mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-23 21:00:44 +00:00
perf (client): Refactored the command-line interface and enhanced user experience
- Optimized terminal output style and configuration management using libraries such as `lipgloss` and `go-json` - Introduced the `ui` package to unify the display logic of colors, tables, and prompts - Updated the README document structure and installation script links to improve readability and internationalization support - Improved the interaction flow and log display effects of the daemon startup and attach commands - Fixed some command parameter parsing issues, improving program robustness and user onboarding experience
This commit is contained in:
41
README.md
41
README.md
@@ -1,12 +1,37 @@
|
||||
# Drip - Fast Tunnels to Localhost
|
||||
<p align="center">
|
||||
<img src="images/logo.png" alt="Drip Logo" width="128" />
|
||||
</p>
|
||||
|
||||
Self-hosted tunneling solution. Expose your localhost to the internet securely.
|
||||
<p align="center" style="font-size: 44px; font-weight: 600; margin: 0;">Drip</p>
|
||||
<p align="center" style="font-size: 20px; font-weight: 500; margin: 8px 0 0;">
|
||||
Your Tunnel, Your Domain, Anywhere
|
||||
</p>
|
||||
|
||||
[中文文档](README_CN.md)
|
||||
<p align="center">
|
||||
A self-hosted tunneling solution to securely expose your services to the internet.
|
||||
</p>
|
||||
|
||||
<p align="center ">
|
||||
<a href="README.md">English</a>
|
||||
<span> | </span>
|
||||
<a href="README_CN.md">中文文档</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://golang.org/">
|
||||
<img src="https://img.shields.io/badge/Go-1.21+-00ADD8?style=flat&logo=go" alt="Go Version" />
|
||||
</a>
|
||||
<a href="LICENSE">
|
||||
<img src="https://img.shields.io/badge/License-BSD--3--Clause-blue.svg" alt="License" />
|
||||
</a>
|
||||
<a href="https://tools.ietf.org/html/rfc8446">
|
||||
<img src="https://img.shields.io/badge/TLS-1.3-green.svg" alt="TLS" />
|
||||
</a>
|
||||
</p>
|
||||
|
||||
> Drip is a quiet, disciplined tunnel.
|
||||
> You light a small lamp on your network, and it carries that light outward—through your own infrastructure, on your own terms.
|
||||
|
||||
[](https://golang.org/)
|
||||
[](LICENSE)
|
||||
[](https://tools.ietf.org/html/rfc8446)
|
||||
|
||||
## Why?
|
||||
|
||||
@@ -32,13 +57,13 @@ Self-hosted tunneling solution. Expose your localhost to the internet securely.
|
||||
### Client (macOS/Linux)
|
||||
|
||||
```bash
|
||||
bash <(curl -sL https://raw.githubusercontent.com/Gouryella/drip/refs/heads/main/scripts/install.sh)
|
||||
bash <(curl -sL https://raw.githubusercontent.com/Gouryella/drip/main/scripts/install.sh)
|
||||
```
|
||||
|
||||
### Server (Linux)
|
||||
|
||||
```bash
|
||||
bash <(curl -sL https://raw.githubusercontent.com/Gouryella/drip/refs/heads/main/scripts/install-server.sh)
|
||||
bash <(curl -sL https://raw.githubusercontent.com/Gouryella/drip/main/scripts/install-server.sh)
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
@@ -32,13 +32,13 @@
|
||||
### 客户端 (macOS/Linux)
|
||||
|
||||
```bash
|
||||
bash <(curl -sL https://raw.githubusercontent.com/Gouryella/drip/refs/heads/main/scripts/install.sh)
|
||||
bash <(curl -sL https://raw.githubusercontent.com/Gouryella/drip/main/scripts/install.sh)
|
||||
```
|
||||
|
||||
### 服务端 (Linux)
|
||||
|
||||
```bash
|
||||
bash <(curl -sL https://raw.githubusercontent.com/Gouryella/drip/refs/heads/main/scripts/install-server.sh)
|
||||
bash <(curl -sL https://raw.githubusercontent.com/Gouryella/drip/main/scripts/install-server.sh)
|
||||
```
|
||||
|
||||
## 使用方法
|
||||
|
||||
@@ -14,10 +14,8 @@ var (
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Set version information
|
||||
cli.SetVersion(Version, GitCommit, BuildTime)
|
||||
|
||||
// Execute CLI
|
||||
if err := cli.Execute(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
|
||||
14
go.mod
14
go.mod
@@ -3,6 +3,8 @@ module drip
|
||||
go 1.25.4
|
||||
|
||||
require (
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/goccy/go-json v0.10.5
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/spf13/cobra v1.10.1
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1
|
||||
@@ -12,11 +14,23 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
|
||||
github.com/charmbracelet/x/ansi v0.8.0 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
|
||||
github.com/charmbracelet/x/term v0.2.1 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
github.com/stretchr/testify v1.11.1 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
)
|
||||
|
||||
32
go.sum
32
go.sum
@@ -1,12 +1,37 @@
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
|
||||
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||
github.com/charmbracelet/x/ansi v0.8.0 h1:9GTq3xq9caJW8ZrBTe0LIe2fvfLR/bYXKTx2llXn7xE=
|
||||
github.com/charmbracelet/x/ansi v0.8.0/go.mod h1:wdYl/ONOLHLIVmQaxbIYEC/cRKOQyjTkowiI4blgS9Q=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
|
||||
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
|
||||
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
|
||||
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s=
|
||||
github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0=
|
||||
@@ -19,6 +44,8 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
@@ -27,8 +54,13 @@ go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
|
||||
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E=
|
||||
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
|
||||
BIN
images/logo.png
Normal file
BIN
images/logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.4 MiB |
@@ -12,6 +12,7 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/cli/ui"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -36,27 +37,26 @@ func init() {
|
||||
}
|
||||
|
||||
func runAttach(cmd *cobra.Command, args []string) error {
|
||||
// Clean up stale daemons first
|
||||
CleanupStaleDaemons()
|
||||
|
||||
// Get all running daemons
|
||||
daemons, err := ListAllDaemons()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list daemons: %w", err)
|
||||
}
|
||||
|
||||
if len(daemons) == 0 {
|
||||
fmt.Println("\033[90mNo running tunnels.\033[0m")
|
||||
fmt.Println()
|
||||
fmt.Println("Start a tunnel in background with:")
|
||||
fmt.Println(" \033[36mdrip http 3000 -d\033[0m")
|
||||
fmt.Println(" \033[36mdrip tcp 5432 -d\033[0m")
|
||||
fmt.Println(ui.Info(
|
||||
"No Running Tunnels",
|
||||
"",
|
||||
ui.Muted("Start a tunnel in background with:"),
|
||||
ui.Cyan(" drip http 3000 -d"),
|
||||
ui.Cyan(" drip tcp 5432 -d"),
|
||||
))
|
||||
return nil
|
||||
}
|
||||
|
||||
var selectedDaemon *DaemonInfo
|
||||
|
||||
// If type and port are specified, find the specific daemon
|
||||
if len(args) == 2 {
|
||||
tunnelType := args[0]
|
||||
if tunnelType != "http" && tunnelType != "tcp" {
|
||||
@@ -68,7 +68,6 @@ func runAttach(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("invalid port number: %s", args[1])
|
||||
}
|
||||
|
||||
// Find the daemon
|
||||
for _, d := range daemons {
|
||||
if d.Type == tunnelType && d.Port == port {
|
||||
if !IsProcessRunning(d.PID) {
|
||||
@@ -84,29 +83,21 @@ func runAttach(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("no %s tunnel running on port %d", tunnelType, port)
|
||||
}
|
||||
} else if len(args) == 0 {
|
||||
// Interactive selection
|
||||
selectedDaemon, err = selectDaemonInteractive(daemons)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if selectedDaemon == nil {
|
||||
return nil // User cancelled
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("usage: drip attach [type port]")
|
||||
}
|
||||
|
||||
// Attach to the selected daemon
|
||||
return attachToDaemon(selectedDaemon)
|
||||
}
|
||||
|
||||
func selectDaemonInteractive(daemons []*DaemonInfo) (*DaemonInfo, error) {
|
||||
// Print header
|
||||
fmt.Println()
|
||||
fmt.Println("\033[1;37mSelect a tunnel to attach:\033[0m")
|
||||
fmt.Println("\033[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\033[0m")
|
||||
|
||||
// Filter out non-running daemons
|
||||
var runningDaemons []*DaemonInfo
|
||||
for _, d := range daemons {
|
||||
if IsProcessRunning(d.PID) {
|
||||
@@ -117,36 +108,36 @@ func selectDaemonInteractive(daemons []*DaemonInfo) (*DaemonInfo, error) {
|
||||
}
|
||||
|
||||
if len(runningDaemons) == 0 {
|
||||
fmt.Println("\033[90mNo running tunnels.\033[0m")
|
||||
fmt.Println(ui.Muted("No running tunnels."))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Print list
|
||||
table := ui.NewTable([]string{"#", "TYPE", "PORT", "URL", "UPTIME"}).
|
||||
WithTitle("Select a tunnel to attach")
|
||||
|
||||
for i, d := range runningDaemons {
|
||||
uptime := time.Since(d.StartTime)
|
||||
|
||||
// Format type with color
|
||||
var typeStr string
|
||||
if d.Type == "http" {
|
||||
typeStr = "\033[32mHTTP\033[0m"
|
||||
typeStr = ui.Success("HTTP")
|
||||
} else {
|
||||
typeStr = "\033[35mTCP\033[0m"
|
||||
typeStr = ui.Highlight("TCP")
|
||||
}
|
||||
|
||||
// Truncate URL if too long
|
||||
url := d.URL
|
||||
if len(url) > 50 {
|
||||
url = url[:47] + "..."
|
||||
}
|
||||
|
||||
fmt.Printf("\033[1;36m%d.\033[0m %-15s \033[90mPort:\033[0m %-6d \033[90mURL:\033[0m %-50s \033[90mUptime:\033[0m %s\n",
|
||||
i+1, typeStr, d.Port, url, FormatDuration(uptime))
|
||||
table.AddRow([]string{
|
||||
ui.Highlight(fmt.Sprintf("%d", i+1)),
|
||||
typeStr,
|
||||
fmt.Sprintf("%d", d.Port),
|
||||
ui.URL(d.URL),
|
||||
FormatDuration(uptime),
|
||||
})
|
||||
}
|
||||
|
||||
fmt.Println("\033[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\033[0m")
|
||||
fmt.Print(table.Render())
|
||||
|
||||
fmt.Printf("Enter number (1-%d) or 'q' to quit: ", len(runningDaemons))
|
||||
|
||||
// Read user input
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
input, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
@@ -158,7 +149,6 @@ func selectDaemonInteractive(daemons []*DaemonInfo) (*DaemonInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Parse selection
|
||||
selection, err := strconv.Atoi(input)
|
||||
if err != nil || selection < 1 || selection > len(runningDaemons) {
|
||||
return nil, fmt.Errorf("invalid selection: %s", input)
|
||||
@@ -168,35 +158,28 @@ func selectDaemonInteractive(daemons []*DaemonInfo) (*DaemonInfo, error) {
|
||||
}
|
||||
|
||||
func attachToDaemon(daemon *DaemonInfo) error {
|
||||
// Get log file path
|
||||
logPath := filepath.Join(getDaemonDir(), fmt.Sprintf("%s_%d.log", daemon.Type, daemon.Port))
|
||||
|
||||
// Check if log file exists
|
||||
if _, err := os.Stat(logPath); os.IsNotExist(err) {
|
||||
return fmt.Errorf("log file not found: %s", logPath)
|
||||
}
|
||||
|
||||
// Print header
|
||||
fmt.Println()
|
||||
fmt.Println("\033[1;32m╔══════════════════════════════════════════════════════════════════╗\033[0m")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[1;37mAttached to %s tunnel on port %d\033[0m", strings.ToUpper(daemon.Type), daemon.Port)
|
||||
fmt.Printf("%s\033[1;32m║\033[0m\n", strings.Repeat(" ", 32-len(daemon.Type)))
|
||||
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mURL:\033[0m \033[36m%-52s\033[0m \033[1;32m║\033[0m\n", daemon.URL)
|
||||
uptime := time.Since(daemon.StartTime)
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mPID:\033[0m \033[90m%-52d\033[0m \033[1;32m║\033[0m\n", daemon.PID)
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mUptime:\033[0m \033[90m%-52s\033[0m \033[1;32m║\033[0m\n", FormatDuration(uptime))
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mLog:\033[0m \033[90m%-52s\033[0m \033[1;32m║\033[0m\n", truncatePath(logPath, 52))
|
||||
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
|
||||
fmt.Println("\033[1;32m║\033[0m \033[33mPress Ctrl+C to detach (tunnel will continue running)\033[0m \033[1;32m║\033[0m")
|
||||
fmt.Println("\033[1;32m╚══════════════════════════════════════════════════════════════════╝\033[0m")
|
||||
fmt.Println()
|
||||
|
||||
// Setup signal handler
|
||||
fmt.Println(ui.Info(
|
||||
fmt.Sprintf("Attached to %s tunnel on port %d", strings.ToUpper(daemon.Type), daemon.Port),
|
||||
"",
|
||||
ui.KeyValue("URL", ui.URL(daemon.URL)),
|
||||
ui.KeyValue("PID", fmt.Sprintf("%d", daemon.PID)),
|
||||
ui.KeyValue("Uptime", FormatDuration(uptime)),
|
||||
ui.KeyValue("Log", truncatePath(logPath, 48)),
|
||||
"",
|
||||
ui.Warning("Press Ctrl+C to detach (tunnel will continue running)"),
|
||||
))
|
||||
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Start tail command
|
||||
tailCmd := exec.Command("tail", "-f", logPath)
|
||||
tailCmd.Stdout = os.Stdout
|
||||
tailCmd.Stderr = os.Stderr
|
||||
@@ -205,7 +188,6 @@ func attachToDaemon(daemon *DaemonInfo) error {
|
||||
return fmt.Errorf("failed to start tail: %w", err)
|
||||
}
|
||||
|
||||
// Wait for signal or tail to exit
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- tailCmd.Wait()
|
||||
@@ -213,14 +195,13 @@ func attachToDaemon(daemon *DaemonInfo) error {
|
||||
|
||||
select {
|
||||
case <-sigCh:
|
||||
// Kill tail process
|
||||
if tailCmd.Process != nil {
|
||||
tailCmd.Process.Kill()
|
||||
}
|
||||
fmt.Println()
|
||||
fmt.Println("\033[33mDetached from tunnel (tunnel is still running)\033[0m")
|
||||
fmt.Printf("Use '\033[36mdrip attach %s %d\033[0m' to reattach\n", daemon.Type, daemon.Port)
|
||||
fmt.Printf("Use '\033[36mdrip stop %s %d\033[0m' to stop the tunnel\n", daemon.Type, daemon.Port)
|
||||
fmt.Println(ui.Warning("Detached from tunnel (tunnel is still running)"))
|
||||
fmt.Println(ui.Muted(fmt.Sprintf("Use '%s' to reattach", ui.Cyan(fmt.Sprintf("drip attach %s %d", daemon.Type, daemon.Port)))))
|
||||
fmt.Println(ui.Muted(fmt.Sprintf("Use '%s' to stop the tunnel", ui.Cyan(fmt.Sprintf("drip stop %s %d", daemon.Type, daemon.Port)))))
|
||||
return nil
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
@@ -234,7 +215,6 @@ func truncatePath(path string, maxLen int) string {
|
||||
if len(path) <= maxLen {
|
||||
return path
|
||||
}
|
||||
// Try to keep filename and show ... in the middle
|
||||
filename := filepath.Base(path)
|
||||
if len(filename) >= maxLen-3 {
|
||||
return "..." + filename[len(filename)-(maxLen-3):]
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"drip/internal/client/cli/ui"
|
||||
"drip/pkg/config"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
@@ -59,36 +60,28 @@ var (
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Add subcommands
|
||||
configCmd.AddCommand(configInitCmd)
|
||||
configCmd.AddCommand(configShowCmd)
|
||||
configCmd.AddCommand(configSetCmd)
|
||||
configCmd.AddCommand(configResetCmd)
|
||||
configCmd.AddCommand(configValidateCmd)
|
||||
|
||||
// Flags for show
|
||||
configShowCmd.Flags().BoolVar(&configFull, "full", false, "Show full token (not hidden)")
|
||||
|
||||
// Flags for set
|
||||
configSetCmd.Flags().StringVar(&configServer, "server", "", "Server address (e.g., tunnel.example.com:443)")
|
||||
configSetCmd.Flags().StringVar(&configToken, "token", "", "Authentication token")
|
||||
|
||||
// Flags for reset
|
||||
configResetCmd.Flags().BoolVar(&configForce, "force", false, "Force reset without confirmation")
|
||||
|
||||
// Add to root
|
||||
rootCmd.AddCommand(configCmd)
|
||||
}
|
||||
|
||||
func runConfigInit(cmd *cobra.Command, args []string) error {
|
||||
fmt.Println("\n╔═══════════════════════════════════════╗")
|
||||
fmt.Println("║ Drip Configuration Setup ║")
|
||||
fmt.Println("╚═══════════════════════════════════════╝")
|
||||
fmt.Print(ui.RenderConfigInit())
|
||||
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
|
||||
// Get server address
|
||||
fmt.Print("Server address (e.g., tunnel.example.com:443): ")
|
||||
fmt.Print(ui.Muted("Server address (e.g., tunnel.example.com:443): "))
|
||||
serverAddr, _ := reader.ReadString('\n')
|
||||
serverAddr = strings.TrimSpace(serverAddr)
|
||||
|
||||
@@ -96,102 +89,79 @@ func runConfigInit(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("server address is required")
|
||||
}
|
||||
|
||||
// Get token
|
||||
fmt.Print("Authentication token (leave empty to skip): ")
|
||||
fmt.Print(ui.Muted("Authentication token (leave empty to skip): "))
|
||||
token, _ := reader.ReadString('\n')
|
||||
token = strings.TrimSpace(token)
|
||||
|
||||
// Create config
|
||||
cfg := &config.ClientConfig{
|
||||
Server: serverAddr,
|
||||
Token: token,
|
||||
TLS: true,
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := config.SaveClientConfig(cfg, ""); err != nil {
|
||||
return fmt.Errorf("failed to save configuration: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Configuration saved to", config.DefaultClientConfigPath())
|
||||
fmt.Println("✓ You can now use 'drip' without --server and --token")
|
||||
fmt.Println(ui.RenderConfigSaved(config.DefaultClientConfigPath()))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runConfigShow(cmd *cobra.Command, args []string) error {
|
||||
// Load config
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println("\n╔═══════════════════════════════════════╗")
|
||||
fmt.Println("║ Current Configuration ║")
|
||||
fmt.Println("╚═══════════════════════════════════════╝")
|
||||
|
||||
fmt.Printf("Server: %s\n", cfg.Server)
|
||||
|
||||
// Show token (hidden or full)
|
||||
var displayToken string
|
||||
if cfg.Token != "" {
|
||||
if configFull {
|
||||
fmt.Printf("Token: %s\n", cfg.Token)
|
||||
if len(cfg.Token) > 10 {
|
||||
displayToken = cfg.Token[:3] + "***" + cfg.Token[len(cfg.Token)-3:]
|
||||
} else {
|
||||
// Hide middle part of token
|
||||
if len(cfg.Token) > 10 {
|
||||
fmt.Printf("Token: %s***%s (hidden)\n",
|
||||
cfg.Token[:3],
|
||||
cfg.Token[len(cfg.Token)-3:],
|
||||
)
|
||||
} else {
|
||||
fmt.Printf("Token: %s (hidden)\n", cfg.Token[:3]+"***")
|
||||
}
|
||||
displayToken = cfg.Token[:3] + "***"
|
||||
}
|
||||
} else {
|
||||
fmt.Println("Token: (not set)")
|
||||
displayToken = ""
|
||||
}
|
||||
|
||||
fmt.Printf("TLS: %s\n", enabledDisabled(cfg.TLS))
|
||||
fmt.Printf("Config: %s\n\n", config.DefaultClientConfigPath())
|
||||
fmt.Println(ui.RenderConfigShow(cfg.Server, displayToken, !configFull, cfg.TLS, config.DefaultClientConfigPath()))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runConfigSet(cmd *cobra.Command, args []string) error {
|
||||
// Load existing config or create new
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
// Create new config if not exists
|
||||
cfg = &config.ClientConfig{
|
||||
TLS: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Update fields if provided
|
||||
modified := false
|
||||
var updates []string
|
||||
|
||||
if configServer != "" {
|
||||
cfg.Server = configServer
|
||||
modified = true
|
||||
fmt.Printf("✓ Server updated: %s\n", configServer)
|
||||
updates = append(updates, "Server updated: "+configServer)
|
||||
}
|
||||
|
||||
if configToken != "" {
|
||||
cfg.Token = configToken
|
||||
modified = true
|
||||
fmt.Println("✓ Token updated")
|
||||
updates = append(updates, "Token updated")
|
||||
}
|
||||
|
||||
if !modified {
|
||||
return fmt.Errorf("no changes specified. Use --server or --token")
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := config.SaveClientConfig(cfg, ""); err != nil {
|
||||
return fmt.Errorf("failed to save configuration: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("✓ Configuration saved")
|
||||
fmt.Println(ui.RenderConfigUpdated(updates))
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -199,13 +169,11 @@ func runConfigSet(cmd *cobra.Command, args []string) error {
|
||||
func runConfigReset(cmd *cobra.Command, args []string) error {
|
||||
configPath := config.DefaultClientConfigPath()
|
||||
|
||||
// Check if config exists
|
||||
if !config.ConfigExists("") {
|
||||
fmt.Println("No configuration file found")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Confirm deletion
|
||||
if !configForce {
|
||||
fmt.Print("Are you sure you want to delete the configuration? (y/N): ")
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
@@ -218,49 +186,31 @@ func runConfigReset(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Delete config file
|
||||
if err := os.Remove(configPath); err != nil {
|
||||
return fmt.Errorf("failed to delete configuration: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("✓ Configuration file deleted")
|
||||
fmt.Println(ui.RenderConfigDeleted())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runConfigValidate(cmd *cobra.Command, args []string) error {
|
||||
fmt.Println("\nValidating configuration...")
|
||||
fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
|
||||
|
||||
// Load config
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
fmt.Println("✗ Failed to load configuration")
|
||||
fmt.Println(ui.Error("Failed to load configuration"))
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate server address
|
||||
if cfg.Server == "" {
|
||||
fmt.Println("✗ Server address is not set")
|
||||
serverValid := cfg.Server != ""
|
||||
tokenSet := cfg.Token != ""
|
||||
tlsEnabled := cfg.TLS
|
||||
|
||||
fmt.Println(ui.RenderConfigValidation(serverValid, tokenSet, tlsEnabled))
|
||||
|
||||
if !serverValid {
|
||||
return fmt.Errorf("invalid configuration")
|
||||
}
|
||||
fmt.Println("✓ Server address is valid")
|
||||
|
||||
// Validate token
|
||||
if cfg.Token != "" {
|
||||
fmt.Println("✓ Token is set")
|
||||
} else {
|
||||
fmt.Println("⚠ Token is not set (authentication may fail)")
|
||||
}
|
||||
|
||||
// Validate TLS
|
||||
if cfg.TLS {
|
||||
fmt.Println("✓ TLS is enabled")
|
||||
} else {
|
||||
fmt.Println("⚠ TLS is disabled (not recommended for production)")
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Configuration is valid")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/cli/ui"
|
||||
json "github.com/goccy/go-json"
|
||||
)
|
||||
|
||||
// DaemonInfo stores information about a running daemon process
|
||||
@@ -170,13 +172,10 @@ func StartDaemon(tunnelType string, port int, args []string) error {
|
||||
cleanArgs = append(cleanArgs, arg)
|
||||
}
|
||||
|
||||
// Create the command
|
||||
cmd := exec.Command(executable, cleanArgs...)
|
||||
|
||||
// Detach from parent process (platform-specific)
|
||||
setupDaemonCmd(cmd)
|
||||
|
||||
// Create log file for daemon output
|
||||
logDir := getDaemonDir()
|
||||
if err := os.MkdirAll(logDir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create daemon directory: %w", err)
|
||||
@@ -187,7 +186,6 @@ func StartDaemon(tunnelType string, port int, args []string) error {
|
||||
return fmt.Errorf("failed to create log file: %w", err)
|
||||
}
|
||||
|
||||
// Redirect stdin to /dev/null
|
||||
devNull, err := os.OpenFile(os.DevNull, os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
logFile.Close()
|
||||
@@ -197,7 +195,6 @@ func StartDaemon(tunnelType string, port int, args []string) error {
|
||||
cmd.Stdout = logFile
|
||||
cmd.Stderr = logFile
|
||||
|
||||
// Start the process
|
||||
if err := cmd.Start(); err != nil {
|
||||
logFile.Close()
|
||||
devNull.Close()
|
||||
@@ -207,11 +204,7 @@ func StartDaemon(tunnelType string, port int, args []string) error {
|
||||
// Don't wait for the process - let it run in background
|
||||
// The child process will save its own daemon info after connecting
|
||||
|
||||
fmt.Printf("\033[32m✓\033[0m Started %s tunnel on port %d in background (PID: %d)\n", tunnelType, port, cmd.Process.Pid)
|
||||
fmt.Printf(" Use '\033[36mdrip list\033[0m' to check tunnel status\n")
|
||||
fmt.Printf(" Use '\033[36mdrip attach %s %d\033[0m' to view logs\n", tunnelType, port)
|
||||
fmt.Printf(" Use '\033[36mdrip stop %s %d\033[0m' to stop this tunnel\n", tunnelType, port)
|
||||
fmt.Printf(" Logs: \033[90m%s\033[0m\n", logPath)
|
||||
fmt.Println(ui.RenderDaemonStarted(tunnelType, port, cmd.Process.Pid, logPath))
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -249,11 +242,9 @@ func FormatDuration(d time.Duration) string {
|
||||
// ParsePortFromArgs extracts the port number from command arguments
|
||||
func ParsePortFromArgs(args []string) (int, error) {
|
||||
for _, arg := range args {
|
||||
// Skip flags
|
||||
if len(arg) > 0 && arg[0] == '-' {
|
||||
continue
|
||||
}
|
||||
// Try to parse as port number
|
||||
port, err := strconv.Atoi(arg)
|
||||
if err == nil && port > 0 && port <= 65535 {
|
||||
return port, nil
|
||||
|
||||
@@ -3,19 +3,14 @@ package cli
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/tcp"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/utils"
|
||||
"drip/pkg/config"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -87,13 +82,6 @@ func runHTTP(cmd *cobra.Command, args []string) error {
|
||||
return StartDaemon("http", port, daemonArgs)
|
||||
}
|
||||
|
||||
if err := utils.InitLogger(verbose); err != nil {
|
||||
return fmt.Errorf("failed to initialize logger: %w", err)
|
||||
}
|
||||
defer utils.Sync()
|
||||
|
||||
logger := utils.GetLogger()
|
||||
|
||||
var serverAddr, token string
|
||||
|
||||
if serverURL == "" {
|
||||
@@ -125,182 +113,18 @@ Please run 'drip config init' first, or use flags:
|
||||
Insecure: insecure,
|
||||
}
|
||||
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
reconnectAttempts := 0
|
||||
for {
|
||||
connector := tcp.NewConnector(connConfig, logger)
|
||||
|
||||
if reconnectAttempts == 0 {
|
||||
fmt.Printf("\033[36m🔌 Connecting to %s...\033[0m\n", serverAddr)
|
||||
} else {
|
||||
fmt.Printf("\033[33m🔄 Reconnecting to %s (attempt %d/%d)...\033[0m\n", serverAddr, reconnectAttempts, maxReconnectAttempts)
|
||||
}
|
||||
|
||||
if err := connector.Connect(); err != nil {
|
||||
if isNonRetryableError(err) {
|
||||
return fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
reconnectAttempts++
|
||||
if reconnectAttempts >= maxReconnectAttempts {
|
||||
return fmt.Errorf("failed to connect after %d attempts: %w", maxReconnectAttempts, err)
|
||||
}
|
||||
fmt.Printf("\033[31m✗ Connection failed: %v\033[0m\n", err)
|
||||
fmt.Printf("\033[90m Retrying in %v...\033[0m\n", reconnectInterval)
|
||||
|
||||
select {
|
||||
case <-quit:
|
||||
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
|
||||
return nil
|
||||
case <-time.After(reconnectInterval):
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
reconnectAttempts = 0
|
||||
|
||||
if daemonMarker {
|
||||
daemonInfo := &DaemonInfo{
|
||||
PID: os.Getpid(),
|
||||
Type: "http",
|
||||
Port: port,
|
||||
Subdomain: subdomain,
|
||||
Server: serverAddr,
|
||||
URL: connector.GetURL(),
|
||||
StartTime: time.Now(),
|
||||
Executable: os.Args[0],
|
||||
}
|
||||
if err := SaveDaemonInfo(daemonInfo); err != nil {
|
||||
logger.Warn("Failed to save daemon info", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("\033[1;32m╔══════════════════════════════════════════════════════════════════╗\033[0m")
|
||||
fmt.Println("\033[1;32m║\033[0m \033[1;37m🚀 HTTP Tunnel Connected Successfully!\033[0m \033[1;32m║\033[0m")
|
||||
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[1;37mTunnel URL:\033[0m \033[1;32m║\033[0m\n")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[1;36m%-60s\033[0m \033[1;32m║\033[0m\n", connector.GetURL())
|
||||
fmt.Println("\033[1;32m║\033[0m \033[1;32m║\033[0m")
|
||||
displayAddr := localAddress
|
||||
if displayAddr == "127.0.0.1" {
|
||||
displayAddr = "localhost"
|
||||
}
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mForwarding:\033[0m \033[1m%s:%d\033[0m → \033[36m%s\033[0m%-15s\033[1;32m║\033[0m\n", displayAddr, port, "public", "")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mLatency:\033[0m \033[90mmeasuring...\033[0m%-40s\033[1;32m║\033[0m\n", "")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mTraffic:\033[0m \033[90m↓ 0 B ↑ 0 B\033[0m%-32s\033[1;32m║\033[0m\n", "")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mSpeed:\033[0m \033[90m↓ 0 B/s ↑ 0 B/s\033[0m%-28s\033[1;32m║\033[0m\n", "")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mRequests:\033[0m \033[90m0\033[0m%-43s\033[1;32m║\033[0m\n", "")
|
||||
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
|
||||
fmt.Println("\033[1;32m║\033[0m \033[90mPress Ctrl+C to stop the tunnel\033[0m \033[1;32m║\033[0m")
|
||||
fmt.Println("\033[1;32m╚══════════════════════════════════════════════════════════════════╝\033[0m")
|
||||
fmt.Println()
|
||||
|
||||
latencyCh := make(chan time.Duration, 1)
|
||||
connector.SetLatencyCallback(func(latency time.Duration) {
|
||||
select {
|
||||
case latencyCh <- latency:
|
||||
default:
|
||||
}
|
||||
})
|
||||
|
||||
stopDisplay := make(chan struct{})
|
||||
disconnected := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
var lastLatency time.Duration
|
||||
for {
|
||||
select {
|
||||
case latency := <-latencyCh:
|
||||
lastLatency = latency
|
||||
case <-ticker.C:
|
||||
stats := connector.GetStats()
|
||||
if stats != nil {
|
||||
stats.UpdateSpeed()
|
||||
snapshot := stats.GetSnapshot()
|
||||
|
||||
fmt.Print("\033[8A")
|
||||
|
||||
fmt.Printf("\r\033[1;32m║\033[0m \033[90mLatency:\033[0m %s%-40s\033[1;32m║\033[0m\n", formatLatency(lastLatency), "")
|
||||
|
||||
trafficStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatBytes(snapshot.TotalBytesIn), tcp.FormatBytes(snapshot.TotalBytesOut))
|
||||
fmt.Printf("\r\033[1;32m║\033[0m \033[90mTraffic:\033[0m \033[36m%-48s\033[0m\033[1;32m║\033[0m\n", trafficStr)
|
||||
|
||||
speedStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatSpeed(snapshot.SpeedIn), tcp.FormatSpeed(snapshot.SpeedOut))
|
||||
fmt.Printf("\r\033[1;32m║\033[0m \033[90mSpeed:\033[0m \033[33m%-48s\033[0m\033[1;32m║\033[0m\n", speedStr)
|
||||
|
||||
fmt.Printf("\r\033[1;32m║\033[0m \033[90mRequests:\033[0m \033[35m%-47d\033[0m\033[1;32m║\033[0m\n", snapshot.TotalRequests)
|
||||
|
||||
fmt.Print("\033[4B")
|
||||
}
|
||||
case <-stopDisplay:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
connector.Wait()
|
||||
close(disconnected)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-quit:
|
||||
close(stopDisplay)
|
||||
fmt.Println("\n\n\033[33m🛑 Shutting down...\033[0m")
|
||||
connector.Close()
|
||||
if daemonMarker {
|
||||
RemoveDaemonInfo("http", port)
|
||||
}
|
||||
fmt.Println("\033[32m✓\033[0m Tunnel closed")
|
||||
return nil
|
||||
case <-disconnected:
|
||||
close(stopDisplay)
|
||||
fmt.Println("\n\n\033[31m⚠ Connection lost!\033[0m")
|
||||
reconnectAttempts++
|
||||
if reconnectAttempts >= maxReconnectAttempts {
|
||||
return fmt.Errorf("connection lost after %d reconnect attempts", maxReconnectAttempts)
|
||||
}
|
||||
fmt.Printf("\033[90m Reconnecting in %v...\033[0m\n", reconnectInterval)
|
||||
|
||||
select {
|
||||
case <-quit:
|
||||
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
|
||||
return nil
|
||||
case <-time.After(reconnectInterval):
|
||||
continue
|
||||
}
|
||||
var daemon *DaemonInfo
|
||||
if daemonMarker {
|
||||
daemon = &DaemonInfo{
|
||||
PID: os.Getpid(),
|
||||
Type: "http",
|
||||
Port: port,
|
||||
Subdomain: subdomain,
|
||||
Server: serverAddr,
|
||||
StartTime: time.Now(),
|
||||
Executable: os.Args[0],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func formatLatency(d time.Duration) string {
|
||||
ms := d.Milliseconds()
|
||||
if ms < 50 {
|
||||
return fmt.Sprintf("\033[32m%dms\033[0m", ms)
|
||||
} else if ms < 100 {
|
||||
return fmt.Sprintf("\033[33m%dms\033[0m", ms)
|
||||
} else if ms < 200 {
|
||||
return fmt.Sprintf("\033[38;5;208m%dms\033[0m", ms)
|
||||
}
|
||||
return fmt.Sprintf("\033[31m%dms\033[0m", ms)
|
||||
}
|
||||
|
||||
func isNonRetryableError(err error) bool {
|
||||
errStr := err.Error()
|
||||
if strings.Contains(errStr, "subdomain is already taken") ||
|
||||
strings.Contains(errStr, "subdomain is reserved") ||
|
||||
strings.Contains(errStr, "invalid subdomain") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(errStr, "authentication") ||
|
||||
strings.Contains(errStr, "Invalid authentication token") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
return runTunnelWithUI(connConfig, daemon)
|
||||
}
|
||||
|
||||
@@ -3,18 +3,14 @@ package cli
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/tcp"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/utils"
|
||||
"drip/pkg/config"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -52,15 +48,12 @@ func init() {
|
||||
}
|
||||
|
||||
func runHTTPS(cmd *cobra.Command, args []string) error {
|
||||
// Parse port
|
||||
port, err := strconv.Atoi(args[0])
|
||||
if err != nil || port < 1 || port > 65535 {
|
||||
return fmt.Errorf("invalid port number: %s", args[0])
|
||||
}
|
||||
|
||||
// Handle daemon mode
|
||||
if httpsDaemonMode && !httpsDaemonMarker {
|
||||
// Start as daemon
|
||||
daemonArgs := append([]string{"https"}, args...)
|
||||
daemonArgs = append(daemonArgs, "--daemon-child")
|
||||
if httpsSubdomain != "" {
|
||||
@@ -84,19 +77,9 @@ func runHTTPS(cmd *cobra.Command, args []string) error {
|
||||
return StartDaemon("https", port, daemonArgs)
|
||||
}
|
||||
|
||||
// Initialize logger
|
||||
if err := utils.InitLogger(verbose); err != nil {
|
||||
return fmt.Errorf("failed to initialize logger: %w", err)
|
||||
}
|
||||
defer utils.Sync()
|
||||
|
||||
logger := utils.GetLogger()
|
||||
|
||||
// Load configuration or use command line flags
|
||||
var serverAddr, token string
|
||||
|
||||
if serverURL == "" {
|
||||
// Try to load from config file
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
return fmt.Errorf(`configuration not found.
|
||||
@@ -107,17 +90,14 @@ Please run 'drip config init' first, or use flags:
|
||||
serverAddr = cfg.Server
|
||||
token = cfg.Token
|
||||
} else {
|
||||
// Use command line flags
|
||||
serverAddr = serverURL
|
||||
token = authToken
|
||||
}
|
||||
|
||||
// Validate server address
|
||||
if serverAddr == "" {
|
||||
return fmt.Errorf("server address is required")
|
||||
}
|
||||
|
||||
// Create connector config
|
||||
connConfig := &tcp.ConnectorConfig{
|
||||
ServerAddr: serverAddr,
|
||||
Token: token,
|
||||
@@ -128,178 +108,18 @@ Please run 'drip config init' first, or use flags:
|
||||
Insecure: insecure,
|
||||
}
|
||||
|
||||
// Setup signal handler for graceful shutdown
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Connection loop with reconnect support
|
||||
reconnectAttempts := 0
|
||||
for {
|
||||
// Create connector
|
||||
connector := tcp.NewConnector(connConfig, logger)
|
||||
|
||||
// Connect to server
|
||||
if reconnectAttempts == 0 {
|
||||
fmt.Printf("\033[36m🔌 Connecting to %s...\033[0m\n", serverAddr)
|
||||
} else {
|
||||
fmt.Printf("\033[33m🔄 Reconnecting to %s (attempt %d/%d)...\033[0m\n", serverAddr, reconnectAttempts, maxReconnectAttempts)
|
||||
}
|
||||
|
||||
if err := connector.Connect(); err != nil {
|
||||
// Check if this is a non-retryable error
|
||||
if isNonRetryableError(err) {
|
||||
return fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
reconnectAttempts++
|
||||
if reconnectAttempts >= maxReconnectAttempts {
|
||||
return fmt.Errorf("failed to connect after %d attempts: %w", maxReconnectAttempts, err)
|
||||
}
|
||||
fmt.Printf("\033[31m✗ Connection failed: %v\033[0m\n", err)
|
||||
fmt.Printf("\033[90m Retrying in %v...\033[0m\n", reconnectInterval)
|
||||
|
||||
// Wait before retry, but allow interrupt
|
||||
select {
|
||||
case <-quit:
|
||||
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
|
||||
return nil
|
||||
case <-time.After(reconnectInterval):
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Reset reconnect attempts on successful connection
|
||||
reconnectAttempts = 0
|
||||
|
||||
// Save daemon info if running as daemon child
|
||||
if httpsDaemonMarker {
|
||||
daemonInfo := &DaemonInfo{
|
||||
PID: os.Getpid(),
|
||||
Type: "https",
|
||||
Port: port,
|
||||
Subdomain: httpsSubdomain,
|
||||
Server: serverAddr,
|
||||
URL: connector.GetURL(),
|
||||
StartTime: time.Now(),
|
||||
Executable: os.Args[0],
|
||||
}
|
||||
if err := SaveDaemonInfo(daemonInfo); err != nil {
|
||||
// Log but don't fail
|
||||
logger.Warn("Failed to save daemon info", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Print tunnel information
|
||||
fmt.Println()
|
||||
fmt.Println("\033[1;32m╔══════════════════════════════════════════════════════════════════╗\033[0m")
|
||||
fmt.Println("\033[1;32m║\033[0m \033[1;37m🔒 HTTPS Tunnel Connected Successfully!\033[0m \033[1;32m║\033[0m")
|
||||
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[1;37mTunnel URL:\033[0m \033[1;32m║\033[0m\n")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[1;36m%-60s\033[0m \033[1;32m║\033[0m\n", connector.GetURL())
|
||||
fmt.Println("\033[1;32m║\033[0m \033[1;32m║\033[0m")
|
||||
displayAddr := httpsLocalAddress
|
||||
if displayAddr == "127.0.0.1" {
|
||||
displayAddr = "localhost"
|
||||
}
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mForwarding:\033[0m \033[1m%s:%d\033[0m → \033[36m%s\033[0m%-15s\033[1;32m║\033[0m\n", displayAddr, port, "public", "")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mLatency:\033[0m \033[90mmeasuring...\033[0m%-40s\033[1;32m║\033[0m\n", "")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mTraffic:\033[0m \033[90m↓ 0 B ↑ 0 B\033[0m%-32s\033[1;32m║\033[0m\n", "")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mSpeed:\033[0m \033[90m↓ 0 B/s ↑ 0 B/s\033[0m%-28s\033[1;32m║\033[0m\n", "")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mRequests:\033[0m \033[90m0\033[0m%-43s\033[1;32m║\033[0m\n", "")
|
||||
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
|
||||
fmt.Println("\033[1;32m║\033[0m \033[90mPress Ctrl+C to stop the tunnel\033[0m \033[1;32m║\033[0m")
|
||||
fmt.Println("\033[1;32m╚══════════════════════════════════════════════════════════════════╝\033[0m")
|
||||
fmt.Println()
|
||||
|
||||
// Setup latency display
|
||||
latencyCh := make(chan time.Duration, 1)
|
||||
connector.SetLatencyCallback(func(latency time.Duration) {
|
||||
select {
|
||||
case latencyCh <- latency:
|
||||
default:
|
||||
}
|
||||
})
|
||||
|
||||
// Start stats display updater (updates every second)
|
||||
stopDisplay := make(chan struct{})
|
||||
disconnected := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
var lastLatency time.Duration
|
||||
for {
|
||||
select {
|
||||
case latency := <-latencyCh:
|
||||
lastLatency = latency
|
||||
case <-ticker.C:
|
||||
// Update speed calculation
|
||||
stats := connector.GetStats()
|
||||
if stats != nil {
|
||||
stats.UpdateSpeed()
|
||||
snapshot := stats.GetSnapshot()
|
||||
|
||||
// Move cursor up 8 lines to update display
|
||||
fmt.Print("\033[8A")
|
||||
|
||||
// Update latency line
|
||||
fmt.Printf("\r\033[1;32m║\033[0m \033[90mLatency:\033[0m %s%-40s\033[1;32m║\033[0m\n", formatLatency(lastLatency), "")
|
||||
|
||||
// Update traffic line
|
||||
trafficStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatBytes(snapshot.TotalBytesIn), tcp.FormatBytes(snapshot.TotalBytesOut))
|
||||
fmt.Printf("\r\033[1;32m║\033[0m \033[90mTraffic:\033[0m \033[36m%-48s\033[0m\033[1;32m║\033[0m\n", trafficStr)
|
||||
|
||||
// Update speed line
|
||||
speedStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatSpeed(snapshot.SpeedIn), tcp.FormatSpeed(snapshot.SpeedOut))
|
||||
fmt.Printf("\r\033[1;32m║\033[0m \033[90mSpeed:\033[0m \033[33m%-48s\033[0m\033[1;32m║\033[0m\n", speedStr)
|
||||
|
||||
// Update requests line
|
||||
fmt.Printf("\r\033[1;32m║\033[0m \033[90mRequests:\033[0m \033[35m%-47d\033[0m\033[1;32m║\033[0m\n", snapshot.TotalRequests)
|
||||
|
||||
// Move back down 4 lines
|
||||
fmt.Print("\033[4B")
|
||||
}
|
||||
case <-stopDisplay:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Monitor connection in background
|
||||
go func() {
|
||||
connector.Wait()
|
||||
close(disconnected)
|
||||
}()
|
||||
|
||||
// Wait for signal or disconnection
|
||||
select {
|
||||
case <-quit:
|
||||
close(stopDisplay)
|
||||
fmt.Println("\n\n\033[33m🛑 Shutting down...\033[0m")
|
||||
connector.Close()
|
||||
if httpsDaemonMarker {
|
||||
RemoveDaemonInfo("https", port)
|
||||
}
|
||||
fmt.Println("\033[32m✓\033[0m Tunnel closed")
|
||||
return nil
|
||||
case <-disconnected:
|
||||
close(stopDisplay)
|
||||
fmt.Println("\n\n\033[31m⚠ Connection lost!\033[0m")
|
||||
reconnectAttempts++
|
||||
if reconnectAttempts >= maxReconnectAttempts {
|
||||
return fmt.Errorf("connection lost after %d reconnect attempts", maxReconnectAttempts)
|
||||
}
|
||||
fmt.Printf("\033[90m Reconnecting in %v...\033[0m\n", reconnectInterval)
|
||||
|
||||
// Wait before reconnect, but allow interrupt
|
||||
select {
|
||||
case <-quit:
|
||||
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
|
||||
return nil
|
||||
case <-time.After(reconnectInterval):
|
||||
continue
|
||||
}
|
||||
var daemon *DaemonInfo
|
||||
if httpsDaemonMarker {
|
||||
daemon = &DaemonInfo{
|
||||
PID: os.Getpid(),
|
||||
Type: "https",
|
||||
Port: port,
|
||||
Subdomain: httpsSubdomain,
|
||||
Server: serverAddr,
|
||||
StartTime: time.Now(),
|
||||
Executable: os.Args[0],
|
||||
}
|
||||
}
|
||||
|
||||
return runTunnelWithUI(connConfig, daemon)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/cli/ui"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -44,90 +45,83 @@ func init() {
|
||||
}
|
||||
|
||||
func runList(cmd *cobra.Command, args []string) error {
|
||||
// Clean up stale daemons first
|
||||
CleanupStaleDaemons()
|
||||
|
||||
// Get all running daemons
|
||||
daemons, err := ListAllDaemons()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list daemons: %w", err)
|
||||
}
|
||||
|
||||
if len(daemons) == 0 {
|
||||
fmt.Println("\033[90mNo running tunnels.\033[0m")
|
||||
fmt.Println()
|
||||
fmt.Println("Start a tunnel in background with:")
|
||||
fmt.Println(" \033[36mdrip http 3000 -d\033[0m")
|
||||
fmt.Println(" \033[36mdrip tcp 5432 -d\033[0m")
|
||||
fmt.Println(ui.Info(
|
||||
"No Running Tunnels",
|
||||
"",
|
||||
ui.Muted("Start a tunnel in background with:"),
|
||||
"",
|
||||
ui.Cyan(" drip http 3000 -d"),
|
||||
ui.Cyan(" drip tcp 5432 -d"),
|
||||
))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Print header
|
||||
fmt.Println()
|
||||
fmt.Println("\033[1;37mRunning Tunnels\033[0m")
|
||||
fmt.Println("\033[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\033[0m")
|
||||
fmt.Printf("\033[1m%-4s %-6s %-6s %-40s %-8s %s\033[0m\n", "#", "TYPE", "PORT", "URL", "PID", "UPTIME")
|
||||
fmt.Println("\033[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\033[0m")
|
||||
table := ui.NewTable([]string{"#", "TYPE", "PORT", "URL", "PID", "UPTIME"}).
|
||||
WithTitle("Running Tunnels")
|
||||
|
||||
idx := 1
|
||||
for _, d := range daemons {
|
||||
// Check if process is still running
|
||||
if !IsProcessRunning(d.PID) {
|
||||
// Clean up stale entry
|
||||
RemoveDaemonInfo(d.Type, d.Port)
|
||||
continue
|
||||
}
|
||||
|
||||
// Calculate uptime
|
||||
uptime := time.Since(d.StartTime)
|
||||
|
||||
// Format type with color
|
||||
var typeStr string
|
||||
if d.Type == "http" {
|
||||
typeStr = "\033[32mHTTP\033[0m"
|
||||
typeStr = ui.Highlight("HTTP")
|
||||
} else if d.Type == "https" {
|
||||
typeStr = ui.Highlight("HTTPS")
|
||||
} else {
|
||||
typeStr = "\033[35mTCP\033[0m"
|
||||
typeStr = ui.Cyan("TCP")
|
||||
}
|
||||
|
||||
// Truncate URL if too long
|
||||
url := d.URL
|
||||
if len(url) > 40 {
|
||||
url = url[:37] + "..."
|
||||
}
|
||||
|
||||
fmt.Printf("\033[1;36m%-4d\033[0m %-15s %-6d %-40s %-8d %s\n",
|
||||
idx, typeStr, d.Port, url, d.PID, FormatDuration(uptime))
|
||||
table.AddRow([]string{
|
||||
ui.Highlight(fmt.Sprintf("%d", idx)),
|
||||
typeStr,
|
||||
fmt.Sprintf("%d", d.Port),
|
||||
ui.URL(d.URL),
|
||||
ui.Muted(fmt.Sprintf("%d", d.PID)),
|
||||
FormatDuration(uptime),
|
||||
})
|
||||
idx++
|
||||
}
|
||||
|
||||
fmt.Println("\033[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\033[0m")
|
||||
fmt.Println()
|
||||
fmt.Print(table.Render())
|
||||
|
||||
// Interactive mode or show commands
|
||||
if interactiveMode || shouldPromptForAction() {
|
||||
return runInteractiveList(daemons)
|
||||
}
|
||||
|
||||
fmt.Println("Commands:")
|
||||
fmt.Println(" \033[36mdrip list -i\033[0m Interactive mode")
|
||||
fmt.Println(" \033[36mdrip attach http 3000\033[0m Attach to tunnel (view logs)")
|
||||
fmt.Println(" \033[36mdrip stop http 3000\033[0m Stop tunnel")
|
||||
fmt.Println(" \033[36mdrip stop all\033[0m Stop all tunnels")
|
||||
fmt.Println(ui.Muted("Commands:"))
|
||||
fmt.Println(ui.RenderList([]string{
|
||||
ui.Cyan("drip list -i") + ui.Muted(" Interactive mode"),
|
||||
ui.Cyan("drip attach http 3000") + ui.Muted(" Attach to tunnel (view logs)"),
|
||||
ui.Cyan("drip stop http 3000") + ui.Muted(" Stop tunnel"),
|
||||
ui.Cyan("drip stop all") + ui.Muted(" Stop all tunnels"),
|
||||
}))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func shouldPromptForAction() bool {
|
||||
// Check if running in a terminal
|
||||
if fileInfo, _ := os.Stdout.Stat(); (fileInfo.Mode() & os.ModeCharDevice) == 0 {
|
||||
return false
|
||||
}
|
||||
// Always prompt when there are tunnels running
|
||||
return true
|
||||
return false
|
||||
}
|
||||
|
||||
func runInteractiveList(daemons []*DaemonInfo) error {
|
||||
// Filter out non-running daemons
|
||||
var runningDaemons []*DaemonInfo
|
||||
for _, d := range daemons {
|
||||
if IsProcessRunning(d.PID) {
|
||||
@@ -141,8 +135,8 @@ func runInteractiveList(daemons []*DaemonInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Prompt for action
|
||||
fmt.Print("Select a tunnel (number) or 'q' to quit: ")
|
||||
fmt.Println()
|
||||
fmt.Print(ui.Muted("Select a tunnel (number) or 'q' to quit: "))
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
input, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
@@ -154,7 +148,6 @@ func runInteractiveList(daemons []*DaemonInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse selection
|
||||
selection, err := strconv.Atoi(input)
|
||||
if err != nil || selection < 1 || selection > len(runningDaemons) {
|
||||
return fmt.Errorf("invalid selection: %s", input)
|
||||
@@ -162,16 +155,18 @@ func runInteractiveList(daemons []*DaemonInfo) error {
|
||||
|
||||
selectedDaemon := runningDaemons[selection-1]
|
||||
|
||||
// Prompt for action
|
||||
fmt.Println()
|
||||
fmt.Printf("Selected: \033[1m%s\033[0m tunnel on port \033[1m%d\033[0m\n", strings.ToUpper(selectedDaemon.Type), selectedDaemon.Port)
|
||||
fmt.Println()
|
||||
fmt.Println("What would you like to do?")
|
||||
fmt.Println(" \033[36m1.\033[0m Attach (view logs)")
|
||||
fmt.Println(" \033[36m2.\033[0m Stop tunnel")
|
||||
fmt.Println(" \033[90mq. Cancel\033[0m")
|
||||
fmt.Println()
|
||||
fmt.Print("Choose an action: ")
|
||||
fmt.Println(ui.Info(
|
||||
fmt.Sprintf("Selected: %s tunnel on port %d", strings.ToUpper(selectedDaemon.Type), selectedDaemon.Port),
|
||||
"",
|
||||
ui.Muted("What would you like to do?"),
|
||||
"",
|
||||
ui.Cyan(" 1.") + " Attach (view logs)",
|
||||
ui.Cyan(" 2.") + " Stop tunnel",
|
||||
ui.Muted(" q.") + " Cancel",
|
||||
))
|
||||
|
||||
fmt.Print(ui.Muted("Choose an action: "))
|
||||
|
||||
actionInput, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
@@ -181,10 +176,8 @@ func runInteractiveList(daemons []*DaemonInfo) error {
|
||||
actionInput = strings.TrimSpace(actionInput)
|
||||
switch actionInput {
|
||||
case "1":
|
||||
// Attach to daemon
|
||||
return attachToDaemon(selectedDaemon)
|
||||
case "2":
|
||||
// Stop daemon
|
||||
return stopDaemon(selectedDaemon.Type, selectedDaemon.Port)
|
||||
case "q", "Q", "":
|
||||
return nil
|
||||
|
||||
@@ -3,6 +3,7 @@ package cli
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"drip/internal/client/cli/ui"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -59,12 +60,13 @@ var versionCmd = &cobra.Command{
|
||||
Use: "version",
|
||||
Short: "Print version information",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
fmt.Printf("Drip Client\n")
|
||||
fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")
|
||||
fmt.Printf("Version: %s\n", Version)
|
||||
fmt.Printf("Git Commit: %s\n", GitCommit)
|
||||
fmt.Printf("Build Time: %s\n", BuildTime)
|
||||
fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")
|
||||
fmt.Println(ui.Info(
|
||||
"Drip Client",
|
||||
"",
|
||||
ui.KeyValue("Version", Version),
|
||||
ui.KeyValue("Git Commit", GitCommit),
|
||||
ui.KeyValue("Build Time", BuildTime),
|
||||
))
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -60,7 +60,6 @@ func init() {
|
||||
}
|
||||
|
||||
func runServer(cmd *cobra.Command, args []string) error {
|
||||
// Validate required TLS configuration
|
||||
if serverTLSCert == "" {
|
||||
return fmt.Errorf("TLS certificate path is required (use --tls-cert flag or DRIP_TLS_CERT environment variable)")
|
||||
}
|
||||
@@ -68,7 +67,6 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
return fmt.Errorf("TLS private key path is required (use --tls-key flag or DRIP_TLS_KEY environment variable)")
|
||||
}
|
||||
|
||||
// Initialize logger
|
||||
if err := utils.InitServerLogger(serverDebug); err != nil {
|
||||
return fmt.Errorf("failed to initialize logger: %w", err)
|
||||
}
|
||||
@@ -81,7 +79,6 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
zap.String("commit", GitCommit),
|
||||
)
|
||||
|
||||
// Start pprof server if enabled
|
||||
if serverPprofPort > 0 {
|
||||
go func() {
|
||||
pprofAddr := fmt.Sprintf("localhost:%d", serverPprofPort)
|
||||
@@ -92,7 +89,6 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
}()
|
||||
}
|
||||
|
||||
// Create server config
|
||||
displayPort := serverPublicPort
|
||||
if displayPort == 0 {
|
||||
displayPort = serverPort
|
||||
@@ -111,7 +107,6 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
Debug: serverDebug,
|
||||
}
|
||||
|
||||
// Load TLS configuration
|
||||
tlsConfig, err := serverConfig.LoadTLSConfig()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to load TLS configuration", zap.Error(err))
|
||||
@@ -122,28 +117,21 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
zap.String("key", serverTLSKey),
|
||||
)
|
||||
|
||||
// Create tunnel manager
|
||||
tunnelManager := tunnel.NewManager(logger)
|
||||
|
||||
// Create TCP port allocator
|
||||
portAllocator, err := tcp.NewPortAllocator(serverTCPPortMin, serverTCPPortMax)
|
||||
if err != nil {
|
||||
logger.Fatal("Invalid TCP port range", zap.Error(err))
|
||||
}
|
||||
|
||||
// Create TCP listener address
|
||||
listenAddr := fmt.Sprintf("0.0.0.0:%d", serverPort)
|
||||
|
||||
// Response handler for HTTP-over-frame responses
|
||||
responseHandler := proxy.NewResponseHandler(logger)
|
||||
|
||||
// Create HTTP proxy handler (for handling HTTP requests on TCP port)
|
||||
httpHandler := proxy.NewHandler(tunnelManager, logger, responseHandler, serverDomain, serverAuthToken)
|
||||
|
||||
// Create TCP listener (wsHandler also serves as response channel handler)
|
||||
listener := tcp.NewListener(listenAddr, tlsConfig, serverAuthToken, tunnelManager, logger, portAllocator, serverDomain, displayPort, httpHandler, responseHandler)
|
||||
|
||||
// Start listener
|
||||
if err := listener.Start(); err != nil {
|
||||
logger.Fatal("Failed to start TCP listener", zap.Error(err))
|
||||
}
|
||||
@@ -154,16 +142,13 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
zap.String("protocol", "TCP over TLS 1.3"),
|
||||
)
|
||||
|
||||
// Setup signal handler
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Wait for signal
|
||||
<-quit
|
||||
|
||||
logger.Info("Shutting down server...")
|
||||
|
||||
// Stop listener
|
||||
if err := listener.Stop(); err != nil {
|
||||
logger.Error("Error stopping listener", zap.Error(err))
|
||||
}
|
||||
|
||||
@@ -28,12 +28,10 @@ func init() {
|
||||
}
|
||||
|
||||
func runStop(cmd *cobra.Command, args []string) error {
|
||||
// Handle "stop all"
|
||||
if args[0] == "all" {
|
||||
return stopAllDaemons()
|
||||
}
|
||||
|
||||
// Handle "stop <type> <port>"
|
||||
if len(args) < 2 {
|
||||
return fmt.Errorf("usage: drip stop <type> <port> or drip stop all")
|
||||
}
|
||||
@@ -61,19 +59,15 @@ func stopDaemon(tunnelType string, port int) error {
|
||||
return fmt.Errorf("no %s tunnel running on port %d", tunnelType, port)
|
||||
}
|
||||
|
||||
// Check if process is still running
|
||||
if !IsProcessRunning(info.PID) {
|
||||
// Clean up stale entry
|
||||
RemoveDaemonInfo(tunnelType, port)
|
||||
return fmt.Errorf("tunnel was not running (cleaned up stale entry)")
|
||||
}
|
||||
|
||||
// Kill the process
|
||||
if err := KillProcess(info.PID); err != nil {
|
||||
return fmt.Errorf("failed to stop tunnel: %w", err)
|
||||
}
|
||||
|
||||
// Remove daemon info
|
||||
RemoveDaemonInfo(tunnelType, port)
|
||||
|
||||
fmt.Printf("\033[32m✓\033[0m Stopped %s tunnel on port %d (PID: %d)\n", tunnelType, port, info.PID)
|
||||
@@ -81,7 +75,6 @@ func stopDaemon(tunnelType string, port int) error {
|
||||
}
|
||||
|
||||
func stopAllDaemons() error {
|
||||
// Clean up stale daemons first
|
||||
CleanupStaleDaemons()
|
||||
|
||||
daemons, err := ListAllDaemons()
|
||||
|
||||
@@ -3,19 +3,14 @@ package cli
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/tcp"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/utils"
|
||||
"drip/pkg/config"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var tcpCmd = &cobra.Command{
|
||||
@@ -53,15 +48,12 @@ func init() {
|
||||
}
|
||||
|
||||
func runTCP(cmd *cobra.Command, args []string) error {
|
||||
// Parse port
|
||||
port, err := strconv.Atoi(args[0])
|
||||
if err != nil || port < 1 || port > 65535 {
|
||||
return fmt.Errorf("invalid port number: %s", args[0])
|
||||
}
|
||||
|
||||
// Handle daemon mode
|
||||
if daemonMode && !daemonMarker {
|
||||
// Start as daemon
|
||||
daemonArgs := append([]string{"tcp"}, args...)
|
||||
daemonArgs = append(daemonArgs, "--daemon-child")
|
||||
if subdomain != "" {
|
||||
@@ -85,19 +77,9 @@ func runTCP(cmd *cobra.Command, args []string) error {
|
||||
return StartDaemon("tcp", port, daemonArgs)
|
||||
}
|
||||
|
||||
// Initialize logger
|
||||
if err := utils.InitLogger(verbose); err != nil {
|
||||
return fmt.Errorf("failed to initialize logger: %w", err)
|
||||
}
|
||||
defer utils.Sync()
|
||||
|
||||
logger := utils.GetLogger()
|
||||
|
||||
// Load configuration or use command line flags
|
||||
var serverAddr, token string
|
||||
|
||||
if serverURL == "" {
|
||||
// Try to load from config file
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
return fmt.Errorf(`configuration not found.
|
||||
@@ -108,17 +90,14 @@ Please run 'drip config init' first, or use flags:
|
||||
serverAddr = cfg.Server
|
||||
token = cfg.Token
|
||||
} else {
|
||||
// Use command line flags
|
||||
serverAddr = serverURL
|
||||
token = authToken
|
||||
}
|
||||
|
||||
// Validate server address
|
||||
if serverAddr == "" {
|
||||
return fmt.Errorf("server address is required")
|
||||
}
|
||||
|
||||
// Create connector config
|
||||
connConfig := &tcp.ConnectorConfig{
|
||||
ServerAddr: serverAddr,
|
||||
Token: token,
|
||||
@@ -129,233 +108,18 @@ Please run 'drip config init' first, or use flags:
|
||||
Insecure: insecure,
|
||||
}
|
||||
|
||||
// Setup signal handler for graceful shutdown
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Connection loop with reconnect support
|
||||
reconnectAttempts := 0
|
||||
serviceName := getServiceName(port)
|
||||
for {
|
||||
// Create connector
|
||||
connector := tcp.NewConnector(connConfig, logger)
|
||||
|
||||
// Connect to server
|
||||
if reconnectAttempts == 0 {
|
||||
fmt.Printf("\033[36m🔌 Connecting to %s...\033[0m\n", serverAddr)
|
||||
} else {
|
||||
fmt.Printf("\033[33m🔄 Reconnecting to %s (attempt %d/%d)...\033[0m\n", serverAddr, reconnectAttempts, maxReconnectAttempts)
|
||||
}
|
||||
|
||||
if err := connector.Connect(); err != nil {
|
||||
// Check if this is a non-retryable error
|
||||
if isNonRetryableErrorTCP(err) {
|
||||
return fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
reconnectAttempts++
|
||||
if reconnectAttempts >= maxReconnectAttempts {
|
||||
return fmt.Errorf("failed to connect after %d attempts: %w", maxReconnectAttempts, err)
|
||||
}
|
||||
fmt.Printf("\033[31m✗ Connection failed: %v\033[0m\n", err)
|
||||
fmt.Printf("\033[90m Retrying in %v...\033[0m\n", reconnectInterval)
|
||||
|
||||
// Wait before retry, but allow interrupt
|
||||
select {
|
||||
case <-quit:
|
||||
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
|
||||
return nil
|
||||
case <-time.After(reconnectInterval):
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Reset reconnect attempts on successful connection
|
||||
reconnectAttempts = 0
|
||||
|
||||
// Save daemon info if running as daemon child
|
||||
if daemonMarker {
|
||||
daemonInfo := &DaemonInfo{
|
||||
PID: os.Getpid(),
|
||||
Type: "tcp",
|
||||
Port: port,
|
||||
Subdomain: subdomain,
|
||||
Server: serverAddr,
|
||||
URL: connector.GetURL(),
|
||||
StartTime: time.Now(),
|
||||
Executable: os.Args[0],
|
||||
}
|
||||
if err := SaveDaemonInfo(daemonInfo); err != nil {
|
||||
// Log but don't fail
|
||||
logger.Warn("Failed to save daemon info", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Print tunnel information
|
||||
fmt.Println()
|
||||
fmt.Println("\033[1;35m╔══════════════════════════════════════════════════════════════════╗\033[0m")
|
||||
fmt.Println("\033[1;35m║\033[0m \033[1;37m🔌 TCP Tunnel Connected Successfully!\033[0m \033[1;35m║\033[0m")
|
||||
fmt.Println("\033[1;35m╠══════════════════════════════════════════════════════════════════╣\033[0m")
|
||||
fmt.Printf("\033[1;35m║\033[0m \033[1;37mTunnel URL:\033[0m \033[1;35m║\033[0m\n")
|
||||
fmt.Printf("\033[1;35m║\033[0m \033[1;36m%-60s\033[0m \033[1;35m║\033[0m\n", connector.GetURL())
|
||||
fmt.Println("\033[1;35m║\033[0m \033[1;35m║\033[0m")
|
||||
fmt.Printf("\033[1;35m║\033[0m \033[90mService:\033[0m \033[1;35m%-50s\033[0m \033[1;35m║\033[0m\n", serviceName)
|
||||
displayAddr := localAddress
|
||||
if displayAddr == "127.0.0.1" {
|
||||
displayAddr = "localhost"
|
||||
}
|
||||
fmt.Printf("\033[1;35m║\033[0m \033[90mForwarding:\033[0m \033[1m%s:%d\033[0m → \033[36m%s\033[0m%-15s\033[1;35m║\033[0m\n", displayAddr, port, "public", "")
|
||||
fmt.Printf("\033[1;35m║\033[0m \033[90mLatency:\033[0m \033[90mmeasuring...\033[0m%-40s\033[1;35m║\033[0m\n", "")
|
||||
fmt.Printf("\033[1;35m║\033[0m \033[90mTraffic:\033[0m \033[90m↓ 0 B ↑ 0 B\033[0m%-32s\033[1;35m║\033[0m\n", "")
|
||||
fmt.Printf("\033[1;35m║\033[0m \033[90mSpeed:\033[0m \033[90m↓ 0 B/s ↑ 0 B/s\033[0m%-28s\033[1;35m║\033[0m\n", "")
|
||||
fmt.Printf("\033[1;35m║\033[0m \033[90mRequests:\033[0m \033[90m0\033[0m%-43s\033[1;35m║\033[0m\n", "")
|
||||
fmt.Println("\033[1;35m╠══════════════════════════════════════════════════════════════════╣\033[0m")
|
||||
fmt.Println("\033[1;35m║\033[0m \033[90mPress Ctrl+C to stop the tunnel\033[0m \033[1;35m║\033[0m")
|
||||
fmt.Println("\033[1;35m╚══════════════════════════════════════════════════════════════════╝\033[0m")
|
||||
fmt.Println()
|
||||
|
||||
// Setup latency display
|
||||
latencyCh := make(chan time.Duration, 1)
|
||||
connector.SetLatencyCallback(func(latency time.Duration) {
|
||||
select {
|
||||
case latencyCh <- latency:
|
||||
default:
|
||||
}
|
||||
})
|
||||
|
||||
// Start stats display updater (updates every second)
|
||||
stopDisplay := make(chan struct{})
|
||||
disconnected := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
var lastLatency time.Duration
|
||||
for {
|
||||
select {
|
||||
case latency := <-latencyCh:
|
||||
lastLatency = latency
|
||||
case <-ticker.C:
|
||||
// Update speed calculation
|
||||
stats := connector.GetStats()
|
||||
if stats != nil {
|
||||
stats.UpdateSpeed()
|
||||
snapshot := stats.GetSnapshot()
|
||||
|
||||
// Move cursor up 8 lines to update display
|
||||
fmt.Print("\033[8A")
|
||||
|
||||
// Update latency line
|
||||
fmt.Printf("\r\033[1;35m║\033[0m \033[90mLatency:\033[0m %s%-40s\033[1;35m║\033[0m\n", formatLatencyTCP(lastLatency), "")
|
||||
|
||||
// Update traffic line
|
||||
trafficStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatBytes(snapshot.TotalBytesIn), tcp.FormatBytes(snapshot.TotalBytesOut))
|
||||
fmt.Printf("\r\033[1;35m║\033[0m \033[90mTraffic:\033[0m \033[36m%-48s\033[0m\033[1;35m║\033[0m\n", trafficStr)
|
||||
|
||||
// Update speed line
|
||||
speedStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatSpeed(snapshot.SpeedIn), tcp.FormatSpeed(snapshot.SpeedOut))
|
||||
fmt.Printf("\r\033[1;35m║\033[0m \033[90mSpeed:\033[0m \033[33m%-48s\033[0m\033[1;35m║\033[0m\n", speedStr)
|
||||
|
||||
// Update requests line
|
||||
fmt.Printf("\r\033[1;35m║\033[0m \033[90mRequests:\033[0m \033[35m%-47d\033[0m\033[1;35m║\033[0m\n", snapshot.TotalRequests)
|
||||
|
||||
// Move back down 4 lines
|
||||
fmt.Print("\033[4B")
|
||||
}
|
||||
case <-stopDisplay:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Monitor connection in background
|
||||
go func() {
|
||||
connector.Wait()
|
||||
close(disconnected)
|
||||
}()
|
||||
|
||||
// Wait for signal or disconnection
|
||||
select {
|
||||
case <-quit:
|
||||
close(stopDisplay)
|
||||
fmt.Println("\n\n\033[33m🛑 Shutting down...\033[0m")
|
||||
connector.Close()
|
||||
if daemonMarker {
|
||||
RemoveDaemonInfo("tcp", port)
|
||||
}
|
||||
fmt.Println("\033[32m✓\033[0m Tunnel closed")
|
||||
return nil
|
||||
case <-disconnected:
|
||||
close(stopDisplay)
|
||||
fmt.Println("\n\n\033[31m⚠ Connection lost!\033[0m")
|
||||
reconnectAttempts++
|
||||
if reconnectAttempts >= maxReconnectAttempts {
|
||||
return fmt.Errorf("connection lost after %d reconnect attempts", maxReconnectAttempts)
|
||||
}
|
||||
fmt.Printf("\033[90m Reconnecting in %v...\033[0m\n", reconnectInterval)
|
||||
|
||||
// Wait before reconnect, but allow interrupt
|
||||
select {
|
||||
case <-quit:
|
||||
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
|
||||
return nil
|
||||
case <-time.After(reconnectInterval):
|
||||
continue
|
||||
}
|
||||
var daemon *DaemonInfo
|
||||
if daemonMarker {
|
||||
daemon = &DaemonInfo{
|
||||
PID: os.Getpid(),
|
||||
Type: "tcp",
|
||||
Port: port,
|
||||
Subdomain: subdomain,
|
||||
Server: serverAddr,
|
||||
StartTime: time.Now(),
|
||||
Executable: os.Args[0],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getServiceName returns a friendly name for common port numbers
|
||||
func getServiceName(port int) string {
|
||||
services := map[int]string{
|
||||
22: "SSH",
|
||||
80: "HTTP",
|
||||
443: "HTTPS",
|
||||
3306: "MySQL",
|
||||
5432: "PostgreSQL",
|
||||
6379: "Redis",
|
||||
27017: "MongoDB",
|
||||
3389: "RDP",
|
||||
5900: "VNC",
|
||||
8080: "HTTP (Alt)",
|
||||
8443: "HTTPS (Alt)",
|
||||
}
|
||||
|
||||
if name, ok := services[port]; ok {
|
||||
return fmt.Sprintf("%s (port %d)", name, port)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("TCP service on port %d", port)
|
||||
}
|
||||
|
||||
// formatLatencyTCP formats latency with color based on value
|
||||
func formatLatencyTCP(d time.Duration) string {
|
||||
ms := d.Milliseconds()
|
||||
if ms < 50 {
|
||||
return fmt.Sprintf("\033[32m%dms\033[0m", ms) // Green: excellent
|
||||
} else if ms < 100 {
|
||||
return fmt.Sprintf("\033[33m%dms\033[0m", ms) // Yellow: good
|
||||
} else if ms < 200 {
|
||||
return fmt.Sprintf("\033[38;5;208m%dms\033[0m", ms) // Orange: moderate
|
||||
}
|
||||
return fmt.Sprintf("\033[31m%dms\033[0m", ms) // Red: poor
|
||||
}
|
||||
|
||||
// isNonRetryableErrorTCP checks if an error should not be retried
|
||||
func isNonRetryableErrorTCP(err error) bool {
|
||||
errStr := err.Error()
|
||||
// Subdomain conflicts - no point retrying
|
||||
if strings.Contains(errStr, "subdomain is already taken") ||
|
||||
strings.Contains(errStr, "subdomain is reserved") ||
|
||||
strings.Contains(errStr, "invalid subdomain") {
|
||||
return true
|
||||
}
|
||||
// Authentication errors - no point retrying
|
||||
if strings.Contains(errStr, "authentication") ||
|
||||
strings.Contains(errStr, "Invalid authentication token") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
return runTunnelWithUI(connConfig, daemon)
|
||||
}
|
||||
|
||||
209
internal/client/cli/tunnel_runner.go
Normal file
209
internal/client/cli/tunnel_runner.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/cli/ui"
|
||||
"drip/internal/client/tcp"
|
||||
"drip/internal/shared/utils"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// runTunnelWithUI runs a tunnel with the new UI
|
||||
func runTunnelWithUI(connConfig *tcp.ConnectorConfig, daemonInfo *DaemonInfo) error {
|
||||
if err := utils.InitLogger(verbose); err != nil {
|
||||
return fmt.Errorf("failed to initialize logger: %w", err)
|
||||
}
|
||||
defer utils.Sync()
|
||||
|
||||
logger := utils.GetLogger()
|
||||
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
reconnectAttempts := 0
|
||||
for {
|
||||
connector := tcp.NewConnector(connConfig, logger)
|
||||
|
||||
fmt.Println(ui.RenderConnecting(connConfig.ServerAddr, reconnectAttempts, maxReconnectAttempts))
|
||||
|
||||
if err := connector.Connect(); err != nil {
|
||||
if isNonRetryableError(err) {
|
||||
return fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
reconnectAttempts++
|
||||
if reconnectAttempts >= maxReconnectAttempts {
|
||||
return fmt.Errorf("failed to connect after %d attempts: %w", maxReconnectAttempts, err)
|
||||
}
|
||||
fmt.Println(ui.RenderConnectionFailed(err))
|
||||
fmt.Println(ui.RenderRetrying(reconnectInterval))
|
||||
|
||||
select {
|
||||
case <-quit:
|
||||
fmt.Println(ui.RenderShuttingDown())
|
||||
return nil
|
||||
case <-time.After(reconnectInterval):
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
reconnectAttempts = 0
|
||||
|
||||
if daemonInfo != nil {
|
||||
daemonInfo.URL = connector.GetURL()
|
||||
if err := SaveDaemonInfo(daemonInfo); err != nil {
|
||||
logger.Warn("Failed to save daemon info", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
displayAddr := connConfig.LocalHost
|
||||
if displayAddr == "127.0.0.1" {
|
||||
displayAddr = "localhost"
|
||||
}
|
||||
|
||||
status := &ui.TunnelStatus{
|
||||
Type: string(connConfig.TunnelType),
|
||||
URL: connector.GetURL(),
|
||||
LocalAddr: fmt.Sprintf("%s:%d", displayAddr, connConfig.LocalPort),
|
||||
}
|
||||
|
||||
fmt.Print(ui.RenderTunnelConnected(status))
|
||||
|
||||
latencyCh := make(chan time.Duration, 1)
|
||||
connector.SetLatencyCallback(func(latency time.Duration) {
|
||||
select {
|
||||
case latencyCh <- latency:
|
||||
default:
|
||||
}
|
||||
})
|
||||
|
||||
stopDisplay := make(chan struct{})
|
||||
disconnected := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
var lastLatency time.Duration
|
||||
lastRenderedLines := 0
|
||||
|
||||
for {
|
||||
select {
|
||||
case latency := <-latencyCh:
|
||||
lastLatency = latency
|
||||
case <-ticker.C:
|
||||
stats := connector.GetStats()
|
||||
if stats != nil {
|
||||
stats.UpdateSpeed()
|
||||
snapshot := stats.GetSnapshot()
|
||||
|
||||
status.Latency = lastLatency
|
||||
status.BytesIn = snapshot.TotalBytesIn
|
||||
status.BytesOut = snapshot.TotalBytesOut
|
||||
status.SpeedIn = float64(snapshot.SpeedIn)
|
||||
status.SpeedOut = float64(snapshot.SpeedOut)
|
||||
status.TotalRequest = snapshot.TotalRequests
|
||||
|
||||
statsView := ui.RenderTunnelStats(status)
|
||||
if lastRenderedLines > 0 {
|
||||
fmt.Print(clearLines(lastRenderedLines))
|
||||
}
|
||||
|
||||
fmt.Print(statsView)
|
||||
lastRenderedLines = countRenderedLines(statsView)
|
||||
}
|
||||
case <-stopDisplay:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
connector.Wait()
|
||||
close(disconnected)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-quit:
|
||||
close(stopDisplay)
|
||||
fmt.Println()
|
||||
fmt.Println(ui.RenderShuttingDown())
|
||||
|
||||
// Close with timeout (wait for ongoing requests to complete)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
connector.Close()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Closed successfully
|
||||
case <-time.After(5 * time.Second):
|
||||
fmt.Println(ui.Warning("Force closing (timeout)..."))
|
||||
}
|
||||
|
||||
if daemonInfo != nil {
|
||||
RemoveDaemonInfo(daemonInfo.Type, daemonInfo.Port)
|
||||
}
|
||||
fmt.Println(ui.Success("Tunnel closed"))
|
||||
return nil
|
||||
case <-disconnected:
|
||||
close(stopDisplay)
|
||||
fmt.Println()
|
||||
fmt.Println(ui.RenderConnectionLost())
|
||||
reconnectAttempts++
|
||||
if reconnectAttempts >= maxReconnectAttempts {
|
||||
return fmt.Errorf("connection lost after %d reconnect attempts", maxReconnectAttempts)
|
||||
}
|
||||
fmt.Println(ui.RenderRetrying(reconnectInterval))
|
||||
|
||||
select {
|
||||
case <-quit:
|
||||
fmt.Println(ui.RenderShuttingDown())
|
||||
return nil
|
||||
case <-time.After(reconnectInterval):
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func clearLines(lines int) string {
|
||||
if lines <= 0 {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("\033[%dA\033[J", lines)
|
||||
}
|
||||
|
||||
func countRenderedLines(block string) int {
|
||||
if block == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
lines := strings.Count(block, "\n")
|
||||
if !strings.HasSuffix(block, "\n") {
|
||||
lines++
|
||||
}
|
||||
|
||||
return lines
|
||||
}
|
||||
|
||||
func isNonRetryableError(err error) bool {
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "subdomain is already taken") ||
|
||||
strings.Contains(errStr, "subdomain is reserved") ||
|
||||
strings.Contains(errStr, "invalid subdomain") ||
|
||||
strings.Contains(errStr, "authentication") ||
|
||||
strings.Contains(errStr, "Invalid authentication token")
|
||||
}
|
||||
|
||||
func isNonRetryableErrorTCP(err error) bool {
|
||||
return isNonRetryableError(err)
|
||||
}
|
||||
117
internal/client/cli/ui/config.go
Normal file
117
internal/client/cli/ui/config.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// RenderConfigInit renders config initialization UI
|
||||
func RenderConfigInit() string {
|
||||
title := "Drip Configuration Setup"
|
||||
box := boxStyle.Copy().Width(50)
|
||||
return "\n" + box.Render(titleStyle.Render(title)) + "\n"
|
||||
}
|
||||
|
||||
// RenderConfigShow renders the config display
|
||||
func RenderConfigShow(server, token string, tokenHidden bool, tlsEnabled bool, configPath string) string {
|
||||
lines := []string{
|
||||
KeyValue("Server", server),
|
||||
}
|
||||
|
||||
if token != "" {
|
||||
if tokenHidden {
|
||||
if len(token) > 10 {
|
||||
displayToken := token[:3] + "***" + token[len(token)-3:]
|
||||
lines = append(lines, KeyValue("Token", Muted(displayToken+" (hidden)")))
|
||||
} else {
|
||||
lines = append(lines, KeyValue("Token", Muted(token[:3]+"*** (hidden)")))
|
||||
}
|
||||
} else {
|
||||
lines = append(lines, KeyValue("Token", token))
|
||||
}
|
||||
} else {
|
||||
lines = append(lines, KeyValue("Token", Muted("(not set)")))
|
||||
}
|
||||
|
||||
tlsStatus := "enabled"
|
||||
if !tlsEnabled {
|
||||
tlsStatus = "disabled"
|
||||
}
|
||||
lines = append(lines, KeyValue("TLS", tlsStatus))
|
||||
lines = append(lines, KeyValue("Config", Muted(configPath)))
|
||||
|
||||
return Info("Current Configuration", lines...)
|
||||
}
|
||||
|
||||
// RenderConfigSaved renders config saved message
|
||||
func RenderConfigSaved(configPath string) string {
|
||||
return SuccessBox(
|
||||
"Configuration Saved",
|
||||
Muted("Config saved to: ")+configPath,
|
||||
"",
|
||||
Muted("You can now use 'drip' without --server and --token flags"),
|
||||
)
|
||||
}
|
||||
|
||||
// RenderConfigUpdated renders config updated message
|
||||
func RenderConfigUpdated(updates []string) string {
|
||||
lines := make([]string, len(updates)+1)
|
||||
for i, update := range updates {
|
||||
lines[i] = Success(update)
|
||||
}
|
||||
lines[len(updates)] = ""
|
||||
lines = append(lines, Muted("Configuration has been updated"))
|
||||
return SuccessBox("Configuration Updated", lines...)
|
||||
}
|
||||
|
||||
// RenderConfigDeleted renders config deleted message
|
||||
func RenderConfigDeleted() string {
|
||||
return SuccessBox("Configuration Deleted", Muted("Configuration file has been removed"))
|
||||
}
|
||||
|
||||
// RenderConfigValidation renders config validation results
|
||||
func RenderConfigValidation(serverValid, tokenSet, tlsEnabled bool) string {
|
||||
lines := []string{}
|
||||
|
||||
if serverValid {
|
||||
lines = append(lines, Success("Server address is valid"))
|
||||
} else {
|
||||
lines = append(lines, Error("Server address is not set"))
|
||||
}
|
||||
|
||||
if tokenSet {
|
||||
lines = append(lines, Success("Token is set"))
|
||||
} else {
|
||||
lines = append(lines, Warning("Token is not set (authentication may fail)"))
|
||||
}
|
||||
|
||||
if tlsEnabled {
|
||||
lines = append(lines, Success("TLS is enabled"))
|
||||
} else {
|
||||
lines = append(lines, Warning("TLS is disabled (not recommended for production)"))
|
||||
}
|
||||
|
||||
lines = append(lines, "")
|
||||
lines = append(lines, Muted("Configuration validation complete"))
|
||||
|
||||
if serverValid && tokenSet && tlsEnabled {
|
||||
return SuccessBox("Configuration Valid", lines...)
|
||||
}
|
||||
return WarningBox("Configuration Validation", lines...)
|
||||
}
|
||||
|
||||
// RenderDaemonStarted renders daemon started message
|
||||
func RenderDaemonStarted(tunnelType string, port int, pid int, logPath string) string {
|
||||
lines := []string{
|
||||
KeyValue("Type", Highlight(tunnelType)),
|
||||
KeyValue("Port", fmt.Sprintf("%d", port)),
|
||||
KeyValue("PID", fmt.Sprintf("%d", pid)),
|
||||
"",
|
||||
Muted("Commands:"),
|
||||
Cyan(" drip list") + Muted(" Check tunnel status"),
|
||||
Cyan(fmt.Sprintf(" drip attach %s %d", tunnelType, port)) + Muted(" View logs"),
|
||||
Cyan(fmt.Sprintf(" drip stop %s %d", tunnelType, port)) + Muted(" Stop tunnel"),
|
||||
"",
|
||||
Muted("Logs: ")+mutedStyle.Render(logPath),
|
||||
}
|
||||
return SuccessBox("Tunnel Started in Background", lines...)
|
||||
}
|
||||
204
internal/client/cli/ui/styles.go
Normal file
204
internal/client/cli/ui/styles.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
var (
|
||||
// Colors inspired by Vercel CLI
|
||||
primaryColor = lipgloss.Color("#0070F3")
|
||||
successColor = lipgloss.Color("#0070F3")
|
||||
warningColor = lipgloss.Color("#F5A623")
|
||||
errorColor = lipgloss.Color("#E00")
|
||||
mutedColor = lipgloss.Color("#888")
|
||||
highlightColor = lipgloss.Color("#0070F3")
|
||||
cyanColor = lipgloss.Color("#50E3C2")
|
||||
purpleColor = lipgloss.Color("#7928CA")
|
||||
|
||||
// Base styles
|
||||
baseStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(1).
|
||||
PaddingRight(1)
|
||||
|
||||
// Box styles - Vercel-like clean box
|
||||
boxStyle = lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color("#333")).
|
||||
Padding(1, 2).
|
||||
MarginTop(1).
|
||||
MarginBottom(1)
|
||||
|
||||
successBoxStyle = boxStyle.Copy().
|
||||
BorderForeground(successColor)
|
||||
|
||||
warningBoxStyle = boxStyle.Copy().
|
||||
BorderForeground(warningColor)
|
||||
|
||||
errorBoxStyle = boxStyle.Copy().
|
||||
BorderForeground(errorColor)
|
||||
|
||||
// Text styles
|
||||
titleStyle = lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color("#FFF"))
|
||||
|
||||
subtitleStyle = lipgloss.NewStyle().
|
||||
Foreground(mutedColor)
|
||||
|
||||
successStyle = lipgloss.NewStyle().
|
||||
Foreground(successColor).
|
||||
Bold(true)
|
||||
|
||||
errorStyle = lipgloss.NewStyle().
|
||||
Foreground(errorColor).
|
||||
Bold(true)
|
||||
|
||||
warningStyle = lipgloss.NewStyle().
|
||||
Foreground(warningColor).
|
||||
Bold(true)
|
||||
|
||||
mutedStyle = lipgloss.NewStyle().
|
||||
Foreground(mutedColor)
|
||||
|
||||
highlightStyle = lipgloss.NewStyle().
|
||||
Foreground(highlightColor).
|
||||
Bold(true)
|
||||
|
||||
cyanStyle = lipgloss.NewStyle().
|
||||
Foreground(cyanColor)
|
||||
|
||||
urlStyle = lipgloss.NewStyle().
|
||||
Foreground(highlightColor).
|
||||
Underline(true).
|
||||
Bold(true)
|
||||
|
||||
labelStyle = lipgloss.NewStyle().
|
||||
Foreground(mutedColor).
|
||||
Width(12)
|
||||
|
||||
valueStyle = lipgloss.NewStyle().
|
||||
Bold(true)
|
||||
|
||||
// Table styles
|
||||
tableHeaderStyle = lipgloss.NewStyle().
|
||||
Foreground(mutedColor).
|
||||
Bold(true).
|
||||
PaddingRight(2)
|
||||
|
||||
tableCellStyle = lipgloss.NewStyle().
|
||||
PaddingRight(2)
|
||||
|
||||
tableRowHighlight = lipgloss.NewStyle().
|
||||
Foreground(highlightColor).
|
||||
Bold(true)
|
||||
)
|
||||
|
||||
// Success returns a styled success message
|
||||
func Success(text string) string {
|
||||
return successStyle.Render("✓ " + text)
|
||||
}
|
||||
|
||||
// Error returns a styled error message
|
||||
func Error(text string) string {
|
||||
return errorStyle.Render("✗ " + text)
|
||||
}
|
||||
|
||||
// Warning returns a styled warning message
|
||||
func Warning(text string) string {
|
||||
return warningStyle.Render("⚠ " + text)
|
||||
}
|
||||
|
||||
// Muted returns a styled muted text
|
||||
func Muted(text string) string {
|
||||
return mutedStyle.Render(text)
|
||||
}
|
||||
|
||||
// Highlight returns a styled highlighted text
|
||||
func Highlight(text string) string {
|
||||
return highlightStyle.Render(text)
|
||||
}
|
||||
|
||||
// Cyan returns a styled cyan text
|
||||
func Cyan(text string) string {
|
||||
return cyanStyle.Render(text)
|
||||
}
|
||||
|
||||
// URL returns a styled URL
|
||||
func URL(text string) string {
|
||||
return urlStyle.Render(text)
|
||||
}
|
||||
|
||||
// Title returns a styled title
|
||||
func Title(text string) string {
|
||||
return titleStyle.Render(text)
|
||||
}
|
||||
|
||||
// Subtitle returns a styled subtitle
|
||||
func Subtitle(text string) string {
|
||||
return subtitleStyle.Render(text)
|
||||
}
|
||||
|
||||
// KeyValue returns a styled key-value pair
|
||||
func KeyValue(key, value string) string {
|
||||
return labelStyle.Render(key+":") + " " + valueStyle.Render(value)
|
||||
}
|
||||
|
||||
// Info renders an info box (Vercel-style)
|
||||
func Info(title string, lines ...string) string {
|
||||
content := titleStyle.Render(title)
|
||||
if len(lines) > 0 {
|
||||
content += "\n\n"
|
||||
for i, line := range lines {
|
||||
if i > 0 {
|
||||
content += "\n"
|
||||
}
|
||||
content += line
|
||||
}
|
||||
}
|
||||
return boxStyle.Render(content)
|
||||
}
|
||||
|
||||
// SuccessBox renders a success box
|
||||
func SuccessBox(title string, lines ...string) string {
|
||||
content := successStyle.Render("✓ " + title)
|
||||
if len(lines) > 0 {
|
||||
content += "\n\n"
|
||||
for i, line := range lines {
|
||||
if i > 0 {
|
||||
content += "\n"
|
||||
}
|
||||
content += line
|
||||
}
|
||||
}
|
||||
return successBoxStyle.Render(content)
|
||||
}
|
||||
|
||||
// WarningBox renders a warning box
|
||||
func WarningBox(title string, lines ...string) string {
|
||||
content := warningStyle.Render("⚠ " + title)
|
||||
if len(lines) > 0 {
|
||||
content += "\n\n"
|
||||
for i, line := range lines {
|
||||
if i > 0 {
|
||||
content += "\n"
|
||||
}
|
||||
content += line
|
||||
}
|
||||
}
|
||||
return warningBoxStyle.Render(content)
|
||||
}
|
||||
|
||||
// ErrorBox renders an error box
|
||||
func ErrorBox(title string, lines ...string) string {
|
||||
content := errorStyle.Render("✗ " + title)
|
||||
if len(lines) > 0 {
|
||||
content += "\n\n"
|
||||
for i, line := range lines {
|
||||
if i > 0 {
|
||||
content += "\n"
|
||||
}
|
||||
content += line
|
||||
}
|
||||
}
|
||||
return errorBoxStyle.Render(content)
|
||||
}
|
||||
127
internal/client/cli/ui/table.go
Normal file
127
internal/client/cli/ui/table.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// Table represents a simple table for CLI output
|
||||
type Table struct {
|
||||
headers []string
|
||||
rows [][]string
|
||||
title string
|
||||
}
|
||||
|
||||
// NewTable creates a new table
|
||||
func NewTable(headers []string) *Table {
|
||||
return &Table{
|
||||
headers: headers,
|
||||
rows: [][]string{},
|
||||
}
|
||||
}
|
||||
|
||||
// WithTitle sets the table title
|
||||
func (t *Table) WithTitle(title string) *Table {
|
||||
t.title = title
|
||||
return t
|
||||
}
|
||||
|
||||
// AddRow adds a row to the table
|
||||
func (t *Table) AddRow(row []string) *Table {
|
||||
t.rows = append(t.rows, row)
|
||||
return t
|
||||
}
|
||||
|
||||
// Render renders the table (Vercel-style)
|
||||
func (t *Table) Render() string {
|
||||
if len(t.rows) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Calculate column widths
|
||||
colWidths := make([]int, len(t.headers))
|
||||
for i, header := range t.headers {
|
||||
colWidths[i] = lipgloss.Width(header)
|
||||
}
|
||||
for _, row := range t.rows {
|
||||
for i, cell := range row {
|
||||
if i < len(colWidths) {
|
||||
width := lipgloss.Width(cell)
|
||||
if width > colWidths[i] {
|
||||
colWidths[i] = width
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
|
||||
// Title
|
||||
if t.title != "" {
|
||||
output.WriteString("\n")
|
||||
output.WriteString(titleStyle.Render(t.title))
|
||||
output.WriteString("\n\n")
|
||||
}
|
||||
|
||||
// Header
|
||||
headerParts := make([]string, len(t.headers))
|
||||
for i, header := range t.headers {
|
||||
style := tableHeaderStyle.Copy().Width(colWidths[i])
|
||||
headerParts[i] = style.Render(header)
|
||||
}
|
||||
output.WriteString(strings.Join(headerParts, " "))
|
||||
output.WriteString("\n")
|
||||
|
||||
// Separator line
|
||||
separatorParts := make([]string, len(t.headers))
|
||||
for i := range t.headers {
|
||||
separatorParts[i] = mutedStyle.Render(strings.Repeat("─", colWidths[i]))
|
||||
}
|
||||
output.WriteString(strings.Join(separatorParts, " "))
|
||||
output.WriteString("\n")
|
||||
|
||||
// Rows
|
||||
for _, row := range t.rows {
|
||||
rowParts := make([]string, len(t.headers))
|
||||
for i, cell := range row {
|
||||
if i < len(colWidths) {
|
||||
style := tableCellStyle.Copy().Width(colWidths[i])
|
||||
rowParts[i] = style.Render(cell)
|
||||
}
|
||||
}
|
||||
output.WriteString(strings.Join(rowParts, " "))
|
||||
output.WriteString("\n")
|
||||
}
|
||||
|
||||
output.WriteString("\n")
|
||||
return output.String()
|
||||
}
|
||||
|
||||
// Print prints the table
|
||||
func (t *Table) Print() {
|
||||
fmt.Print(t.Render())
|
||||
}
|
||||
|
||||
// RenderList renders a simple list with bullet points
|
||||
func RenderList(items []string) string {
|
||||
var output strings.Builder
|
||||
for _, item := range items {
|
||||
output.WriteString(mutedStyle.Render(" • "))
|
||||
output.WriteString(item)
|
||||
output.WriteString("\n")
|
||||
}
|
||||
return output.String()
|
||||
}
|
||||
|
||||
// RenderNumberedList renders a numbered list
|
||||
func RenderNumberedList(items []string) string {
|
||||
var output strings.Builder
|
||||
for i, item := range items {
|
||||
output.WriteString(mutedStyle.Render(fmt.Sprintf(" %d. ", i+1)))
|
||||
output.WriteString(item)
|
||||
output.WriteString("\n")
|
||||
}
|
||||
return output.String()
|
||||
}
|
||||
231
internal/client/cli/ui/tunnel.go
Normal file
231
internal/client/cli/ui/tunnel.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
const (
|
||||
tunnelCardWidth = 76
|
||||
statsColumnWidth = 32
|
||||
)
|
||||
|
||||
var (
|
||||
latencyFastColor = lipgloss.Color("#22c55e") // green
|
||||
latencyYellowColor = lipgloss.Color("#eab308") // yellow
|
||||
latencyOrangeColor = lipgloss.Color("#f97316") // orange
|
||||
latencyRedColor = lipgloss.Color("#ef4444") // red
|
||||
)
|
||||
|
||||
// TunnelStatus represents the status of a tunnel
|
||||
type TunnelStatus struct {
|
||||
Type string // "http", "https", "tcp"
|
||||
URL string // Public URL
|
||||
LocalAddr string // Local address
|
||||
Latency time.Duration // Current latency
|
||||
BytesIn int64 // Bytes received
|
||||
BytesOut int64 // Bytes sent
|
||||
SpeedIn float64 // Download speed
|
||||
SpeedOut float64 // Upload speed
|
||||
TotalRequest int64 // Total requests
|
||||
}
|
||||
|
||||
// RenderTunnelConnected renders the tunnel connection card
|
||||
func RenderTunnelConnected(status *TunnelStatus) string {
|
||||
icon, typeStr, accent := tunnelVisuals(status.Type)
|
||||
|
||||
card := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(accent).
|
||||
Padding(1, 2).
|
||||
Width(tunnelCardWidth)
|
||||
|
||||
typeBadge := lipgloss.NewStyle().
|
||||
Background(accent).
|
||||
Foreground(lipgloss.Color("#f8fafc")).
|
||||
Bold(true).
|
||||
Padding(0, 1).
|
||||
Render(strings.ToUpper(typeStr) + " TUNNEL")
|
||||
|
||||
headline := lipgloss.JoinHorizontal(
|
||||
lipgloss.Left,
|
||||
lipgloss.NewStyle().Foreground(accent).Render(icon),
|
||||
lipgloss.NewStyle().Bold(true).MarginLeft(1).Render("Tunnel Connected"),
|
||||
lipgloss.NewStyle().MarginLeft(2).Render(typeBadge),
|
||||
)
|
||||
|
||||
urlLine := urlStyle.Copy().Foreground(accent).Render(status.URL)
|
||||
forwardLine := Muted("⇢ ") + valueStyle.Render(status.LocalAddr)
|
||||
hint := mutedStyle.Render("Ctrl+C to stop • reconnects automatically")
|
||||
|
||||
content := lipgloss.JoinVertical(
|
||||
lipgloss.Left,
|
||||
headline,
|
||||
"",
|
||||
urlLine,
|
||||
forwardLine,
|
||||
"",
|
||||
hint,
|
||||
)
|
||||
|
||||
return "\n" + card.Render(content) + "\n"
|
||||
}
|
||||
|
||||
// RenderTunnelStats renders real-time tunnel statistics in a card
|
||||
func RenderTunnelStats(status *TunnelStatus) string {
|
||||
latencyStr := formatLatency(status.Latency)
|
||||
trafficStr := fmt.Sprintf("↓ %s ↑ %s", formatBytes(status.BytesIn), formatBytes(status.BytesOut))
|
||||
speedStr := fmt.Sprintf("↓ %s ↑ %s", formatSpeed(status.SpeedIn), formatSpeed(status.SpeedOut))
|
||||
requestsStr := fmt.Sprintf("%d", status.TotalRequest)
|
||||
|
||||
_, _, accent := tunnelVisuals(status.Type)
|
||||
|
||||
header := lipgloss.JoinHorizontal(
|
||||
lipgloss.Left,
|
||||
lipgloss.NewStyle().Foreground(accent).Render("◉"),
|
||||
lipgloss.NewStyle().Bold(true).MarginLeft(1).Render("Live Metrics"),
|
||||
)
|
||||
|
||||
row1 := lipgloss.JoinHorizontal(
|
||||
lipgloss.Top,
|
||||
statColumn("Latency", latencyStr, statsColumnWidth),
|
||||
statColumn("Requests", highlightStyle.Render(requestsStr), statsColumnWidth),
|
||||
)
|
||||
|
||||
row2 := lipgloss.JoinHorizontal(
|
||||
lipgloss.Top,
|
||||
statColumn("Traffic", Cyan(trafficStr), statsColumnWidth),
|
||||
statColumn("Speed", warningStyle.Render(speedStr), statsColumnWidth),
|
||||
)
|
||||
|
||||
card := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(accent).
|
||||
Padding(1, 2).
|
||||
Width(tunnelCardWidth)
|
||||
|
||||
body := lipgloss.JoinVertical(
|
||||
lipgloss.Left,
|
||||
header,
|
||||
"",
|
||||
row1,
|
||||
row2,
|
||||
)
|
||||
|
||||
return "\n" + card.Render(body) + "\n"
|
||||
}
|
||||
|
||||
// RenderConnecting renders the connecting message
|
||||
func RenderConnecting(serverAddr string, attempt int, maxAttempts int) string {
|
||||
if attempt == 0 {
|
||||
return Highlight("◌") + " Connecting to " + Muted(serverAddr) + "..."
|
||||
}
|
||||
return Warning(fmt.Sprintf("◌ Reconnecting to %s (attempt %d/%d)...", serverAddr, attempt, maxAttempts))
|
||||
}
|
||||
|
||||
// RenderConnectionFailed renders connection failure message
|
||||
func RenderConnectionFailed(err error) string {
|
||||
return Error(fmt.Sprintf("Connection failed: %v", err))
|
||||
}
|
||||
|
||||
// RenderShuttingDown renders shutdown message
|
||||
func RenderShuttingDown() string {
|
||||
return Warning("⏹ Shutting down...")
|
||||
}
|
||||
|
||||
// RenderConnectionLost renders connection lost message
|
||||
func RenderConnectionLost() string {
|
||||
return Error("⚠ Connection lost!")
|
||||
}
|
||||
|
||||
// RenderRetrying renders retry message
|
||||
func RenderRetrying(interval time.Duration) string {
|
||||
return Muted(fmt.Sprintf(" Retrying in %v...", interval))
|
||||
}
|
||||
|
||||
// formatLatency formats latency with color
|
||||
func formatLatency(d time.Duration) string {
|
||||
ms := d.Milliseconds()
|
||||
var style lipgloss.Style
|
||||
|
||||
if ms == 0 {
|
||||
return mutedStyle.Render("measuring...")
|
||||
}
|
||||
|
||||
switch {
|
||||
case ms < 50:
|
||||
style = lipgloss.NewStyle().Foreground(latencyFastColor)
|
||||
case ms < 150:
|
||||
style = lipgloss.NewStyle().Foreground(latencyYellowColor)
|
||||
case ms < 300:
|
||||
style = lipgloss.NewStyle().Foreground(latencyOrangeColor)
|
||||
default:
|
||||
style = lipgloss.NewStyle().Foreground(latencyRedColor)
|
||||
}
|
||||
|
||||
return style.Render(fmt.Sprintf("%dms", ms))
|
||||
}
|
||||
|
||||
// formatBytes formats bytes to human readable format
|
||||
func formatBytes(bytes int64) string {
|
||||
const unit = 1024
|
||||
if bytes < unit {
|
||||
return fmt.Sprintf("%d B", bytes)
|
||||
}
|
||||
div, exp := int64(unit), 0
|
||||
for n := bytes / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
// formatSpeed formats speed to human readable format
|
||||
func formatSpeed(bytesPerSec float64) string {
|
||||
const unit = 1024.0
|
||||
if bytesPerSec < unit {
|
||||
return fmt.Sprintf("%.0f B/s", bytesPerSec)
|
||||
}
|
||||
div, exp := unit, 0
|
||||
for n := bytesPerSec / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %cB/s", bytesPerSec/div, "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
func statColumn(label, value string, width int) string {
|
||||
labelView := lipgloss.NewStyle().
|
||||
Foreground(mutedColor).
|
||||
Render(strings.ToUpper(label))
|
||||
|
||||
block := lipgloss.JoinHorizontal(
|
||||
lipgloss.Left,
|
||||
labelView,
|
||||
lipgloss.NewStyle().MarginLeft(1).Render(value),
|
||||
)
|
||||
|
||||
if width <= 0 {
|
||||
return block
|
||||
}
|
||||
|
||||
return lipgloss.NewStyle().
|
||||
Width(width).
|
||||
Render(block)
|
||||
}
|
||||
|
||||
func tunnelVisuals(tunnelType string) (string, string, lipgloss.Color) {
|
||||
switch tunnelType {
|
||||
case "http":
|
||||
return "🚀", "HTTP", lipgloss.Color("#0070F3")
|
||||
case "https":
|
||||
return "🔒", "HTTPS", lipgloss.Color("#2D8CFF")
|
||||
case "tcp":
|
||||
return "🔌", "TCP", lipgloss.Color("#50E3C2")
|
||||
default:
|
||||
return "🌐", strings.ToUpper(tunnelType), lipgloss.Color("#0070F3")
|
||||
}
|
||||
}
|
||||
@@ -2,9 +2,10 @@ package tcp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -132,9 +133,10 @@ func (c *Connector) Connect() error {
|
||||
bufferPool,
|
||||
)
|
||||
|
||||
c.frameWriter.EnableHeartbeat(constants.HeartbeatInterval, c.createHeartbeatFrame)
|
||||
|
||||
go c.frameHandler.WarmupConnectionPool(3)
|
||||
go c.handleFrames()
|
||||
go c.heartbeat()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -278,38 +280,21 @@ func (c *Connector) handleFrames() {
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeat sends periodic heartbeat frames
|
||||
func (c *Connector) heartbeat() {
|
||||
ticker := time.NewTicker(constants.HeartbeatInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
c.sendHeartbeat()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.sendHeartbeat()
|
||||
}
|
||||
// createHeartbeatFrame creates a heartbeat frame to be sent by the write loop.
|
||||
func (c *Connector) createHeartbeatFrame() *protocol.Frame {
|
||||
c.closedMu.RLock()
|
||||
if c.closed {
|
||||
c.closedMu.RUnlock()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// sendHeartbeat sends a heartbeat frame and records the time
|
||||
func (c *Connector) sendHeartbeat() {
|
||||
hbFrame := protocol.NewFrame(protocol.FrameTypeHeartbeat, nil)
|
||||
c.closedMu.RUnlock()
|
||||
|
||||
c.heartbeatMu.Lock()
|
||||
c.heartbeatSentAt = time.Now()
|
||||
c.heartbeatMu.Unlock()
|
||||
|
||||
err := c.frameWriter.WriteFrame(hbFrame)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to send heartbeat", zap.Error(err))
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
c.logger.Debug("Heartbeat sent")
|
||||
return protocol.NewFrame(protocol.FrameTypeHeartbeat, nil)
|
||||
}
|
||||
|
||||
// SendFrame sends a frame to the server
|
||||
@@ -330,8 +315,20 @@ func (c *Connector) Close() error {
|
||||
|
||||
close(c.stopCh)
|
||||
|
||||
// Wait for active handlers with timeout
|
||||
c.logger.Debug("Waiting for active handlers to complete")
|
||||
c.handlerWg.Wait()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
c.handlerWg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
c.logger.Debug("All handlers completed")
|
||||
case <-time.After(3 * time.Second):
|
||||
c.logger.Warn("Force closing: some handlers are still active")
|
||||
}
|
||||
|
||||
if c.conn != nil {
|
||||
closeFrame := protocol.NewFrame(protocol.FrameTypeClose, nil)
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
// FrameHandler handles data frames and forwards to local service
|
||||
type FrameHandler struct {
|
||||
conn net.Conn
|
||||
frameWriter *protocol.FrameWriter // Async batch writer (replaces writeMu)
|
||||
frameWriter *protocol.FrameWriter
|
||||
localHost string
|
||||
localPort int
|
||||
logger *zap.Logger
|
||||
@@ -28,9 +28,9 @@ type FrameHandler struct {
|
||||
tunnelType protocol.TunnelType
|
||||
httpClient *http.Client
|
||||
stats *TrafficStats
|
||||
isClosedCheck func() bool // Function to check if connection is closed
|
||||
isClosedCheck func() bool
|
||||
bufferPool *pool.BufferPool
|
||||
headerPool *pool.HeaderPool // Header pool for Priority 9 optimization
|
||||
headerPool *pool.HeaderPool
|
||||
}
|
||||
|
||||
// Stream represents a single request/response stream
|
||||
|
||||
@@ -2,7 +2,6 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -10,6 +9,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/constants"
|
||||
"drip/internal/shared/pool"
|
||||
@@ -244,7 +245,7 @@ func (h *Handler) serveHomePage(w http.ResponseWriter, r *http.Request) {
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<title>Drip - Tunnel to Localhost</title>
|
||||
<title>Drip - Your Tunnel, Your Domain, Anywhere</title>
|
||||
<style>
|
||||
body { font-family: Arial, sans-serif; max-width: 800px; margin: 50px auto; padding: 20px; }
|
||||
h1 { color: #333; }
|
||||
@@ -253,12 +254,12 @@ func (h *Handler) serveHomePage(w http.ResponseWriter, r *http.Request) {
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>💧 Drip - Fast Tunnels to Localhost</h1>
|
||||
<p>A high-performance tunneling service.</p>
|
||||
<h1>💧 Drip - Your Tunnel, Your Domain, Anywhere</h1>
|
||||
<p>A self-hosted tunneling solution to securely expose your services to the internet.</p>
|
||||
|
||||
<h2>Quick Start</h2>
|
||||
<p>Install the client:</p>
|
||||
<code>bash <(curl -fsSL https://raw.githubusercontent.com/Gouryella/drip/refs/heads/main/scripts/install.sh)</code>
|
||||
<code>bash <(curl -fsSL https://raw.githubusercontent.com/Gouryella/drip/main/scripts/install.sh)</code>
|
||||
|
||||
<p>Start a tunnel:</p>
|
||||
<code>drip http 3000</code><br><br>
|
||||
|
||||
@@ -2,7 +2,7 @@ package tcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
json "github.com/goccy/go-json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -82,7 +82,6 @@ func (c *Connection) Handle() error {
|
||||
return fmt.Errorf("failed to peek connection: %w", err)
|
||||
}
|
||||
|
||||
// Check if this is an HTTP request
|
||||
peekStr := string(peek)
|
||||
httpMethods := []string{"GET ", "POST", "PUT ", "DELE", "HEAD", "OPTI", "PATC", "CONN", "TRAC"}
|
||||
isHTTP := false
|
||||
@@ -110,16 +109,13 @@ func (c *Connection) Handle() error {
|
||||
return fmt.Errorf("expected register frame, got %s", frame.Type)
|
||||
}
|
||||
|
||||
// Parse registration request
|
||||
var req protocol.RegisterRequest
|
||||
if err := json.Unmarshal(frame.Payload, &req); err != nil {
|
||||
return fmt.Errorf("failed to parse registration request: %w", err)
|
||||
}
|
||||
|
||||
// Store tunnel type
|
||||
c.tunnelType = req.TunnelType
|
||||
|
||||
// Authenticate
|
||||
if c.authToken != "" && req.Token != c.authToken {
|
||||
c.sendError("authentication_failed", "Invalid authentication token")
|
||||
return fmt.Errorf("authentication failed")
|
||||
@@ -144,7 +140,6 @@ func (c *Connection) Handle() error {
|
||||
}
|
||||
}
|
||||
|
||||
// Register tunnel
|
||||
subdomain, err := c.manager.Register(nil, req.CustomSubdomain)
|
||||
if err != nil {
|
||||
c.sendError("registration_failed", err.Error())
|
||||
@@ -155,7 +150,6 @@ func (c *Connection) Handle() error {
|
||||
|
||||
c.subdomain = subdomain
|
||||
|
||||
// Get tunnel connection
|
||||
tunnelConn, ok := c.manager.Get(subdomain)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to get registered tunnel")
|
||||
@@ -211,7 +205,6 @@ func (c *Connection) Handle() error {
|
||||
// Create frame writer for async writes
|
||||
c.frameWriter = protocol.NewFrameWriter(c.conn)
|
||||
|
||||
// Clear read deadline
|
||||
c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// Start TCP proxy only for TCP tunnels
|
||||
@@ -222,7 +215,6 @@ func (c *Connection) Handle() error {
|
||||
}
|
||||
}
|
||||
|
||||
// Start heartbeat checker
|
||||
go c.heartbeatChecker()
|
||||
|
||||
// Handle frames (pass reader for consistent buffering)
|
||||
@@ -510,24 +502,20 @@ func (c *Connection) Close() {
|
||||
c.once.Do(func() {
|
||||
close(c.stopCh)
|
||||
|
||||
// Close frame writer
|
||||
if c.frameWriter != nil {
|
||||
c.frameWriter.Close()
|
||||
}
|
||||
|
||||
// Stop TCP proxy
|
||||
if c.proxy != nil {
|
||||
c.proxy.Stop()
|
||||
}
|
||||
|
||||
c.conn.Close()
|
||||
|
||||
// Release allocated port
|
||||
if c.port > 0 && c.portAlloc != nil {
|
||||
c.portAlloc.Release(c.port)
|
||||
}
|
||||
|
||||
// Unregister tunnel
|
||||
if c.subdomain != "" {
|
||||
c.manager.Unregister(c.subdomain)
|
||||
}
|
||||
|
||||
@@ -61,7 +61,6 @@ func NewListener(address string, tlsConfig *tls.Config, authToken string, manage
|
||||
func (l *Listener) Start() error {
|
||||
var err error
|
||||
|
||||
// Create TLS listener
|
||||
l.listener, err = tls.Listen("tcp", l.address, l.tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start TLS listener: %w", err)
|
||||
@@ -72,7 +71,6 @@ func (l *Listener) Start() error {
|
||||
zap.String("tls_version", "TLS 1.3"),
|
||||
)
|
||||
|
||||
// Accept connections in background
|
||||
l.wg.Add(1)
|
||||
go l.acceptLoop()
|
||||
|
||||
@@ -102,7 +100,7 @@ func (l *Listener) acceptLoop() {
|
||||
}
|
||||
select {
|
||||
case <-l.stopCh:
|
||||
return // Listener was stopped
|
||||
return
|
||||
default:
|
||||
l.logger.Error("Failed to accept connection", zap.Error(err))
|
||||
continue
|
||||
@@ -128,14 +126,12 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
defer l.wg.Done()
|
||||
defer netConn.Close()
|
||||
|
||||
// Get TLS connection info
|
||||
tlsConn, ok := netConn.(*tls.Conn)
|
||||
if !ok {
|
||||
l.logger.Error("Connection is not TLS")
|
||||
return
|
||||
}
|
||||
|
||||
// Force TLS handshake to complete
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
// TLS handshake failures are common (HTTP clients, scanners, etc.)
|
||||
// Log as WARN instead of ERROR
|
||||
@@ -146,7 +142,6 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
// Log connection info
|
||||
state := tlsConn.ConnectionState()
|
||||
l.logger.Info("New connection",
|
||||
zap.String("remote_addr", netConn.RemoteAddr().String()),
|
||||
@@ -154,7 +149,6 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
zap.String("cipher_suite", tls.CipherSuiteName(state.CipherSuite)),
|
||||
)
|
||||
|
||||
// Verify TLS 1.3
|
||||
if state.Version != tls.VersionTLS13 {
|
||||
l.logger.Warn("Connection not using TLS 1.3",
|
||||
zap.Uint16("version", state.Version),
|
||||
@@ -162,23 +156,19 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
// Create connection handler
|
||||
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.publicPort, l.httpHandler, l.responseChans)
|
||||
|
||||
// Store connection
|
||||
connID := netConn.RemoteAddr().String()
|
||||
l.connMu.Lock()
|
||||
l.connections[connID] = conn
|
||||
l.connMu.Unlock()
|
||||
|
||||
// Remove connection on exit
|
||||
defer func() {
|
||||
l.connMu.Lock()
|
||||
delete(l.connections, connID)
|
||||
l.connMu.Unlock()
|
||||
}()
|
||||
|
||||
// Handle connection (blocking)
|
||||
if err := conn.Handle(); err != nil {
|
||||
errStr := err.Error()
|
||||
|
||||
@@ -217,27 +207,22 @@ func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
func (l *Listener) Stop() error {
|
||||
l.logger.Info("Stopping TCP listener")
|
||||
|
||||
// Signal stop
|
||||
close(l.stopCh)
|
||||
|
||||
// Close listener
|
||||
if l.listener != nil {
|
||||
if err := l.listener.Close(); err != nil {
|
||||
l.logger.Error("Failed to close listener", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Close all connections
|
||||
l.connMu.Lock()
|
||||
for _, conn := range l.connections {
|
||||
conn.Close()
|
||||
}
|
||||
l.connMu.Unlock()
|
||||
|
||||
// Wait for all goroutines to finish
|
||||
l.wg.Wait()
|
||||
|
||||
// Close worker pool
|
||||
if l.workerPool != nil {
|
||||
l.workerPool.Close()
|
||||
}
|
||||
|
||||
@@ -58,7 +58,6 @@ func (p *TunnelProxy) Start() error {
|
||||
zap.String("subdomain", p.subdomain),
|
||||
)
|
||||
|
||||
// Accept connections in background
|
||||
p.wg.Add(1)
|
||||
go p.acceptLoop()
|
||||
|
||||
@@ -76,7 +75,6 @@ func (p *TunnelProxy) acceptLoop() {
|
||||
default:
|
||||
}
|
||||
|
||||
// Set accept deadline
|
||||
p.listener.(*net.TCPListener).SetDeadline(time.Now().Add(1 * time.Second))
|
||||
|
||||
conn, err := p.listener.Accept()
|
||||
@@ -92,7 +90,6 @@ func (p *TunnelProxy) acceptLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
// Handle connection
|
||||
p.wg.Add(1)
|
||||
go p.handleConnection(conn)
|
||||
}
|
||||
|
||||
@@ -1,263 +0,0 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestNewConnection(t *testing.T) {
|
||||
subdomain := "test123"
|
||||
logger := zap.NewNop()
|
||||
|
||||
// We can't create a real WebSocket connection in unit tests,
|
||||
// so we'll just test with nil
|
||||
conn := NewConnection(subdomain, nil, logger)
|
||||
|
||||
if conn == nil {
|
||||
t.Fatal("NewConnection() returned nil")
|
||||
}
|
||||
|
||||
if conn.Subdomain != subdomain {
|
||||
t.Errorf("Subdomain = %v, want %v", conn.Subdomain, subdomain)
|
||||
}
|
||||
|
||||
if conn.SendCh == nil {
|
||||
t.Error("SendCh is nil")
|
||||
}
|
||||
|
||||
if conn.CloseCh == nil {
|
||||
t.Error("CloseCh is nil")
|
||||
}
|
||||
|
||||
// Check that LastActive is recent (within last second)
|
||||
now := time.Now()
|
||||
if now.Sub(conn.LastActive) > time.Second {
|
||||
t.Error("LastActive is not recent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionUpdateActivity(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
conn := NewConnection("test", nil, logger)
|
||||
|
||||
// Get initial LastActive
|
||||
initial := conn.LastActive
|
||||
|
||||
// Wait a bit
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Update activity
|
||||
conn.UpdateActivity()
|
||||
|
||||
// Check that LastActive was updated
|
||||
if !conn.LastActive.After(initial) {
|
||||
t.Error("UpdateActivity() did not update LastActive")
|
||||
}
|
||||
|
||||
// Check that it's recent
|
||||
now := time.Now()
|
||||
if now.Sub(conn.LastActive) > time.Second {
|
||||
t.Error("UpdateActivity() did not set recent timestamp")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionIsAlive(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
lastActive time.Time
|
||||
timeout time.Duration
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "fresh connection is alive",
|
||||
lastActive: time.Now(),
|
||||
timeout: 90 * time.Second,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "stale connection is not alive",
|
||||
lastActive: time.Now().Add(-2 * time.Minute),
|
||||
timeout: 90 * time.Second,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "exactly at timeout is not alive",
|
||||
lastActive: time.Now().Add(-90 * time.Second),
|
||||
timeout: 90 * time.Second,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "just before timeout is alive",
|
||||
lastActive: time.Now().Add(-89 * time.Second),
|
||||
timeout: 90 * time.Second,
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
logger := zap.NewNop()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
conn := NewConnection("test", nil, logger)
|
||||
conn.mu.Lock()
|
||||
conn.LastActive = tt.lastActive
|
||||
conn.mu.Unlock()
|
||||
|
||||
got := conn.IsAlive(tt.timeout)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsAlive() = %v, want %v (age: %v, timeout: %v)",
|
||||
got, tt.want, time.Since(tt.lastActive), tt.timeout)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionSend(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
conn := NewConnection("test", nil, logger)
|
||||
|
||||
data := []byte("test message")
|
||||
|
||||
// Test successful send
|
||||
err := conn.Send(data)
|
||||
if err != nil {
|
||||
t.Errorf("Send() error = %v, want nil", err)
|
||||
}
|
||||
|
||||
// Verify data was sent to channel
|
||||
select {
|
||||
case received := <-conn.SendCh:
|
||||
if string(received) != string(data) {
|
||||
t.Errorf("Received data = %v, want %v", string(received), string(data))
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Send() did not send data to channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionSendTimeout(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
conn := NewConnection("test", nil, logger)
|
||||
|
||||
// Fill the channel
|
||||
for i := 0; i < 256; i++ {
|
||||
conn.SendCh <- []byte("fill")
|
||||
}
|
||||
|
||||
// Try to send when channel is full
|
||||
data := []byte("test message")
|
||||
err := conn.Send(data)
|
||||
|
||||
if err != ErrSendTimeout {
|
||||
t.Errorf("Send() on full channel error = %v, want %v", err, ErrSendTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionClose(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
conn := NewConnection("test", nil, logger)
|
||||
|
||||
// Close the connection
|
||||
conn.Close()
|
||||
|
||||
// Verify CloseCh is closed
|
||||
select {
|
||||
case <-conn.CloseCh:
|
||||
// Successfully received from closed channel
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Close() did not close CloseCh")
|
||||
}
|
||||
|
||||
// Try to close again (should not panic)
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Error("Close() panicked on second call")
|
||||
}
|
||||
}()
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func TestConnectionConcurrentUpdateActivity(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
conn := NewConnection("test", nil, logger)
|
||||
|
||||
// Update activity concurrently
|
||||
done := make(chan bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
go func() {
|
||||
conn.UpdateActivity()
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for i := 0; i < 100; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify LastActive is recent
|
||||
now := time.Now()
|
||||
if now.Sub(conn.LastActive) > time.Second {
|
||||
t.Error("Concurrent UpdateActivity() failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionConcurrentIsAlive(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
conn := NewConnection("test", nil, logger)
|
||||
|
||||
// Check IsAlive concurrently
|
||||
done := make(chan bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
go func() {
|
||||
conn.IsAlive(90 * time.Second)
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for i := 0; i < 100; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkConnectionSend(b *testing.B) {
|
||||
logger := zap.NewNop()
|
||||
conn := NewConnection("test", nil, logger)
|
||||
|
||||
// Drain channel in background
|
||||
go func() {
|
||||
for range conn.SendCh {
|
||||
}
|
||||
}()
|
||||
|
||||
data := []byte("test message")
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn.Send(data)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkConnectionUpdateActivity(b *testing.B) {
|
||||
logger := zap.NewNop()
|
||||
conn := NewConnection("test", nil, logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn.UpdateActivity()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkConnectionIsAlive(b *testing.B) {
|
||||
logger := zap.NewNop()
|
||||
conn := NewConnection("test", nil, logger)
|
||||
timeout := 90 * time.Second
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn.IsAlive(timeout)
|
||||
}
|
||||
}
|
||||
@@ -1,376 +0,0 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestNewManager(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
if manager == nil {
|
||||
t.Fatal("NewManager() returned nil")
|
||||
}
|
||||
|
||||
if manager.tunnels == nil {
|
||||
t.Error("Manager tunnels map is nil")
|
||||
}
|
||||
|
||||
if manager.used == nil {
|
||||
t.Error("Manager used map is nil")
|
||||
}
|
||||
|
||||
if manager.logger == nil {
|
||||
t.Error("Manager logger is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerRegister(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
// Register with empty subdomain (auto-generate)
|
||||
subdomain, err := manager.Register(nil, "")
|
||||
if err != nil {
|
||||
t.Errorf("Register() error = %v, want nil", err)
|
||||
}
|
||||
|
||||
if subdomain == "" {
|
||||
t.Error("Register() returned empty subdomain")
|
||||
}
|
||||
|
||||
if len(subdomain) != 6 {
|
||||
t.Errorf("Register() subdomain length = %d, want 6", len(subdomain))
|
||||
}
|
||||
|
||||
// Verify connection is registered
|
||||
_, ok := manager.Get(subdomain)
|
||||
if !ok {
|
||||
t.Error("Get() failed to retrieve registered connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerRegisterCustomSubdomain(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
customSubdomain := "mytest"
|
||||
|
||||
// Register with custom subdomain
|
||||
subdomain, err := manager.Register(nil, customSubdomain)
|
||||
if err != nil {
|
||||
t.Errorf("Register() error = %v, want nil", err)
|
||||
}
|
||||
|
||||
if subdomain != customSubdomain {
|
||||
t.Errorf("Register() subdomain = %v, want %v", subdomain, customSubdomain)
|
||||
}
|
||||
|
||||
// Verify connection is registered
|
||||
conn, ok := manager.Get(subdomain)
|
||||
if !ok {
|
||||
t.Error("Get() failed to retrieve registered connection")
|
||||
}
|
||||
|
||||
if conn.Subdomain != customSubdomain {
|
||||
t.Errorf("Connection subdomain = %v, want %v", conn.Subdomain, customSubdomain)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerRegisterDuplicate(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
customSubdomain := "test123"
|
||||
|
||||
// Register first connection
|
||||
_, err := manager.Register(nil, customSubdomain)
|
||||
if err != nil {
|
||||
t.Fatalf("First Register() error = %v, want nil", err)
|
||||
}
|
||||
|
||||
// Try to register second connection with same subdomain
|
||||
_, err = manager.Register(nil, customSubdomain)
|
||||
if err != ErrSubdomainTaken {
|
||||
t.Errorf("Register() error = %v, want %v", err, ErrSubdomainTaken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerRegisterInvalidSubdomain(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
subdomain string
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "invalid uppercase",
|
||||
subdomain: "TEST",
|
||||
wantErr: ErrInvalidSubdomain,
|
||||
},
|
||||
{
|
||||
name: "invalid special char",
|
||||
subdomain: "test@123",
|
||||
wantErr: ErrInvalidSubdomain,
|
||||
},
|
||||
{
|
||||
name: "reserved www",
|
||||
subdomain: "www",
|
||||
wantErr: ErrReservedSubdomain,
|
||||
},
|
||||
{
|
||||
name: "reserved api",
|
||||
subdomain: "api",
|
||||
wantErr: ErrReservedSubdomain,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := manager.Register(nil, tt.subdomain)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("Register() error = %v, want %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerUnregister(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
// Register connection
|
||||
subdomain, err := manager.Register(nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Register() error = %v", err)
|
||||
}
|
||||
|
||||
// Unregister connection
|
||||
manager.Unregister(subdomain)
|
||||
|
||||
// Verify connection is removed
|
||||
_, ok := manager.Get(subdomain)
|
||||
if ok {
|
||||
t.Error("Get() succeeded after Unregister(), want failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerGet(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
customSubdomain := "test123"
|
||||
|
||||
// Test Get on non-existent connection
|
||||
_, ok := manager.Get(customSubdomain)
|
||||
if ok {
|
||||
t.Error("Get() succeeded for non-existent connection")
|
||||
}
|
||||
|
||||
// Register and test Get
|
||||
subdomain, _ := manager.Register(nil, customSubdomain)
|
||||
retrieved, ok := manager.Get(subdomain)
|
||||
if !ok {
|
||||
t.Error("Get() failed for existing connection")
|
||||
}
|
||||
if retrieved.Subdomain != customSubdomain {
|
||||
t.Error("Get() returned wrong connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerList(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
// Test empty manager
|
||||
all := manager.List()
|
||||
if len(all) != 0 {
|
||||
t.Errorf("List() on empty manager returned %d connections, want 0", len(all))
|
||||
}
|
||||
|
||||
// Add multiple connections
|
||||
count := 5
|
||||
for i := 0; i < count; i++ {
|
||||
manager.Register(nil, "")
|
||||
}
|
||||
|
||||
// Test List
|
||||
all = manager.List()
|
||||
if len(all) != count {
|
||||
t.Errorf("List() returned %d connections, want %d", len(all), count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerCount(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
// Test empty manager
|
||||
count := manager.Count()
|
||||
if count != 0 {
|
||||
t.Errorf("Count() on empty manager = %d, want 0", count)
|
||||
}
|
||||
|
||||
// Add connections
|
||||
numConns := 3
|
||||
for i := 0; i < numConns; i++ {
|
||||
manager.Register(nil, "")
|
||||
}
|
||||
|
||||
count = manager.Count()
|
||||
if count != numConns {
|
||||
t.Errorf("Count() = %d, want %d", count, numConns)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerGenerateSubdomain(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
// Generate subdomain via Register
|
||||
subdomain1, err := manager.Register(nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("First Register() error = %v", err)
|
||||
}
|
||||
|
||||
if subdomain1 == "" {
|
||||
t.Error("Register() returned empty subdomain")
|
||||
}
|
||||
|
||||
if len(subdomain1) != 6 {
|
||||
t.Errorf("Register() subdomain length = %d, want 6", len(subdomain1))
|
||||
}
|
||||
|
||||
// Generate another subdomain, should be different
|
||||
subdomain2, err := manager.Register(nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Second Register() error = %v", err)
|
||||
}
|
||||
|
||||
if subdomain1 == subdomain2 {
|
||||
t.Error("Register() generated duplicate subdomain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerGenerateSubdomainUniqueness(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
subdomains := make(map[string]bool)
|
||||
count := 100
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
subdomain, err := manager.Register(nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Register() error = %v", err)
|
||||
}
|
||||
if subdomains[subdomain] {
|
||||
t.Errorf("Register() generated duplicate: %s", subdomain)
|
||||
}
|
||||
subdomains[subdomain] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerCleanupStale(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
// Create fresh connection
|
||||
freshSubdomain, _ := manager.Register(nil, "fresh")
|
||||
|
||||
// Create stale connection
|
||||
staleSubdomain, _ := manager.Register(nil, "stale")
|
||||
|
||||
// Manually set LastActive to be stale
|
||||
if staleConn, ok := manager.Get(staleSubdomain); ok {
|
||||
staleConn.mu.Lock()
|
||||
staleConn.LastActive = time.Now().Add(-2 * time.Minute)
|
||||
staleConn.mu.Unlock()
|
||||
}
|
||||
|
||||
// Run cleanup with 90 second timeout
|
||||
count := manager.CleanupStale(90 * time.Second)
|
||||
if count != 1 {
|
||||
t.Errorf("CleanupStale() returned %d, want 1", count)
|
||||
}
|
||||
|
||||
// Fresh connection should still exist
|
||||
_, ok := manager.Get(freshSubdomain)
|
||||
if !ok {
|
||||
t.Error("CleanupStale() removed fresh connection")
|
||||
}
|
||||
|
||||
// Stale connection should be removed
|
||||
_, ok = manager.Get(staleSubdomain)
|
||||
if ok {
|
||||
t.Error("CleanupStale() did not remove stale connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerConcurrentAccess(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
count := 50 // Reduced from 100 to avoid potential issues
|
||||
|
||||
// Concurrent registrations
|
||||
for i := 0; i < count; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
_, err := manager.Register(nil, "")
|
||||
if err != nil {
|
||||
t.Errorf("Concurrent Register() error = %v", err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all connections are registered
|
||||
all := manager.List()
|
||||
if len(all) != count {
|
||||
t.Errorf("Expected %d connections, got %d", count, len(all))
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkManagerRegister(b *testing.B) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.Register(nil, "")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkManagerGet(b *testing.B) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
// Setup: register a connection
|
||||
subdomain, _ := manager.Register(nil, "test123")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.Get(subdomain)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkManagerGenerateSubdomain(b *testing.B) {
|
||||
logger := zap.NewNop()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager := NewManager(logger)
|
||||
manager.Register(nil, "")
|
||||
}
|
||||
}
|
||||
@@ -5,8 +5,7 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// HeaderPool manages a pool of http.Header objects for reuse
|
||||
// This reduces GC pressure from repeated header map allocations
|
||||
// HeaderPool manages a pool of http.Header objects for reuse.
|
||||
type HeaderPool struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
@@ -16,32 +15,26 @@ func NewHeaderPool() *HeaderPool {
|
||||
return &HeaderPool{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
// Pre-allocate with capacity for common header count (8-12 headers)
|
||||
return make(http.Header, 12)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a header from the pool
|
||||
// Returns a clean, empty header ready for use
|
||||
// Get retrieves a header from the pool.
|
||||
func (p *HeaderPool) Get() http.Header {
|
||||
h := p.pool.Get().(http.Header)
|
||||
// Clear any existing data (headers might be dirty from previous use)
|
||||
for k := range h {
|
||||
delete(h, k)
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
// Put returns a header to the pool
|
||||
// The header will be reused by future Get() calls
|
||||
// Put returns a header to the pool.
|
||||
func (p *HeaderPool) Put(h http.Header) {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
// Note: We don't clear here, clearing is done in Get() for better performance
|
||||
// (allows the GC to collect during idle time)
|
||||
p.pool.Put(h)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
json "github.com/goccy/go-json"
|
||||
"errors"
|
||||
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
|
||||
@@ -1,307 +0,0 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMessageType_Values(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mt MessageType
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "register",
|
||||
mt: TypeRegister,
|
||||
want: "register",
|
||||
},
|
||||
{
|
||||
name: "request",
|
||||
mt: TypeRequest,
|
||||
want: "request",
|
||||
},
|
||||
{
|
||||
name: "response",
|
||||
mt: TypeResponse,
|
||||
want: "response",
|
||||
},
|
||||
{
|
||||
name: "heartbeat",
|
||||
mt: TypeHeartbeat,
|
||||
want: "heartbeat",
|
||||
},
|
||||
{
|
||||
name: "error",
|
||||
mt: TypeError,
|
||||
want: "error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := string(tt.mt)
|
||||
if got != tt.want {
|
||||
t.Errorf("MessageType value = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessage_JSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
message *Message
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple message",
|
||||
message: &Message{
|
||||
Type: TypeRegister,
|
||||
ID: "test-id-123",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "message with subdomain",
|
||||
message: &Message{
|
||||
Type: TypeRegister,
|
||||
ID: "test-id-456",
|
||||
Subdomain: "abc123",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "message with data",
|
||||
message: &Message{
|
||||
Type: TypeRequest,
|
||||
ID: "test-id-789",
|
||||
Data: map[string]interface{}{
|
||||
"method": "GET",
|
||||
"path": "/test",
|
||||
"headers": map[string]interface{}{
|
||||
"User-Agent": "Test",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "error message",
|
||||
message: &Message{
|
||||
Type: TypeError,
|
||||
ID: "test-id-error",
|
||||
Data: map[string]interface{}{
|
||||
"error": "something went wrong",
|
||||
"code": float64(500), // JSON unmarshals numbers as float64
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Marshal to JSON
|
||||
data, err := json.Marshal(tt.message)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("json.Marshal() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
|
||||
// Unmarshal back
|
||||
var decoded Message
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Errorf("json.Unmarshal() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Compare
|
||||
if decoded.Type != tt.message.Type {
|
||||
t.Errorf("Type = %v, want %v", decoded.Type, tt.message.Type)
|
||||
}
|
||||
if decoded.ID != tt.message.ID {
|
||||
t.Errorf("ID = %v, want %v", decoded.ID, tt.message.ID)
|
||||
}
|
||||
if decoded.Subdomain != tt.message.Subdomain {
|
||||
t.Errorf("Subdomain = %v, want %v", decoded.Subdomain, tt.message.Subdomain)
|
||||
}
|
||||
|
||||
// Deep compare Data if present
|
||||
if tt.message.Data != nil {
|
||||
if !reflect.DeepEqual(decoded.Data, tt.message.Data) {
|
||||
t.Errorf("Data = %v, want %v", decoded.Data, tt.message.Data)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPRequest_JSON(t *testing.T) {
|
||||
req := &HTTPRequest{
|
||||
Method: "POST",
|
||||
URL: "http://localhost:3000/api/test",
|
||||
Headers: map[string][]string{
|
||||
"Content-Type": {"application/json"},
|
||||
"User-Agent": {"Test Agent"},
|
||||
},
|
||||
Body: []byte(`{"key":"value"}`),
|
||||
}
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var decoded HTTPRequest
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
|
||||
// Compare
|
||||
if decoded.Method != req.Method {
|
||||
t.Errorf("Method = %v, want %v", decoded.Method, req.Method)
|
||||
}
|
||||
if decoded.URL != req.URL {
|
||||
t.Errorf("URL = %v, want %v", decoded.URL, req.URL)
|
||||
}
|
||||
if !reflect.DeepEqual(decoded.Headers, req.Headers) {
|
||||
t.Errorf("Headers = %v, want %v", decoded.Headers, req.Headers)
|
||||
}
|
||||
if string(decoded.Body) != string(req.Body) {
|
||||
t.Errorf("Body = %v, want %v", string(decoded.Body), string(req.Body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPResponse_JSON(t *testing.T) {
|
||||
resp := &HTTPResponse{
|
||||
StatusCode: 200,
|
||||
Status: "200 OK",
|
||||
Headers: map[string][]string{
|
||||
"Content-Type": {"text/html"},
|
||||
},
|
||||
Body: []byte("<html>Test</html>"),
|
||||
}
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var decoded HTTPResponse
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
|
||||
// Compare
|
||||
if decoded.StatusCode != resp.StatusCode {
|
||||
t.Errorf("StatusCode = %v, want %v", decoded.StatusCode, resp.StatusCode)
|
||||
}
|
||||
if decoded.Status != resp.Status {
|
||||
t.Errorf("Status = %v, want %v", decoded.Status, resp.Status)
|
||||
}
|
||||
if !reflect.DeepEqual(decoded.Headers, resp.Headers) {
|
||||
t.Errorf("Headers = %v, want %v", decoded.Headers, resp.Headers)
|
||||
}
|
||||
if string(decoded.Body) != string(resp.Body) {
|
||||
t.Errorf("Body = %v, want %v", string(decoded.Body), string(resp.Body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessage_ToMap(t *testing.T) {
|
||||
msg := &Message{
|
||||
Type: TypeRequest,
|
||||
ID: "test-123",
|
||||
Subdomain: "abc",
|
||||
Data: map[string]interface{}{
|
||||
"test": "value",
|
||||
},
|
||||
}
|
||||
|
||||
// Convert to map (simulated by marshaling and unmarshaling)
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
err = json.Unmarshal(data, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify fields exist
|
||||
if result["type"] == nil {
|
||||
t.Error("Map missing 'type' field")
|
||||
}
|
||||
if result["id"] == nil {
|
||||
t.Error("Map missing 'id' field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewMessage(t *testing.T) {
|
||||
msgType := TypeRegister
|
||||
id := "test-id"
|
||||
|
||||
msg := &Message{
|
||||
Type: msgType,
|
||||
ID: id,
|
||||
}
|
||||
|
||||
if msg.Type != msgType {
|
||||
t.Errorf("Type = %v, want %v", msg.Type, msgType)
|
||||
}
|
||||
if msg.ID != id {
|
||||
t.Errorf("ID = %v, want %v", msg.ID, id)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkMessageMarshal(b *testing.B) {
|
||||
msg := &Message{
|
||||
Type: TypeRequest,
|
||||
ID: "test-id-123",
|
||||
Subdomain: "abc123",
|
||||
Data: map[string]interface{}{
|
||||
"method": "GET",
|
||||
"path": "/test",
|
||||
},
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
json.Marshal(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMessageUnmarshal(b *testing.B) {
|
||||
msg := &Message{
|
||||
Type: TypeRequest,
|
||||
ID: "test-id-123",
|
||||
Subdomain: "abc123",
|
||||
Data: map[string]interface{}{
|
||||
"method": "GET",
|
||||
"path": "/test",
|
||||
},
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(msg)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var decoded Message
|
||||
json.Unmarshal(data, &decoded)
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
package protocol
|
||||
|
||||
import "encoding/json"
|
||||
import json "github.com/goccy/go-json"
|
||||
|
||||
// RegisterRequest is sent by client to register a tunnel
|
||||
type RegisterRequest struct {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
json "github.com/goccy/go-json"
|
||||
"errors"
|
||||
)
|
||||
|
||||
|
||||
@@ -17,6 +17,10 @@ type FrameWriter struct {
|
||||
|
||||
maxBatch int
|
||||
maxBatchWait time.Duration
|
||||
|
||||
heartbeatInterval time.Duration
|
||||
heartbeatCallback func() *Frame
|
||||
heartbeatEnabled bool
|
||||
}
|
||||
|
||||
func NewFrameWriter(conn io.Writer) *FrameWriter {
|
||||
@@ -37,8 +41,22 @@ func NewFrameWriterWithConfig(conn io.Writer, maxBatch int, maxBatchWait time.Du
|
||||
}
|
||||
|
||||
func (w *FrameWriter) writeLoop() {
|
||||
ticker := time.NewTicker(w.maxBatchWait)
|
||||
defer ticker.Stop()
|
||||
batchTicker := time.NewTicker(w.maxBatchWait)
|
||||
defer batchTicker.Stop()
|
||||
|
||||
var heartbeatTicker *time.Ticker
|
||||
var heartbeatCh <-chan time.Time
|
||||
|
||||
w.mu.Lock()
|
||||
if w.heartbeatEnabled && w.heartbeatInterval > 0 {
|
||||
heartbeatTicker = time.NewTicker(w.heartbeatInterval)
|
||||
heartbeatCh = heartbeatTicker.C
|
||||
}
|
||||
w.mu.Unlock()
|
||||
|
||||
if heartbeatTicker != nil {
|
||||
defer heartbeatTicker.Stop()
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -58,13 +76,23 @@ func (w *FrameWriter) writeLoop() {
|
||||
}
|
||||
w.mu.Unlock()
|
||||
|
||||
case <-ticker.C:
|
||||
case <-batchTicker.C:
|
||||
w.mu.Lock()
|
||||
if len(w.batch) > 0 {
|
||||
w.flushBatchLocked()
|
||||
}
|
||||
w.mu.Unlock()
|
||||
|
||||
case <-heartbeatCh:
|
||||
w.mu.Lock()
|
||||
if w.heartbeatCallback != nil {
|
||||
if frame := w.heartbeatCallback(); frame != nil {
|
||||
w.batch = append(w.batch, frame)
|
||||
w.flushBatchLocked()
|
||||
}
|
||||
}
|
||||
w.mu.Unlock()
|
||||
|
||||
case <-w.done:
|
||||
w.mu.Lock()
|
||||
w.flushBatchLocked()
|
||||
@@ -86,6 +114,8 @@ func (w *FrameWriter) flushBatchLocked() {
|
||||
w.batch = w.batch[:0]
|
||||
}
|
||||
|
||||
// WriteFrame queues a frame to be written by the write loop.
|
||||
// Blocks if the queue is full to ensure all writes go through the single write loop.
|
||||
func (w *FrameWriter) WriteFrame(frame *Frame) error {
|
||||
w.mu.Lock()
|
||||
if w.closed {
|
||||
@@ -99,8 +129,6 @@ func (w *FrameWriter) WriteFrame(frame *Frame) error {
|
||||
return nil
|
||||
case <-w.done:
|
||||
return errors.New("writer closed")
|
||||
default:
|
||||
return WriteFrame(w.conn, frame)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,3 +152,19 @@ func (w *FrameWriter) Flush() {
|
||||
defer w.mu.Unlock()
|
||||
w.flushBatchLocked()
|
||||
}
|
||||
|
||||
// EnableHeartbeat enables automatic heartbeat sending in the write loop.
|
||||
func (w *FrameWriter) EnableHeartbeat(interval time.Duration, callback func() *Frame) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.heartbeatInterval = interval
|
||||
w.heartbeatCallback = callback
|
||||
w.heartbeatEnabled = true
|
||||
}
|
||||
|
||||
// DisableHeartbeat disables automatic heartbeat sending.
|
||||
func (w *FrameWriter) DisableHeartbeat() {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.heartbeatEnabled = false
|
||||
}
|
||||
|
||||
@@ -1,158 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenerateID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wantLength int // expected minimum length
|
||||
}{
|
||||
{
|
||||
name: "generate valid ID",
|
||||
wantLength: 16, // At least 16 characters for hex-encoded random bytes
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GenerateID()
|
||||
|
||||
// Check that ID is not empty
|
||||
if got == "" {
|
||||
t.Error("GenerateID() returned empty string")
|
||||
}
|
||||
|
||||
// Check minimum length
|
||||
if len(got) < tt.wantLength {
|
||||
t.Errorf("GenerateID() length = %v, want at least %v", len(got), tt.wantLength)
|
||||
}
|
||||
|
||||
// Check that it's a valid hex string
|
||||
for _, char := range got {
|
||||
if !isHexChar(char) {
|
||||
t.Errorf("GenerateID() contains non-hex character: %c", char)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateIDUniqueness(t *testing.T) {
|
||||
// Generate 10000 IDs and check for uniqueness
|
||||
ids := make(map[string]bool)
|
||||
count := 10000
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
id := GenerateID()
|
||||
if ids[id] {
|
||||
t.Errorf("GenerateID() generated duplicate: %s", id)
|
||||
}
|
||||
ids[id] = true
|
||||
}
|
||||
|
||||
if len(ids) != count {
|
||||
t.Errorf("Expected %d unique IDs, got %d", count, len(ids))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateIDFormat(t *testing.T) {
|
||||
id := GenerateID()
|
||||
|
||||
// Check that it's lowercase
|
||||
if id != strings.ToLower(id) {
|
||||
t.Errorf("GenerateID() is not lowercase: %s", id)
|
||||
}
|
||||
|
||||
// Check that it doesn't contain special characters
|
||||
for _, char := range id {
|
||||
if !isHexChar(char) {
|
||||
t.Errorf("GenerateID() contains invalid character: %c in %s", char, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateIDConsistency(t *testing.T) {
|
||||
// Generate multiple IDs and ensure they all follow the same format
|
||||
count := 100
|
||||
firstID := GenerateID()
|
||||
firstLen := len(firstID)
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
id := GenerateID()
|
||||
|
||||
// All IDs should have the same length
|
||||
if len(id) != firstLen {
|
||||
t.Errorf("ID length inconsistency: first=%d, current=%d", firstLen, len(id))
|
||||
}
|
||||
|
||||
// All IDs should be hex strings
|
||||
for _, char := range id {
|
||||
if !isHexChar(char) {
|
||||
t.Errorf("Invalid hex character %c in ID: %s", char, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateIDNotEmpty(t *testing.T) {
|
||||
// Generate 1000 IDs and ensure none are empty
|
||||
for i := 0; i < 1000; i++ {
|
||||
id := GenerateID()
|
||||
if id == "" {
|
||||
t.Error("GenerateID() returned empty string")
|
||||
}
|
||||
if len(id) == 0 {
|
||||
t.Error("GenerateID() returned zero-length string")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to check if a character is a valid hex character
|
||||
func isHexChar(char rune) bool {
|
||||
return (char >= '0' && char <= '9') || (char >= 'a' && char <= 'f')
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkGenerateID(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
GenerateID()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGenerateIDParallel(b *testing.B) {
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
GenerateID()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test for concurrent ID generation
|
||||
func TestGenerateIDConcurrent(t *testing.T) {
|
||||
count := 1000
|
||||
ch := make(chan string, count)
|
||||
|
||||
// Generate IDs concurrently
|
||||
for i := 0; i < count; i++ {
|
||||
go func() {
|
||||
ch <- GenerateID()
|
||||
}()
|
||||
}
|
||||
|
||||
// Collect all IDs
|
||||
ids := make(map[string]bool)
|
||||
for i := 0; i < count; i++ {
|
||||
id := <-ch
|
||||
if ids[id] {
|
||||
t.Errorf("Concurrent GenerateID() generated duplicate: %s", id)
|
||||
}
|
||||
ids[id] = true
|
||||
}
|
||||
|
||||
if len(ids) != count {
|
||||
t.Errorf("Expected %d unique IDs, got %d", count, len(ids))
|
||||
}
|
||||
}
|
||||
@@ -1,266 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenerateSubdomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
length int
|
||||
want int // expected length
|
||||
}{
|
||||
{
|
||||
name: "default length 6",
|
||||
length: 6,
|
||||
want: 6,
|
||||
},
|
||||
{
|
||||
name: "length 8",
|
||||
length: 8,
|
||||
want: 8,
|
||||
},
|
||||
{
|
||||
name: "length 10",
|
||||
length: 10,
|
||||
want: 10,
|
||||
},
|
||||
{
|
||||
name: "minimum length 4",
|
||||
length: 4,
|
||||
want: 4,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GenerateSubdomain(tt.length)
|
||||
|
||||
// Check length
|
||||
if len(got) != tt.want {
|
||||
t.Errorf("GenerateSubdomain() length = %v, want %v", len(got), tt.want)
|
||||
}
|
||||
|
||||
// Check that it only contains alphanumeric characters
|
||||
for _, char := range got {
|
||||
if !isAlphanumeric(char) {
|
||||
t.Errorf("GenerateSubdomain() contains non-alphanumeric character: %c", char)
|
||||
}
|
||||
}
|
||||
|
||||
// Check that it's lowercase
|
||||
if got != strings.ToLower(got) {
|
||||
t.Errorf("GenerateSubdomain() is not lowercase: %s", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSubdomainUniqueness(t *testing.T) {
|
||||
// Generate 1000 subdomains and check for uniqueness
|
||||
subdomains := make(map[string]bool)
|
||||
count := 1000
|
||||
length := 6
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
subdomain := GenerateSubdomain(length)
|
||||
if subdomains[subdomain] {
|
||||
t.Errorf("GenerateSubdomain() generated duplicate: %s", subdomain)
|
||||
}
|
||||
subdomains[subdomain] = true
|
||||
}
|
||||
|
||||
if len(subdomains) != count {
|
||||
t.Errorf("Expected %d unique subdomains, got %d", count, len(subdomains))
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSubdomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
subdomain string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "valid lowercase",
|
||||
subdomain: "abc123",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "valid all letters",
|
||||
subdomain: "abcdef",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "valid all numbers",
|
||||
subdomain: "123456",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "invalid uppercase",
|
||||
subdomain: "ABC123",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "valid with hyphen",
|
||||
subdomain: "abc-123",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "invalid starting with hyphen",
|
||||
subdomain: "-abc123",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid ending with hyphen",
|
||||
subdomain: "abc123-",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid with underscore",
|
||||
subdomain: "abc_123",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid with dot",
|
||||
subdomain: "abc.123",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid with space",
|
||||
subdomain: "abc 123",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid empty",
|
||||
subdomain: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid special characters",
|
||||
subdomain: "abc@123",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "valid minimum length",
|
||||
subdomain: "abc",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "invalid too short",
|
||||
subdomain: "ab",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ValidateSubdomain(tt.subdomain)
|
||||
if got != tt.want {
|
||||
t.Errorf("ValidateSubdomain(%q) = %v, want %v", tt.subdomain, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsReserved(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
subdomain string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "reserved www",
|
||||
subdomain: "www",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "reserved api",
|
||||
subdomain: "api",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "reserved admin",
|
||||
subdomain: "admin",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "reserved mail",
|
||||
subdomain: "mail",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "reserved ftp",
|
||||
subdomain: "ftp",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "reserved health",
|
||||
subdomain: "health",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "reserved test",
|
||||
subdomain: "test",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "reserved dev",
|
||||
subdomain: "dev",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "reserved staging",
|
||||
subdomain: "staging",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "not reserved random",
|
||||
subdomain: "abc123",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "not reserved user",
|
||||
subdomain: "myapp",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsReserved(tt.subdomain)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsReserved(%q) = %v, want %v", tt.subdomain, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to check if a character is alphanumeric
|
||||
func isAlphanumeric(char rune) bool {
|
||||
return (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9')
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkGenerateSubdomain(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
GenerateSubdomain(6)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkValidateSubdomain(b *testing.B) {
|
||||
subdomain := "abc123"
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
ValidateSubdomain(subdomain)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIsReserved(b *testing.B) {
|
||||
subdomain := "www"
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
IsReserved(subdomain)
|
||||
}
|
||||
}
|
||||
@@ -43,7 +43,6 @@ func LoadClientConfig(path string) (*ClientConfig, error) {
|
||||
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if config.Server == "" {
|
||||
return nil, fmt.Errorf("server address is required in config")
|
||||
}
|
||||
@@ -57,13 +56,11 @@ func SaveClientConfig(config *ClientConfig, path string) error {
|
||||
path = DefaultClientConfigPath()
|
||||
}
|
||||
|
||||
// Create directory if not exists
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create config directory: %w", err)
|
||||
}
|
||||
|
||||
// Marshal to YAML
|
||||
data, err := yaml.Marshal(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
|
||||
@@ -47,7 +47,6 @@ func (c *ServerConfig) LoadTLSConfig() (*tls.Config, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Check if certificate files exist
|
||||
if c.TLSCertFile == "" || c.TLSKeyFile == "" {
|
||||
return nil, fmt.Errorf("TLS enabled but certificate files not specified")
|
||||
}
|
||||
@@ -60,7 +59,6 @@ func (c *ServerConfig) LoadTLSConfig() (*tls.Config, error) {
|
||||
return nil, fmt.Errorf("key file not found: %s", c.TLSKeyFile)
|
||||
}
|
||||
|
||||
// Load certificate
|
||||
cert, err := tls.LoadX509KeyPair(c.TLSCertFile, c.TLSKeyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load certificate: %w", err)
|
||||
|
||||
Reference in New Issue
Block a user