perf (client): Refactored the command-line interface and enhanced user experience

- Optimized terminal output style and configuration management using libraries such as `lipgloss` and `go-json`

- Introduced the `ui` package to unify the display logic of colors, tables, and prompts

- Updated the README document structure and installation script links to improve readability and internationalization support

- Improved the interaction flow and log display effects of the daemon startup and attach commands

- Fixed some command parameter parsing issues, improving program robustness and user onboarding experience
This commit is contained in:
Gouryella
2025-12-03 10:18:52 +08:00
parent 37d1c4e005
commit dd54e79ad7
39 changed files with 1215 additions and 2326 deletions

View File

@@ -1,12 +1,37 @@
# Drip - Fast Tunnels to Localhost
<p align="center">
<img src="images/logo.png" alt="Drip Logo" width="128" />
</p>
Self-hosted tunneling solution. Expose your localhost to the internet securely.
<p align="center" style="font-size: 44px; font-weight: 600; margin: 0;">Drip</p>
<p align="center" style="font-size: 20px; font-weight: 500; margin: 8px 0 0;">
Your Tunnel, Your Domain, Anywhere
</p>
[中文文档](README_CN.md)
<p align="center">
A self-hosted tunneling solution to securely expose your services to the internet.
</p>
<p align="center ">
<a href="README.md">English</a>
<span> | </span>
<a href="README_CN.md">中文文档</a>
</p>
<p align="center">
<a href="https://golang.org/">
<img src="https://img.shields.io/badge/Go-1.21+-00ADD8?style=flat&logo=go" alt="Go Version" />
</a>
<a href="LICENSE">
<img src="https://img.shields.io/badge/License-BSD--3--Clause-blue.svg" alt="License" />
</a>
<a href="https://tools.ietf.org/html/rfc8446">
<img src="https://img.shields.io/badge/TLS-1.3-green.svg" alt="TLS" />
</a>
</p>
> Drip is a quiet, disciplined tunnel.
> You light a small lamp on your network, and it carries that light outward—through your own infrastructure, on your own terms.
[![Go Version](https://img.shields.io/badge/Go-1.21+-00ADD8?style=flat&logo=go)](https://golang.org/)
[![License](https://img.shields.io/badge/License-BSD--3--Clause-blue.svg)](LICENSE)
[![TLS](https://img.shields.io/badge/TLS-1.3-green.svg)](https://tools.ietf.org/html/rfc8446)
## Why?
@@ -32,13 +57,13 @@ Self-hosted tunneling solution. Expose your localhost to the internet securely.
### Client (macOS/Linux)
```bash
bash <(curl -sL https://raw.githubusercontent.com/Gouryella/drip/refs/heads/main/scripts/install.sh)
bash <(curl -sL https://raw.githubusercontent.com/Gouryella/drip/main/scripts/install.sh)
```
### Server (Linux)
```bash
bash <(curl -sL https://raw.githubusercontent.com/Gouryella/drip/refs/heads/main/scripts/install-server.sh)
bash <(curl -sL https://raw.githubusercontent.com/Gouryella/drip/main/scripts/install-server.sh)
```
## Usage

View File

@@ -32,13 +32,13 @@
### 客户端 (macOS/Linux)
```bash
bash <(curl -sL https://raw.githubusercontent.com/Gouryella/drip/refs/heads/main/scripts/install.sh)
bash <(curl -sL https://raw.githubusercontent.com/Gouryella/drip/main/scripts/install.sh)
```
### 服务端 (Linux)
```bash
bash <(curl -sL https://raw.githubusercontent.com/Gouryella/drip/refs/heads/main/scripts/install-server.sh)
bash <(curl -sL https://raw.githubusercontent.com/Gouryella/drip/main/scripts/install-server.sh)
```
## 使用方法

View File

@@ -14,10 +14,8 @@ var (
)
func main() {
// Set version information
cli.SetVersion(Version, GitCommit, BuildTime)
// Execute CLI
if err := cli.Execute(); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)

14
go.mod
View File

@@ -3,6 +3,8 @@ module drip
go 1.25.4
require (
github.com/charmbracelet/lipgloss v1.1.0
github.com/goccy/go-json v0.10.5
github.com/gorilla/websocket v1.5.3
github.com/spf13/cobra v1.10.1
github.com/vmihailenco/msgpack/v5 v5.4.1
@@ -12,11 +14,23 @@ require (
)
require (
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
github.com/charmbracelet/x/ansi v0.8.0 // indirect
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
github.com/charmbracelet/x/term v0.2.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/stretchr/testify v1.11.1 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
)

32
go.sum
View File

@@ -1,12 +1,37 @@
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.8.0 h1:9GTq3xq9caJW8ZrBTe0LIe2fvfLR/bYXKTx2llXn7xE=
github.com/charmbracelet/x/ansi v0.8.0/go.mod h1:wdYl/ONOLHLIVmQaxbIYEC/cRKOQyjTkowiI4blgS9Q=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s=
github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0=
@@ -19,6 +44,8 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
@@ -27,8 +54,13 @@ go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E=
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=

BIN
images/logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

View File

@@ -12,6 +12,7 @@ import (
"syscall"
"time"
"drip/internal/client/cli/ui"
"github.com/spf13/cobra"
)
@@ -36,27 +37,26 @@ func init() {
}
func runAttach(cmd *cobra.Command, args []string) error {
// Clean up stale daemons first
CleanupStaleDaemons()
// Get all running daemons
daemons, err := ListAllDaemons()
if err != nil {
return fmt.Errorf("failed to list daemons: %w", err)
}
if len(daemons) == 0 {
fmt.Println("\033[90mNo running tunnels.\033[0m")
fmt.Println()
fmt.Println("Start a tunnel in background with:")
fmt.Println(" \033[36mdrip http 3000 -d\033[0m")
fmt.Println(" \033[36mdrip tcp 5432 -d\033[0m")
fmt.Println(ui.Info(
"No Running Tunnels",
"",
ui.Muted("Start a tunnel in background with:"),
ui.Cyan(" drip http 3000 -d"),
ui.Cyan(" drip tcp 5432 -d"),
))
return nil
}
var selectedDaemon *DaemonInfo
// If type and port are specified, find the specific daemon
if len(args) == 2 {
tunnelType := args[0]
if tunnelType != "http" && tunnelType != "tcp" {
@@ -68,7 +68,6 @@ func runAttach(cmd *cobra.Command, args []string) error {
return fmt.Errorf("invalid port number: %s", args[1])
}
// Find the daemon
for _, d := range daemons {
if d.Type == tunnelType && d.Port == port {
if !IsProcessRunning(d.PID) {
@@ -84,29 +83,21 @@ func runAttach(cmd *cobra.Command, args []string) error {
return fmt.Errorf("no %s tunnel running on port %d", tunnelType, port)
}
} else if len(args) == 0 {
// Interactive selection
selectedDaemon, err = selectDaemonInteractive(daemons)
if err != nil {
return err
}
if selectedDaemon == nil {
return nil // User cancelled
return nil
}
} else {
return fmt.Errorf("usage: drip attach [type port]")
}
// Attach to the selected daemon
return attachToDaemon(selectedDaemon)
}
func selectDaemonInteractive(daemons []*DaemonInfo) (*DaemonInfo, error) {
// Print header
fmt.Println()
fmt.Println("\033[1;37mSelect a tunnel to attach:\033[0m")
fmt.Println("\033[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\033[0m")
// Filter out non-running daemons
var runningDaemons []*DaemonInfo
for _, d := range daemons {
if IsProcessRunning(d.PID) {
@@ -117,36 +108,36 @@ func selectDaemonInteractive(daemons []*DaemonInfo) (*DaemonInfo, error) {
}
if len(runningDaemons) == 0 {
fmt.Println("\033[90mNo running tunnels.\033[0m")
fmt.Println(ui.Muted("No running tunnels."))
return nil, nil
}
// Print list
table := ui.NewTable([]string{"#", "TYPE", "PORT", "URL", "UPTIME"}).
WithTitle("Select a tunnel to attach")
for i, d := range runningDaemons {
uptime := time.Since(d.StartTime)
// Format type with color
var typeStr string
if d.Type == "http" {
typeStr = "\033[32mHTTP\033[0m"
typeStr = ui.Success("HTTP")
} else {
typeStr = "\033[35mTCP\033[0m"
typeStr = ui.Highlight("TCP")
}
// Truncate URL if too long
url := d.URL
if len(url) > 50 {
url = url[:47] + "..."
}
fmt.Printf("\033[1;36m%d.\033[0m %-15s \033[90mPort:\033[0m %-6d \033[90mURL:\033[0m %-50s \033[90mUptime:\033[0m %s\n",
i+1, typeStr, d.Port, url, FormatDuration(uptime))
table.AddRow([]string{
ui.Highlight(fmt.Sprintf("%d", i+1)),
typeStr,
fmt.Sprintf("%d", d.Port),
ui.URL(d.URL),
FormatDuration(uptime),
})
}
fmt.Println("\033[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\033[0m")
fmt.Print(table.Render())
fmt.Printf("Enter number (1-%d) or 'q' to quit: ", len(runningDaemons))
// Read user input
reader := bufio.NewReader(os.Stdin)
input, err := reader.ReadString('\n')
if err != nil {
@@ -158,7 +149,6 @@ func selectDaemonInteractive(daemons []*DaemonInfo) (*DaemonInfo, error) {
return nil, nil
}
// Parse selection
selection, err := strconv.Atoi(input)
if err != nil || selection < 1 || selection > len(runningDaemons) {
return nil, fmt.Errorf("invalid selection: %s", input)
@@ -168,35 +158,28 @@ func selectDaemonInteractive(daemons []*DaemonInfo) (*DaemonInfo, error) {
}
func attachToDaemon(daemon *DaemonInfo) error {
// Get log file path
logPath := filepath.Join(getDaemonDir(), fmt.Sprintf("%s_%d.log", daemon.Type, daemon.Port))
// Check if log file exists
if _, err := os.Stat(logPath); os.IsNotExist(err) {
return fmt.Errorf("log file not found: %s", logPath)
}
// Print header
fmt.Println()
fmt.Println("\033[1;32m╔══════════════════════════════════════════════════════════════════╗\033[0m")
fmt.Printf("\033[1;32m║\033[0m \033[1;37mAttached to %s tunnel on port %d\033[0m", strings.ToUpper(daemon.Type), daemon.Port)
fmt.Printf("%s\033[1;32m║\033[0m\n", strings.Repeat(" ", 32-len(daemon.Type)))
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
fmt.Printf("\033[1;32m║\033[0m \033[90mURL:\033[0m \033[36m%-52s\033[0m \033[1;32m║\033[0m\n", daemon.URL)
uptime := time.Since(daemon.StartTime)
fmt.Printf("\033[1;32m║\033[0m \033[90mPID:\033[0m \033[90m%-52d\033[0m \033[1;32m║\033[0m\n", daemon.PID)
fmt.Printf("\033[1;32m║\033[0m \033[90mUptime:\033[0m \033[90m%-52s\033[0m \033[1;32m║\033[0m\n", FormatDuration(uptime))
fmt.Printf("\033[1;32m║\033[0m \033[90mLog:\033[0m \033[90m%-52s\033[0m \033[1;32m║\033[0m\n", truncatePath(logPath, 52))
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
fmt.Println("\033[1;32m║\033[0m \033[33mPress Ctrl+C to detach (tunnel will continue running)\033[0m \033[1;32m║\033[0m")
fmt.Println("\033[1;32m╚══════════════════════════════════════════════════════════════════╝\033[0m")
fmt.Println()
// Setup signal handler
fmt.Println(ui.Info(
fmt.Sprintf("Attached to %s tunnel on port %d", strings.ToUpper(daemon.Type), daemon.Port),
"",
ui.KeyValue("URL", ui.URL(daemon.URL)),
ui.KeyValue("PID", fmt.Sprintf("%d", daemon.PID)),
ui.KeyValue("Uptime", FormatDuration(uptime)),
ui.KeyValue("Log", truncatePath(logPath, 48)),
"",
ui.Warning("Press Ctrl+C to detach (tunnel will continue running)"),
))
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
// Start tail command
tailCmd := exec.Command("tail", "-f", logPath)
tailCmd.Stdout = os.Stdout
tailCmd.Stderr = os.Stderr
@@ -205,7 +188,6 @@ func attachToDaemon(daemon *DaemonInfo) error {
return fmt.Errorf("failed to start tail: %w", err)
}
// Wait for signal or tail to exit
done := make(chan error, 1)
go func() {
done <- tailCmd.Wait()
@@ -213,14 +195,13 @@ func attachToDaemon(daemon *DaemonInfo) error {
select {
case <-sigCh:
// Kill tail process
if tailCmd.Process != nil {
tailCmd.Process.Kill()
}
fmt.Println()
fmt.Println("\033[33mDetached from tunnel (tunnel is still running)\033[0m")
fmt.Printf("Use '\033[36mdrip attach %s %d\033[0m' to reattach\n", daemon.Type, daemon.Port)
fmt.Printf("Use '\033[36mdrip stop %s %d\033[0m' to stop the tunnel\n", daemon.Type, daemon.Port)
fmt.Println(ui.Warning("Detached from tunnel (tunnel is still running)"))
fmt.Println(ui.Muted(fmt.Sprintf("Use '%s' to reattach", ui.Cyan(fmt.Sprintf("drip attach %s %d", daemon.Type, daemon.Port)))))
fmt.Println(ui.Muted(fmt.Sprintf("Use '%s' to stop the tunnel", ui.Cyan(fmt.Sprintf("drip stop %s %d", daemon.Type, daemon.Port)))))
return nil
case err := <-done:
if err != nil {
@@ -234,7 +215,6 @@ func truncatePath(path string, maxLen int) string {
if len(path) <= maxLen {
return path
}
// Try to keep filename and show ... in the middle
filename := filepath.Base(path)
if len(filename) >= maxLen-3 {
return "..." + filename[len(filename)-(maxLen-3):]

View File

@@ -6,6 +6,7 @@ import (
"os"
"strings"
"drip/internal/client/cli/ui"
"drip/pkg/config"
"github.com/spf13/cobra"
)
@@ -59,36 +60,28 @@ var (
)
func init() {
// Add subcommands
configCmd.AddCommand(configInitCmd)
configCmd.AddCommand(configShowCmd)
configCmd.AddCommand(configSetCmd)
configCmd.AddCommand(configResetCmd)
configCmd.AddCommand(configValidateCmd)
// Flags for show
configShowCmd.Flags().BoolVar(&configFull, "full", false, "Show full token (not hidden)")
// Flags for set
configSetCmd.Flags().StringVar(&configServer, "server", "", "Server address (e.g., tunnel.example.com:443)")
configSetCmd.Flags().StringVar(&configToken, "token", "", "Authentication token")
// Flags for reset
configResetCmd.Flags().BoolVar(&configForce, "force", false, "Force reset without confirmation")
// Add to root
rootCmd.AddCommand(configCmd)
}
func runConfigInit(cmd *cobra.Command, args []string) error {
fmt.Println("\n╔═══════════════════════════════════════╗")
fmt.Println("║ Drip Configuration Setup ║")
fmt.Println("╚═══════════════════════════════════════╝")
fmt.Print(ui.RenderConfigInit())
reader := bufio.NewReader(os.Stdin)
// Get server address
fmt.Print("Server address (e.g., tunnel.example.com:443): ")
fmt.Print(ui.Muted("Server address (e.g., tunnel.example.com:443): "))
serverAddr, _ := reader.ReadString('\n')
serverAddr = strings.TrimSpace(serverAddr)
@@ -96,102 +89,79 @@ func runConfigInit(cmd *cobra.Command, args []string) error {
return fmt.Errorf("server address is required")
}
// Get token
fmt.Print("Authentication token (leave empty to skip): ")
fmt.Print(ui.Muted("Authentication token (leave empty to skip): "))
token, _ := reader.ReadString('\n')
token = strings.TrimSpace(token)
// Create config
cfg := &config.ClientConfig{
Server: serverAddr,
Token: token,
TLS: true,
}
// Save config
if err := config.SaveClientConfig(cfg, ""); err != nil {
return fmt.Errorf("failed to save configuration: %w", err)
}
fmt.Println("\n✓ Configuration saved to", config.DefaultClientConfigPath())
fmt.Println("✓ You can now use 'drip' without --server and --token")
fmt.Println(ui.RenderConfigSaved(config.DefaultClientConfigPath()))
return nil
}
func runConfigShow(cmd *cobra.Command, args []string) error {
// Load config
cfg, err := config.LoadClientConfig("")
if err != nil {
return err
}
fmt.Println("\n╔═══════════════════════════════════════╗")
fmt.Println("║ Current Configuration ║")
fmt.Println("╚═══════════════════════════════════════╝")
fmt.Printf("Server: %s\n", cfg.Server)
// Show token (hidden or full)
var displayToken string
if cfg.Token != "" {
if configFull {
fmt.Printf("Token: %s\n", cfg.Token)
if len(cfg.Token) > 10 {
displayToken = cfg.Token[:3] + "***" + cfg.Token[len(cfg.Token)-3:]
} else {
// Hide middle part of token
if len(cfg.Token) > 10 {
fmt.Printf("Token: %s***%s (hidden)\n",
cfg.Token[:3],
cfg.Token[len(cfg.Token)-3:],
)
} else {
fmt.Printf("Token: %s (hidden)\n", cfg.Token[:3]+"***")
}
displayToken = cfg.Token[:3] + "***"
}
} else {
fmt.Println("Token: (not set)")
displayToken = ""
}
fmt.Printf("TLS: %s\n", enabledDisabled(cfg.TLS))
fmt.Printf("Config: %s\n\n", config.DefaultClientConfigPath())
fmt.Println(ui.RenderConfigShow(cfg.Server, displayToken, !configFull, cfg.TLS, config.DefaultClientConfigPath()))
return nil
}
func runConfigSet(cmd *cobra.Command, args []string) error {
// Load existing config or create new
cfg, err := config.LoadClientConfig("")
if err != nil {
// Create new config if not exists
cfg = &config.ClientConfig{
TLS: true,
}
}
// Update fields if provided
modified := false
var updates []string
if configServer != "" {
cfg.Server = configServer
modified = true
fmt.Printf("✓ Server updated: %s\n", configServer)
updates = append(updates, "Server updated: "+configServer)
}
if configToken != "" {
cfg.Token = configToken
modified = true
fmt.Println("✓ Token updated")
updates = append(updates, "Token updated")
}
if !modified {
return fmt.Errorf("no changes specified. Use --server or --token")
}
// Save config
if err := config.SaveClientConfig(cfg, ""); err != nil {
return fmt.Errorf("failed to save configuration: %w", err)
}
fmt.Println("✓ Configuration saved")
fmt.Println(ui.RenderConfigUpdated(updates))
return nil
}
@@ -199,13 +169,11 @@ func runConfigSet(cmd *cobra.Command, args []string) error {
func runConfigReset(cmd *cobra.Command, args []string) error {
configPath := config.DefaultClientConfigPath()
// Check if config exists
if !config.ConfigExists("") {
fmt.Println("No configuration file found")
return nil
}
// Confirm deletion
if !configForce {
fmt.Print("Are you sure you want to delete the configuration? (y/N): ")
reader := bufio.NewReader(os.Stdin)
@@ -218,49 +186,31 @@ func runConfigReset(cmd *cobra.Command, args []string) error {
}
}
// Delete config file
if err := os.Remove(configPath); err != nil {
return fmt.Errorf("failed to delete configuration: %w", err)
}
fmt.Println("✓ Configuration file deleted")
fmt.Println(ui.RenderConfigDeleted())
return nil
}
func runConfigValidate(cmd *cobra.Command, args []string) error {
fmt.Println("\nValidating configuration...")
fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
// Load config
cfg, err := config.LoadClientConfig("")
if err != nil {
fmt.Println("Failed to load configuration")
fmt.Println(ui.Error("Failed to load configuration"))
return err
}
// Validate server address
if cfg.Server == "" {
fmt.Println("✗ Server address is not set")
serverValid := cfg.Server != ""
tokenSet := cfg.Token != ""
tlsEnabled := cfg.TLS
fmt.Println(ui.RenderConfigValidation(serverValid, tokenSet, tlsEnabled))
if !serverValid {
return fmt.Errorf("invalid configuration")
}
fmt.Println("✓ Server address is valid")
// Validate token
if cfg.Token != "" {
fmt.Println("✓ Token is set")
} else {
fmt.Println("⚠ Token is not set (authentication may fail)")
}
// Validate TLS
if cfg.TLS {
fmt.Println("✓ TLS is enabled")
} else {
fmt.Println("⚠ TLS is disabled (not recommended for production)")
}
fmt.Println("\n✓ Configuration is valid")
return nil
}

View File

@@ -1,13 +1,15 @@
package cli
import (
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"strconv"
"time"
"drip/internal/client/cli/ui"
json "github.com/goccy/go-json"
)
// DaemonInfo stores information about a running daemon process
@@ -170,13 +172,10 @@ func StartDaemon(tunnelType string, port int, args []string) error {
cleanArgs = append(cleanArgs, arg)
}
// Create the command
cmd := exec.Command(executable, cleanArgs...)
// Detach from parent process (platform-specific)
setupDaemonCmd(cmd)
// Create log file for daemon output
logDir := getDaemonDir()
if err := os.MkdirAll(logDir, 0700); err != nil {
return fmt.Errorf("failed to create daemon directory: %w", err)
@@ -187,7 +186,6 @@ func StartDaemon(tunnelType string, port int, args []string) error {
return fmt.Errorf("failed to create log file: %w", err)
}
// Redirect stdin to /dev/null
devNull, err := os.OpenFile(os.DevNull, os.O_RDONLY, 0)
if err != nil {
logFile.Close()
@@ -197,7 +195,6 @@ func StartDaemon(tunnelType string, port int, args []string) error {
cmd.Stdout = logFile
cmd.Stderr = logFile
// Start the process
if err := cmd.Start(); err != nil {
logFile.Close()
devNull.Close()
@@ -207,11 +204,7 @@ func StartDaemon(tunnelType string, port int, args []string) error {
// Don't wait for the process - let it run in background
// The child process will save its own daemon info after connecting
fmt.Printf("\033[32m✓\033[0m Started %s tunnel on port %d in background (PID: %d)\n", tunnelType, port, cmd.Process.Pid)
fmt.Printf(" Use '\033[36mdrip list\033[0m' to check tunnel status\n")
fmt.Printf(" Use '\033[36mdrip attach %s %d\033[0m' to view logs\n", tunnelType, port)
fmt.Printf(" Use '\033[36mdrip stop %s %d\033[0m' to stop this tunnel\n", tunnelType, port)
fmt.Printf(" Logs: \033[90m%s\033[0m\n", logPath)
fmt.Println(ui.RenderDaemonStarted(tunnelType, port, cmd.Process.Pid, logPath))
return nil
}
@@ -249,11 +242,9 @@ func FormatDuration(d time.Duration) string {
// ParsePortFromArgs extracts the port number from command arguments
func ParsePortFromArgs(args []string) (int, error) {
for _, arg := range args {
// Skip flags
if len(arg) > 0 && arg[0] == '-' {
continue
}
// Try to parse as port number
port, err := strconv.Atoi(arg)
if err == nil && port > 0 && port <= 65535 {
return port, nil

View File

@@ -3,19 +3,14 @@ package cli
import (
"fmt"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
"drip/internal/client/tcp"
"drip/internal/shared/protocol"
"drip/internal/shared/utils"
"drip/pkg/config"
"github.com/spf13/cobra"
"go.uber.org/zap"
)
const (
@@ -87,13 +82,6 @@ func runHTTP(cmd *cobra.Command, args []string) error {
return StartDaemon("http", port, daemonArgs)
}
if err := utils.InitLogger(verbose); err != nil {
return fmt.Errorf("failed to initialize logger: %w", err)
}
defer utils.Sync()
logger := utils.GetLogger()
var serverAddr, token string
if serverURL == "" {
@@ -125,182 +113,18 @@ Please run 'drip config init' first, or use flags:
Insecure: insecure,
}
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
reconnectAttempts := 0
for {
connector := tcp.NewConnector(connConfig, logger)
if reconnectAttempts == 0 {
fmt.Printf("\033[36m🔌 Connecting to %s...\033[0m\n", serverAddr)
} else {
fmt.Printf("\033[33m🔄 Reconnecting to %s (attempt %d/%d)...\033[0m\n", serverAddr, reconnectAttempts, maxReconnectAttempts)
}
if err := connector.Connect(); err != nil {
if isNonRetryableError(err) {
return fmt.Errorf("failed to connect: %w", err)
}
reconnectAttempts++
if reconnectAttempts >= maxReconnectAttempts {
return fmt.Errorf("failed to connect after %d attempts: %w", maxReconnectAttempts, err)
}
fmt.Printf("\033[31m✗ Connection failed: %v\033[0m\n", err)
fmt.Printf("\033[90m Retrying in %v...\033[0m\n", reconnectInterval)
select {
case <-quit:
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
return nil
case <-time.After(reconnectInterval):
continue
}
}
reconnectAttempts = 0
if daemonMarker {
daemonInfo := &DaemonInfo{
PID: os.Getpid(),
Type: "http",
Port: port,
Subdomain: subdomain,
Server: serverAddr,
URL: connector.GetURL(),
StartTime: time.Now(),
Executable: os.Args[0],
}
if err := SaveDaemonInfo(daemonInfo); err != nil {
logger.Warn("Failed to save daemon info", zap.Error(err))
}
}
fmt.Println()
fmt.Println("\033[1;32m╔══════════════════════════════════════════════════════════════════╗\033[0m")
fmt.Println("\033[1;32m║\033[0m \033[1;37m🚀 HTTP Tunnel Connected Successfully!\033[0m \033[1;32m║\033[0m")
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
fmt.Printf("\033[1;32m║\033[0m \033[1;37mTunnel URL:\033[0m \033[1;32m║\033[0m\n")
fmt.Printf("\033[1;32m║\033[0m \033[1;36m%-60s\033[0m \033[1;32m║\033[0m\n", connector.GetURL())
fmt.Println("\033[1;32m║\033[0m \033[1;32m║\033[0m")
displayAddr := localAddress
if displayAddr == "127.0.0.1" {
displayAddr = "localhost"
}
fmt.Printf("\033[1;32m║\033[0m \033[90mForwarding:\033[0m \033[1m%s:%d\033[0m → \033[36m%s\033[0m%-15s\033[1;32m║\033[0m\n", displayAddr, port, "public", "")
fmt.Printf("\033[1;32m║\033[0m \033[90mLatency:\033[0m \033[90mmeasuring...\033[0m%-40s\033[1;32m║\033[0m\n", "")
fmt.Printf("\033[1;32m║\033[0m \033[90mTraffic:\033[0m \033[90m↓ 0 B ↑ 0 B\033[0m%-32s\033[1;32m║\033[0m\n", "")
fmt.Printf("\033[1;32m║\033[0m \033[90mSpeed:\033[0m \033[90m↓ 0 B/s ↑ 0 B/s\033[0m%-28s\033[1;32m║\033[0m\n", "")
fmt.Printf("\033[1;32m║\033[0m \033[90mRequests:\033[0m \033[90m0\033[0m%-43s\033[1;32m║\033[0m\n", "")
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
fmt.Println("\033[1;32m║\033[0m \033[90mPress Ctrl+C to stop the tunnel\033[0m \033[1;32m║\033[0m")
fmt.Println("\033[1;32m╚══════════════════════════════════════════════════════════════════╝\033[0m")
fmt.Println()
latencyCh := make(chan time.Duration, 1)
connector.SetLatencyCallback(func(latency time.Duration) {
select {
case latencyCh <- latency:
default:
}
})
stopDisplay := make(chan struct{})
disconnected := make(chan struct{})
go func() {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
var lastLatency time.Duration
for {
select {
case latency := <-latencyCh:
lastLatency = latency
case <-ticker.C:
stats := connector.GetStats()
if stats != nil {
stats.UpdateSpeed()
snapshot := stats.GetSnapshot()
fmt.Print("\033[8A")
fmt.Printf("\r\033[1;32m║\033[0m \033[90mLatency:\033[0m %s%-40s\033[1;32m║\033[0m\n", formatLatency(lastLatency), "")
trafficStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatBytes(snapshot.TotalBytesIn), tcp.FormatBytes(snapshot.TotalBytesOut))
fmt.Printf("\r\033[1;32m║\033[0m \033[90mTraffic:\033[0m \033[36m%-48s\033[0m\033[1;32m║\033[0m\n", trafficStr)
speedStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatSpeed(snapshot.SpeedIn), tcp.FormatSpeed(snapshot.SpeedOut))
fmt.Printf("\r\033[1;32m║\033[0m \033[90mSpeed:\033[0m \033[33m%-48s\033[0m\033[1;32m║\033[0m\n", speedStr)
fmt.Printf("\r\033[1;32m║\033[0m \033[90mRequests:\033[0m \033[35m%-47d\033[0m\033[1;32m║\033[0m\n", snapshot.TotalRequests)
fmt.Print("\033[4B")
}
case <-stopDisplay:
return
}
}
}()
go func() {
connector.Wait()
close(disconnected)
}()
select {
case <-quit:
close(stopDisplay)
fmt.Println("\n\n\033[33m🛑 Shutting down...\033[0m")
connector.Close()
if daemonMarker {
RemoveDaemonInfo("http", port)
}
fmt.Println("\033[32m✓\033[0m Tunnel closed")
return nil
case <-disconnected:
close(stopDisplay)
fmt.Println("\n\n\033[31m⚠ Connection lost!\033[0m")
reconnectAttempts++
if reconnectAttempts >= maxReconnectAttempts {
return fmt.Errorf("connection lost after %d reconnect attempts", maxReconnectAttempts)
}
fmt.Printf("\033[90m Reconnecting in %v...\033[0m\n", reconnectInterval)
select {
case <-quit:
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
return nil
case <-time.After(reconnectInterval):
continue
}
var daemon *DaemonInfo
if daemonMarker {
daemon = &DaemonInfo{
PID: os.Getpid(),
Type: "http",
Port: port,
Subdomain: subdomain,
Server: serverAddr,
StartTime: time.Now(),
Executable: os.Args[0],
}
}
}
func formatLatency(d time.Duration) string {
ms := d.Milliseconds()
if ms < 50 {
return fmt.Sprintf("\033[32m%dms\033[0m", ms)
} else if ms < 100 {
return fmt.Sprintf("\033[33m%dms\033[0m", ms)
} else if ms < 200 {
return fmt.Sprintf("\033[38;5;208m%dms\033[0m", ms)
}
return fmt.Sprintf("\033[31m%dms\033[0m", ms)
}
func isNonRetryableError(err error) bool {
errStr := err.Error()
if strings.Contains(errStr, "subdomain is already taken") ||
strings.Contains(errStr, "subdomain is reserved") ||
strings.Contains(errStr, "invalid subdomain") {
return true
}
if strings.Contains(errStr, "authentication") ||
strings.Contains(errStr, "Invalid authentication token") {
return true
}
return false
return runTunnelWithUI(connConfig, daemon)
}

View File

@@ -3,18 +3,14 @@ package cli
import (
"fmt"
"os"
"os/signal"
"strconv"
"syscall"
"time"
"drip/internal/client/tcp"
"drip/internal/shared/protocol"
"drip/internal/shared/utils"
"drip/pkg/config"
"github.com/spf13/cobra"
"go.uber.org/zap"
)
var (
@@ -52,15 +48,12 @@ func init() {
}
func runHTTPS(cmd *cobra.Command, args []string) error {
// Parse port
port, err := strconv.Atoi(args[0])
if err != nil || port < 1 || port > 65535 {
return fmt.Errorf("invalid port number: %s", args[0])
}
// Handle daemon mode
if httpsDaemonMode && !httpsDaemonMarker {
// Start as daemon
daemonArgs := append([]string{"https"}, args...)
daemonArgs = append(daemonArgs, "--daemon-child")
if httpsSubdomain != "" {
@@ -84,19 +77,9 @@ func runHTTPS(cmd *cobra.Command, args []string) error {
return StartDaemon("https", port, daemonArgs)
}
// Initialize logger
if err := utils.InitLogger(verbose); err != nil {
return fmt.Errorf("failed to initialize logger: %w", err)
}
defer utils.Sync()
logger := utils.GetLogger()
// Load configuration or use command line flags
var serverAddr, token string
if serverURL == "" {
// Try to load from config file
cfg, err := config.LoadClientConfig("")
if err != nil {
return fmt.Errorf(`configuration not found.
@@ -107,17 +90,14 @@ Please run 'drip config init' first, or use flags:
serverAddr = cfg.Server
token = cfg.Token
} else {
// Use command line flags
serverAddr = serverURL
token = authToken
}
// Validate server address
if serverAddr == "" {
return fmt.Errorf("server address is required")
}
// Create connector config
connConfig := &tcp.ConnectorConfig{
ServerAddr: serverAddr,
Token: token,
@@ -128,178 +108,18 @@ Please run 'drip config init' first, or use flags:
Insecure: insecure,
}
// Setup signal handler for graceful shutdown
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
// Connection loop with reconnect support
reconnectAttempts := 0
for {
// Create connector
connector := tcp.NewConnector(connConfig, logger)
// Connect to server
if reconnectAttempts == 0 {
fmt.Printf("\033[36m🔌 Connecting to %s...\033[0m\n", serverAddr)
} else {
fmt.Printf("\033[33m🔄 Reconnecting to %s (attempt %d/%d)...\033[0m\n", serverAddr, reconnectAttempts, maxReconnectAttempts)
}
if err := connector.Connect(); err != nil {
// Check if this is a non-retryable error
if isNonRetryableError(err) {
return fmt.Errorf("failed to connect: %w", err)
}
reconnectAttempts++
if reconnectAttempts >= maxReconnectAttempts {
return fmt.Errorf("failed to connect after %d attempts: %w", maxReconnectAttempts, err)
}
fmt.Printf("\033[31m✗ Connection failed: %v\033[0m\n", err)
fmt.Printf("\033[90m Retrying in %v...\033[0m\n", reconnectInterval)
// Wait before retry, but allow interrupt
select {
case <-quit:
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
return nil
case <-time.After(reconnectInterval):
continue
}
}
// Reset reconnect attempts on successful connection
reconnectAttempts = 0
// Save daemon info if running as daemon child
if httpsDaemonMarker {
daemonInfo := &DaemonInfo{
PID: os.Getpid(),
Type: "https",
Port: port,
Subdomain: httpsSubdomain,
Server: serverAddr,
URL: connector.GetURL(),
StartTime: time.Now(),
Executable: os.Args[0],
}
if err := SaveDaemonInfo(daemonInfo); err != nil {
// Log but don't fail
logger.Warn("Failed to save daemon info", zap.Error(err))
}
}
// Print tunnel information
fmt.Println()
fmt.Println("\033[1;32m╔══════════════════════════════════════════════════════════════════╗\033[0m")
fmt.Println("\033[1;32m║\033[0m \033[1;37m🔒 HTTPS Tunnel Connected Successfully!\033[0m \033[1;32m║\033[0m")
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
fmt.Printf("\033[1;32m║\033[0m \033[1;37mTunnel URL:\033[0m \033[1;32m║\033[0m\n")
fmt.Printf("\033[1;32m║\033[0m \033[1;36m%-60s\033[0m \033[1;32m║\033[0m\n", connector.GetURL())
fmt.Println("\033[1;32m║\033[0m \033[1;32m║\033[0m")
displayAddr := httpsLocalAddress
if displayAddr == "127.0.0.1" {
displayAddr = "localhost"
}
fmt.Printf("\033[1;32m║\033[0m \033[90mForwarding:\033[0m \033[1m%s:%d\033[0m → \033[36m%s\033[0m%-15s\033[1;32m║\033[0m\n", displayAddr, port, "public", "")
fmt.Printf("\033[1;32m║\033[0m \033[90mLatency:\033[0m \033[90mmeasuring...\033[0m%-40s\033[1;32m║\033[0m\n", "")
fmt.Printf("\033[1;32m║\033[0m \033[90mTraffic:\033[0m \033[90m↓ 0 B ↑ 0 B\033[0m%-32s\033[1;32m║\033[0m\n", "")
fmt.Printf("\033[1;32m║\033[0m \033[90mSpeed:\033[0m \033[90m↓ 0 B/s ↑ 0 B/s\033[0m%-28s\033[1;32m║\033[0m\n", "")
fmt.Printf("\033[1;32m║\033[0m \033[90mRequests:\033[0m \033[90m0\033[0m%-43s\033[1;32m║\033[0m\n", "")
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
fmt.Println("\033[1;32m║\033[0m \033[90mPress Ctrl+C to stop the tunnel\033[0m \033[1;32m║\033[0m")
fmt.Println("\033[1;32m╚══════════════════════════════════════════════════════════════════╝\033[0m")
fmt.Println()
// Setup latency display
latencyCh := make(chan time.Duration, 1)
connector.SetLatencyCallback(func(latency time.Duration) {
select {
case latencyCh <- latency:
default:
}
})
// Start stats display updater (updates every second)
stopDisplay := make(chan struct{})
disconnected := make(chan struct{})
go func() {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
var lastLatency time.Duration
for {
select {
case latency := <-latencyCh:
lastLatency = latency
case <-ticker.C:
// Update speed calculation
stats := connector.GetStats()
if stats != nil {
stats.UpdateSpeed()
snapshot := stats.GetSnapshot()
// Move cursor up 8 lines to update display
fmt.Print("\033[8A")
// Update latency line
fmt.Printf("\r\033[1;32m║\033[0m \033[90mLatency:\033[0m %s%-40s\033[1;32m║\033[0m\n", formatLatency(lastLatency), "")
// Update traffic line
trafficStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatBytes(snapshot.TotalBytesIn), tcp.FormatBytes(snapshot.TotalBytesOut))
fmt.Printf("\r\033[1;32m║\033[0m \033[90mTraffic:\033[0m \033[36m%-48s\033[0m\033[1;32m║\033[0m\n", trafficStr)
// Update speed line
speedStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatSpeed(snapshot.SpeedIn), tcp.FormatSpeed(snapshot.SpeedOut))
fmt.Printf("\r\033[1;32m║\033[0m \033[90mSpeed:\033[0m \033[33m%-48s\033[0m\033[1;32m║\033[0m\n", speedStr)
// Update requests line
fmt.Printf("\r\033[1;32m║\033[0m \033[90mRequests:\033[0m \033[35m%-47d\033[0m\033[1;32m║\033[0m\n", snapshot.TotalRequests)
// Move back down 4 lines
fmt.Print("\033[4B")
}
case <-stopDisplay:
return
}
}
}()
// Monitor connection in background
go func() {
connector.Wait()
close(disconnected)
}()
// Wait for signal or disconnection
select {
case <-quit:
close(stopDisplay)
fmt.Println("\n\n\033[33m🛑 Shutting down...\033[0m")
connector.Close()
if httpsDaemonMarker {
RemoveDaemonInfo("https", port)
}
fmt.Println("\033[32m✓\033[0m Tunnel closed")
return nil
case <-disconnected:
close(stopDisplay)
fmt.Println("\n\n\033[31m⚠ Connection lost!\033[0m")
reconnectAttempts++
if reconnectAttempts >= maxReconnectAttempts {
return fmt.Errorf("connection lost after %d reconnect attempts", maxReconnectAttempts)
}
fmt.Printf("\033[90m Reconnecting in %v...\033[0m\n", reconnectInterval)
// Wait before reconnect, but allow interrupt
select {
case <-quit:
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
return nil
case <-time.After(reconnectInterval):
continue
}
var daemon *DaemonInfo
if httpsDaemonMarker {
daemon = &DaemonInfo{
PID: os.Getpid(),
Type: "https",
Port: port,
Subdomain: httpsSubdomain,
Server: serverAddr,
StartTime: time.Now(),
Executable: os.Args[0],
}
}
return runTunnelWithUI(connConfig, daemon)
}

View File

@@ -8,6 +8,7 @@ import (
"strings"
"time"
"drip/internal/client/cli/ui"
"github.com/spf13/cobra"
)
@@ -44,90 +45,83 @@ func init() {
}
func runList(cmd *cobra.Command, args []string) error {
// Clean up stale daemons first
CleanupStaleDaemons()
// Get all running daemons
daemons, err := ListAllDaemons()
if err != nil {
return fmt.Errorf("failed to list daemons: %w", err)
}
if len(daemons) == 0 {
fmt.Println("\033[90mNo running tunnels.\033[0m")
fmt.Println()
fmt.Println("Start a tunnel in background with:")
fmt.Println(" \033[36mdrip http 3000 -d\033[0m")
fmt.Println(" \033[36mdrip tcp 5432 -d\033[0m")
fmt.Println(ui.Info(
"No Running Tunnels",
"",
ui.Muted("Start a tunnel in background with:"),
"",
ui.Cyan(" drip http 3000 -d"),
ui.Cyan(" drip tcp 5432 -d"),
))
return nil
}
// Print header
fmt.Println()
fmt.Println("\033[1;37mRunning Tunnels\033[0m")
fmt.Println("\033[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\033[0m")
fmt.Printf("\033[1m%-4s %-6s %-6s %-40s %-8s %s\033[0m\n", "#", "TYPE", "PORT", "URL", "PID", "UPTIME")
fmt.Println("\033[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\033[0m")
table := ui.NewTable([]string{"#", "TYPE", "PORT", "URL", "PID", "UPTIME"}).
WithTitle("Running Tunnels")
idx := 1
for _, d := range daemons {
// Check if process is still running
if !IsProcessRunning(d.PID) {
// Clean up stale entry
RemoveDaemonInfo(d.Type, d.Port)
continue
}
// Calculate uptime
uptime := time.Since(d.StartTime)
// Format type with color
var typeStr string
if d.Type == "http" {
typeStr = "\033[32mHTTP\033[0m"
typeStr = ui.Highlight("HTTP")
} else if d.Type == "https" {
typeStr = ui.Highlight("HTTPS")
} else {
typeStr = "\033[35mTCP\033[0m"
typeStr = ui.Cyan("TCP")
}
// Truncate URL if too long
url := d.URL
if len(url) > 40 {
url = url[:37] + "..."
}
fmt.Printf("\033[1;36m%-4d\033[0m %-15s %-6d %-40s %-8d %s\n",
idx, typeStr, d.Port, url, d.PID, FormatDuration(uptime))
table.AddRow([]string{
ui.Highlight(fmt.Sprintf("%d", idx)),
typeStr,
fmt.Sprintf("%d", d.Port),
ui.URL(d.URL),
ui.Muted(fmt.Sprintf("%d", d.PID)),
FormatDuration(uptime),
})
idx++
}
fmt.Println("\033[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\033[0m")
fmt.Println()
fmt.Print(table.Render())
// Interactive mode or show commands
if interactiveMode || shouldPromptForAction() {
return runInteractiveList(daemons)
}
fmt.Println("Commands:")
fmt.Println(" \033[36mdrip list -i\033[0m Interactive mode")
fmt.Println(" \033[36mdrip attach http 3000\033[0m Attach to tunnel (view logs)")
fmt.Println(" \033[36mdrip stop http 3000\033[0m Stop tunnel")
fmt.Println(" \033[36mdrip stop all\033[0m Stop all tunnels")
fmt.Println(ui.Muted("Commands:"))
fmt.Println(ui.RenderList([]string{
ui.Cyan("drip list -i") + ui.Muted(" Interactive mode"),
ui.Cyan("drip attach http 3000") + ui.Muted(" Attach to tunnel (view logs)"),
ui.Cyan("drip stop http 3000") + ui.Muted(" Stop tunnel"),
ui.Cyan("drip stop all") + ui.Muted(" Stop all tunnels"),
}))
return nil
}
func shouldPromptForAction() bool {
// Check if running in a terminal
if fileInfo, _ := os.Stdout.Stat(); (fileInfo.Mode() & os.ModeCharDevice) == 0 {
return false
}
// Always prompt when there are tunnels running
return true
return false
}
func runInteractiveList(daemons []*DaemonInfo) error {
// Filter out non-running daemons
var runningDaemons []*DaemonInfo
for _, d := range daemons {
if IsProcessRunning(d.PID) {
@@ -141,8 +135,8 @@ func runInteractiveList(daemons []*DaemonInfo) error {
return nil
}
// Prompt for action
fmt.Print("Select a tunnel (number) or 'q' to quit: ")
fmt.Println()
fmt.Print(ui.Muted("Select a tunnel (number) or 'q' to quit: "))
reader := bufio.NewReader(os.Stdin)
input, err := reader.ReadString('\n')
if err != nil {
@@ -154,7 +148,6 @@ func runInteractiveList(daemons []*DaemonInfo) error {
return nil
}
// Parse selection
selection, err := strconv.Atoi(input)
if err != nil || selection < 1 || selection > len(runningDaemons) {
return fmt.Errorf("invalid selection: %s", input)
@@ -162,16 +155,18 @@ func runInteractiveList(daemons []*DaemonInfo) error {
selectedDaemon := runningDaemons[selection-1]
// Prompt for action
fmt.Println()
fmt.Printf("Selected: \033[1m%s\033[0m tunnel on port \033[1m%d\033[0m\n", strings.ToUpper(selectedDaemon.Type), selectedDaemon.Port)
fmt.Println()
fmt.Println("What would you like to do?")
fmt.Println(" \033[36m1.\033[0m Attach (view logs)")
fmt.Println(" \033[36m2.\033[0m Stop tunnel")
fmt.Println(" \033[90mq. Cancel\033[0m")
fmt.Println()
fmt.Print("Choose an action: ")
fmt.Println(ui.Info(
fmt.Sprintf("Selected: %s tunnel on port %d", strings.ToUpper(selectedDaemon.Type), selectedDaemon.Port),
"",
ui.Muted("What would you like to do?"),
"",
ui.Cyan(" 1.") + " Attach (view logs)",
ui.Cyan(" 2.") + " Stop tunnel",
ui.Muted(" q.") + " Cancel",
))
fmt.Print(ui.Muted("Choose an action: "))
actionInput, err := reader.ReadString('\n')
if err != nil {
@@ -181,10 +176,8 @@ func runInteractiveList(daemons []*DaemonInfo) error {
actionInput = strings.TrimSpace(actionInput)
switch actionInput {
case "1":
// Attach to daemon
return attachToDaemon(selectedDaemon)
case "2":
// Stop daemon
return stopDaemon(selectedDaemon.Type, selectedDaemon.Port)
case "q", "Q", "":
return nil

View File

@@ -3,6 +3,7 @@ package cli
import (
"fmt"
"drip/internal/client/cli/ui"
"github.com/spf13/cobra"
)
@@ -59,12 +60,13 @@ var versionCmd = &cobra.Command{
Use: "version",
Short: "Print version information",
Run: func(cmd *cobra.Command, args []string) {
fmt.Printf("Drip Client\n")
fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")
fmt.Printf("Version: %s\n", Version)
fmt.Printf("Git Commit: %s\n", GitCommit)
fmt.Printf("Build Time: %s\n", BuildTime)
fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")
fmt.Println(ui.Info(
"Drip Client",
"",
ui.KeyValue("Version", Version),
ui.KeyValue("Git Commit", GitCommit),
ui.KeyValue("Build Time", BuildTime),
))
},
}

View File

@@ -60,7 +60,6 @@ func init() {
}
func runServer(cmd *cobra.Command, args []string) error {
// Validate required TLS configuration
if serverTLSCert == "" {
return fmt.Errorf("TLS certificate path is required (use --tls-cert flag or DRIP_TLS_CERT environment variable)")
}
@@ -68,7 +67,6 @@ func runServer(cmd *cobra.Command, args []string) error {
return fmt.Errorf("TLS private key path is required (use --tls-key flag or DRIP_TLS_KEY environment variable)")
}
// Initialize logger
if err := utils.InitServerLogger(serverDebug); err != nil {
return fmt.Errorf("failed to initialize logger: %w", err)
}
@@ -81,7 +79,6 @@ func runServer(cmd *cobra.Command, args []string) error {
zap.String("commit", GitCommit),
)
// Start pprof server if enabled
if serverPprofPort > 0 {
go func() {
pprofAddr := fmt.Sprintf("localhost:%d", serverPprofPort)
@@ -92,7 +89,6 @@ func runServer(cmd *cobra.Command, args []string) error {
}()
}
// Create server config
displayPort := serverPublicPort
if displayPort == 0 {
displayPort = serverPort
@@ -111,7 +107,6 @@ func runServer(cmd *cobra.Command, args []string) error {
Debug: serverDebug,
}
// Load TLS configuration
tlsConfig, err := serverConfig.LoadTLSConfig()
if err != nil {
logger.Fatal("Failed to load TLS configuration", zap.Error(err))
@@ -122,28 +117,21 @@ func runServer(cmd *cobra.Command, args []string) error {
zap.String("key", serverTLSKey),
)
// Create tunnel manager
tunnelManager := tunnel.NewManager(logger)
// Create TCP port allocator
portAllocator, err := tcp.NewPortAllocator(serverTCPPortMin, serverTCPPortMax)
if err != nil {
logger.Fatal("Invalid TCP port range", zap.Error(err))
}
// Create TCP listener address
listenAddr := fmt.Sprintf("0.0.0.0:%d", serverPort)
// Response handler for HTTP-over-frame responses
responseHandler := proxy.NewResponseHandler(logger)
// Create HTTP proxy handler (for handling HTTP requests on TCP port)
httpHandler := proxy.NewHandler(tunnelManager, logger, responseHandler, serverDomain, serverAuthToken)
// Create TCP listener (wsHandler also serves as response channel handler)
listener := tcp.NewListener(listenAddr, tlsConfig, serverAuthToken, tunnelManager, logger, portAllocator, serverDomain, displayPort, httpHandler, responseHandler)
// Start listener
if err := listener.Start(); err != nil {
logger.Fatal("Failed to start TCP listener", zap.Error(err))
}
@@ -154,16 +142,13 @@ func runServer(cmd *cobra.Command, args []string) error {
zap.String("protocol", "TCP over TLS 1.3"),
)
// Setup signal handler
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
// Wait for signal
<-quit
logger.Info("Shutting down server...")
// Stop listener
if err := listener.Stop(); err != nil {
logger.Error("Error stopping listener", zap.Error(err))
}

View File

@@ -28,12 +28,10 @@ func init() {
}
func runStop(cmd *cobra.Command, args []string) error {
// Handle "stop all"
if args[0] == "all" {
return stopAllDaemons()
}
// Handle "stop <type> <port>"
if len(args) < 2 {
return fmt.Errorf("usage: drip stop <type> <port> or drip stop all")
}
@@ -61,19 +59,15 @@ func stopDaemon(tunnelType string, port int) error {
return fmt.Errorf("no %s tunnel running on port %d", tunnelType, port)
}
// Check if process is still running
if !IsProcessRunning(info.PID) {
// Clean up stale entry
RemoveDaemonInfo(tunnelType, port)
return fmt.Errorf("tunnel was not running (cleaned up stale entry)")
}
// Kill the process
if err := KillProcess(info.PID); err != nil {
return fmt.Errorf("failed to stop tunnel: %w", err)
}
// Remove daemon info
RemoveDaemonInfo(tunnelType, port)
fmt.Printf("\033[32m✓\033[0m Stopped %s tunnel on port %d (PID: %d)\n", tunnelType, port, info.PID)
@@ -81,7 +75,6 @@ func stopDaemon(tunnelType string, port int) error {
}
func stopAllDaemons() error {
// Clean up stale daemons first
CleanupStaleDaemons()
daemons, err := ListAllDaemons()

View File

@@ -3,19 +3,14 @@ package cli
import (
"fmt"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
"drip/internal/client/tcp"
"drip/internal/shared/protocol"
"drip/internal/shared/utils"
"drip/pkg/config"
"github.com/spf13/cobra"
"go.uber.org/zap"
)
var tcpCmd = &cobra.Command{
@@ -53,15 +48,12 @@ func init() {
}
func runTCP(cmd *cobra.Command, args []string) error {
// Parse port
port, err := strconv.Atoi(args[0])
if err != nil || port < 1 || port > 65535 {
return fmt.Errorf("invalid port number: %s", args[0])
}
// Handle daemon mode
if daemonMode && !daemonMarker {
// Start as daemon
daemonArgs := append([]string{"tcp"}, args...)
daemonArgs = append(daemonArgs, "--daemon-child")
if subdomain != "" {
@@ -85,19 +77,9 @@ func runTCP(cmd *cobra.Command, args []string) error {
return StartDaemon("tcp", port, daemonArgs)
}
// Initialize logger
if err := utils.InitLogger(verbose); err != nil {
return fmt.Errorf("failed to initialize logger: %w", err)
}
defer utils.Sync()
logger := utils.GetLogger()
// Load configuration or use command line flags
var serverAddr, token string
if serverURL == "" {
// Try to load from config file
cfg, err := config.LoadClientConfig("")
if err != nil {
return fmt.Errorf(`configuration not found.
@@ -108,17 +90,14 @@ Please run 'drip config init' first, or use flags:
serverAddr = cfg.Server
token = cfg.Token
} else {
// Use command line flags
serverAddr = serverURL
token = authToken
}
// Validate server address
if serverAddr == "" {
return fmt.Errorf("server address is required")
}
// Create connector config
connConfig := &tcp.ConnectorConfig{
ServerAddr: serverAddr,
Token: token,
@@ -129,233 +108,18 @@ Please run 'drip config init' first, or use flags:
Insecure: insecure,
}
// Setup signal handler for graceful shutdown
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
// Connection loop with reconnect support
reconnectAttempts := 0
serviceName := getServiceName(port)
for {
// Create connector
connector := tcp.NewConnector(connConfig, logger)
// Connect to server
if reconnectAttempts == 0 {
fmt.Printf("\033[36m🔌 Connecting to %s...\033[0m\n", serverAddr)
} else {
fmt.Printf("\033[33m🔄 Reconnecting to %s (attempt %d/%d)...\033[0m\n", serverAddr, reconnectAttempts, maxReconnectAttempts)
}
if err := connector.Connect(); err != nil {
// Check if this is a non-retryable error
if isNonRetryableErrorTCP(err) {
return fmt.Errorf("failed to connect: %w", err)
}
reconnectAttempts++
if reconnectAttempts >= maxReconnectAttempts {
return fmt.Errorf("failed to connect after %d attempts: %w", maxReconnectAttempts, err)
}
fmt.Printf("\033[31m✗ Connection failed: %v\033[0m\n", err)
fmt.Printf("\033[90m Retrying in %v...\033[0m\n", reconnectInterval)
// Wait before retry, but allow interrupt
select {
case <-quit:
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
return nil
case <-time.After(reconnectInterval):
continue
}
}
// Reset reconnect attempts on successful connection
reconnectAttempts = 0
// Save daemon info if running as daemon child
if daemonMarker {
daemonInfo := &DaemonInfo{
PID: os.Getpid(),
Type: "tcp",
Port: port,
Subdomain: subdomain,
Server: serverAddr,
URL: connector.GetURL(),
StartTime: time.Now(),
Executable: os.Args[0],
}
if err := SaveDaemonInfo(daemonInfo); err != nil {
// Log but don't fail
logger.Warn("Failed to save daemon info", zap.Error(err))
}
}
// Print tunnel information
fmt.Println()
fmt.Println("\033[1;35m╔══════════════════════════════════════════════════════════════════╗\033[0m")
fmt.Println("\033[1;35m║\033[0m \033[1;37m🔌 TCP Tunnel Connected Successfully!\033[0m \033[1;35m║\033[0m")
fmt.Println("\033[1;35m╠══════════════════════════════════════════════════════════════════╣\033[0m")
fmt.Printf("\033[1;35m║\033[0m \033[1;37mTunnel URL:\033[0m \033[1;35m║\033[0m\n")
fmt.Printf("\033[1;35m║\033[0m \033[1;36m%-60s\033[0m \033[1;35m║\033[0m\n", connector.GetURL())
fmt.Println("\033[1;35m║\033[0m \033[1;35m║\033[0m")
fmt.Printf("\033[1;35m║\033[0m \033[90mService:\033[0m \033[1;35m%-50s\033[0m \033[1;35m║\033[0m\n", serviceName)
displayAddr := localAddress
if displayAddr == "127.0.0.1" {
displayAddr = "localhost"
}
fmt.Printf("\033[1;35m║\033[0m \033[90mForwarding:\033[0m \033[1m%s:%d\033[0m → \033[36m%s\033[0m%-15s\033[1;35m║\033[0m\n", displayAddr, port, "public", "")
fmt.Printf("\033[1;35m║\033[0m \033[90mLatency:\033[0m \033[90mmeasuring...\033[0m%-40s\033[1;35m║\033[0m\n", "")
fmt.Printf("\033[1;35m║\033[0m \033[90mTraffic:\033[0m \033[90m↓ 0 B ↑ 0 B\033[0m%-32s\033[1;35m║\033[0m\n", "")
fmt.Printf("\033[1;35m║\033[0m \033[90mSpeed:\033[0m \033[90m↓ 0 B/s ↑ 0 B/s\033[0m%-28s\033[1;35m║\033[0m\n", "")
fmt.Printf("\033[1;35m║\033[0m \033[90mRequests:\033[0m \033[90m0\033[0m%-43s\033[1;35m║\033[0m\n", "")
fmt.Println("\033[1;35m╠══════════════════════════════════════════════════════════════════╣\033[0m")
fmt.Println("\033[1;35m║\033[0m \033[90mPress Ctrl+C to stop the tunnel\033[0m \033[1;35m║\033[0m")
fmt.Println("\033[1;35m╚══════════════════════════════════════════════════════════════════╝\033[0m")
fmt.Println()
// Setup latency display
latencyCh := make(chan time.Duration, 1)
connector.SetLatencyCallback(func(latency time.Duration) {
select {
case latencyCh <- latency:
default:
}
})
// Start stats display updater (updates every second)
stopDisplay := make(chan struct{})
disconnected := make(chan struct{})
go func() {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
var lastLatency time.Duration
for {
select {
case latency := <-latencyCh:
lastLatency = latency
case <-ticker.C:
// Update speed calculation
stats := connector.GetStats()
if stats != nil {
stats.UpdateSpeed()
snapshot := stats.GetSnapshot()
// Move cursor up 8 lines to update display
fmt.Print("\033[8A")
// Update latency line
fmt.Printf("\r\033[1;35m║\033[0m \033[90mLatency:\033[0m %s%-40s\033[1;35m║\033[0m\n", formatLatencyTCP(lastLatency), "")
// Update traffic line
trafficStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatBytes(snapshot.TotalBytesIn), tcp.FormatBytes(snapshot.TotalBytesOut))
fmt.Printf("\r\033[1;35m║\033[0m \033[90mTraffic:\033[0m \033[36m%-48s\033[0m\033[1;35m║\033[0m\n", trafficStr)
// Update speed line
speedStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatSpeed(snapshot.SpeedIn), tcp.FormatSpeed(snapshot.SpeedOut))
fmt.Printf("\r\033[1;35m║\033[0m \033[90mSpeed:\033[0m \033[33m%-48s\033[0m\033[1;35m║\033[0m\n", speedStr)
// Update requests line
fmt.Printf("\r\033[1;35m║\033[0m \033[90mRequests:\033[0m \033[35m%-47d\033[0m\033[1;35m║\033[0m\n", snapshot.TotalRequests)
// Move back down 4 lines
fmt.Print("\033[4B")
}
case <-stopDisplay:
return
}
}
}()
// Monitor connection in background
go func() {
connector.Wait()
close(disconnected)
}()
// Wait for signal or disconnection
select {
case <-quit:
close(stopDisplay)
fmt.Println("\n\n\033[33m🛑 Shutting down...\033[0m")
connector.Close()
if daemonMarker {
RemoveDaemonInfo("tcp", port)
}
fmt.Println("\033[32m✓\033[0m Tunnel closed")
return nil
case <-disconnected:
close(stopDisplay)
fmt.Println("\n\n\033[31m⚠ Connection lost!\033[0m")
reconnectAttempts++
if reconnectAttempts >= maxReconnectAttempts {
return fmt.Errorf("connection lost after %d reconnect attempts", maxReconnectAttempts)
}
fmt.Printf("\033[90m Reconnecting in %v...\033[0m\n", reconnectInterval)
// Wait before reconnect, but allow interrupt
select {
case <-quit:
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
return nil
case <-time.After(reconnectInterval):
continue
}
var daemon *DaemonInfo
if daemonMarker {
daemon = &DaemonInfo{
PID: os.Getpid(),
Type: "tcp",
Port: port,
Subdomain: subdomain,
Server: serverAddr,
StartTime: time.Now(),
Executable: os.Args[0],
}
}
}
// getServiceName returns a friendly name for common port numbers
func getServiceName(port int) string {
services := map[int]string{
22: "SSH",
80: "HTTP",
443: "HTTPS",
3306: "MySQL",
5432: "PostgreSQL",
6379: "Redis",
27017: "MongoDB",
3389: "RDP",
5900: "VNC",
8080: "HTTP (Alt)",
8443: "HTTPS (Alt)",
}
if name, ok := services[port]; ok {
return fmt.Sprintf("%s (port %d)", name, port)
}
return fmt.Sprintf("TCP service on port %d", port)
}
// formatLatencyTCP formats latency with color based on value
func formatLatencyTCP(d time.Duration) string {
ms := d.Milliseconds()
if ms < 50 {
return fmt.Sprintf("\033[32m%dms\033[0m", ms) // Green: excellent
} else if ms < 100 {
return fmt.Sprintf("\033[33m%dms\033[0m", ms) // Yellow: good
} else if ms < 200 {
return fmt.Sprintf("\033[38;5;208m%dms\033[0m", ms) // Orange: moderate
}
return fmt.Sprintf("\033[31m%dms\033[0m", ms) // Red: poor
}
// isNonRetryableErrorTCP checks if an error should not be retried
func isNonRetryableErrorTCP(err error) bool {
errStr := err.Error()
// Subdomain conflicts - no point retrying
if strings.Contains(errStr, "subdomain is already taken") ||
strings.Contains(errStr, "subdomain is reserved") ||
strings.Contains(errStr, "invalid subdomain") {
return true
}
// Authentication errors - no point retrying
if strings.Contains(errStr, "authentication") ||
strings.Contains(errStr, "Invalid authentication token") {
return true
}
return false
return runTunnelWithUI(connConfig, daemon)
}

View File

@@ -0,0 +1,209 @@
package cli
import (
"fmt"
"os"
"os/signal"
"strings"
"syscall"
"time"
"drip/internal/client/cli/ui"
"drip/internal/client/tcp"
"drip/internal/shared/utils"
"go.uber.org/zap"
)
// runTunnelWithUI runs a tunnel with the new UI
func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) error {
if err := utils.InitLogger(verbose); err != nil {
return fmt.Errorf("failed to initialize logger: %w", err)
}
defer utils.Sync()
logger := utils.GetLogger()
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
reconnectAttempts := 0
for {
connector := tcp.NewConnector(connConfig, logger)
fmt.Println(ui.RenderConnecting(connConfig.ServerAddr, reconnectAttempts, maxReconnectAttempts))
if err := connector.Connect(); err != nil {
if isNonRetryableError(err) {
return fmt.Errorf("failed to connect: %w", err)
}
reconnectAttempts++
if reconnectAttempts >= maxReconnectAttempts {
return fmt.Errorf("failed to connect after %d attempts: %w", maxReconnectAttempts, err)
}
fmt.Println(ui.RenderConnectionFailed(err))
fmt.Println(ui.RenderRetrying(reconnectInterval))
select {
case <-quit:
fmt.Println(ui.RenderShuttingDown())
return nil
case <-time.After(reconnectInterval):
continue
}
}
reconnectAttempts = 0
if daemonInfo != nil {
daemonInfo.URL = connector.GetURL()
if err := SaveDaemonInfo(daemonInfo); err != nil {
logger.Warn("Failed to save daemon info", zap.Error(err))
}
}
displayAddr := connConfig.LocalHost
if displayAddr == "127.0.0.1" {
displayAddr = "localhost"
}
status := &ui.TunnelStatus{
Type: string(connConfig.TunnelType),
URL: connector.GetURL(),
LocalAddr: fmt.Sprintf("%s:%d", displayAddr, connConfig.LocalPort),
}
fmt.Print(ui.RenderTunnelConnected(status))
latencyCh := make(chan time.Duration, 1)
connector.SetLatencyCallback(func(latency time.Duration) {
select {
case latencyCh <- latency:
default:
}
})
stopDisplay := make(chan struct{})
disconnected := make(chan struct{})
go func() {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
var lastLatency time.Duration
lastRenderedLines := 0
for {
select {
case latency := <-latencyCh:
lastLatency = latency
case <-ticker.C:
stats := connector.GetStats()
if stats != nil {
stats.UpdateSpeed()
snapshot := stats.GetSnapshot()
status.Latency = lastLatency
status.BytesIn = snapshot.TotalBytesIn
status.BytesOut = snapshot.TotalBytesOut
status.SpeedIn = float64(snapshot.SpeedIn)
status.SpeedOut = float64(snapshot.SpeedOut)
status.TotalRequest = snapshot.TotalRequests
statsView := ui.RenderTunnelStats(status)
if lastRenderedLines > 0 {
fmt.Print(clearLines(lastRenderedLines))
}
fmt.Print(statsView)
lastRenderedLines = countRenderedLines(statsView)
}
case <-stopDisplay:
return
}
}
}()
go func() {
connector.Wait()
close(disconnected)
}()
select {
case <-quit:
close(stopDisplay)
fmt.Println()
fmt.Println(ui.RenderShuttingDown())
// Close with timeout (wait for ongoing requests to complete)
done := make(chan struct{})
go func() {
connector.Close()
close(done)
}()
select {
case <-done:
// Closed successfully
case <-time.After(5 * time.Second):
fmt.Println(ui.Warning("Force closing (timeout)..."))
}
if daemonInfo != nil {
RemoveDaemonInfo(daemonInfo.Type, daemonInfo.Port)
}
fmt.Println(ui.Success("Tunnel closed"))
return nil
case <-disconnected:
close(stopDisplay)
fmt.Println()
fmt.Println(ui.RenderConnectionLost())
reconnectAttempts++
if reconnectAttempts >= maxReconnectAttempts {
return fmt.Errorf("connection lost after %d reconnect attempts", maxReconnectAttempts)
}
fmt.Println(ui.RenderRetrying(reconnectInterval))
select {
case <-quit:
fmt.Println(ui.RenderShuttingDown())
return nil
case <-time.After(reconnectInterval):
continue
}
}
}
}
func clearLines(lines int) string {
if lines <= 0 {
return ""
}
return fmt.Sprintf("\033[%dA\033[J", lines)
}
func countRenderedLines(block string) int {
if block == "" {
return 0
}
lines := strings.Count(block, "\n")
if !strings.HasSuffix(block, "\n") {
lines++
}
return lines
}
func isNonRetryableError(err error) bool {
errStr := err.Error()
return strings.Contains(errStr, "subdomain is already taken") ||
strings.Contains(errStr, "subdomain is reserved") ||
strings.Contains(errStr, "invalid subdomain") ||
strings.Contains(errStr, "authentication") ||
strings.Contains(errStr, "Invalid authentication token")
}
func isNonRetryableErrorTCP(err error) bool {
return isNonRetryableError(err)
}

View File

@@ -0,0 +1,117 @@
package ui
import (
"fmt"
)
// RenderConfigInit renders config initialization UI
func RenderConfigInit() string {
title := "Drip Configuration Setup"
box := boxStyle.Copy().Width(50)
return "\n" + box.Render(titleStyle.Render(title)) + "\n"
}
// RenderConfigShow renders the config display
func RenderConfigShow(server, token string, tokenHidden bool, tlsEnabled bool, configPath string) string {
lines := []string{
KeyValue("Server", server),
}
if token != "" {
if tokenHidden {
if len(token) > 10 {
displayToken := token[:3] + "***" + token[len(token)-3:]
lines = append(lines, KeyValue("Token", Muted(displayToken+" (hidden)")))
} else {
lines = append(lines, KeyValue("Token", Muted(token[:3]+"*** (hidden)")))
}
} else {
lines = append(lines, KeyValue("Token", token))
}
} else {
lines = append(lines, KeyValue("Token", Muted("(not set)")))
}
tlsStatus := "enabled"
if !tlsEnabled {
tlsStatus = "disabled"
}
lines = append(lines, KeyValue("TLS", tlsStatus))
lines = append(lines, KeyValue("Config", Muted(configPath)))
return Info("Current Configuration", lines...)
}
// RenderConfigSaved renders config saved message
func RenderConfigSaved(configPath string) string {
return SuccessBox(
"Configuration Saved",
Muted("Config saved to: ")+configPath,
"",
Muted("You can now use 'drip' without --server and --token flags"),
)
}
// RenderConfigUpdated renders config updated message
func RenderConfigUpdated(updates []string) string {
lines := make([]string, len(updates)+1)
for i, update := range updates {
lines[i] = Success(update)
}
lines[len(updates)] = ""
lines = append(lines, Muted("Configuration has been updated"))
return SuccessBox("Configuration Updated", lines...)
}
// RenderConfigDeleted renders config deleted message
func RenderConfigDeleted() string {
return SuccessBox("Configuration Deleted", Muted("Configuration file has been removed"))
}
// RenderConfigValidation renders config validation results
func RenderConfigValidation(serverValid, tokenSet, tlsEnabled bool) string {
lines := []string{}
if serverValid {
lines = append(lines, Success("Server address is valid"))
} else {
lines = append(lines, Error("Server address is not set"))
}
if tokenSet {
lines = append(lines, Success("Token is set"))
} else {
lines = append(lines, Warning("Token is not set (authentication may fail)"))
}
if tlsEnabled {
lines = append(lines, Success("TLS is enabled"))
} else {
lines = append(lines, Warning("TLS is disabled (not recommended for production)"))
}
lines = append(lines, "")
lines = append(lines, Muted("Configuration validation complete"))
if serverValid && tokenSet && tlsEnabled {
return SuccessBox("Configuration Valid", lines...)
}
return WarningBox("Configuration Validation", lines...)
}
// RenderDaemonStarted renders daemon started message
func RenderDaemonStarted(tunnelType string, port int, pid int, logPath string) string {
lines := []string{
KeyValue("Type", Highlight(tunnelType)),
KeyValue("Port", fmt.Sprintf("%d", port)),
KeyValue("PID", fmt.Sprintf("%d", pid)),
"",
Muted("Commands:"),
Cyan(" drip list") + Muted(" Check tunnel status"),
Cyan(fmt.Sprintf(" drip attach %s %d", tunnelType, port)) + Muted(" View logs"),
Cyan(fmt.Sprintf(" drip stop %s %d", tunnelType, port)) + Muted(" Stop tunnel"),
"",
Muted("Logs: ")+mutedStyle.Render(logPath),
}
return SuccessBox("Tunnel Started in Background", lines...)
}

View File

@@ -0,0 +1,204 @@
package ui
import (
"github.com/charmbracelet/lipgloss"
)
var (
// Colors inspired by Vercel CLI
primaryColor = lipgloss.Color("#0070F3")
successColor = lipgloss.Color("#0070F3")
warningColor = lipgloss.Color("#F5A623")
errorColor = lipgloss.Color("#E00")
mutedColor = lipgloss.Color("#888")
highlightColor = lipgloss.Color("#0070F3")
cyanColor = lipgloss.Color("#50E3C2")
purpleColor = lipgloss.Color("#7928CA")
// Base styles
baseStyle = lipgloss.NewStyle().
PaddingLeft(1).
PaddingRight(1)
// Box styles - Vercel-like clean box
boxStyle = lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#333")).
Padding(1, 2).
MarginTop(1).
MarginBottom(1)
successBoxStyle = boxStyle.Copy().
BorderForeground(successColor)
warningBoxStyle = boxStyle.Copy().
BorderForeground(warningColor)
errorBoxStyle = boxStyle.Copy().
BorderForeground(errorColor)
// Text styles
titleStyle = lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#FFF"))
subtitleStyle = lipgloss.NewStyle().
Foreground(mutedColor)
successStyle = lipgloss.NewStyle().
Foreground(successColor).
Bold(true)
errorStyle = lipgloss.NewStyle().
Foreground(errorColor).
Bold(true)
warningStyle = lipgloss.NewStyle().
Foreground(warningColor).
Bold(true)
mutedStyle = lipgloss.NewStyle().
Foreground(mutedColor)
highlightStyle = lipgloss.NewStyle().
Foreground(highlightColor).
Bold(true)
cyanStyle = lipgloss.NewStyle().
Foreground(cyanColor)
urlStyle = lipgloss.NewStyle().
Foreground(highlightColor).
Underline(true).
Bold(true)
labelStyle = lipgloss.NewStyle().
Foreground(mutedColor).
Width(12)
valueStyle = lipgloss.NewStyle().
Bold(true)
// Table styles
tableHeaderStyle = lipgloss.NewStyle().
Foreground(mutedColor).
Bold(true).
PaddingRight(2)
tableCellStyle = lipgloss.NewStyle().
PaddingRight(2)
tableRowHighlight = lipgloss.NewStyle().
Foreground(highlightColor).
Bold(true)
)
// Success returns a styled success message
func Success(text string) string {
return successStyle.Render("✓ " + text)
}
// Error returns a styled error message
func Error(text string) string {
return errorStyle.Render("✗ " + text)
}
// Warning returns a styled warning message
func Warning(text string) string {
return warningStyle.Render("⚠ " + text)
}
// Muted returns a styled muted text
func Muted(text string) string {
return mutedStyle.Render(text)
}
// Highlight returns a styled highlighted text
func Highlight(text string) string {
return highlightStyle.Render(text)
}
// Cyan returns a styled cyan text
func Cyan(text string) string {
return cyanStyle.Render(text)
}
// URL returns a styled URL
func URL(text string) string {
return urlStyle.Render(text)
}
// Title returns a styled title
func Title(text string) string {
return titleStyle.Render(text)
}
// Subtitle returns a styled subtitle
func Subtitle(text string) string {
return subtitleStyle.Render(text)
}
// KeyValue returns a styled key-value pair
func KeyValue(key, value string) string {
return labelStyle.Render(key+":") + " " + valueStyle.Render(value)
}
// Info renders an info box (Vercel-style)
func Info(title string, lines ...string) string {
content := titleStyle.Render(title)
if len(lines) > 0 {
content += "\n\n"
for i, line := range lines {
if i > 0 {
content += "\n"
}
content += line
}
}
return boxStyle.Render(content)
}
// SuccessBox renders a success box
func SuccessBox(title string, lines ...string) string {
content := successStyle.Render("✓ " + title)
if len(lines) > 0 {
content += "\n\n"
for i, line := range lines {
if i > 0 {
content += "\n"
}
content += line
}
}
return successBoxStyle.Render(content)
}
// WarningBox renders a warning box
func WarningBox(title string, lines ...string) string {
content := warningStyle.Render("⚠ " + title)
if len(lines) > 0 {
content += "\n\n"
for i, line := range lines {
if i > 0 {
content += "\n"
}
content += line
}
}
return warningBoxStyle.Render(content)
}
// ErrorBox renders an error box
func ErrorBox(title string, lines ...string) string {
content := errorStyle.Render("✗ " + title)
if len(lines) > 0 {
content += "\n\n"
for i, line := range lines {
if i > 0 {
content += "\n"
}
content += line
}
}
return errorBoxStyle.Render(content)
}

View File

@@ -0,0 +1,127 @@
package ui
import (
"fmt"
"strings"
"github.com/charmbracelet/lipgloss"
)
// Table represents a simple table for CLI output
type Table struct {
headers []string
rows [][]string
title string
}
// NewTable creates a new table
func NewTable(headers []string) *Table {
return &Table{
headers: headers,
rows: [][]string{},
}
}
// WithTitle sets the table title
func (t *Table) WithTitle(title string) *Table {
t.title = title
return t
}
// AddRow adds a row to the table
func (t *Table) AddRow(row []string) *Table {
t.rows = append(t.rows, row)
return t
}
// Render renders the table (Vercel-style)
func (t *Table) Render() string {
if len(t.rows) == 0 {
return ""
}
// Calculate column widths
colWidths := make([]int, len(t.headers))
for i, header := range t.headers {
colWidths[i] = lipgloss.Width(header)
}
for _, row := range t.rows {
for i, cell := range row {
if i < len(colWidths) {
width := lipgloss.Width(cell)
if width > colWidths[i] {
colWidths[i] = width
}
}
}
}
var output strings.Builder
// Title
if t.title != "" {
output.WriteString("\n")
output.WriteString(titleStyle.Render(t.title))
output.WriteString("\n\n")
}
// Header
headerParts := make([]string, len(t.headers))
for i, header := range t.headers {
style := tableHeaderStyle.Copy().Width(colWidths[i])
headerParts[i] = style.Render(header)
}
output.WriteString(strings.Join(headerParts, " "))
output.WriteString("\n")
// Separator line
separatorParts := make([]string, len(t.headers))
for i := range t.headers {
separatorParts[i] = mutedStyle.Render(strings.Repeat("─", colWidths[i]))
}
output.WriteString(strings.Join(separatorParts, " "))
output.WriteString("\n")
// Rows
for _, row := range t.rows {
rowParts := make([]string, len(t.headers))
for i, cell := range row {
if i < len(colWidths) {
style := tableCellStyle.Copy().Width(colWidths[i])
rowParts[i] = style.Render(cell)
}
}
output.WriteString(strings.Join(rowParts, " "))
output.WriteString("\n")
}
output.WriteString("\n")
return output.String()
}
// Print prints the table
func (t *Table) Print() {
fmt.Print(t.Render())
}
// RenderList renders a simple list with bullet points
func RenderList(items []string) string {
var output strings.Builder
for _, item := range items {
output.WriteString(mutedStyle.Render(" • "))
output.WriteString(item)
output.WriteString("\n")
}
return output.String()
}
// RenderNumberedList renders a numbered list
func RenderNumberedList(items []string) string {
var output strings.Builder
for i, item := range items {
output.WriteString(mutedStyle.Render(fmt.Sprintf(" %d. ", i+1)))
output.WriteString(item)
output.WriteString("\n")
}
return output.String()
}

View File

@@ -0,0 +1,231 @@
package ui
import (
"fmt"
"strings"
"time"
"github.com/charmbracelet/lipgloss"
)
const (
tunnelCardWidth = 76
statsColumnWidth = 32
)
var (
latencyFastColor = lipgloss.Color("#22c55e") // green
latencyYellowColor = lipgloss.Color("#eab308") // yellow
latencyOrangeColor = lipgloss.Color("#f97316") // orange
latencyRedColor = lipgloss.Color("#ef4444") // red
)
// TunnelStatus represents the status of a tunnel
type TunnelStatus struct {
Type string // "http", "https", "tcp"
URL string // Public URL
LocalAddr string // Local address
Latency time.Duration // Current latency
BytesIn int64 // Bytes received
BytesOut int64 // Bytes sent
SpeedIn float64 // Download speed
SpeedOut float64 // Upload speed
TotalRequest int64 // Total requests
}
// RenderTunnelConnected renders the tunnel connection card
func RenderTunnelConnected(status *TunnelStatus) string {
icon, typeStr, accent := tunnelVisuals(status.Type)
card := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(accent).
Padding(1, 2).
Width(tunnelCardWidth)
typeBadge := lipgloss.NewStyle().
Background(accent).
Foreground(lipgloss.Color("#f8fafc")).
Bold(true).
Padding(0, 1).
Render(strings.ToUpper(typeStr) + " TUNNEL")
headline := lipgloss.JoinHorizontal(
lipgloss.Left,
lipgloss.NewStyle().Foreground(accent).Render(icon),
lipgloss.NewStyle().Bold(true).MarginLeft(1).Render("Tunnel Connected"),
lipgloss.NewStyle().MarginLeft(2).Render(typeBadge),
)
urlLine := urlStyle.Copy().Foreground(accent).Render(status.URL)
forwardLine := Muted("⇢ ") + valueStyle.Render(status.LocalAddr)
hint := mutedStyle.Render("Ctrl+C to stop • reconnects automatically")
content := lipgloss.JoinVertical(
lipgloss.Left,
headline,
"",
urlLine,
forwardLine,
"",
hint,
)
return "\n" + card.Render(content) + "\n"
}
// RenderTunnelStats renders real-time tunnel statistics in a card
func RenderTunnelStats(status *TunnelStatus) string {
latencyStr := formatLatency(status.Latency)
trafficStr := fmt.Sprintf("↓ %s ↑ %s", formatBytes(status.BytesIn), formatBytes(status.BytesOut))
speedStr := fmt.Sprintf("↓ %s ↑ %s", formatSpeed(status.SpeedIn), formatSpeed(status.SpeedOut))
requestsStr := fmt.Sprintf("%d", status.TotalRequest)
_, _, accent := tunnelVisuals(status.Type)
header := lipgloss.JoinHorizontal(
lipgloss.Left,
lipgloss.NewStyle().Foreground(accent).Render("◉"),
lipgloss.NewStyle().Bold(true).MarginLeft(1).Render("Live Metrics"),
)
row1 := lipgloss.JoinHorizontal(
lipgloss.Top,
statColumn("Latency", latencyStr, statsColumnWidth),
statColumn("Requests", highlightStyle.Render(requestsStr), statsColumnWidth),
)
row2 := lipgloss.JoinHorizontal(
lipgloss.Top,
statColumn("Traffic", Cyan(trafficStr), statsColumnWidth),
statColumn("Speed", warningStyle.Render(speedStr), statsColumnWidth),
)
card := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(accent).
Padding(1, 2).
Width(tunnelCardWidth)
body := lipgloss.JoinVertical(
lipgloss.Left,
header,
"",
row1,
row2,
)
return "\n" + card.Render(body) + "\n"
}
// RenderConnecting renders the connecting message
func RenderConnecting(serverAddr string, attempt int, maxAttempts int) string {
if attempt == 0 {
return Highlight("◌") + " Connecting to " + Muted(serverAddr) + "..."
}
return Warning(fmt.Sprintf("◌ Reconnecting to %s (attempt %d/%d)...", serverAddr, attempt, maxAttempts))
}
// RenderConnectionFailed renders connection failure message
func RenderConnectionFailed(err error) string {
return Error(fmt.Sprintf("Connection failed: %v", err))
}
// RenderShuttingDown renders shutdown message
func RenderShuttingDown() string {
return Warning("⏹ Shutting down...")
}
// RenderConnectionLost renders connection lost message
func RenderConnectionLost() string {
return Error("⚠ Connection lost!")
}
// RenderRetrying renders retry message
func RenderRetrying(interval time.Duration) string {
return Muted(fmt.Sprintf(" Retrying in %v...", interval))
}
// formatLatency formats latency with color
func formatLatency(d time.Duration) string {
ms := d.Milliseconds()
var style lipgloss.Style
if ms == 0 {
return mutedStyle.Render("measuring...")
}
switch {
case ms < 50:
style = lipgloss.NewStyle().Foreground(latencyFastColor)
case ms < 150:
style = lipgloss.NewStyle().Foreground(latencyYellowColor)
case ms < 300:
style = lipgloss.NewStyle().Foreground(latencyOrangeColor)
default:
style = lipgloss.NewStyle().Foreground(latencyRedColor)
}
return style.Render(fmt.Sprintf("%dms", ms))
}
// formatBytes formats bytes to human readable format
func formatBytes(bytes int64) string {
const unit = 1024
if bytes < unit {
return fmt.Sprintf("%d B", bytes)
}
div, exp := int64(unit), 0
for n := bytes / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
}
// formatSpeed formats speed to human readable format
func formatSpeed(bytesPerSec float64) string {
const unit = 1024.0
if bytesPerSec < unit {
return fmt.Sprintf("%.0f B/s", bytesPerSec)
}
div, exp := unit, 0
for n := bytesPerSec / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %cB/s", bytesPerSec/div, "KMGTPE"[exp])
}
func statColumn(label, value string, width int) string {
labelView := lipgloss.NewStyle().
Foreground(mutedColor).
Render(strings.ToUpper(label))
block := lipgloss.JoinHorizontal(
lipgloss.Left,
labelView,
lipgloss.NewStyle().MarginLeft(1).Render(value),
)
if width <= 0 {
return block
}
return lipgloss.NewStyle().
Width(width).
Render(block)
}
func tunnelVisuals(tunnelType string) (string, string, lipgloss.Color) {
switch tunnelType {
case "http":
return "🚀", "HTTP", lipgloss.Color("#0070F3")
case "https":
return "🔒", "HTTPS", lipgloss.Color("#2D8CFF")
case "tcp":
return "🔌", "TCP", lipgloss.Color("#50E3C2")
default:
return "🌐", strings.ToUpper(tunnelType), lipgloss.Color("#0070F3")
}
}

View File

@@ -2,9 +2,10 @@ package tcp
import (
"crypto/tls"
"encoding/json"
"fmt"
"net"
json "github.com/goccy/go-json"
"sync"
"time"
@@ -132,9 +133,10 @@ func (c *Connector) Connect() error {
bufferPool,
)
c.frameWriter.EnableHeartbeat(constants.HeartbeatInterval, c.createHeartbeatFrame)
go c.frameHandler.WarmupConnectionPool(3)
go c.handleFrames()
go c.heartbeat()
return nil
}
@@ -278,38 +280,21 @@ func (c *Connector) handleFrames() {
}
}
// heartbeat sends periodic heartbeat frames
func (c *Connector) heartbeat() {
ticker := time.NewTicker(constants.HeartbeatInterval)
defer ticker.Stop()
c.sendHeartbeat()
for {
select {
case <-c.stopCh:
return
case <-ticker.C:
c.sendHeartbeat()
}
// createHeartbeatFrame creates a heartbeat frame to be sent by the write loop.
func (c *Connector) createHeartbeatFrame() *protocol.Frame {
c.closedMu.RLock()
if c.closed {
c.closedMu.RUnlock()
return nil
}
}
// sendHeartbeat sends a heartbeat frame and records the time
func (c *Connector) sendHeartbeat() {
hbFrame := protocol.NewFrame(protocol.FrameTypeHeartbeat, nil)
c.closedMu.RUnlock()
c.heartbeatMu.Lock()
c.heartbeatSentAt = time.Now()
c.heartbeatMu.Unlock()
err := c.frameWriter.WriteFrame(hbFrame)
if err != nil {
c.logger.Error("Failed to send heartbeat", zap.Error(err))
c.Close()
return
}
c.logger.Debug("Heartbeat sent")
return protocol.NewFrame(protocol.FrameTypeHeartbeat, nil)
}
// SendFrame sends a frame to the server
@@ -330,8 +315,20 @@ func (c *Connector) Close() error {
close(c.stopCh)
// Wait for active handlers with timeout
c.logger.Debug("Waiting for active handlers to complete")
c.handlerWg.Wait()
done := make(chan struct{})
go func() {
c.handlerWg.Wait()
close(done)
}()
select {
case <-done:
c.logger.Debug("All handlers completed")
case <-time.After(3 * time.Second):
c.logger.Warn("Force closing: some handlers are still active")
}
if c.conn != nil {
closeFrame := protocol.NewFrame(protocol.FrameTypeClose, nil)

View File

@@ -19,7 +19,7 @@ import (
// FrameHandler handles data frames and forwards to local service
type FrameHandler struct {
conn net.Conn
frameWriter *protocol.FrameWriter // Async batch writer (replaces writeMu)
frameWriter *protocol.FrameWriter
localHost string
localPort int
logger *zap.Logger
@@ -28,9 +28,9 @@ type FrameHandler struct {
tunnelType protocol.TunnelType
httpClient *http.Client
stats *TrafficStats
isClosedCheck func() bool // Function to check if connection is closed
isClosedCheck func() bool
bufferPool *pool.BufferPool
headerPool *pool.HeaderPool // Header pool for Priority 9 optimization
headerPool *pool.HeaderPool
}
// Stream represents a single request/response stream

View File

@@ -2,7 +2,6 @@ package proxy
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
@@ -10,6 +9,8 @@ import (
"strings"
"time"
json "github.com/goccy/go-json"
"drip/internal/server/tunnel"
"drip/internal/shared/constants"
"drip/internal/shared/pool"
@@ -244,7 +245,7 @@ func (h *Handler) serveHomePage(w http.ResponseWriter, r *http.Request) {
<html>
<head>
<meta charset="UTF-8" />
<title>Drip - Tunnel to Localhost</title>
<title>Drip - Your Tunnel, Your Domain, Anywhere</title>
<style>
body { font-family: Arial, sans-serif; max-width: 800px; margin: 50px auto; padding: 20px; }
h1 { color: #333; }
@@ -253,12 +254,12 @@ func (h *Handler) serveHomePage(w http.ResponseWriter, r *http.Request) {
</style>
</head>
<body>
<h1>💧 Drip - Fast Tunnels to Localhost</h1>
<p>A high-performance tunneling service.</p>
<h1>💧 Drip - Your Tunnel, Your Domain, Anywhere</h1>
<p>A self-hosted tunneling solution to securely expose your services to the internet.</p>
<h2>Quick Start</h2>
<p>Install the client:</p>
<code>bash <(curl -fsSL https://raw.githubusercontent.com/Gouryella/drip/refs/heads/main/scripts/install.sh)</code>
<code>bash <(curl -fsSL https://raw.githubusercontent.com/Gouryella/drip/main/scripts/install.sh)</code>
<p>Start a tunnel:</p>
<code>drip http 3000</code><br><br>

View File

@@ -2,7 +2,7 @@ package tcp
import (
"bufio"
"encoding/json"
json "github.com/goccy/go-json"
"fmt"
"io"
"net"
@@ -82,7 +82,6 @@ func (c *Connection) Handle() error {
return fmt.Errorf("failed to peek connection: %w", err)
}
// Check if this is an HTTP request
peekStr := string(peek)
httpMethods := []string{"GET ", "POST", "PUT ", "DELE", "HEAD", "OPTI", "PATC", "CONN", "TRAC"}
isHTTP := false
@@ -110,16 +109,13 @@ func (c *Connection) Handle() error {
return fmt.Errorf("expected register frame, got %s", frame.Type)
}
// Parse registration request
var req protocol.RegisterRequest
if err := json.Unmarshal(frame.Payload, &req); err != nil {
return fmt.Errorf("failed to parse registration request: %w", err)
}
// Store tunnel type
c.tunnelType = req.TunnelType
// Authenticate
if c.authToken != "" && req.Token != c.authToken {
c.sendError("authentication_failed", "Invalid authentication token")
return fmt.Errorf("authentication failed")
@@ -144,7 +140,6 @@ func (c *Connection) Handle() error {
}
}
// Register tunnel
subdomain, err := c.manager.Register(nil, req.CustomSubdomain)
if err != nil {
c.sendError("registration_failed", err.Error())
@@ -155,7 +150,6 @@ func (c *Connection) Handle() error {
c.subdomain = subdomain
// Get tunnel connection
tunnelConn, ok := c.manager.Get(subdomain)
if !ok {
return fmt.Errorf("failed to get registered tunnel")
@@ -211,7 +205,6 @@ func (c *Connection) Handle() error {
// Create frame writer for async writes
c.frameWriter = protocol.NewFrameWriter(c.conn)
// Clear read deadline
c.conn.SetReadDeadline(time.Time{})
// Start TCP proxy only for TCP tunnels
@@ -222,7 +215,6 @@ func (c *Connection) Handle() error {
}
}
// Start heartbeat checker
go c.heartbeatChecker()
// Handle frames (pass reader for consistent buffering)
@@ -510,24 +502,20 @@ func (c *Connection) Close() {
c.once.Do(func() {
close(c.stopCh)
// Close frame writer
if c.frameWriter != nil {
c.frameWriter.Close()
}
// Stop TCP proxy
if c.proxy != nil {
c.proxy.Stop()
}
c.conn.Close()
// Release allocated port
if c.port > 0 && c.portAlloc != nil {
c.portAlloc.Release(c.port)
}
// Unregister tunnel
if c.subdomain != "" {
c.manager.Unregister(c.subdomain)
}

View File

@@ -61,7 +61,6 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
func (l *Listener) Start() error {
var err error
// Create TLS listener
l.listener, err = tls.Listen("tcp", l.address, l.tlsConfig)
if err != nil {
return fmt.Errorf("failed to start TLS listener: %w", err)
@@ -72,7 +71,6 @@ func (l *Listener) Start() error {
zap.String("tls_version", "TLS 1.3"),
)
// Accept connections in background
l.wg.Add(1)
go l.acceptLoop()
@@ -102,7 +100,7 @@ func (l *Listener) acceptLoop() {
}
select {
case <-l.stopCh:
return // Listener was stopped
return
default:
l.logger.Error("Failed to accept connection", zap.Error(err))
continue
@@ -128,14 +126,12 @@ func (l *Listener) handleConnection(netConn net.Conn) {
defer l.wg.Done()
defer netConn.Close()
// Get TLS connection info
tlsConn, ok := netConn.(*tls.Conn)
if !ok {
l.logger.Error("Connection is not TLS")
return
}
// Force TLS handshake to complete
if err := tlsConn.Handshake(); err != nil {
// TLS handshake failures are common (HTTP clients, scanners, etc.)
// Log as WARN instead of ERROR
@@ -146,7 +142,6 @@ func (l *Listener) handleConnection(netConn net.Conn) {
return
}
// Log connection info
state := tlsConn.ConnectionState()
l.logger.Info("New connection",
zap.String("remote_addr", netConn.RemoteAddr().String()),
@@ -154,7 +149,6 @@ func (l *Listener) handleConnection(netConn net.Conn) {
zap.String("cipher_suite", tls.CipherSuiteName(state.CipherSuite)),
)
// Verify TLS 1.3
if state.Version != tls.VersionTLS13 {
l.logger.Warn("Connection not using TLS 1.3",
zap.Uint16("version", state.Version),
@@ -162,23 +156,19 @@ func (l *Listener) handleConnection(netConn net.Conn) {
return
}
// Create connection handler
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.publicPort, l.httpHandler, l.responseChans)
// Store connection
connID := netConn.RemoteAddr().String()
l.connMu.Lock()
l.connections[connID] = conn
l.connMu.Unlock()
// Remove connection on exit
defer func() {
l.connMu.Lock()
delete(l.connections, connID)
l.connMu.Unlock()
}()
// Handle connection (blocking)
if err := conn.Handle(); err != nil {
errStr := err.Error()
@@ -217,27 +207,22 @@ func (l *Listener) handleConnection(netConn net.Conn) {
func (l *Listener) Stop() error {
l.logger.Info("Stopping TCP listener")
// Signal stop
close(l.stopCh)
// Close listener
if l.listener != nil {
if err := l.listener.Close(); err != nil {
l.logger.Error("Failed to close listener", zap.Error(err))
}
}
// Close all connections
l.connMu.Lock()
for _, conn := range l.connections {
conn.Close()
}
l.connMu.Unlock()
// Wait for all goroutines to finish
l.wg.Wait()
// Close worker pool
if l.workerPool != nil {
l.workerPool.Close()
}

View File

@@ -58,7 +58,6 @@ func (p *TunnelProxy) Start() error {
zap.String("subdomain", p.subdomain),
)
// Accept connections in background
p.wg.Add(1)
go p.acceptLoop()
@@ -76,7 +75,6 @@ func (p *TunnelProxy) acceptLoop() {
default:
}
// Set accept deadline
p.listener.(*net.TCPListener).SetDeadline(time.Now().Add(1 * time.Second))
conn, err := p.listener.Accept()
@@ -92,7 +90,6 @@ func (p *TunnelProxy) acceptLoop() {
}
}
// Handle connection
p.wg.Add(1)
go p.handleConnection(conn)
}

View File

@@ -1,263 +0,0 @@
package tunnel
import (
"testing"
"time"
"go.uber.org/zap"
)
func TestNewConnection(t *testing.T) {
subdomain := "test123"
logger := zap.NewNop()
// We can't create a real WebSocket connection in unit tests,
// so we'll just test with nil
conn := NewConnection(subdomain, nil, logger)
if conn == nil {
t.Fatal("NewConnection() returned nil")
}
if conn.Subdomain != subdomain {
t.Errorf("Subdomain = %v, want %v", conn.Subdomain, subdomain)
}
if conn.SendCh == nil {
t.Error("SendCh is nil")
}
if conn.CloseCh == nil {
t.Error("CloseCh is nil")
}
// Check that LastActive is recent (within last second)
now := time.Now()
if now.Sub(conn.LastActive) > time.Second {
t.Error("LastActive is not recent")
}
}
func TestConnectionUpdateActivity(t *testing.T) {
logger := zap.NewNop()
conn := NewConnection("test", nil, logger)
// Get initial LastActive
initial := conn.LastActive
// Wait a bit
time.Sleep(10 * time.Millisecond)
// Update activity
conn.UpdateActivity()
// Check that LastActive was updated
if !conn.LastActive.After(initial) {
t.Error("UpdateActivity() did not update LastActive")
}
// Check that it's recent
now := time.Now()
if now.Sub(conn.LastActive) > time.Second {
t.Error("UpdateActivity() did not set recent timestamp")
}
}
func TestConnectionIsAlive(t *testing.T) {
tests := []struct {
name string
lastActive time.Time
timeout time.Duration
want bool
}{
{
name: "fresh connection is alive",
lastActive: time.Now(),
timeout: 90 * time.Second,
want: true,
},
{
name: "stale connection is not alive",
lastActive: time.Now().Add(-2 * time.Minute),
timeout: 90 * time.Second,
want: false,
},
{
name: "exactly at timeout is not alive",
lastActive: time.Now().Add(-90 * time.Second),
timeout: 90 * time.Second,
want: false,
},
{
name: "just before timeout is alive",
lastActive: time.Now().Add(-89 * time.Second),
timeout: 90 * time.Second,
want: true,
},
}
logger := zap.NewNop()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
conn := NewConnection("test", nil, logger)
conn.mu.Lock()
conn.LastActive = tt.lastActive
conn.mu.Unlock()
got := conn.IsAlive(tt.timeout)
if got != tt.want {
t.Errorf("IsAlive() = %v, want %v (age: %v, timeout: %v)",
got, tt.want, time.Since(tt.lastActive), tt.timeout)
}
})
}
}
func TestConnectionSend(t *testing.T) {
logger := zap.NewNop()
conn := NewConnection("test", nil, logger)
data := []byte("test message")
// Test successful send
err := conn.Send(data)
if err != nil {
t.Errorf("Send() error = %v, want nil", err)
}
// Verify data was sent to channel
select {
case received := <-conn.SendCh:
if string(received) != string(data) {
t.Errorf("Received data = %v, want %v", string(received), string(data))
}
case <-time.After(100 * time.Millisecond):
t.Error("Send() did not send data to channel")
}
}
func TestConnectionSendTimeout(t *testing.T) {
logger := zap.NewNop()
conn := NewConnection("test", nil, logger)
// Fill the channel
for i := 0; i < 256; i++ {
conn.SendCh <- []byte("fill")
}
// Try to send when channel is full
data := []byte("test message")
err := conn.Send(data)
if err != ErrSendTimeout {
t.Errorf("Send() on full channel error = %v, want %v", err, ErrSendTimeout)
}
}
func TestConnectionClose(t *testing.T) {
logger := zap.NewNop()
conn := NewConnection("test", nil, logger)
// Close the connection
conn.Close()
// Verify CloseCh is closed
select {
case <-conn.CloseCh:
// Successfully received from closed channel
case <-time.After(100 * time.Millisecond):
t.Error("Close() did not close CloseCh")
}
// Try to close again (should not panic)
defer func() {
if r := recover(); r != nil {
t.Error("Close() panicked on second call")
}
}()
conn.Close()
}
func TestConnectionConcurrentUpdateActivity(t *testing.T) {
logger := zap.NewNop()
conn := NewConnection("test", nil, logger)
// Update activity concurrently
done := make(chan bool)
for i := 0; i < 100; i++ {
go func() {
conn.UpdateActivity()
done <- true
}()
}
// Wait for all goroutines
for i := 0; i < 100; i++ {
<-done
}
// Verify LastActive is recent
now := time.Now()
if now.Sub(conn.LastActive) > time.Second {
t.Error("Concurrent UpdateActivity() failed")
}
}
func TestConnectionConcurrentIsAlive(t *testing.T) {
logger := zap.NewNop()
conn := NewConnection("test", nil, logger)
// Check IsAlive concurrently
done := make(chan bool)
for i := 0; i < 100; i++ {
go func() {
conn.IsAlive(90 * time.Second)
done <- true
}()
}
// Wait for all goroutines
for i := 0; i < 100; i++ {
<-done
}
}
// Benchmark tests
func BenchmarkConnectionSend(b *testing.B) {
logger := zap.NewNop()
conn := NewConnection("test", nil, logger)
// Drain channel in background
go func() {
for range conn.SendCh {
}
}()
data := []byte("test message")
b.ResetTimer()
for i := 0; i < b.N; i++ {
conn.Send(data)
}
}
func BenchmarkConnectionUpdateActivity(b *testing.B) {
logger := zap.NewNop()
conn := NewConnection("test", nil, logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
conn.UpdateActivity()
}
}
func BenchmarkConnectionIsAlive(b *testing.B) {
logger := zap.NewNop()
conn := NewConnection("test", nil, logger)
timeout := 90 * time.Second
b.ResetTimer()
for i := 0; i < b.N; i++ {
conn.IsAlive(timeout)
}
}

View File

@@ -1,376 +0,0 @@
package tunnel
import (
"sync"
"testing"
"time"
"go.uber.org/zap"
)
func TestNewManager(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
if manager == nil {
t.Fatal("NewManager() returned nil")
}
if manager.tunnels == nil {
t.Error("Manager tunnels map is nil")
}
if manager.used == nil {
t.Error("Manager used map is nil")
}
if manager.logger == nil {
t.Error("Manager logger is nil")
}
}
func TestManagerRegister(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
// Register with empty subdomain (auto-generate)
subdomain, err := manager.Register(nil, "")
if err != nil {
t.Errorf("Register() error = %v, want nil", err)
}
if subdomain == "" {
t.Error("Register() returned empty subdomain")
}
if len(subdomain) != 6 {
t.Errorf("Register() subdomain length = %d, want 6", len(subdomain))
}
// Verify connection is registered
_, ok := manager.Get(subdomain)
if !ok {
t.Error("Get() failed to retrieve registered connection")
}
}
func TestManagerRegisterCustomSubdomain(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
customSubdomain := "mytest"
// Register with custom subdomain
subdomain, err := manager.Register(nil, customSubdomain)
if err != nil {
t.Errorf("Register() error = %v, want nil", err)
}
if subdomain != customSubdomain {
t.Errorf("Register() subdomain = %v, want %v", subdomain, customSubdomain)
}
// Verify connection is registered
conn, ok := manager.Get(subdomain)
if !ok {
t.Error("Get() failed to retrieve registered connection")
}
if conn.Subdomain != customSubdomain {
t.Errorf("Connection subdomain = %v, want %v", conn.Subdomain, customSubdomain)
}
}
func TestManagerRegisterDuplicate(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
customSubdomain := "test123"
// Register first connection
_, err := manager.Register(nil, customSubdomain)
if err != nil {
t.Fatalf("First Register() error = %v, want nil", err)
}
// Try to register second connection with same subdomain
_, err = manager.Register(nil, customSubdomain)
if err != ErrSubdomainTaken {
t.Errorf("Register() error = %v, want %v", err, ErrSubdomainTaken)
}
}
func TestManagerRegisterInvalidSubdomain(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
tests := []struct {
name string
subdomain string
wantErr error
}{
{
name: "invalid uppercase",
subdomain: "TEST",
wantErr: ErrInvalidSubdomain,
},
{
name: "invalid special char",
subdomain: "test@123",
wantErr: ErrInvalidSubdomain,
},
{
name: "reserved www",
subdomain: "www",
wantErr: ErrReservedSubdomain,
},
{
name: "reserved api",
subdomain: "api",
wantErr: ErrReservedSubdomain,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := manager.Register(nil, tt.subdomain)
if err != tt.wantErr {
t.Errorf("Register() error = %v, want %v", err, tt.wantErr)
}
})
}
}
func TestManagerUnregister(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
// Register connection
subdomain, err := manager.Register(nil, "")
if err != nil {
t.Fatalf("Register() error = %v", err)
}
// Unregister connection
manager.Unregister(subdomain)
// Verify connection is removed
_, ok := manager.Get(subdomain)
if ok {
t.Error("Get() succeeded after Unregister(), want failure")
}
}
func TestManagerGet(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
customSubdomain := "test123"
// Test Get on non-existent connection
_, ok := manager.Get(customSubdomain)
if ok {
t.Error("Get() succeeded for non-existent connection")
}
// Register and test Get
subdomain, _ := manager.Register(nil, customSubdomain)
retrieved, ok := manager.Get(subdomain)
if !ok {
t.Error("Get() failed for existing connection")
}
if retrieved.Subdomain != customSubdomain {
t.Error("Get() returned wrong connection")
}
}
func TestManagerList(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
// Test empty manager
all := manager.List()
if len(all) != 0 {
t.Errorf("List() on empty manager returned %d connections, want 0", len(all))
}
// Add multiple connections
count := 5
for i := 0; i < count; i++ {
manager.Register(nil, "")
}
// Test List
all = manager.List()
if len(all) != count {
t.Errorf("List() returned %d connections, want %d", len(all), count)
}
}
func TestManagerCount(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
// Test empty manager
count := manager.Count()
if count != 0 {
t.Errorf("Count() on empty manager = %d, want 0", count)
}
// Add connections
numConns := 3
for i := 0; i < numConns; i++ {
manager.Register(nil, "")
}
count = manager.Count()
if count != numConns {
t.Errorf("Count() = %d, want %d", count, numConns)
}
}
func TestManagerGenerateSubdomain(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
// Generate subdomain via Register
subdomain1, err := manager.Register(nil, "")
if err != nil {
t.Fatalf("First Register() error = %v", err)
}
if subdomain1 == "" {
t.Error("Register() returned empty subdomain")
}
if len(subdomain1) != 6 {
t.Errorf("Register() subdomain length = %d, want 6", len(subdomain1))
}
// Generate another subdomain, should be different
subdomain2, err := manager.Register(nil, "")
if err != nil {
t.Fatalf("Second Register() error = %v", err)
}
if subdomain1 == subdomain2 {
t.Error("Register() generated duplicate subdomain")
}
}
func TestManagerGenerateSubdomainUniqueness(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
subdomains := make(map[string]bool)
count := 100
for i := 0; i < count; i++ {
subdomain, err := manager.Register(nil, "")
if err != nil {
t.Fatalf("Register() error = %v", err)
}
if subdomains[subdomain] {
t.Errorf("Register() generated duplicate: %s", subdomain)
}
subdomains[subdomain] = true
}
}
func TestManagerCleanupStale(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
// Create fresh connection
freshSubdomain, _ := manager.Register(nil, "fresh")
// Create stale connection
staleSubdomain, _ := manager.Register(nil, "stale")
// Manually set LastActive to be stale
if staleConn, ok := manager.Get(staleSubdomain); ok {
staleConn.mu.Lock()
staleConn.LastActive = time.Now().Add(-2 * time.Minute)
staleConn.mu.Unlock()
}
// Run cleanup with 90 second timeout
count := manager.CleanupStale(90 * time.Second)
if count != 1 {
t.Errorf("CleanupStale() returned %d, want 1", count)
}
// Fresh connection should still exist
_, ok := manager.Get(freshSubdomain)
if !ok {
t.Error("CleanupStale() removed fresh connection")
}
// Stale connection should be removed
_, ok = manager.Get(staleSubdomain)
if ok {
t.Error("CleanupStale() did not remove stale connection")
}
}
func TestManagerConcurrentAccess(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
var wg sync.WaitGroup
count := 50 // Reduced from 100 to avoid potential issues
// Concurrent registrations
for i := 0; i < count; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
_, err := manager.Register(nil, "")
if err != nil {
t.Errorf("Concurrent Register() error = %v", err)
}
}(i)
}
wg.Wait()
// Verify all connections are registered
all := manager.List()
if len(all) != count {
t.Errorf("Expected %d connections, got %d", count, len(all))
}
}
// Benchmark tests
func BenchmarkManagerRegister(b *testing.B) {
logger := zap.NewNop()
manager := NewManager(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.Register(nil, "")
}
}
func BenchmarkManagerGet(b *testing.B) {
logger := zap.NewNop()
manager := NewManager(logger)
// Setup: register a connection
subdomain, _ := manager.Register(nil, "test123")
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.Get(subdomain)
}
}
func BenchmarkManagerGenerateSubdomain(b *testing.B) {
logger := zap.NewNop()
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager := NewManager(logger)
manager.Register(nil, "")
}
}

View File

@@ -5,8 +5,7 @@ import (
"sync"
)
// HeaderPool manages a pool of http.Header objects for reuse
// This reduces GC pressure from repeated header map allocations
// HeaderPool manages a pool of http.Header objects for reuse.
type HeaderPool struct {
pool sync.Pool
}
@@ -16,32 +15,26 @@ func NewHeaderPool() *HeaderPool {
return &HeaderPool{
pool: sync.Pool{
New: func() interface{} {
// Pre-allocate with capacity for common header count (8-12 headers)
return make(http.Header, 12)
},
},
}
}
// Get retrieves a header from the pool
// Returns a clean, empty header ready for use
// Get retrieves a header from the pool.
func (p *HeaderPool) Get() http.Header {
h := p.pool.Get().(http.Header)
// Clear any existing data (headers might be dirty from previous use)
for k := range h {
delete(h, k)
}
return h
}
// Put returns a header to the pool
// The header will be reused by future Get() calls
// Put returns a header to the pool.
func (p *HeaderPool) Put(h http.Header) {
if h == nil {
return
}
// Note: We don't clear here, clearing is done in Get() for better performance
// (allows the GC to collect during idle time)
p.pool.Put(h)
}

View File

@@ -1,7 +1,7 @@
package protocol
import (
"encoding/json"
json "github.com/goccy/go-json"
"errors"
"github.com/vmihailenco/msgpack/v5"

View File

@@ -1,307 +0,0 @@
package protocol
import (
"encoding/json"
"reflect"
"testing"
)
func TestMessageType_Values(t *testing.T) {
tests := []struct {
name string
mt MessageType
want string
}{
{
name: "register",
mt: TypeRegister,
want: "register",
},
{
name: "request",
mt: TypeRequest,
want: "request",
},
{
name: "response",
mt: TypeResponse,
want: "response",
},
{
name: "heartbeat",
mt: TypeHeartbeat,
want: "heartbeat",
},
{
name: "error",
mt: TypeError,
want: "error",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := string(tt.mt)
if got != tt.want {
t.Errorf("MessageType value = %v, want %v", got, tt.want)
}
})
}
}
func TestMessage_JSON(t *testing.T) {
tests := []struct {
name string
message *Message
wantErr bool
}{
{
name: "simple message",
message: &Message{
Type: TypeRegister,
ID: "test-id-123",
},
wantErr: false,
},
{
name: "message with subdomain",
message: &Message{
Type: TypeRegister,
ID: "test-id-456",
Subdomain: "abc123",
},
wantErr: false,
},
{
name: "message with data",
message: &Message{
Type: TypeRequest,
ID: "test-id-789",
Data: map[string]interface{}{
"method": "GET",
"path": "/test",
"headers": map[string]interface{}{
"User-Agent": "Test",
},
},
},
wantErr: false,
},
{
name: "error message",
message: &Message{
Type: TypeError,
ID: "test-id-error",
Data: map[string]interface{}{
"error": "something went wrong",
"code": float64(500), // JSON unmarshals numbers as float64
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Marshal to JSON
data, err := json.Marshal(tt.message)
if (err != nil) != tt.wantErr {
t.Errorf("json.Marshal() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
// Unmarshal back
var decoded Message
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Errorf("json.Unmarshal() error = %v", err)
return
}
// Compare
if decoded.Type != tt.message.Type {
t.Errorf("Type = %v, want %v", decoded.Type, tt.message.Type)
}
if decoded.ID != tt.message.ID {
t.Errorf("ID = %v, want %v", decoded.ID, tt.message.ID)
}
if decoded.Subdomain != tt.message.Subdomain {
t.Errorf("Subdomain = %v, want %v", decoded.Subdomain, tt.message.Subdomain)
}
// Deep compare Data if present
if tt.message.Data != nil {
if !reflect.DeepEqual(decoded.Data, tt.message.Data) {
t.Errorf("Data = %v, want %v", decoded.Data, tt.message.Data)
}
}
})
}
}
func TestHTTPRequest_JSON(t *testing.T) {
req := &HTTPRequest{
Method: "POST",
URL: "http://localhost:3000/api/test",
Headers: map[string][]string{
"Content-Type": {"application/json"},
"User-Agent": {"Test Agent"},
},
Body: []byte(`{"key":"value"}`),
}
// Marshal
data, err := json.Marshal(req)
if err != nil {
t.Fatalf("json.Marshal() error = %v", err)
}
// Unmarshal
var decoded HTTPRequest
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
// Compare
if decoded.Method != req.Method {
t.Errorf("Method = %v, want %v", decoded.Method, req.Method)
}
if decoded.URL != req.URL {
t.Errorf("URL = %v, want %v", decoded.URL, req.URL)
}
if !reflect.DeepEqual(decoded.Headers, req.Headers) {
t.Errorf("Headers = %v, want %v", decoded.Headers, req.Headers)
}
if string(decoded.Body) != string(req.Body) {
t.Errorf("Body = %v, want %v", string(decoded.Body), string(req.Body))
}
}
func TestHTTPResponse_JSON(t *testing.T) {
resp := &HTTPResponse{
StatusCode: 200,
Status: "200 OK",
Headers: map[string][]string{
"Content-Type": {"text/html"},
},
Body: []byte("<html>Test</html>"),
}
// Marshal
data, err := json.Marshal(resp)
if err != nil {
t.Fatalf("json.Marshal() error = %v", err)
}
// Unmarshal
var decoded HTTPResponse
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
// Compare
if decoded.StatusCode != resp.StatusCode {
t.Errorf("StatusCode = %v, want %v", decoded.StatusCode, resp.StatusCode)
}
if decoded.Status != resp.Status {
t.Errorf("Status = %v, want %v", decoded.Status, resp.Status)
}
if !reflect.DeepEqual(decoded.Headers, resp.Headers) {
t.Errorf("Headers = %v, want %v", decoded.Headers, resp.Headers)
}
if string(decoded.Body) != string(resp.Body) {
t.Errorf("Body = %v, want %v", string(decoded.Body), string(resp.Body))
}
}
func TestMessage_ToMap(t *testing.T) {
msg := &Message{
Type: TypeRequest,
ID: "test-123",
Subdomain: "abc",
Data: map[string]interface{}{
"test": "value",
},
}
// Convert to map (simulated by marshaling and unmarshaling)
data, err := json.Marshal(msg)
if err != nil {
t.Fatalf("json.Marshal() error = %v", err)
}
var result map[string]interface{}
err = json.Unmarshal(data, &result)
if err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
// Verify fields exist
if result["type"] == nil {
t.Error("Map missing 'type' field")
}
if result["id"] == nil {
t.Error("Map missing 'id' field")
}
}
func TestNewMessage(t *testing.T) {
msgType := TypeRegister
id := "test-id"
msg := &Message{
Type: msgType,
ID: id,
}
if msg.Type != msgType {
t.Errorf("Type = %v, want %v", msg.Type, msgType)
}
if msg.ID != id {
t.Errorf("ID = %v, want %v", msg.ID, id)
}
}
// Benchmark tests
func BenchmarkMessageMarshal(b *testing.B) {
msg := &Message{
Type: TypeRequest,
ID: "test-id-123",
Subdomain: "abc123",
Data: map[string]interface{}{
"method": "GET",
"path": "/test",
},
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
json.Marshal(msg)
}
}
func BenchmarkMessageUnmarshal(b *testing.B) {
msg := &Message{
Type: TypeRequest,
ID: "test-id-123",
Subdomain: "abc123",
Data: map[string]interface{}{
"method": "GET",
"path": "/test",
},
}
data, _ := json.Marshal(msg)
b.ResetTimer()
for i := 0; i < b.N; i++ {
var decoded Message
json.Unmarshal(data, &decoded)
}
}

View File

@@ -1,6 +1,6 @@
package protocol
import "encoding/json"
import json "github.com/goccy/go-json"
// RegisterRequest is sent by client to register a tunnel
type RegisterRequest struct {

View File

@@ -1,7 +1,7 @@
package protocol
import (
"encoding/json"
json "github.com/goccy/go-json"
"errors"
)

View File

@@ -17,6 +17,10 @@ type FrameWriter struct {
maxBatch int
maxBatchWait time.Duration
heartbeatInterval time.Duration
heartbeatCallback func() *Frame
heartbeatEnabled bool
}
func NewFrameWriter(conn io.Writer) *FrameWriter {
@@ -37,8 +41,22 @@ func NewFrameWriterWithConfig(conn io.Writer, maxBatch int, maxBatchWait time.Du
}
func (w *FrameWriter) writeLoop() {
ticker := time.NewTicker(w.maxBatchWait)
defer ticker.Stop()
batchTicker := time.NewTicker(w.maxBatchWait)
defer batchTicker.Stop()
var heartbeatTicker *time.Ticker
var heartbeatCh <-chan time.Time
w.mu.Lock()
if w.heartbeatEnabled && w.heartbeatInterval > 0 {
heartbeatTicker = time.NewTicker(w.heartbeatInterval)
heartbeatCh = heartbeatTicker.C
}
w.mu.Unlock()
if heartbeatTicker != nil {
defer heartbeatTicker.Stop()
}
for {
select {
@@ -58,13 +76,23 @@ func (w *FrameWriter) writeLoop() {
}
w.mu.Unlock()
case <-ticker.C:
case <-batchTicker.C:
w.mu.Lock()
if len(w.batch) > 0 {
w.flushBatchLocked()
}
w.mu.Unlock()
case <-heartbeatCh:
w.mu.Lock()
if w.heartbeatCallback != nil {
if frame := w.heartbeatCallback(); frame != nil {
w.batch = append(w.batch, frame)
w.flushBatchLocked()
}
}
w.mu.Unlock()
case <-w.done:
w.mu.Lock()
w.flushBatchLocked()
@@ -86,6 +114,8 @@ func (w *FrameWriter) flushBatchLocked() {
w.batch = w.batch[:0]
}
// WriteFrame queues a frame to be written by the write loop.
// Blocks if the queue is full to ensure all writes go through the single write loop.
func (w *FrameWriter) WriteFrame(frame *Frame) error {
w.mu.Lock()
if w.closed {
@@ -99,8 +129,6 @@ func (w *FrameWriter) WriteFrame(frame *Frame) error {
return nil
case <-w.done:
return errors.New("writer closed")
default:
return WriteFrame(w.conn, frame)
}
}
@@ -124,3 +152,19 @@ func (w *FrameWriter) Flush() {
defer w.mu.Unlock()
w.flushBatchLocked()
}
// EnableHeartbeat enables automatic heartbeat sending in the write loop.
func (w *FrameWriter) EnableHeartbeat(interval time.Duration, callback func() *Frame) {
w.mu.Lock()
defer w.mu.Unlock()
w.heartbeatInterval = interval
w.heartbeatCallback = callback
w.heartbeatEnabled = true
}
// DisableHeartbeat disables automatic heartbeat sending.
func (w *FrameWriter) DisableHeartbeat() {
w.mu.Lock()
defer w.mu.Unlock()
w.heartbeatEnabled = false
}

View File

@@ -1,158 +0,0 @@
package utils
import (
"strings"
"testing"
)
func TestGenerateID(t *testing.T) {
tests := []struct {
name string
wantLength int // expected minimum length
}{
{
name: "generate valid ID",
wantLength: 16, // At least 16 characters for hex-encoded random bytes
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GenerateID()
// Check that ID is not empty
if got == "" {
t.Error("GenerateID() returned empty string")
}
// Check minimum length
if len(got) < tt.wantLength {
t.Errorf("GenerateID() length = %v, want at least %v", len(got), tt.wantLength)
}
// Check that it's a valid hex string
for _, char := range got {
if !isHexChar(char) {
t.Errorf("GenerateID() contains non-hex character: %c", char)
}
}
})
}
}
func TestGenerateIDUniqueness(t *testing.T) {
// Generate 10000 IDs and check for uniqueness
ids := make(map[string]bool)
count := 10000
for i := 0; i < count; i++ {
id := GenerateID()
if ids[id] {
t.Errorf("GenerateID() generated duplicate: %s", id)
}
ids[id] = true
}
if len(ids) != count {
t.Errorf("Expected %d unique IDs, got %d", count, len(ids))
}
}
func TestGenerateIDFormat(t *testing.T) {
id := GenerateID()
// Check that it's lowercase
if id != strings.ToLower(id) {
t.Errorf("GenerateID() is not lowercase: %s", id)
}
// Check that it doesn't contain special characters
for _, char := range id {
if !isHexChar(char) {
t.Errorf("GenerateID() contains invalid character: %c in %s", char, id)
}
}
}
func TestGenerateIDConsistency(t *testing.T) {
// Generate multiple IDs and ensure they all follow the same format
count := 100
firstID := GenerateID()
firstLen := len(firstID)
for i := 0; i < count; i++ {
id := GenerateID()
// All IDs should have the same length
if len(id) != firstLen {
t.Errorf("ID length inconsistency: first=%d, current=%d", firstLen, len(id))
}
// All IDs should be hex strings
for _, char := range id {
if !isHexChar(char) {
t.Errorf("Invalid hex character %c in ID: %s", char, id)
}
}
}
}
func TestGenerateIDNotEmpty(t *testing.T) {
// Generate 1000 IDs and ensure none are empty
for i := 0; i < 1000; i++ {
id := GenerateID()
if id == "" {
t.Error("GenerateID() returned empty string")
}
if len(id) == 0 {
t.Error("GenerateID() returned zero-length string")
}
}
}
// Helper function to check if a character is a valid hex character
func isHexChar(char rune) bool {
return (char >= '0' && char <= '9') || (char >= 'a' && char <= 'f')
}
// Benchmark tests
func BenchmarkGenerateID(b *testing.B) {
for i := 0; i < b.N; i++ {
GenerateID()
}
}
func BenchmarkGenerateIDParallel(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
GenerateID()
}
})
}
// Test for concurrent ID generation
func TestGenerateIDConcurrent(t *testing.T) {
count := 1000
ch := make(chan string, count)
// Generate IDs concurrently
for i := 0; i < count; i++ {
go func() {
ch <- GenerateID()
}()
}
// Collect all IDs
ids := make(map[string]bool)
for i := 0; i < count; i++ {
id := <-ch
if ids[id] {
t.Errorf("Concurrent GenerateID() generated duplicate: %s", id)
}
ids[id] = true
}
if len(ids) != count {
t.Errorf("Expected %d unique IDs, got %d", count, len(ids))
}
}

View File

@@ -1,266 +0,0 @@
package utils
import (
"strings"
"testing"
)
func TestGenerateSubdomain(t *testing.T) {
tests := []struct {
name string
length int
want int // expected length
}{
{
name: "default length 6",
length: 6,
want: 6,
},
{
name: "length 8",
length: 8,
want: 8,
},
{
name: "length 10",
length: 10,
want: 10,
},
{
name: "minimum length 4",
length: 4,
want: 4,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GenerateSubdomain(tt.length)
// Check length
if len(got) != tt.want {
t.Errorf("GenerateSubdomain() length = %v, want %v", len(got), tt.want)
}
// Check that it only contains alphanumeric characters
for _, char := range got {
if !isAlphanumeric(char) {
t.Errorf("GenerateSubdomain() contains non-alphanumeric character: %c", char)
}
}
// Check that it's lowercase
if got != strings.ToLower(got) {
t.Errorf("GenerateSubdomain() is not lowercase: %s", got)
}
})
}
}
func TestGenerateSubdomainUniqueness(t *testing.T) {
// Generate 1000 subdomains and check for uniqueness
subdomains := make(map[string]bool)
count := 1000
length := 6
for i := 0; i < count; i++ {
subdomain := GenerateSubdomain(length)
if subdomains[subdomain] {
t.Errorf("GenerateSubdomain() generated duplicate: %s", subdomain)
}
subdomains[subdomain] = true
}
if len(subdomains) != count {
t.Errorf("Expected %d unique subdomains, got %d", count, len(subdomains))
}
}
func TestValidateSubdomain(t *testing.T) {
tests := []struct {
name string
subdomain string
want bool
}{
{
name: "valid lowercase",
subdomain: "abc123",
want: true,
},
{
name: "valid all letters",
subdomain: "abcdef",
want: true,
},
{
name: "valid all numbers",
subdomain: "123456",
want: true,
},
{
name: "invalid uppercase",
subdomain: "ABC123",
want: false,
},
{
name: "valid with hyphen",
subdomain: "abc-123",
want: true,
},
{
name: "invalid starting with hyphen",
subdomain: "-abc123",
want: false,
},
{
name: "invalid ending with hyphen",
subdomain: "abc123-",
want: false,
},
{
name: "invalid with underscore",
subdomain: "abc_123",
want: false,
},
{
name: "invalid with dot",
subdomain: "abc.123",
want: false,
},
{
name: "invalid with space",
subdomain: "abc 123",
want: false,
},
{
name: "invalid empty",
subdomain: "",
want: false,
},
{
name: "invalid special characters",
subdomain: "abc@123",
want: false,
},
{
name: "valid minimum length",
subdomain: "abc",
want: true,
},
{
name: "invalid too short",
subdomain: "ab",
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ValidateSubdomain(tt.subdomain)
if got != tt.want {
t.Errorf("ValidateSubdomain(%q) = %v, want %v", tt.subdomain, got, tt.want)
}
})
}
}
func TestIsReserved(t *testing.T) {
tests := []struct {
name string
subdomain string
want bool
}{
{
name: "reserved www",
subdomain: "www",
want: true,
},
{
name: "reserved api",
subdomain: "api",
want: true,
},
{
name: "reserved admin",
subdomain: "admin",
want: true,
},
{
name: "reserved mail",
subdomain: "mail",
want: true,
},
{
name: "reserved ftp",
subdomain: "ftp",
want: true,
},
{
name: "reserved health",
subdomain: "health",
want: true,
},
{
name: "reserved test",
subdomain: "test",
want: true,
},
{
name: "reserved dev",
subdomain: "dev",
want: true,
},
{
name: "reserved staging",
subdomain: "staging",
want: true,
},
{
name: "not reserved random",
subdomain: "abc123",
want: false,
},
{
name: "not reserved user",
subdomain: "myapp",
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsReserved(tt.subdomain)
if got != tt.want {
t.Errorf("IsReserved(%q) = %v, want %v", tt.subdomain, got, tt.want)
}
})
}
}
// Helper function to check if a character is alphanumeric
func isAlphanumeric(char rune) bool {
return (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9')
}
// Benchmark tests
func BenchmarkGenerateSubdomain(b *testing.B) {
for i := 0; i < b.N; i++ {
GenerateSubdomain(6)
}
}
func BenchmarkValidateSubdomain(b *testing.B) {
subdomain := "abc123"
b.ResetTimer()
for i := 0; i < b.N; i++ {
ValidateSubdomain(subdomain)
}
}
func BenchmarkIsReserved(b *testing.B) {
subdomain := "www"
b.ResetTimer()
for i := 0; i < b.N; i++ {
IsReserved(subdomain)
}
}

View File

@@ -43,7 +43,6 @@ func LoadClientConfig(path string) (*ClientConfig, error) {
return nil, fmt.Errorf("failed to parse config file: %w", err)
}
// Validate required fields
if config.Server == "" {
return nil, fmt.Errorf("server address is required in config")
}
@@ -57,13 +56,11 @@ func SaveClientConfig(config *ClientConfig, path string) error {
path = DefaultClientConfigPath()
}
// Create directory if not exists
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0700); err != nil {
return fmt.Errorf("failed to create config directory: %w", err)
}
// Marshal to YAML
data, err := yaml.Marshal(config)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)

View File

@@ -47,7 +47,6 @@ func (c *ServerConfig) LoadTLSConfig() (*tls.Config, error) {
return nil, nil
}
// Check if certificate files exist
if c.TLSCertFile == "" || c.TLSKeyFile == "" {
return nil, fmt.Errorf("TLS enabled but certificate files not specified")
}
@@ -60,7 +59,6 @@ func (c *ServerConfig) LoadTLSConfig() (*tls.Config, error) {
return nil, fmt.Errorf("key file not found: %s", c.TLSKeyFile)
}
// Load certificate
cert, err := tls.LoadX509KeyPair(c.TLSCertFile, c.TLSKeyFile)
if err != nil {
return nil, fmt.Errorf("failed to load certificate: %w", err)