mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-23 12:53:43 +00:00
feat(init): Initializes the project's basic structure and configuration files.
This commit is contained in:
68
.dockerignore
Normal file
68
.dockerignore
Normal file
@@ -0,0 +1,68 @@
|
||||
# Git
|
||||
.git
|
||||
.gitignore
|
||||
.gitattributes
|
||||
|
||||
# Documentation
|
||||
*.md
|
||||
docs/
|
||||
LICENSE
|
||||
|
||||
# Build artifacts
|
||||
bin/
|
||||
dist/
|
||||
*.exe
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
|
||||
# Test files
|
||||
*_test.go
|
||||
**/*_test.go
|
||||
coverage.*
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Temporary files
|
||||
tmp/
|
||||
temp/
|
||||
*.tmp
|
||||
|
||||
# Certificates (will be generated in container)
|
||||
*.pem
|
||||
*.key
|
||||
*.crt
|
||||
*.csr
|
||||
certs/
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
logs/
|
||||
|
||||
# Environment
|
||||
.env
|
||||
.env.local
|
||||
|
||||
# Dependencies (will be downloaded in container)
|
||||
vendor/
|
||||
|
||||
# CI/CD
|
||||
.github/
|
||||
.gitlab-ci.yml
|
||||
.travis.yml
|
||||
|
||||
# Development tools
|
||||
Makefile
|
||||
scripts/
|
||||
|
||||
# Examples
|
||||
examples/
|
||||
49
.env.example
Normal file
49
.env.example
Normal file
@@ -0,0 +1,49 @@
|
||||
# ===========================================
|
||||
# Server Configuration (docker-compose.yml)
|
||||
# ===========================================
|
||||
|
||||
# Domain for tunnel service
|
||||
DOMAIN=tunnel.example.com
|
||||
|
||||
# Authentication token
|
||||
AUTH_TOKEN=your-secret-token-here
|
||||
|
||||
# Server port
|
||||
PORT=8080
|
||||
|
||||
# Timezone
|
||||
TZ=UTC
|
||||
|
||||
# TLS Configuration (choose one)
|
||||
# Option 1: Auto TLS with Let's Encrypt
|
||||
# AUTO_TLS=1
|
||||
|
||||
# Option 2: Manual certificates (place in ./certs/)
|
||||
# TLS_CERT=1
|
||||
# TLS_KEY=1
|
||||
|
||||
# Build version
|
||||
VERSION=latest
|
||||
# GIT_COMMIT=
|
||||
|
||||
# ===========================================
|
||||
# Client Configuration (docker-compose.client.yml)
|
||||
# ===========================================
|
||||
|
||||
# Server address
|
||||
SERVER_ADDR=tunnel.example.com:443
|
||||
|
||||
# Tunnel type: http, https, or tcp
|
||||
TUNNEL_TYPE=http
|
||||
|
||||
# Local port to expose
|
||||
LOCAL_PORT=3000
|
||||
|
||||
# Local address (default: 127.0.0.1)
|
||||
# LOCAL_ADDRESS=192.168.1.100
|
||||
|
||||
# Custom subdomain (optional)
|
||||
# SUBDOMAIN=myapp
|
||||
|
||||
# Run as daemon
|
||||
# DAEMON=1
|
||||
54
.gitignore
vendored
Normal file
54
.gitignore
vendored
Normal file
@@ -0,0 +1,54 @@
|
||||
# Binaries for programs and plugins
|
||||
*.exe
|
||||
*.exe~
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
bin/
|
||||
dist/
|
||||
|
||||
# Test binary, built with `go test -c`
|
||||
*.test
|
||||
|
||||
# Output of the go coverage tool
|
||||
*.out
|
||||
coverage.html
|
||||
coverage.txt
|
||||
|
||||
# Dependency directories
|
||||
vendor/
|
||||
|
||||
# Go workspace file
|
||||
go.work
|
||||
|
||||
# IDEs
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Environment files
|
||||
.env
|
||||
.env.local
|
||||
*.pem
|
||||
*.key
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
logs/
|
||||
|
||||
# Build artifacts
|
||||
*.tar.gz
|
||||
*.zip
|
||||
|
||||
# Temporary files
|
||||
tmp/
|
||||
temp/
|
||||
certs/
|
||||
.drip-server.env
|
||||
benchmark-results/
|
||||
136
Makefile
Normal file
136
Makefile
Normal file
@@ -0,0 +1,136 @@
|
||||
.PHONY: all build build-all clean test run-server run-client install deps fmt lint
|
||||
|
||||
# Variables
|
||||
BINARY=bin/drip
|
||||
VERSION?=dev
|
||||
COMMIT=$(shell git rev-parse --short=10 HEAD 2>/dev/null || echo "unknown")
|
||||
BUILD_TIME=$(shell date -u '+%Y-%m-%d_%H:%M:%S')
|
||||
LDFLAGS=-ldflags "-s -w -X main.Version=${VERSION} -X main.GitCommit=${COMMIT} -X main.BuildTime=${BUILD_TIME}"
|
||||
|
||||
# Default target
|
||||
all: clean deps test build
|
||||
|
||||
# Install dependencies
|
||||
deps:
|
||||
go mod download
|
||||
go mod tidy
|
||||
|
||||
# Build unified binary
|
||||
build:
|
||||
@echo "Building Drip..."
|
||||
@mkdir -p bin
|
||||
go build ${LDFLAGS} -o ${BINARY} ./cmd/drip
|
||||
@echo "Build complete!"
|
||||
|
||||
# Build for all platforms
|
||||
build-all: clean
|
||||
@echo "Building for multiple platforms..."
|
||||
@mkdir -p bin
|
||||
|
||||
# Linux AMD64
|
||||
GOOS=linux GOARCH=amd64 go build ${LDFLAGS} -o bin/drip-linux-amd64 ./cmd/drip
|
||||
|
||||
# Linux ARM64
|
||||
GOOS=linux GOARCH=arm64 go build ${LDFLAGS} -o bin/drip-linux-arm64 ./cmd/drip
|
||||
|
||||
# macOS AMD64
|
||||
GOOS=darwin GOARCH=amd64 go build ${LDFLAGS} -o bin/drip-darwin-amd64 ./cmd/drip
|
||||
|
||||
# macOS ARM64 (Apple Silicon)
|
||||
GOOS=darwin GOARCH=arm64 go build ${LDFLAGS} -o bin/drip-darwin-arm64 ./cmd/drip
|
||||
|
||||
# Windows AMD64
|
||||
GOOS=windows GOARCH=amd64 go build ${LDFLAGS} -o bin/drip-windows-amd64.exe ./cmd/drip
|
||||
|
||||
# Windows ARM64
|
||||
GOOS=windows GOARCH=arm64 go build ${LDFLAGS} -o bin/drip-windows-arm64.exe ./cmd/drip
|
||||
|
||||
@echo "Multi-platform build complete!"
|
||||
|
||||
# Run tests
|
||||
test:
|
||||
go test -v -race -cover ./...
|
||||
|
||||
# Run tests with coverage
|
||||
test-coverage:
|
||||
go test -v -race -coverprofile=coverage.out -covermode=atomic ./...
|
||||
go tool cover -html=coverage.out -o coverage.html
|
||||
@echo "Coverage report generated: coverage.html"
|
||||
|
||||
# Benchmark tests
|
||||
bench:
|
||||
go test -bench=. -benchmem ./...
|
||||
|
||||
# Run server locally
|
||||
run-server:
|
||||
go run ./cmd/drip server
|
||||
|
||||
# Run client locally (example)
|
||||
run-client:
|
||||
go run ./cmd/drip http 3000
|
||||
|
||||
# Install globally
|
||||
install:
|
||||
go install ${LDFLAGS} ./cmd/drip
|
||||
|
||||
# Format code
|
||||
fmt:
|
||||
go fmt ./...
|
||||
gofmt -s -w .
|
||||
|
||||
# Lint code
|
||||
lint:
|
||||
@which golangci-lint > /dev/null || (echo "Installing golangci-lint..." && go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest)
|
||||
golangci-lint run ./...
|
||||
|
||||
# Clean build artifacts
|
||||
clean:
|
||||
@echo "Cleaning..."
|
||||
@rm -rf bin/
|
||||
@rm -f coverage.out coverage.html
|
||||
@echo "Clean complete!"
|
||||
|
||||
# Docker build
|
||||
docker-build:
|
||||
docker build -t drip-server:${VERSION} -f deployments/Dockerfile .
|
||||
|
||||
# Docker run
|
||||
docker-run:
|
||||
docker run -p 80:80 -p 8080:8080 drip-server:${VERSION}
|
||||
|
||||
# Generate test certificates
|
||||
gen-certs:
|
||||
@echo "Generating test TLS 1.3 certificates..."
|
||||
@mkdir -p certs
|
||||
openssl req -x509 -newkey rsa:4096 -nodes \
|
||||
-keyout certs/server-key.pem \
|
||||
-out certs/server-cert.pem \
|
||||
-days 365 \
|
||||
-subj "/CN=localhost"
|
||||
@echo "Test certificates generated in certs/"
|
||||
@echo "⚠️ Warning: These are self-signed certificates for testing only!"
|
||||
|
||||
# Help
|
||||
help:
|
||||
@echo "Drip - Available Make Targets:"
|
||||
@echo ""
|
||||
@echo " make build - Build server and client"
|
||||
@echo " make build-all - Build for all platforms"
|
||||
@echo " make test - Run tests"
|
||||
@echo " make test-coverage - Run tests with coverage report"
|
||||
@echo " make bench - Run benchmark tests"
|
||||
@echo " make run-server - Run server locally"
|
||||
@echo " make run-client - Run client locally (port 3000)"
|
||||
@echo " make gen-certs - Generate test TLS certificates"
|
||||
@echo " make install - Install client globally"
|
||||
@echo " make fmt - Format code"
|
||||
@echo " make lint - Lint code"
|
||||
@echo " make clean - Clean build artifacts"
|
||||
@echo " make deps - Install dependencies"
|
||||
@echo " make docker-build - Build Docker image"
|
||||
@echo " make docker-run - Run Docker container"
|
||||
@echo ""
|
||||
@echo "Build info:"
|
||||
@echo " VERSION=${VERSION}"
|
||||
@echo " COMMIT=${COMMIT}"
|
||||
@echo " BUILD_TIME=${BUILD_TIME}"
|
||||
244
README.md
Normal file
244
README.md
Normal file
@@ -0,0 +1,244 @@
|
||||
# Drip - Fast Tunnels to Localhost
|
||||
|
||||
Self-hosted tunneling solution. Expose your localhost to the internet securely.
|
||||
|
||||
[中文文档](README_CN.md)
|
||||
|
||||
[](https://golang.org/)
|
||||
[](LICENSE)
|
||||
[](https://tools.ietf.org/html/rfc8446)
|
||||
|
||||
## Why?
|
||||
|
||||
**Control your data.** No third-party servers means your traffic stays between your client and your server.
|
||||
|
||||
**No limits.** Run as many tunnels as you need, use as much bandwidth as your server can handle.
|
||||
|
||||
**Actually free.** Use your own domain, no paid tiers or feature restrictions.
|
||||
|
||||
| Feature | Drip | ngrok Free |
|
||||
|---------|------|------------|
|
||||
| Privacy | Your infrastructure | Third-party servers |
|
||||
| Domain | Your domain | 1 static subdomain |
|
||||
| Bandwidth | Unlimited | 1 GB/month |
|
||||
| Active Endpoints | Unlimited | 1 endpoint |
|
||||
| Tunnels per Agent | Unlimited | Up to 3 |
|
||||
| Requests | Unlimited | 20,000/month |
|
||||
| Interstitial Page | None | Yes (removable with header) |
|
||||
| Open Source | ✓ | ✗ |
|
||||
|
||||
## Quick Install
|
||||
|
||||
### Client (macOS/Linux)
|
||||
|
||||
```bash
|
||||
bash <(curl -sL https:///install.sh)
|
||||
```
|
||||
|
||||
### Server (Linux)
|
||||
|
||||
```bash
|
||||
bash <(curl -sL https:///install-server.sh)
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Tunnels
|
||||
|
||||
```bash
|
||||
# Expose local HTTP server
|
||||
drip http 3000
|
||||
|
||||
# Expose local HTTPS server
|
||||
drip https 443
|
||||
|
||||
# Pick your subdomain
|
||||
drip http 3000 --subdomain myapp
|
||||
# → https://myapp.your-domain.com
|
||||
|
||||
# Expose TCP service (database, SSH, etc.)
|
||||
drip tcp 5432
|
||||
```
|
||||
|
||||
### Forward to Any Address
|
||||
|
||||
Not just localhost - forward to any device on your network:
|
||||
|
||||
```bash
|
||||
# Forward to another machine on LAN
|
||||
drip http 8080 --address 192.168.1.100
|
||||
|
||||
# Forward to Docker container
|
||||
drip http 3000 --address 172.17.0.2
|
||||
|
||||
# Forward to specific interface
|
||||
drip http 3000 --address 10.0.0.5
|
||||
```
|
||||
|
||||
### Daemon Mode
|
||||
|
||||
Run tunnels in the background:
|
||||
|
||||
```bash
|
||||
# Start tunnel as daemon
|
||||
drip daemon start http 3000
|
||||
drip daemon start https 8443 --subdomain api
|
||||
|
||||
# Manage daemons
|
||||
drip daemon list
|
||||
drip daemon stop http-3000
|
||||
drip daemon logs http-3000
|
||||
```
|
||||
|
||||
## Server Deployment
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- A domain with DNS pointing to your server (A record)
|
||||
- Wildcard DNS for subdomains: `*.tunnel.example.com -> YOUR_IP`
|
||||
- SSL certificate (wildcard recommended)
|
||||
|
||||
### Option 1: Direct (Recommended)
|
||||
|
||||
Drip server handles TLS directly on port 443:
|
||||
|
||||
```bash
|
||||
# Get wildcard certificate
|
||||
sudo certbot certonly --manual --preferred-challenges dns \
|
||||
-d "*.tunnel.example.com" -d "tunnel.example.com"
|
||||
|
||||
# Start server
|
||||
drip-server \
|
||||
--port 443 \
|
||||
--domain tunnel.example.com \
|
||||
--tls-cert /etc/letsencrypt/live/tunnel.example.com/fullchain.pem \
|
||||
--tls-key /etc/letsencrypt/live/tunnel.example.com/privkey.pem \
|
||||
--token YOUR_SECRET_TOKEN
|
||||
```
|
||||
|
||||
### Option 2: Behind Nginx
|
||||
|
||||
Run Drip on port 8443, let Nginx handle SSL termination:
|
||||
|
||||
```nginx
|
||||
server {
|
||||
listen 443 ssl http2;
|
||||
server_name *.tunnel.example.com;
|
||||
|
||||
ssl_certificate /etc/letsencrypt/live/tunnel.example.com/fullchain.pem;
|
||||
ssl_certificate_key /etc/letsencrypt/live/tunnel.example.com/privkey.pem;
|
||||
|
||||
location / {
|
||||
proxy_pass https://127.0.0.1:8443;
|
||||
proxy_ssl_verify off;
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_buffering off;
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Systemd Service
|
||||
|
||||
The install script creates `/etc/systemd/system/drip-server.service` automatically. Manage with:
|
||||
|
||||
```bash
|
||||
sudo systemctl start drip-server
|
||||
sudo systemctl enable drip-server
|
||||
sudo journalctl -u drip-server -f
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
**Security**
|
||||
- TLS 1.3 encryption for all connections
|
||||
- Token-based authentication
|
||||
- No legacy protocol support
|
||||
|
||||
**Flexibility**
|
||||
- HTTP, HTTPS, and TCP tunnels
|
||||
- Forward to localhost or any LAN address
|
||||
- Custom subdomains or auto-generated
|
||||
- Daemon mode for persistent tunnels
|
||||
|
||||
**Performance**
|
||||
- Binary protocol with msgpack encoding
|
||||
- Connection pooling and reuse
|
||||
- Minimal overhead between client and server
|
||||
|
||||
**Simplicity**
|
||||
- One-line installation
|
||||
- Save config once, use everywhere
|
||||
- Real-time connection stats
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌─────────────┐ ┌──────────────┐ ┌─────────────┐
|
||||
│ Internet │ ──────> │ Server │ <────── │ Client │
|
||||
│ User │ HTTPS │ (Drip) │ TLS 1.3 │ localhost │
|
||||
└─────────────┘ └──────────────┘ └─────────────┘
|
||||
```
|
||||
|
||||
## Common Use Cases
|
||||
|
||||
**Development & Testing**
|
||||
```bash
|
||||
# Show local dev site to client
|
||||
drip http 3000
|
||||
|
||||
# Test webhooks from services like Stripe
|
||||
drip http 8000 --subdomain webhooks
|
||||
```
|
||||
|
||||
**Home Server Access**
|
||||
```bash
|
||||
# Access home NAS remotely
|
||||
drip http 5000 --address 192.168.1.50
|
||||
|
||||
# Remote into home network
|
||||
drip tcp 22
|
||||
```
|
||||
|
||||
**Docker & Containers**
|
||||
```bash
|
||||
# Expose containerized app
|
||||
drip http 8080 --address 172.17.0.3
|
||||
|
||||
# Database access for debugging
|
||||
drip tcp 5432 --address db-container
|
||||
```
|
||||
|
||||
## Command Reference
|
||||
|
||||
```bash
|
||||
# HTTP tunnel
|
||||
drip http <port> [flags]
|
||||
--subdomain, -n Custom subdomain
|
||||
--address, -a Target address (default: 127.0.0.1)
|
||||
--server Server address
|
||||
--token Auth token
|
||||
|
||||
# HTTPS tunnel
|
||||
drip https <port> [flags]
|
||||
|
||||
# TCP tunnel
|
||||
drip tcp <port> [flags]
|
||||
|
||||
# Daemon commands
|
||||
drip daemon start <type> <port> [flags]
|
||||
drip daemon list
|
||||
drip daemon stop <name>
|
||||
drip daemon logs <name>
|
||||
|
||||
# Configuration
|
||||
drip config init
|
||||
drip config show
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
MIT License - see [LICENSE](LICENSE) for details
|
||||
244
README_CN.md
Normal file
244
README_CN.md
Normal file
@@ -0,0 +1,244 @@
|
||||
# Drip - 快速内网穿透工具
|
||||
|
||||
自建隧道服务,让本地服务安全地暴露到公网。
|
||||
|
||||
[English](README.md)
|
||||
|
||||
[](https://golang.org/)
|
||||
[](LICENSE)
|
||||
[](https://tools.ietf.org/html/rfc8446)
|
||||
|
||||
## 为什么用它?
|
||||
|
||||
**数据掌控。** 没有第三方服务器,流量只在你的客户端和服务器之间传输。
|
||||
|
||||
**没有限制。** 想开多少隧道就开多少,带宽只受服务器性能限制。
|
||||
|
||||
**真正免费。** 用自己的域名,没有付费功能,没有阉割版。
|
||||
|
||||
| 特性 | Drip | ngrok 免费版 |
|
||||
|------|------|-------------|
|
||||
| 隐私 | 自己的基础设施 | 第三方服务器 |
|
||||
| 域名 | 你的域名 | 1 个固定子域名 |
|
||||
| 带宽 | 无限制 | 1 GB/月 |
|
||||
| 活跃端点 | 无限制 | 1 个端点 |
|
||||
| 每个代理的隧道数 | 无限制 | 最多 3 个 |
|
||||
| 请求数 | 无限制 | 20,000 次/月 |
|
||||
| 警告页面 | 无 | 有(可用请求头移除) |
|
||||
| 开源 | ✓ | ✗ |
|
||||
|
||||
## 一键安装
|
||||
|
||||
### 客户端 (macOS/Linux)
|
||||
|
||||
```bash
|
||||
bash <(curl -sL https:///install.sh)
|
||||
```
|
||||
|
||||
### 服务端 (Linux)
|
||||
|
||||
```bash
|
||||
bash <(curl -sL https:///install-server.sh)
|
||||
```
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 基础隧道
|
||||
|
||||
```bash
|
||||
# 暴露本地 HTTP 服务
|
||||
drip http 3000
|
||||
|
||||
# 暴露本地 HTTPS 服务
|
||||
drip https 443
|
||||
|
||||
# 自定义子域名
|
||||
drip http 3000 --subdomain myapp
|
||||
# → https://myapp.你的域名.com
|
||||
|
||||
# 暴露 TCP 服务(数据库、SSH 等)
|
||||
drip tcp 5432
|
||||
```
|
||||
|
||||
### 转发到任意地址
|
||||
|
||||
不只是 localhost,可以转发到局域网内的任何设备:
|
||||
|
||||
```bash
|
||||
# 转发到局域网其他设备
|
||||
drip http 8080 --address 192.168.1.100
|
||||
|
||||
# 转发到 Docker 容器
|
||||
drip http 3000 --address 172.17.0.2
|
||||
|
||||
# 转发到特定网卡
|
||||
drip http 3000 --address 10.0.0.5
|
||||
```
|
||||
|
||||
### 后台守护进程
|
||||
|
||||
让隧道在后台持续运行:
|
||||
|
||||
```bash
|
||||
# 启动后台隧道
|
||||
drip daemon start http 3000
|
||||
drip daemon start https 8443 --subdomain api
|
||||
|
||||
# 管理后台隧道
|
||||
drip daemon list
|
||||
drip daemon stop http-3000
|
||||
drip daemon logs http-3000
|
||||
```
|
||||
|
||||
## 服务端部署
|
||||
|
||||
### 前置条件
|
||||
|
||||
- 域名已解析到服务器(A 记录)
|
||||
- 泛域名解析:`*.tunnel.example.com -> 服务器IP`
|
||||
- SSL 证书(推荐通配符证书)
|
||||
|
||||
### 方式一:直接部署(推荐)
|
||||
|
||||
Drip 服务端直接监听 443 端口处理 TLS:
|
||||
|
||||
```bash
|
||||
# 获取通配符证书
|
||||
sudo certbot certonly --manual --preferred-challenges dns \
|
||||
-d "*.tunnel.example.com" -d "tunnel.example.com"
|
||||
|
||||
# 启动服务
|
||||
drip-server \
|
||||
--port 443 \
|
||||
--domain tunnel.example.com \
|
||||
--tls-cert /etc/letsencrypt/live/tunnel.example.com/fullchain.pem \
|
||||
--tls-key /etc/letsencrypt/live/tunnel.example.com/privkey.pem \
|
||||
--token 你的密钥
|
||||
```
|
||||
|
||||
### 方式二:Nginx 反向代理
|
||||
|
||||
Drip 监听 8443 端口,Nginx 处理 SSL:
|
||||
|
||||
```nginx
|
||||
server {
|
||||
listen 443 ssl http2;
|
||||
server_name *.tunnel.example.com;
|
||||
|
||||
ssl_certificate /etc/letsencrypt/live/tunnel.example.com/fullchain.pem;
|
||||
ssl_certificate_key /etc/letsencrypt/live/tunnel.example.com/privkey.pem;
|
||||
|
||||
location / {
|
||||
proxy_pass https://127.0.0.1:8443;
|
||||
proxy_ssl_verify off;
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_buffering off;
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Systemd 服务
|
||||
|
||||
安装脚本会自动创建 `/etc/systemd/system/drip-server.service`,使用以下命令管理:
|
||||
|
||||
```bash
|
||||
sudo systemctl start drip-server # 启动
|
||||
sudo systemctl enable drip-server # 开机启动
|
||||
sudo journalctl -u drip-server -f # 查看日志
|
||||
```
|
||||
|
||||
## 核心特性
|
||||
|
||||
**安全性**
|
||||
- 所有连接使用 TLS 1.3 加密
|
||||
- 基于 Token 的认证
|
||||
- 不支持旧版不安全协议
|
||||
|
||||
**灵活性**
|
||||
- 支持 HTTP、HTTPS 和 TCP 隧道
|
||||
- 转发到 localhost 或局域网任意地址
|
||||
- 自定义子域名或自动生成
|
||||
- 后台守护进程模式
|
||||
|
||||
**性能**
|
||||
- 二进制协议配合 msgpack 编码
|
||||
- 连接池复用
|
||||
- 客户端与服务器之间开销极小
|
||||
|
||||
**简单易用**
|
||||
- 一行命令安装
|
||||
- 配置一次处处使用
|
||||
- 实时连接统计
|
||||
|
||||
## 架构
|
||||
|
||||
```
|
||||
┌─────────────┐ ┌──────────────┐ ┌─────────────┐
|
||||
│ 外部用户 │ ──────> │ 服务器 │ <────── │ 你的电脑 │
|
||||
│ │ HTTPS │ (Drip) │ TLS 1.3 │ localhost │
|
||||
└─────────────┘ └──────────────┘ └─────────────┘
|
||||
```
|
||||
|
||||
## 常见使用场景
|
||||
|
||||
**开发和测试**
|
||||
```bash
|
||||
# 给客户演示本地开发的网站
|
||||
drip http 3000
|
||||
|
||||
# 测试第三方 webhook(如 Stripe)
|
||||
drip http 8000 --subdomain webhooks
|
||||
```
|
||||
|
||||
**家庭服务器**
|
||||
```bash
|
||||
# 远程访问家里的 NAS
|
||||
drip http 5000 --address 192.168.1.50
|
||||
|
||||
# 远程连接家庭网络
|
||||
drip tcp 22
|
||||
```
|
||||
|
||||
**Docker 和容器**
|
||||
```bash
|
||||
# 暴露容器化应用
|
||||
drip http 8080 --address 172.17.0.3
|
||||
|
||||
# 调试数据库
|
||||
drip tcp 5432 --address db-container
|
||||
```
|
||||
|
||||
## 命令参考
|
||||
|
||||
```bash
|
||||
# HTTP 隧道
|
||||
drip http <端口> [选项]
|
||||
--subdomain, -n 自定义子域名
|
||||
--address, -a 目标地址(默认:127.0.0.1)
|
||||
--server 服务器地址
|
||||
--token 认证令牌
|
||||
|
||||
# HTTPS 隧道
|
||||
drip https <端口> [选项]
|
||||
|
||||
# TCP 隧道
|
||||
drip tcp <端口> [选项]
|
||||
|
||||
# 守护进程命令
|
||||
drip daemon start <类型> <端口> [选项]
|
||||
drip daemon list
|
||||
drip daemon stop <名称>
|
||||
drip daemon logs <名称>
|
||||
|
||||
# 配置管理
|
||||
drip config init
|
||||
drip config show
|
||||
```
|
||||
|
||||
## 开源协议
|
||||
|
||||
MIT License - 详见 [LICENSE](LICENSE)
|
||||
25
cmd/drip/main.go
Normal file
25
cmd/drip/main.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"drip/internal/client/cli"
|
||||
)
|
||||
|
||||
var (
|
||||
Version = "dev"
|
||||
GitCommit = "unknown"
|
||||
BuildTime = "unknown"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Set version information
|
||||
cli.SetVersion(Version, GitCommit, BuildTime)
|
||||
|
||||
// Execute CLI
|
||||
if err := cli.Execute(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
38
deployments/Dockerfile
Normal file
38
deployments/Dockerfile
Normal file
@@ -0,0 +1,38 @@
|
||||
FROM golang:1.25-alpine AS builder
|
||||
|
||||
RUN apk add --no-cache git ca-certificates tzdata
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
|
||||
-ldflags="-s -w -X main.Version=${VERSION:-dev} -X main.GitCommit=${GIT_COMMIT:-unknown} -X main.BuildTime=$(date -u '+%Y-%m-%d_%H:%M:%S')" \
|
||||
-o drip \
|
||||
./cmd/drip
|
||||
|
||||
FROM alpine:latest
|
||||
|
||||
RUN apk add --no-cache ca-certificates tzdata
|
||||
|
||||
RUN addgroup -S drip && adduser -S -G drip drip
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN mkdir -p /app/data/certs && \
|
||||
chown -R drip:drip /app
|
||||
|
||||
COPY --from=builder /app/drip /app/drip
|
||||
|
||||
USER drip
|
||||
|
||||
EXPOSE 80 443 8080 20000-40000
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
||||
CMD wget --no-verbose --tries=1 --spider http://localhost:8080/health || exit 1
|
||||
|
||||
ENTRYPOINT ["/app/drip"]
|
||||
CMD ["server", "--port", "8080"]
|
||||
33
deployments/Dockerfile.client
Normal file
33
deployments/Dockerfile.client
Normal file
@@ -0,0 +1,33 @@
|
||||
FROM golang:1.25-alpine AS builder
|
||||
|
||||
RUN apk add --no-cache git ca-certificates
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
|
||||
-ldflags="-s -w -X main.Version=${VERSION:-dev} -X main.GitCommit=${GIT_COMMIT:-unknown} -X main.BuildTime=$(date -u '+%Y-%m-%d_%H:%M:%S')" \
|
||||
-o drip \
|
||||
./cmd/drip
|
||||
|
||||
FROM alpine:latest
|
||||
|
||||
RUN apk add --no-cache ca-certificates
|
||||
|
||||
RUN addgroup -S drip && adduser -S -G drip drip
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN mkdir -p /app/data && \
|
||||
chown -R drip:drip /app
|
||||
|
||||
COPY --from=builder /app/drip /app/drip
|
||||
|
||||
USER drip
|
||||
|
||||
ENTRYPOINT ["/app/drip"]
|
||||
CMD ["--help"]
|
||||
220
deployments/README.md
Normal file
220
deployments/README.md
Normal file
@@ -0,0 +1,220 @@
|
||||
# Docker Deployment
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Server (Production)
|
||||
|
||||
```bash
|
||||
# Copy and configure environment
|
||||
cp .env.example .env
|
||||
nano .env
|
||||
|
||||
# Edit server configuration
|
||||
DOMAIN=tunnel.example.com
|
||||
AUTH_TOKEN=your-secret-token
|
||||
TLS_CERT=1
|
||||
TLS_KEY=1
|
||||
|
||||
# Place certificates
|
||||
mkdir -p certs
|
||||
cp /path/to/fullchain.pem certs/
|
||||
cp /path/to/privkey.pem certs/
|
||||
|
||||
# Uncomment volume mount in docker-compose.yml
|
||||
# - ./certs:/app/data/certs:ro
|
||||
|
||||
# Start server
|
||||
docker compose up -d
|
||||
|
||||
# View logs
|
||||
docker compose logs -f
|
||||
```
|
||||
|
||||
### Client (Development/Testing)
|
||||
|
||||
```bash
|
||||
# Copy and configure client environment
|
||||
cp .env.example .env.client
|
||||
nano .env.client
|
||||
|
||||
# Edit client configuration
|
||||
SERVER_ADDR=tunnel.example.com:443
|
||||
AUTH_TOKEN=your-secret-token
|
||||
TUNNEL_TYPE=http
|
||||
LOCAL_PORT=3000
|
||||
|
||||
# Start client
|
||||
docker compose -f docker-compose.client.yml --env-file .env.client up -d
|
||||
|
||||
# View logs
|
||||
docker compose -f docker-compose.client.yml logs -f
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
|
||||
Create `.env` from `.env.example`:
|
||||
|
||||
```bash
|
||||
DOMAIN=tunnel.example.com
|
||||
AUTH_TOKEN=your-secret-token
|
||||
```
|
||||
|
||||
### TLS Certificates
|
||||
|
||||
**Option 1: Auto TLS (Let's Encrypt)**
|
||||
|
||||
```bash
|
||||
# Enable in .env
|
||||
AUTO_TLS=1
|
||||
|
||||
# Ensure port 80 is accessible for ACME challenges
|
||||
```
|
||||
|
||||
**Option 2: Manual Certificates**
|
||||
|
||||
```bash
|
||||
# Place certificates in ./certs/
|
||||
mkdir -p certs
|
||||
cp fullchain.pem certs/cert.pem
|
||||
cp privkey.pem certs/key.pem
|
||||
|
||||
# Uncomment in docker-compose.yml
|
||||
# - ./certs:/app/data/certs:ro
|
||||
|
||||
# Enable in .env
|
||||
TLS_CERT=1
|
||||
TLS_KEY=1
|
||||
```
|
||||
|
||||
## Data Persistence
|
||||
|
||||
All data is stored in Docker volumes:
|
||||
|
||||
- `drip-data`: Server data and certificates at `/app/data`
|
||||
- `client-data`: Client configuration at `/app/data`
|
||||
|
||||
### Backup
|
||||
|
||||
```bash
|
||||
# Backup server data
|
||||
docker run --rm -v drip-data:/data -v $(pwd):/backup alpine tar czf /backup/drip-backup.tar.gz -C /data .
|
||||
|
||||
# Restore
|
||||
docker run --rm -v drip-data:/data -v $(pwd):/backup alpine tar xzf /backup/drip-backup.tar.gz -C /data
|
||||
```
|
||||
|
||||
## Port Mapping
|
||||
|
||||
| Container Port | Host Port | Purpose |
|
||||
|---------------|-----------|---------|
|
||||
| 80 | 80 | HTTP (ACME challenges) |
|
||||
| 443 | 443 | HTTPS (main service) |
|
||||
| 8080 | 8080 | HTTP (no TLS) |
|
||||
| 20000-20100 | 20000-20100 | TCP tunnels |
|
||||
|
||||
## Management
|
||||
|
||||
### Server
|
||||
|
||||
```bash
|
||||
# Start
|
||||
docker compose up -d
|
||||
|
||||
# Stop
|
||||
docker compose down
|
||||
|
||||
# Restart
|
||||
docker compose restart
|
||||
|
||||
# View logs
|
||||
docker compose logs -f
|
||||
|
||||
# Shell access
|
||||
docker compose exec server sh
|
||||
|
||||
# Update
|
||||
docker compose pull
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
### Client
|
||||
|
||||
```bash
|
||||
# Start
|
||||
docker compose -f docker-compose.client.yml up -d
|
||||
|
||||
# Stop
|
||||
docker compose -f docker-compose.client.yml down
|
||||
|
||||
# View logs
|
||||
docker compose -f docker-compose.client.yml logs -f
|
||||
|
||||
# Different tunnel types
|
||||
TUNNEL_TYPE=http LOCAL_PORT=3000 docker compose -f docker-compose.client.yml up -d
|
||||
TUNNEL_TYPE=https LOCAL_PORT=8443 docker compose -f docker-compose.client.yml up -d
|
||||
TUNNEL_TYPE=tcp LOCAL_PORT=5432 docker compose -f docker-compose.client.yml up -d
|
||||
```
|
||||
|
||||
## Production Deployment
|
||||
|
||||
### With Reverse Proxy
|
||||
|
||||
If using Nginx/Traefik in front:
|
||||
|
||||
```yaml
|
||||
services:
|
||||
server:
|
||||
ports:
|
||||
- "127.0.0.1:8080:8080" # Only expose to localhost
|
||||
command: >
|
||||
server
|
||||
--domain tunnel.example.com
|
||||
--port 8080
|
||||
--token ${AUTH_TOKEN}
|
||||
```
|
||||
|
||||
### Resource Limits
|
||||
|
||||
Adjust in `docker-compose.yml`:
|
||||
|
||||
```yaml
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
cpus: '2'
|
||||
memory: 512M
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**Certificate errors**
|
||||
|
||||
```bash
|
||||
# Check certificate files
|
||||
docker compose exec server ls -la /app/data/certs
|
||||
|
||||
# Check server logs
|
||||
docker compose logs server | grep -i tls
|
||||
```
|
||||
|
||||
**Connection issues**
|
||||
|
||||
```bash
|
||||
# Verify port accessibility
|
||||
curl -I https://tunnel.example.com
|
||||
|
||||
# Check server status
|
||||
docker compose exec server /app/drip server --help
|
||||
```
|
||||
|
||||
**Reset everything**
|
||||
|
||||
```bash
|
||||
# Stop and remove everything
|
||||
docker compose down -v
|
||||
|
||||
# Start fresh
|
||||
docker compose up -d
|
||||
```
|
||||
38
docker-compose.client.yml
Normal file
38
docker-compose.client.yml
Normal file
@@ -0,0 +1,38 @@
|
||||
services:
|
||||
client:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: deployments/Dockerfile.client
|
||||
args:
|
||||
VERSION: ${VERSION:-dev}
|
||||
GIT_COMMIT: ${GIT_COMMIT:-unknown}
|
||||
image: drip-client:${VERSION:-latest}
|
||||
container_name: drip-client
|
||||
restart: unless-stopped
|
||||
network_mode: host
|
||||
|
||||
volumes:
|
||||
- drip-client-data:/app/data
|
||||
# Optional: mount config file
|
||||
# - ./client-config.yaml:/app/data/config.yaml:ro
|
||||
|
||||
environment:
|
||||
TZ: ${TZ:-UTC}
|
||||
|
||||
command: >
|
||||
${TUNNEL_TYPE:-http} ${LOCAL_PORT:-3000}
|
||||
--server ${SERVER_ADDR}
|
||||
${AUTH_TOKEN:+--token ${AUTH_TOKEN}}
|
||||
${SUBDOMAIN:+--subdomain ${SUBDOMAIN}}
|
||||
${LOCAL_ADDRESS:+--address ${LOCAL_ADDRESS}}
|
||||
${DAEMON:+--daemon}
|
||||
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: 10m
|
||||
max-file: "3"
|
||||
|
||||
volumes:
|
||||
drip-client-data:
|
||||
driver: local
|
||||
67
docker-compose.yml
Normal file
67
docker-compose.yml
Normal file
@@ -0,0 +1,67 @@
|
||||
services:
|
||||
server:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: deployments/Dockerfile
|
||||
args:
|
||||
VERSION: ${VERSION:-dev}
|
||||
GIT_COMMIT: ${GIT_COMMIT:-unknown}
|
||||
image: drip-server:${VERSION:-latest}
|
||||
container_name: drip-server
|
||||
restart: unless-stopped
|
||||
|
||||
ports:
|
||||
- "80:80"
|
||||
- "443:443"
|
||||
- "8080:8080"
|
||||
- "20000-20100:20000-20100"
|
||||
|
||||
volumes:
|
||||
- drip-data:/app/data
|
||||
# Mount TLS certificates if not using auto-TLS
|
||||
# - ./certs:/app/data/certs:ro
|
||||
|
||||
environment:
|
||||
TZ: ${TZ:-UTC}
|
||||
|
||||
command: >
|
||||
server
|
||||
--domain ${DOMAIN:-tunnel.localhost}
|
||||
--port ${PORT:-8080}
|
||||
${TLS_CERT:+--tls-cert /app/data/certs/fullchain.pem}
|
||||
${TLS_KEY:+--tls-key /app/data/certs/privkey.pem}
|
||||
${AUTO_TLS:+--auto-tls}
|
||||
${AUTH_TOKEN:+--token ${AUTH_TOKEN}}
|
||||
|
||||
networks:
|
||||
- drip-net
|
||||
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: 10m
|
||||
max-file: "3"
|
||||
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
cpus: '2'
|
||||
memory: 512M
|
||||
reservations:
|
||||
cpus: '0.5'
|
||||
memory: 128M
|
||||
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://localhost:${PORT:-8080}/health"]
|
||||
interval: 30s
|
||||
timeout: 3s
|
||||
retries: 3
|
||||
start_period: 5s
|
||||
|
||||
volumes:
|
||||
drip-data:
|
||||
driver: local
|
||||
|
||||
networks:
|
||||
drip-net:
|
||||
driver: bridge
|
||||
22
go.mod
Normal file
22
go.mod
Normal file
@@ -0,0 +1,22 @@
|
||||
module drip
|
||||
|
||||
go 1.25.4
|
||||
|
||||
require (
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/spf13/cobra v1.10.1
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1
|
||||
go.uber.org/zap v1.27.1
|
||||
golang.org/x/crypto v0.45.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
github.com/stretchr/testify v1.11.1 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
)
|
||||
37
go.sum
Normal file
37
go.sum
Normal file
@@ -0,0 +1,37 @@
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s=
|
||||
github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0=
|
||||
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
|
||||
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
|
||||
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
244
internal/client/cli/attach.go
Normal file
244
internal/client/cli/attach.go
Normal file
@@ -0,0 +1,244 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var attachCmd = &cobra.Command{
|
||||
Use: "attach [type] [port]",
|
||||
Short: "Attach to a running background tunnel",
|
||||
Long: `Attach to a running background tunnel to view its logs in real-time.
|
||||
|
||||
Examples:
|
||||
drip attach List running tunnels and select one
|
||||
drip attach http 3000 Attach to HTTP tunnel on port 3000
|
||||
drip attach tcp 5432 Attach to TCP tunnel on port 5432
|
||||
|
||||
Press Ctrl+C to detach (tunnel will continue running).`,
|
||||
Aliases: []string{"logs", "tail"},
|
||||
Args: cobra.MaximumNArgs(2),
|
||||
RunE: runAttach,
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(attachCmd)
|
||||
}
|
||||
|
||||
func runAttach(cmd *cobra.Command, args []string) error {
|
||||
// Clean up stale daemons first
|
||||
CleanupStaleDaemons()
|
||||
|
||||
// Get all running daemons
|
||||
daemons, err := ListAllDaemons()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list daemons: %w", err)
|
||||
}
|
||||
|
||||
if len(daemons) == 0 {
|
||||
fmt.Println("\033[90mNo running tunnels.\033[0m")
|
||||
fmt.Println()
|
||||
fmt.Println("Start a tunnel in background with:")
|
||||
fmt.Println(" \033[36mdrip http 3000 -d\033[0m")
|
||||
fmt.Println(" \033[36mdrip tcp 5432 -d\033[0m")
|
||||
return nil
|
||||
}
|
||||
|
||||
var selectedDaemon *DaemonInfo
|
||||
|
||||
// If type and port are specified, find the specific daemon
|
||||
if len(args) == 2 {
|
||||
tunnelType := args[0]
|
||||
if tunnelType != "http" && tunnelType != "tcp" {
|
||||
return fmt.Errorf("invalid tunnel type: %s (must be 'http' or 'tcp')", tunnelType)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(args[1])
|
||||
if err != nil || port < 1 || port > 65535 {
|
||||
return fmt.Errorf("invalid port number: %s", args[1])
|
||||
}
|
||||
|
||||
// Find the daemon
|
||||
for _, d := range daemons {
|
||||
if d.Type == tunnelType && d.Port == port {
|
||||
if !IsProcessRunning(d.PID) {
|
||||
RemoveDaemonInfo(d.Type, d.Port)
|
||||
return fmt.Errorf("tunnel is not running (cleaned up stale entry)")
|
||||
}
|
||||
selectedDaemon = d
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if selectedDaemon == nil {
|
||||
return fmt.Errorf("no %s tunnel running on port %d", tunnelType, port)
|
||||
}
|
||||
} else if len(args) == 0 {
|
||||
// Interactive selection
|
||||
selectedDaemon, err = selectDaemonInteractive(daemons)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if selectedDaemon == nil {
|
||||
return nil // User cancelled
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("usage: drip attach [type port]")
|
||||
}
|
||||
|
||||
// Attach to the selected daemon
|
||||
return attachToDaemon(selectedDaemon)
|
||||
}
|
||||
|
||||
func selectDaemonInteractive(daemons []*DaemonInfo) (*DaemonInfo, error) {
|
||||
// Print header
|
||||
fmt.Println()
|
||||
fmt.Println("\033[1;37mSelect a tunnel to attach:\033[0m")
|
||||
fmt.Println("\033[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\033[0m")
|
||||
|
||||
// Filter out non-running daemons
|
||||
var runningDaemons []*DaemonInfo
|
||||
for _, d := range daemons {
|
||||
if IsProcessRunning(d.PID) {
|
||||
runningDaemons = append(runningDaemons, d)
|
||||
} else {
|
||||
RemoveDaemonInfo(d.Type, d.Port)
|
||||
}
|
||||
}
|
||||
|
||||
if len(runningDaemons) == 0 {
|
||||
fmt.Println("\033[90mNo running tunnels.\033[0m")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Print list
|
||||
for i, d := range runningDaemons {
|
||||
uptime := time.Since(d.StartTime)
|
||||
|
||||
// Format type with color
|
||||
var typeStr string
|
||||
if d.Type == "http" {
|
||||
typeStr = "\033[32mHTTP\033[0m"
|
||||
} else {
|
||||
typeStr = "\033[35mTCP\033[0m"
|
||||
}
|
||||
|
||||
// Truncate URL if too long
|
||||
url := d.URL
|
||||
if len(url) > 50 {
|
||||
url = url[:47] + "..."
|
||||
}
|
||||
|
||||
fmt.Printf("\033[1;36m%d.\033[0m %-15s \033[90mPort:\033[0m %-6d \033[90mURL:\033[0m %-50s \033[90mUptime:\033[0m %s\n",
|
||||
i+1, typeStr, d.Port, url, FormatDuration(uptime))
|
||||
}
|
||||
|
||||
fmt.Println("\033[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\033[0m")
|
||||
fmt.Printf("Enter number (1-%d) or 'q' to quit: ", len(runningDaemons))
|
||||
|
||||
// Read user input
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
input, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read input: %w", err)
|
||||
}
|
||||
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "q" || input == "Q" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Parse selection
|
||||
selection, err := strconv.Atoi(input)
|
||||
if err != nil || selection < 1 || selection > len(runningDaemons) {
|
||||
return nil, fmt.Errorf("invalid selection: %s", input)
|
||||
}
|
||||
|
||||
return runningDaemons[selection-1], nil
|
||||
}
|
||||
|
||||
func attachToDaemon(daemon *DaemonInfo) error {
|
||||
// Get log file path
|
||||
logPath := filepath.Join(getDaemonDir(), fmt.Sprintf("%s_%d.log", daemon.Type, daemon.Port))
|
||||
|
||||
// Check if log file exists
|
||||
if _, err := os.Stat(logPath); os.IsNotExist(err) {
|
||||
return fmt.Errorf("log file not found: %s", logPath)
|
||||
}
|
||||
|
||||
// Print header
|
||||
fmt.Println()
|
||||
fmt.Println("\033[1;32m╔══════════════════════════════════════════════════════════════════╗\033[0m")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[1;37mAttached to %s tunnel on port %d\033[0m", strings.ToUpper(daemon.Type), daemon.Port)
|
||||
fmt.Printf("%s\033[1;32m║\033[0m\n", strings.Repeat(" ", 32-len(daemon.Type)))
|
||||
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mURL:\033[0m \033[36m%-52s\033[0m \033[1;32m║\033[0m\n", daemon.URL)
|
||||
uptime := time.Since(daemon.StartTime)
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mPID:\033[0m \033[90m%-52d\033[0m \033[1;32m║\033[0m\n", daemon.PID)
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mUptime:\033[0m \033[90m%-52s\033[0m \033[1;32m║\033[0m\n", FormatDuration(uptime))
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mLog:\033[0m \033[90m%-52s\033[0m \033[1;32m║\033[0m\n", truncatePath(logPath, 52))
|
||||
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
|
||||
fmt.Println("\033[1;32m║\033[0m \033[33mPress Ctrl+C to detach (tunnel will continue running)\033[0m \033[1;32m║\033[0m")
|
||||
fmt.Println("\033[1;32m╚══════════════════════════════════════════════════════════════════╝\033[0m")
|
||||
fmt.Println()
|
||||
|
||||
// Setup signal handler
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Start tail command
|
||||
tailCmd := exec.Command("tail", "-f", logPath)
|
||||
tailCmd.Stdout = os.Stdout
|
||||
tailCmd.Stderr = os.Stderr
|
||||
|
||||
if err := tailCmd.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start tail: %w", err)
|
||||
}
|
||||
|
||||
// Wait for signal or tail to exit
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- tailCmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-sigCh:
|
||||
// Kill tail process
|
||||
if tailCmd.Process != nil {
|
||||
tailCmd.Process.Kill()
|
||||
}
|
||||
fmt.Println()
|
||||
fmt.Println("\033[33mDetached from tunnel (tunnel is still running)\033[0m")
|
||||
fmt.Printf("Use '\033[36mdrip attach %s %d\033[0m' to reattach\n", daemon.Type, daemon.Port)
|
||||
fmt.Printf("Use '\033[36mdrip stop %s %d\033[0m' to stop the tunnel\n", daemon.Type, daemon.Port)
|
||||
return nil
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
return fmt.Errorf("tail process exited: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func truncatePath(path string, maxLen int) string {
|
||||
if len(path) <= maxLen {
|
||||
return path
|
||||
}
|
||||
// Try to keep filename and show ... in the middle
|
||||
filename := filepath.Base(path)
|
||||
if len(filename) >= maxLen-3 {
|
||||
return "..." + filename[len(filename)-(maxLen-3):]
|
||||
}
|
||||
dirLen := maxLen - len(filename) - 3
|
||||
return path[:dirLen] + "..." + filename
|
||||
}
|
||||
273
internal/client/cli/config.go
Normal file
273
internal/client/cli/config.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"drip/pkg/config"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var configCmd = &cobra.Command{
|
||||
Use: "config",
|
||||
Short: "Manage configuration",
|
||||
Long: "Manage Drip client configuration (server, token, etc.)",
|
||||
}
|
||||
|
||||
var configInitCmd = &cobra.Command{
|
||||
Use: "init",
|
||||
Short: "Initialize configuration interactively",
|
||||
Long: "Initialize Drip configuration with interactive prompts",
|
||||
RunE: runConfigInit,
|
||||
}
|
||||
|
||||
var configShowCmd = &cobra.Command{
|
||||
Use: "show",
|
||||
Short: "Show current configuration",
|
||||
Long: "Display the current Drip configuration",
|
||||
RunE: runConfigShow,
|
||||
}
|
||||
|
||||
var configSetCmd = &cobra.Command{
|
||||
Use: "set",
|
||||
Short: "Set configuration values",
|
||||
Long: "Set specific configuration values (server, token)",
|
||||
RunE: runConfigSet,
|
||||
}
|
||||
|
||||
var configResetCmd = &cobra.Command{
|
||||
Use: "reset",
|
||||
Short: "Reset configuration",
|
||||
Long: "Delete the configuration file",
|
||||
RunE: runConfigReset,
|
||||
}
|
||||
|
||||
var configValidateCmd = &cobra.Command{
|
||||
Use: "validate",
|
||||
Short: "Validate configuration",
|
||||
Long: "Validate the configuration file",
|
||||
RunE: runConfigValidate,
|
||||
}
|
||||
|
||||
var (
|
||||
configFull bool
|
||||
configForce bool
|
||||
configServer string
|
||||
configToken string
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Add subcommands
|
||||
configCmd.AddCommand(configInitCmd)
|
||||
configCmd.AddCommand(configShowCmd)
|
||||
configCmd.AddCommand(configSetCmd)
|
||||
configCmd.AddCommand(configResetCmd)
|
||||
configCmd.AddCommand(configValidateCmd)
|
||||
|
||||
// Flags for show
|
||||
configShowCmd.Flags().BoolVar(&configFull, "full", false, "Show full token (not hidden)")
|
||||
|
||||
// Flags for set
|
||||
configSetCmd.Flags().StringVar(&configServer, "server", "", "Server address (e.g., tunnel.example.com:443)")
|
||||
configSetCmd.Flags().StringVar(&configToken, "token", "", "Authentication token")
|
||||
|
||||
// Flags for reset
|
||||
configResetCmd.Flags().BoolVar(&configForce, "force", false, "Force reset without confirmation")
|
||||
|
||||
// Add to root
|
||||
rootCmd.AddCommand(configCmd)
|
||||
}
|
||||
|
||||
func runConfigInit(cmd *cobra.Command, args []string) error {
|
||||
fmt.Println("\n╔═══════════════════════════════════════╗")
|
||||
fmt.Println("║ Drip Configuration Setup ║")
|
||||
fmt.Println("╚═══════════════════════════════════════╝")
|
||||
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
|
||||
// Get server address
|
||||
fmt.Print("Server address (e.g., tunnel.example.com:443): ")
|
||||
serverAddr, _ := reader.ReadString('\n')
|
||||
serverAddr = strings.TrimSpace(serverAddr)
|
||||
|
||||
if serverAddr == "" {
|
||||
return fmt.Errorf("server address is required")
|
||||
}
|
||||
|
||||
// Get token
|
||||
fmt.Print("Authentication token (leave empty to skip): ")
|
||||
token, _ := reader.ReadString('\n')
|
||||
token = strings.TrimSpace(token)
|
||||
|
||||
// Create config
|
||||
cfg := &config.ClientConfig{
|
||||
Server: serverAddr,
|
||||
Token: token,
|
||||
TLS: true,
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := config.SaveClientConfig(cfg, ""); err != nil {
|
||||
return fmt.Errorf("failed to save configuration: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Configuration saved to", config.DefaultClientConfigPath())
|
||||
fmt.Println("✓ You can now use 'drip' without --server and --token")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runConfigShow(cmd *cobra.Command, args []string) error {
|
||||
// Load config
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println("\n╔═══════════════════════════════════════╗")
|
||||
fmt.Println("║ Current Configuration ║")
|
||||
fmt.Println("╚═══════════════════════════════════════╝")
|
||||
|
||||
fmt.Printf("Server: %s\n", cfg.Server)
|
||||
|
||||
// Show token (hidden or full)
|
||||
if cfg.Token != "" {
|
||||
if configFull {
|
||||
fmt.Printf("Token: %s\n", cfg.Token)
|
||||
} else {
|
||||
// Hide middle part of token
|
||||
if len(cfg.Token) > 10 {
|
||||
fmt.Printf("Token: %s***%s (hidden)\n",
|
||||
cfg.Token[:3],
|
||||
cfg.Token[len(cfg.Token)-3:],
|
||||
)
|
||||
} else {
|
||||
fmt.Printf("Token: %s (hidden)\n", cfg.Token[:3]+"***")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
fmt.Println("Token: (not set)")
|
||||
}
|
||||
|
||||
fmt.Printf("TLS: %s\n", enabledDisabled(cfg.TLS))
|
||||
fmt.Printf("Config: %s\n\n", config.DefaultClientConfigPath())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runConfigSet(cmd *cobra.Command, args []string) error {
|
||||
// Load existing config or create new
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
// Create new config if not exists
|
||||
cfg = &config.ClientConfig{
|
||||
TLS: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Update fields if provided
|
||||
modified := false
|
||||
|
||||
if configServer != "" {
|
||||
cfg.Server = configServer
|
||||
modified = true
|
||||
fmt.Printf("✓ Server updated: %s\n", configServer)
|
||||
}
|
||||
|
||||
if configToken != "" {
|
||||
cfg.Token = configToken
|
||||
modified = true
|
||||
fmt.Println("✓ Token updated")
|
||||
}
|
||||
|
||||
if !modified {
|
||||
return fmt.Errorf("no changes specified. Use --server or --token")
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := config.SaveClientConfig(cfg, ""); err != nil {
|
||||
return fmt.Errorf("failed to save configuration: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("✓ Configuration saved")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runConfigReset(cmd *cobra.Command, args []string) error {
|
||||
configPath := config.DefaultClientConfigPath()
|
||||
|
||||
// Check if config exists
|
||||
if !config.ConfigExists("") {
|
||||
fmt.Println("No configuration file found")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Confirm deletion
|
||||
if !configForce {
|
||||
fmt.Print("Are you sure you want to delete the configuration? (y/N): ")
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
response, _ := reader.ReadString('\n')
|
||||
response = strings.ToLower(strings.TrimSpace(response))
|
||||
|
||||
if response != "y" && response != "yes" {
|
||||
fmt.Println("Cancelled")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Delete config file
|
||||
if err := os.Remove(configPath); err != nil {
|
||||
return fmt.Errorf("failed to delete configuration: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("✓ Configuration file deleted")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runConfigValidate(cmd *cobra.Command, args []string) error {
|
||||
fmt.Println("\nValidating configuration...")
|
||||
fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
|
||||
|
||||
// Load config
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
fmt.Println("✗ Failed to load configuration")
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate server address
|
||||
if cfg.Server == "" {
|
||||
fmt.Println("✗ Server address is not set")
|
||||
return fmt.Errorf("invalid configuration")
|
||||
}
|
||||
fmt.Println("✓ Server address is valid")
|
||||
|
||||
// Validate token
|
||||
if cfg.Token != "" {
|
||||
fmt.Println("✓ Token is set")
|
||||
} else {
|
||||
fmt.Println("⚠ Token is not set (authentication may fail)")
|
||||
}
|
||||
|
||||
// Validate TLS
|
||||
if cfg.TLS {
|
||||
fmt.Println("✓ TLS is enabled")
|
||||
} else {
|
||||
fmt.Println("⚠ TLS is disabled (not recommended for production)")
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Configuration is valid")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func enabledDisabled(value bool) string {
|
||||
if value {
|
||||
return "enabled"
|
||||
}
|
||||
return "disabled"
|
||||
}
|
||||
263
internal/client/cli/daemon.go
Normal file
263
internal/client/cli/daemon.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DaemonInfo stores information about a running daemon process
|
||||
type DaemonInfo struct {
|
||||
PID int `json:"pid"`
|
||||
Type string `json:"type"` // "http" or "tcp"
|
||||
Port int `json:"port"` // Local port being tunneled
|
||||
Subdomain string `json:"subdomain"` // Subdomain if specified
|
||||
Server string `json:"server"` // Server address
|
||||
URL string `json:"url"` // Tunnel URL
|
||||
StartTime time.Time `json:"start_time"` // When the daemon started
|
||||
Executable string `json:"executable"` // Path to the executable
|
||||
}
|
||||
|
||||
// getDaemonDir returns the directory for storing daemon info
|
||||
func getDaemonDir() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ".drip"
|
||||
}
|
||||
return filepath.Join(home, ".drip", "daemons")
|
||||
}
|
||||
|
||||
// getDaemonFilePath returns the path to a daemon info file
|
||||
func getDaemonFilePath(tunnelType string, port int) string {
|
||||
return filepath.Join(getDaemonDir(), fmt.Sprintf("%s_%d.json", tunnelType, port))
|
||||
}
|
||||
|
||||
// SaveDaemonInfo saves daemon information to a file
|
||||
func SaveDaemonInfo(info *DaemonInfo) error {
|
||||
dir := getDaemonDir()
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create daemon directory: %w", err)
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(info, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal daemon info: %w", err)
|
||||
}
|
||||
|
||||
path := getDaemonFilePath(info.Type, info.Port)
|
||||
if err := os.WriteFile(path, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write daemon info: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadDaemonInfo loads daemon information from a file
|
||||
func LoadDaemonInfo(tunnelType string, port int) (*DaemonInfo, error) {
|
||||
path := getDaemonFilePath(tunnelType, port)
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to read daemon info: %w", err)
|
||||
}
|
||||
|
||||
var info DaemonInfo
|
||||
if err := json.Unmarshal(data, &info); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse daemon info: %w", err)
|
||||
}
|
||||
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
// RemoveDaemonInfo removes a daemon info file
|
||||
func RemoveDaemonInfo(tunnelType string, port int) error {
|
||||
path := getDaemonFilePath(tunnelType, port)
|
||||
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to remove daemon info: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListAllDaemons returns all daemon info files
|
||||
func ListAllDaemons() ([]*DaemonInfo, error) {
|
||||
dir := getDaemonDir()
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to read daemon directory: %w", err)
|
||||
}
|
||||
|
||||
var daemons []*DaemonInfo
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || filepath.Ext(entry.Name()) != ".json" {
|
||||
continue
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(dir, entry.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var info DaemonInfo
|
||||
if err := json.Unmarshal(data, &info); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
daemons = append(daemons, &info)
|
||||
}
|
||||
|
||||
return daemons, nil
|
||||
}
|
||||
|
||||
// IsProcessRunning checks if a process with the given PID is running
|
||||
func IsProcessRunning(pid int) bool {
|
||||
process, err := os.FindProcess(pid)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return isProcessRunningOS(process)
|
||||
}
|
||||
|
||||
// KillProcess kills a process by PID
|
||||
func KillProcess(pid int) error {
|
||||
process, err := os.FindProcess(pid)
|
||||
if err != nil {
|
||||
return fmt.Errorf("process not found: %w", err)
|
||||
}
|
||||
|
||||
if err := killProcessOS(process); err != nil {
|
||||
return fmt.Errorf("failed to kill process: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartDaemon starts the current process as a daemon
|
||||
func StartDaemon(tunnelType string, port int, args []string) error {
|
||||
// Get the executable path
|
||||
executable, err := os.Executable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get executable path: %w", err)
|
||||
}
|
||||
|
||||
// Build command arguments (remove -D/--daemon flag to prevent recursion)
|
||||
var cleanArgs []string
|
||||
skipNext := false
|
||||
for i, arg := range args {
|
||||
if skipNext {
|
||||
skipNext = false
|
||||
continue
|
||||
}
|
||||
// Skip -D or --daemon flags (but NOT --daemon-child)
|
||||
if arg == "-D" || arg == "--daemon" {
|
||||
continue
|
||||
}
|
||||
// Handle -d (short form) - skip it
|
||||
if arg == "-d" {
|
||||
continue
|
||||
}
|
||||
// Skip if next arg would be a value for a removed flag (not applicable for boolean)
|
||||
_ = i
|
||||
cleanArgs = append(cleanArgs, arg)
|
||||
}
|
||||
|
||||
// Create the command
|
||||
cmd := exec.Command(executable, cleanArgs...)
|
||||
|
||||
// Detach from parent process (platform-specific)
|
||||
setupDaemonCmd(cmd)
|
||||
|
||||
// Create log file for daemon output
|
||||
logDir := getDaemonDir()
|
||||
if err := os.MkdirAll(logDir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create daemon directory: %w", err)
|
||||
}
|
||||
logPath := filepath.Join(logDir, fmt.Sprintf("%s_%d.log", tunnelType, port))
|
||||
logFile, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create log file: %w", err)
|
||||
}
|
||||
|
||||
// Redirect stdin to /dev/null
|
||||
devNull, err := os.OpenFile(os.DevNull, os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
logFile.Close()
|
||||
return fmt.Errorf("failed to open /dev/null: %w", err)
|
||||
}
|
||||
cmd.Stdin = devNull
|
||||
cmd.Stdout = logFile
|
||||
cmd.Stderr = logFile
|
||||
|
||||
// Start the process
|
||||
if err := cmd.Start(); err != nil {
|
||||
logFile.Close()
|
||||
devNull.Close()
|
||||
return fmt.Errorf("failed to start daemon: %w", err)
|
||||
}
|
||||
|
||||
// Don't wait for the process - let it run in background
|
||||
// The child process will save its own daemon info after connecting
|
||||
|
||||
fmt.Printf("\033[32m✓\033[0m Started %s tunnel on port %d in background (PID: %d)\n", tunnelType, port, cmd.Process.Pid)
|
||||
fmt.Printf(" Use '\033[36mdrip list\033[0m' to check tunnel status\n")
|
||||
fmt.Printf(" Use '\033[36mdrip attach %s %d\033[0m' to view logs\n", tunnelType, port)
|
||||
fmt.Printf(" Use '\033[36mdrip stop %s %d\033[0m' to stop this tunnel\n", tunnelType, port)
|
||||
fmt.Printf(" Logs: \033[90m%s\033[0m\n", logPath)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupStaleDaemons removes daemon info for processes that are no longer running
|
||||
func CleanupStaleDaemons() error {
|
||||
daemons, err := ListAllDaemons()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, info := range daemons {
|
||||
if !IsProcessRunning(info.PID) {
|
||||
RemoveDaemonInfo(info.Type, info.Port)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FormatDuration formats a duration in a human-readable way
|
||||
func FormatDuration(d time.Duration) string {
|
||||
if d < time.Minute {
|
||||
return fmt.Sprintf("%ds", int(d.Seconds()))
|
||||
} else if d < time.Hour {
|
||||
return fmt.Sprintf("%dm %ds", int(d.Minutes()), int(d.Seconds())%60)
|
||||
} else if d < 24*time.Hour {
|
||||
return fmt.Sprintf("%dh %dm", int(d.Hours()), int(d.Minutes())%60)
|
||||
}
|
||||
days := int(d.Hours()) / 24
|
||||
hours := int(d.Hours()) % 24
|
||||
return fmt.Sprintf("%dd %dh", days, hours)
|
||||
}
|
||||
|
||||
// ParsePortFromArgs extracts the port number from command arguments
|
||||
func ParsePortFromArgs(args []string) (int, error) {
|
||||
for _, arg := range args {
|
||||
// Skip flags
|
||||
if len(arg) > 0 && arg[0] == '-' {
|
||||
continue
|
||||
}
|
||||
// Try to parse as port number
|
||||
port, err := strconv.Atoi(arg)
|
||||
if err == nil && port > 0 && port <= 65535 {
|
||||
return port, nil
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("port number not found in arguments")
|
||||
}
|
||||
38
internal/client/cli/daemon_unix.go
Normal file
38
internal/client/cli/daemon_unix.go
Normal file
@@ -0,0 +1,38 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// getSysProcAttr returns platform-specific process attributes for daemonization
|
||||
func getSysProcAttr() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{
|
||||
Setsid: true, // Create new session (Unix only)
|
||||
}
|
||||
}
|
||||
|
||||
// isProcessRunningOS checks if a process is running using OS-specific method
|
||||
func isProcessRunningOS(process *os.Process) bool {
|
||||
// On Unix, FindProcess always succeeds, so we need to send signal 0 to check
|
||||
err := process.Signal(syscall.Signal(0))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// killProcessOS kills a process using OS-specific signals
|
||||
func killProcessOS(process *os.Process) error {
|
||||
// First try SIGTERM for graceful shutdown
|
||||
if err := process.Signal(syscall.SIGTERM); err != nil {
|
||||
// If SIGTERM fails, try SIGKILL
|
||||
return process.Signal(syscall.SIGKILL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupDaemonCmd configures the command for daemon mode
|
||||
func setupDaemonCmd(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = getSysProcAttr()
|
||||
}
|
||||
42
internal/client/cli/daemon_windows.go
Normal file
42
internal/client/cli/daemon_windows.go
Normal file
@@ -0,0 +1,42 @@
|
||||
//go:build windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// getSysProcAttr returns platform-specific process attributes for daemonization
|
||||
func getSysProcAttr() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{
|
||||
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP,
|
||||
}
|
||||
}
|
||||
|
||||
// isProcessRunningOS checks if a process is running using OS-specific method
|
||||
func isProcessRunningOS(process *os.Process) bool {
|
||||
// On Windows, we try to open the process to check if it exists
|
||||
// FindProcess doesn't actually check if process exists on Windows
|
||||
// We can try to send signal, but Windows doesn't support signal 0
|
||||
// Instead, we'll try to kill with signal 0 which returns an error if process doesn't exist
|
||||
err := process.Signal(os.Signal(syscall.Signal(0)))
|
||||
if err != nil {
|
||||
// Try alternative: check if we can get process info
|
||||
// If the process doesn't exist, Signal will fail
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// killProcessOS kills a process using OS-specific method
|
||||
func killProcessOS(process *os.Process) error {
|
||||
// On Windows, use Kill() directly
|
||||
return process.Kill()
|
||||
}
|
||||
|
||||
// setupDaemonCmd configures the command for daemon mode
|
||||
func setupDaemonCmd(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = getSysProcAttr()
|
||||
}
|
||||
306
internal/client/cli/http.go
Normal file
306
internal/client/cli/http.go
Normal file
@@ -0,0 +1,306 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/tcp"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/utils"
|
||||
"drip/pkg/config"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
maxReconnectAttempts = 5
|
||||
reconnectInterval = 3 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
subdomain string
|
||||
daemonMode bool
|
||||
daemonMarker bool
|
||||
localAddress string
|
||||
)
|
||||
|
||||
var httpCmd = &cobra.Command{
|
||||
Use: "http <port>",
|
||||
Short: "Start HTTP tunnel",
|
||||
Long: `Start an HTTP tunnel to expose a local HTTP server.
|
||||
|
||||
Example:
|
||||
drip http 3000 Tunnel localhost:3000
|
||||
drip http 8080 --subdomain myapp Use custom subdomain
|
||||
|
||||
Configuration:
|
||||
First time: Run 'drip config init' to save server and token
|
||||
Subsequent: Just run 'drip http <port>'
|
||||
|
||||
Note: Uses TCP over TLS 1.3 for secure communication`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runHTTP,
|
||||
}
|
||||
|
||||
func init() {
|
||||
httpCmd.Flags().StringVarP(&subdomain, "subdomain", "n", "", "Custom subdomain (optional)")
|
||||
httpCmd.Flags().BoolVarP(&daemonMode, "daemon", "d", false, "Run in background (daemon mode)")
|
||||
httpCmd.Flags().StringVarP(&localAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)")
|
||||
httpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
|
||||
httpCmd.Flags().MarkHidden("daemon-child")
|
||||
rootCmd.AddCommand(httpCmd)
|
||||
}
|
||||
|
||||
func runHTTP(cmd *cobra.Command, args []string) error {
|
||||
port, err := strconv.Atoi(args[0])
|
||||
if err != nil || port < 1 || port > 65535 {
|
||||
return fmt.Errorf("invalid port number: %s", args[0])
|
||||
}
|
||||
|
||||
if daemonMode && !daemonMarker {
|
||||
daemonArgs := append([]string{"http"}, args...)
|
||||
daemonArgs = append(daemonArgs, "--daemon-child")
|
||||
if subdomain != "" {
|
||||
daemonArgs = append(daemonArgs, "--subdomain", subdomain)
|
||||
}
|
||||
if localAddress != "127.0.0.1" {
|
||||
daemonArgs = append(daemonArgs, "--address", localAddress)
|
||||
}
|
||||
if serverURL != "" {
|
||||
daemonArgs = append(daemonArgs, "--server", serverURL)
|
||||
}
|
||||
if authToken != "" {
|
||||
daemonArgs = append(daemonArgs, "--token", authToken)
|
||||
}
|
||||
if insecure {
|
||||
daemonArgs = append(daemonArgs, "--insecure")
|
||||
}
|
||||
if verbose {
|
||||
daemonArgs = append(daemonArgs, "--verbose")
|
||||
}
|
||||
return StartDaemon("http", port, daemonArgs)
|
||||
}
|
||||
|
||||
if err := utils.InitLogger(verbose); err != nil {
|
||||
return fmt.Errorf("failed to initialize logger: %w", err)
|
||||
}
|
||||
defer utils.Sync()
|
||||
|
||||
logger := utils.GetLogger()
|
||||
|
||||
var serverAddr, token string
|
||||
|
||||
if serverURL == "" {
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
return fmt.Errorf(`configuration not found.
|
||||
|
||||
Please run 'drip config init' first, or use flags:
|
||||
drip http %d --server SERVER:PORT --token TOKEN`, port)
|
||||
}
|
||||
serverAddr = cfg.Server
|
||||
token = cfg.Token
|
||||
} else {
|
||||
serverAddr = serverURL
|
||||
token = authToken
|
||||
}
|
||||
|
||||
if serverAddr == "" {
|
||||
return fmt.Errorf("server address is required")
|
||||
}
|
||||
|
||||
connConfig := &tcp.ConnectorConfig{
|
||||
ServerAddr: serverAddr,
|
||||
Token: token,
|
||||
TunnelType: protocol.TunnelTypeHTTP,
|
||||
LocalHost: localAddress,
|
||||
LocalPort: port,
|
||||
Subdomain: subdomain,
|
||||
Insecure: insecure,
|
||||
}
|
||||
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
reconnectAttempts := 0
|
||||
for {
|
||||
connector := tcp.NewConnector(connConfig, logger)
|
||||
|
||||
if reconnectAttempts == 0 {
|
||||
fmt.Printf("\033[36m🔌 Connecting to %s...\033[0m\n", serverAddr)
|
||||
} else {
|
||||
fmt.Printf("\033[33m🔄 Reconnecting to %s (attempt %d/%d)...\033[0m\n", serverAddr, reconnectAttempts, maxReconnectAttempts)
|
||||
}
|
||||
|
||||
if err := connector.Connect(); err != nil {
|
||||
if isNonRetryableError(err) {
|
||||
return fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
reconnectAttempts++
|
||||
if reconnectAttempts >= maxReconnectAttempts {
|
||||
return fmt.Errorf("failed to connect after %d attempts: %w", maxReconnectAttempts, err)
|
||||
}
|
||||
fmt.Printf("\033[31m✗ Connection failed: %v\033[0m\n", err)
|
||||
fmt.Printf("\033[90m Retrying in %v...\033[0m\n", reconnectInterval)
|
||||
|
||||
select {
|
||||
case <-quit:
|
||||
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
|
||||
return nil
|
||||
case <-time.After(reconnectInterval):
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
reconnectAttempts = 0
|
||||
|
||||
if daemonMarker {
|
||||
daemonInfo := &DaemonInfo{
|
||||
PID: os.Getpid(),
|
||||
Type: "http",
|
||||
Port: port,
|
||||
Subdomain: subdomain,
|
||||
Server: serverAddr,
|
||||
URL: connector.GetURL(),
|
||||
StartTime: time.Now(),
|
||||
Executable: os.Args[0],
|
||||
}
|
||||
if err := SaveDaemonInfo(daemonInfo); err != nil {
|
||||
logger.Warn("Failed to save daemon info", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("\033[1;32m╔══════════════════════════════════════════════════════════════════╗\033[0m")
|
||||
fmt.Println("\033[1;32m║\033[0m \033[1;37m🚀 HTTP Tunnel Connected Successfully!\033[0m \033[1;32m║\033[0m")
|
||||
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[1;37mTunnel URL:\033[0m \033[1;32m║\033[0m\n")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[1;36m%-60s\033[0m \033[1;32m║\033[0m\n", connector.GetURL())
|
||||
fmt.Println("\033[1;32m║\033[0m \033[1;32m║\033[0m")
|
||||
displayAddr := localAddress
|
||||
if displayAddr == "127.0.0.1" {
|
||||
displayAddr = "localhost"
|
||||
}
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mForwarding:\033[0m \033[1m%s:%d\033[0m → \033[36m%s\033[0m%-15s\033[1;32m║\033[0m\n", displayAddr, port, "public", "")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mLatency:\033[0m \033[90mmeasuring...\033[0m%-40s\033[1;32m║\033[0m\n", "")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mTraffic:\033[0m \033[90m↓ 0 B ↑ 0 B\033[0m%-32s\033[1;32m║\033[0m\n", "")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mSpeed:\033[0m \033[90m↓ 0 B/s ↑ 0 B/s\033[0m%-28s\033[1;32m║\033[0m\n", "")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mRequests:\033[0m \033[90m0\033[0m%-43s\033[1;32m║\033[0m\n", "")
|
||||
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
|
||||
fmt.Println("\033[1;32m║\033[0m \033[90mPress Ctrl+C to stop the tunnel\033[0m \033[1;32m║\033[0m")
|
||||
fmt.Println("\033[1;32m╚══════════════════════════════════════════════════════════════════╝\033[0m")
|
||||
fmt.Println()
|
||||
|
||||
latencyCh := make(chan time.Duration, 1)
|
||||
connector.SetLatencyCallback(func(latency time.Duration) {
|
||||
select {
|
||||
case latencyCh <- latency:
|
||||
default:
|
||||
}
|
||||
})
|
||||
|
||||
stopDisplay := make(chan struct{})
|
||||
disconnected := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
var lastLatency time.Duration
|
||||
for {
|
||||
select {
|
||||
case latency := <-latencyCh:
|
||||
lastLatency = latency
|
||||
case <-ticker.C:
|
||||
stats := connector.GetStats()
|
||||
if stats != nil {
|
||||
stats.UpdateSpeed()
|
||||
snapshot := stats.GetSnapshot()
|
||||
|
||||
fmt.Print("\033[8A")
|
||||
|
||||
fmt.Printf("\r\033[1;32m║\033[0m \033[90mLatency:\033[0m %s%-40s\033[1;32m║\033[0m\n", formatLatency(lastLatency), "")
|
||||
|
||||
trafficStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatBytes(snapshot.TotalBytesIn), tcp.FormatBytes(snapshot.TotalBytesOut))
|
||||
fmt.Printf("\r\033[1;32m║\033[0m \033[90mTraffic:\033[0m \033[36m%-48s\033[0m\033[1;32m║\033[0m\n", trafficStr)
|
||||
|
||||
speedStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatSpeed(snapshot.SpeedIn), tcp.FormatSpeed(snapshot.SpeedOut))
|
||||
fmt.Printf("\r\033[1;32m║\033[0m \033[90mSpeed:\033[0m \033[33m%-48s\033[0m\033[1;32m║\033[0m\n", speedStr)
|
||||
|
||||
fmt.Printf("\r\033[1;32m║\033[0m \033[90mRequests:\033[0m \033[35m%-47d\033[0m\033[1;32m║\033[0m\n", snapshot.TotalRequests)
|
||||
|
||||
fmt.Print("\033[4B")
|
||||
}
|
||||
case <-stopDisplay:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
connector.Wait()
|
||||
close(disconnected)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-quit:
|
||||
close(stopDisplay)
|
||||
fmt.Println("\n\n\033[33m🛑 Shutting down...\033[0m")
|
||||
connector.Close()
|
||||
if daemonMarker {
|
||||
RemoveDaemonInfo("http", port)
|
||||
}
|
||||
fmt.Println("\033[32m✓\033[0m Tunnel closed")
|
||||
return nil
|
||||
case <-disconnected:
|
||||
close(stopDisplay)
|
||||
fmt.Println("\n\n\033[31m⚠ Connection lost!\033[0m")
|
||||
reconnectAttempts++
|
||||
if reconnectAttempts >= maxReconnectAttempts {
|
||||
return fmt.Errorf("connection lost after %d reconnect attempts", maxReconnectAttempts)
|
||||
}
|
||||
fmt.Printf("\033[90m Reconnecting in %v...\033[0m\n", reconnectInterval)
|
||||
|
||||
select {
|
||||
case <-quit:
|
||||
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
|
||||
return nil
|
||||
case <-time.After(reconnectInterval):
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func formatLatency(d time.Duration) string {
|
||||
ms := d.Milliseconds()
|
||||
if ms < 50 {
|
||||
return fmt.Sprintf("\033[32m%dms\033[0m", ms)
|
||||
} else if ms < 100 {
|
||||
return fmt.Sprintf("\033[33m%dms\033[0m", ms)
|
||||
} else if ms < 200 {
|
||||
return fmt.Sprintf("\033[38;5;208m%dms\033[0m", ms)
|
||||
}
|
||||
return fmt.Sprintf("\033[31m%dms\033[0m", ms)
|
||||
}
|
||||
|
||||
func isNonRetryableError(err error) bool {
|
||||
errStr := err.Error()
|
||||
if strings.Contains(errStr, "subdomain is already taken") ||
|
||||
strings.Contains(errStr, "subdomain is reserved") ||
|
||||
strings.Contains(errStr, "invalid subdomain") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(errStr, "authentication") ||
|
||||
strings.Contains(errStr, "Invalid authentication token") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
305
internal/client/cli/https.go
Normal file
305
internal/client/cli/https.go
Normal file
@@ -0,0 +1,305 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/tcp"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/utils"
|
||||
"drip/pkg/config"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var (
|
||||
httpsSubdomain string
|
||||
httpsDaemonMode bool
|
||||
httpsDaemonMarker bool
|
||||
httpsLocalAddress string
|
||||
)
|
||||
|
||||
var httpsCmd = &cobra.Command{
|
||||
Use: "https <port>",
|
||||
Short: "Start HTTPS tunnel",
|
||||
Long: `Start an HTTPS tunnel to expose a local HTTPS server.
|
||||
|
||||
Example:
|
||||
drip https 443 Tunnel localhost:443
|
||||
drip https 8443 --subdomain myapp Use custom subdomain
|
||||
|
||||
Configuration:
|
||||
First time: Run 'drip config init' to save server and token
|
||||
Subsequent: Just run 'drip https <port>'
|
||||
|
||||
Note: Uses TCP over TLS 1.3 for secure communication`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runHTTPS,
|
||||
}
|
||||
|
||||
func init() {
|
||||
httpsCmd.Flags().StringVarP(&httpsSubdomain, "subdomain", "n", "", "Custom subdomain (optional)")
|
||||
httpsCmd.Flags().BoolVarP(&httpsDaemonMode, "daemon", "d", false, "Run in background (daemon mode)")
|
||||
httpsCmd.Flags().StringVarP(&httpsLocalAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)")
|
||||
httpsCmd.Flags().BoolVar(&httpsDaemonMarker, "daemon-child", false, "Internal flag for daemon child process")
|
||||
httpsCmd.Flags().MarkHidden("daemon-child")
|
||||
rootCmd.AddCommand(httpsCmd)
|
||||
}
|
||||
|
||||
func runHTTPS(cmd *cobra.Command, args []string) error {
|
||||
// Parse port
|
||||
port, err := strconv.Atoi(args[0])
|
||||
if err != nil || port < 1 || port > 65535 {
|
||||
return fmt.Errorf("invalid port number: %s", args[0])
|
||||
}
|
||||
|
||||
// Handle daemon mode
|
||||
if httpsDaemonMode && !httpsDaemonMarker {
|
||||
// Start as daemon
|
||||
daemonArgs := append([]string{"https"}, args...)
|
||||
daemonArgs = append(daemonArgs, "--daemon-child")
|
||||
if httpsSubdomain != "" {
|
||||
daemonArgs = append(daemonArgs, "--subdomain", httpsSubdomain)
|
||||
}
|
||||
if httpsLocalAddress != "127.0.0.1" {
|
||||
daemonArgs = append(daemonArgs, "--address", httpsLocalAddress)
|
||||
}
|
||||
if serverURL != "" {
|
||||
daemonArgs = append(daemonArgs, "--server", serverURL)
|
||||
}
|
||||
if authToken != "" {
|
||||
daemonArgs = append(daemonArgs, "--token", authToken)
|
||||
}
|
||||
if insecure {
|
||||
daemonArgs = append(daemonArgs, "--insecure")
|
||||
}
|
||||
if verbose {
|
||||
daemonArgs = append(daemonArgs, "--verbose")
|
||||
}
|
||||
return StartDaemon("https", port, daemonArgs)
|
||||
}
|
||||
|
||||
// Initialize logger
|
||||
if err := utils.InitLogger(verbose); err != nil {
|
||||
return fmt.Errorf("failed to initialize logger: %w", err)
|
||||
}
|
||||
defer utils.Sync()
|
||||
|
||||
logger := utils.GetLogger()
|
||||
|
||||
// Load configuration or use command line flags
|
||||
var serverAddr, token string
|
||||
|
||||
if serverURL == "" {
|
||||
// Try to load from config file
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
return fmt.Errorf(`configuration not found.
|
||||
|
||||
Please run 'drip config init' first, or use flags:
|
||||
drip https %d --server SERVER:PORT --token TOKEN`, port)
|
||||
}
|
||||
serverAddr = cfg.Server
|
||||
token = cfg.Token
|
||||
} else {
|
||||
// Use command line flags
|
||||
serverAddr = serverURL
|
||||
token = authToken
|
||||
}
|
||||
|
||||
// Validate server address
|
||||
if serverAddr == "" {
|
||||
return fmt.Errorf("server address is required")
|
||||
}
|
||||
|
||||
// Create connector config
|
||||
connConfig := &tcp.ConnectorConfig{
|
||||
ServerAddr: serverAddr,
|
||||
Token: token,
|
||||
TunnelType: protocol.TunnelTypeHTTPS,
|
||||
LocalHost: httpsLocalAddress,
|
||||
LocalPort: port,
|
||||
Subdomain: httpsSubdomain,
|
||||
Insecure: insecure,
|
||||
}
|
||||
|
||||
// Setup signal handler for graceful shutdown
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Connection loop with reconnect support
|
||||
reconnectAttempts := 0
|
||||
for {
|
||||
// Create connector
|
||||
connector := tcp.NewConnector(connConfig, logger)
|
||||
|
||||
// Connect to server
|
||||
if reconnectAttempts == 0 {
|
||||
fmt.Printf("\033[36m🔌 Connecting to %s...\033[0m\n", serverAddr)
|
||||
} else {
|
||||
fmt.Printf("\033[33m🔄 Reconnecting to %s (attempt %d/%d)...\033[0m\n", serverAddr, reconnectAttempts, maxReconnectAttempts)
|
||||
}
|
||||
|
||||
if err := connector.Connect(); err != nil {
|
||||
// Check if this is a non-retryable error
|
||||
if isNonRetryableError(err) {
|
||||
return fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
reconnectAttempts++
|
||||
if reconnectAttempts >= maxReconnectAttempts {
|
||||
return fmt.Errorf("failed to connect after %d attempts: %w", maxReconnectAttempts, err)
|
||||
}
|
||||
fmt.Printf("\033[31m✗ Connection failed: %v\033[0m\n", err)
|
||||
fmt.Printf("\033[90m Retrying in %v...\033[0m\n", reconnectInterval)
|
||||
|
||||
// Wait before retry, but allow interrupt
|
||||
select {
|
||||
case <-quit:
|
||||
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
|
||||
return nil
|
||||
case <-time.After(reconnectInterval):
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Reset reconnect attempts on successful connection
|
||||
reconnectAttempts = 0
|
||||
|
||||
// Save daemon info if running as daemon child
|
||||
if httpsDaemonMarker {
|
||||
daemonInfo := &DaemonInfo{
|
||||
PID: os.Getpid(),
|
||||
Type: "https",
|
||||
Port: port,
|
||||
Subdomain: httpsSubdomain,
|
||||
Server: serverAddr,
|
||||
URL: connector.GetURL(),
|
||||
StartTime: time.Now(),
|
||||
Executable: os.Args[0],
|
||||
}
|
||||
if err := SaveDaemonInfo(daemonInfo); err != nil {
|
||||
// Log but don't fail
|
||||
logger.Warn("Failed to save daemon info", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Print tunnel information
|
||||
fmt.Println()
|
||||
fmt.Println("\033[1;32m╔══════════════════════════════════════════════════════════════════╗\033[0m")
|
||||
fmt.Println("\033[1;32m║\033[0m \033[1;37m🔒 HTTPS Tunnel Connected Successfully!\033[0m \033[1;32m║\033[0m")
|
||||
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[1;37mTunnel URL:\033[0m \033[1;32m║\033[0m\n")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[1;36m%-60s\033[0m \033[1;32m║\033[0m\n", connector.GetURL())
|
||||
fmt.Println("\033[1;32m║\033[0m \033[1;32m║\033[0m")
|
||||
displayAddr := httpsLocalAddress
|
||||
if displayAddr == "127.0.0.1" {
|
||||
displayAddr = "localhost"
|
||||
}
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mForwarding:\033[0m \033[1m%s:%d\033[0m → \033[36m%s\033[0m%-15s\033[1;32m║\033[0m\n", displayAddr, port, "public", "")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mLatency:\033[0m \033[90mmeasuring...\033[0m%-40s\033[1;32m║\033[0m\n", "")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mTraffic:\033[0m \033[90m↓ 0 B ↑ 0 B\033[0m%-32s\033[1;32m║\033[0m\n", "")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mSpeed:\033[0m \033[90m↓ 0 B/s ↑ 0 B/s\033[0m%-28s\033[1;32m║\033[0m\n", "")
|
||||
fmt.Printf("\033[1;32m║\033[0m \033[90mRequests:\033[0m \033[90m0\033[0m%-43s\033[1;32m║\033[0m\n", "")
|
||||
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
|
||||
fmt.Println("\033[1;32m║\033[0m \033[90mPress Ctrl+C to stop the tunnel\033[0m \033[1;32m║\033[0m")
|
||||
fmt.Println("\033[1;32m╚══════════════════════════════════════════════════════════════════╝\033[0m")
|
||||
fmt.Println()
|
||||
|
||||
// Setup latency display
|
||||
latencyCh := make(chan time.Duration, 1)
|
||||
connector.SetLatencyCallback(func(latency time.Duration) {
|
||||
select {
|
||||
case latencyCh <- latency:
|
||||
default:
|
||||
}
|
||||
})
|
||||
|
||||
// Start stats display updater (updates every second)
|
||||
stopDisplay := make(chan struct{})
|
||||
disconnected := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
var lastLatency time.Duration
|
||||
for {
|
||||
select {
|
||||
case latency := <-latencyCh:
|
||||
lastLatency = latency
|
||||
case <-ticker.C:
|
||||
// Update speed calculation
|
||||
stats := connector.GetStats()
|
||||
if stats != nil {
|
||||
stats.UpdateSpeed()
|
||||
snapshot := stats.GetSnapshot()
|
||||
|
||||
// Move cursor up 8 lines to update display
|
||||
fmt.Print("\033[8A")
|
||||
|
||||
// Update latency line
|
||||
fmt.Printf("\r\033[1;32m║\033[0m \033[90mLatency:\033[0m %s%-40s\033[1;32m║\033[0m\n", formatLatency(lastLatency), "")
|
||||
|
||||
// Update traffic line
|
||||
trafficStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatBytes(snapshot.TotalBytesIn), tcp.FormatBytes(snapshot.TotalBytesOut))
|
||||
fmt.Printf("\r\033[1;32m║\033[0m \033[90mTraffic:\033[0m \033[36m%-48s\033[0m\033[1;32m║\033[0m\n", trafficStr)
|
||||
|
||||
// Update speed line
|
||||
speedStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatSpeed(snapshot.SpeedIn), tcp.FormatSpeed(snapshot.SpeedOut))
|
||||
fmt.Printf("\r\033[1;32m║\033[0m \033[90mSpeed:\033[0m \033[33m%-48s\033[0m\033[1;32m║\033[0m\n", speedStr)
|
||||
|
||||
// Update requests line
|
||||
fmt.Printf("\r\033[1;32m║\033[0m \033[90mRequests:\033[0m \033[35m%-47d\033[0m\033[1;32m║\033[0m\n", snapshot.TotalRequests)
|
||||
|
||||
// Move back down 4 lines
|
||||
fmt.Print("\033[4B")
|
||||
}
|
||||
case <-stopDisplay:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Monitor connection in background
|
||||
go func() {
|
||||
connector.Wait()
|
||||
close(disconnected)
|
||||
}()
|
||||
|
||||
// Wait for signal or disconnection
|
||||
select {
|
||||
case <-quit:
|
||||
close(stopDisplay)
|
||||
fmt.Println("\n\n\033[33m🛑 Shutting down...\033[0m")
|
||||
connector.Close()
|
||||
if httpsDaemonMarker {
|
||||
RemoveDaemonInfo("https", port)
|
||||
}
|
||||
fmt.Println("\033[32m✓\033[0m Tunnel closed")
|
||||
return nil
|
||||
case <-disconnected:
|
||||
close(stopDisplay)
|
||||
fmt.Println("\n\n\033[31m⚠ Connection lost!\033[0m")
|
||||
reconnectAttempts++
|
||||
if reconnectAttempts >= maxReconnectAttempts {
|
||||
return fmt.Errorf("connection lost after %d reconnect attempts", maxReconnectAttempts)
|
||||
}
|
||||
fmt.Printf("\033[90m Reconnecting in %v...\033[0m\n", reconnectInterval)
|
||||
|
||||
// Wait before reconnect, but allow interrupt
|
||||
select {
|
||||
case <-quit:
|
||||
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
|
||||
return nil
|
||||
case <-time.After(reconnectInterval):
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
194
internal/client/cli/list.go
Normal file
194
internal/client/cli/list.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
interactiveMode bool
|
||||
)
|
||||
|
||||
var listCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List all running background tunnels",
|
||||
Long: `List all running background tunnels.
|
||||
|
||||
Example:
|
||||
drip list Show all running tunnels
|
||||
drip list -i Interactive mode (select to attach/stop)
|
||||
|
||||
This command shows:
|
||||
- Tunnel type (HTTP/TCP)
|
||||
- Local port being tunneled
|
||||
- Public URL
|
||||
- Process ID (PID)
|
||||
- Uptime
|
||||
|
||||
In interactive mode, you can select a tunnel to:
|
||||
- Attach: View real-time logs
|
||||
- Stop: Terminate the tunnel`,
|
||||
Aliases: []string{"ls", "ps", "status"},
|
||||
RunE: runList,
|
||||
}
|
||||
|
||||
func init() {
|
||||
listCmd.Flags().BoolVarP(&interactiveMode, "interactive", "i", false, "Interactive mode for attach/stop")
|
||||
rootCmd.AddCommand(listCmd)
|
||||
}
|
||||
|
||||
func runList(cmd *cobra.Command, args []string) error {
|
||||
// Clean up stale daemons first
|
||||
CleanupStaleDaemons()
|
||||
|
||||
// Get all running daemons
|
||||
daemons, err := ListAllDaemons()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list daemons: %w", err)
|
||||
}
|
||||
|
||||
if len(daemons) == 0 {
|
||||
fmt.Println("\033[90mNo running tunnels.\033[0m")
|
||||
fmt.Println()
|
||||
fmt.Println("Start a tunnel in background with:")
|
||||
fmt.Println(" \033[36mdrip http 3000 -d\033[0m")
|
||||
fmt.Println(" \033[36mdrip tcp 5432 -d\033[0m")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Print header
|
||||
fmt.Println()
|
||||
fmt.Println("\033[1;37mRunning Tunnels\033[0m")
|
||||
fmt.Println("\033[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\033[0m")
|
||||
fmt.Printf("\033[1m%-4s %-6s %-6s %-40s %-8s %s\033[0m\n", "#", "TYPE", "PORT", "URL", "PID", "UPTIME")
|
||||
fmt.Println("\033[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\033[0m")
|
||||
|
||||
idx := 1
|
||||
for _, d := range daemons {
|
||||
// Check if process is still running
|
||||
if !IsProcessRunning(d.PID) {
|
||||
// Clean up stale entry
|
||||
RemoveDaemonInfo(d.Type, d.Port)
|
||||
continue
|
||||
}
|
||||
|
||||
// Calculate uptime
|
||||
uptime := time.Since(d.StartTime)
|
||||
|
||||
// Format type with color
|
||||
var typeStr string
|
||||
if d.Type == "http" {
|
||||
typeStr = "\033[32mHTTP\033[0m"
|
||||
} else {
|
||||
typeStr = "\033[35mTCP\033[0m"
|
||||
}
|
||||
|
||||
// Truncate URL if too long
|
||||
url := d.URL
|
||||
if len(url) > 40 {
|
||||
url = url[:37] + "..."
|
||||
}
|
||||
|
||||
fmt.Printf("\033[1;36m%-4d\033[0m %-15s %-6d %-40s %-8d %s\n",
|
||||
idx, typeStr, d.Port, url, d.PID, FormatDuration(uptime))
|
||||
idx++
|
||||
}
|
||||
|
||||
fmt.Println("\033[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\033[0m")
|
||||
fmt.Println()
|
||||
|
||||
// Interactive mode or show commands
|
||||
if interactiveMode || shouldPromptForAction() {
|
||||
return runInteractiveList(daemons)
|
||||
}
|
||||
|
||||
fmt.Println("Commands:")
|
||||
fmt.Println(" \033[36mdrip list -i\033[0m Interactive mode")
|
||||
fmt.Println(" \033[36mdrip attach http 3000\033[0m Attach to tunnel (view logs)")
|
||||
fmt.Println(" \033[36mdrip stop http 3000\033[0m Stop tunnel")
|
||||
fmt.Println(" \033[36mdrip stop all\033[0m Stop all tunnels")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func shouldPromptForAction() bool {
|
||||
// Check if running in a terminal
|
||||
if fileInfo, _ := os.Stdout.Stat(); (fileInfo.Mode() & os.ModeCharDevice) == 0 {
|
||||
return false
|
||||
}
|
||||
// Always prompt when there are tunnels running
|
||||
return true
|
||||
}
|
||||
|
||||
func runInteractiveList(daemons []*DaemonInfo) error {
|
||||
// Filter out non-running daemons
|
||||
var runningDaemons []*DaemonInfo
|
||||
for _, d := range daemons {
|
||||
if IsProcessRunning(d.PID) {
|
||||
runningDaemons = append(runningDaemons, d)
|
||||
} else {
|
||||
RemoveDaemonInfo(d.Type, d.Port)
|
||||
}
|
||||
}
|
||||
|
||||
if len(runningDaemons) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Prompt for action
|
||||
fmt.Print("Select a tunnel (number) or 'q' to quit: ")
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
input, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read input: %w", err)
|
||||
}
|
||||
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" || input == "q" || input == "Q" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse selection
|
||||
selection, err := strconv.Atoi(input)
|
||||
if err != nil || selection < 1 || selection > len(runningDaemons) {
|
||||
return fmt.Errorf("invalid selection: %s", input)
|
||||
}
|
||||
|
||||
selectedDaemon := runningDaemons[selection-1]
|
||||
|
||||
// Prompt for action
|
||||
fmt.Println()
|
||||
fmt.Printf("Selected: \033[1m%s\033[0m tunnel on port \033[1m%d\033[0m\n", strings.ToUpper(selectedDaemon.Type), selectedDaemon.Port)
|
||||
fmt.Println()
|
||||
fmt.Println("What would you like to do?")
|
||||
fmt.Println(" \033[36m1.\033[0m Attach (view logs)")
|
||||
fmt.Println(" \033[36m2.\033[0m Stop tunnel")
|
||||
fmt.Println(" \033[90mq. Cancel\033[0m")
|
||||
fmt.Println()
|
||||
fmt.Print("Choose an action: ")
|
||||
|
||||
actionInput, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read input: %w", err)
|
||||
}
|
||||
|
||||
actionInput = strings.TrimSpace(actionInput)
|
||||
switch actionInput {
|
||||
case "1":
|
||||
// Attach to daemon
|
||||
return attachToDaemon(selectedDaemon)
|
||||
case "2":
|
||||
// Stop daemon
|
||||
return stopDaemon(selectedDaemon.Type, selectedDaemon.Port)
|
||||
case "q", "Q", "":
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("invalid action: %s", actionInput)
|
||||
}
|
||||
}
|
||||
81
internal/client/cli/root.go
Normal file
81
internal/client/cli/root.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
// Version information
|
||||
Version = "dev"
|
||||
GitCommit = "unknown"
|
||||
BuildTime = "unknown"
|
||||
|
||||
// Global flags
|
||||
serverURL string
|
||||
authToken string
|
||||
verbose bool
|
||||
insecure bool
|
||||
)
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "drip",
|
||||
Short: "Drip - Fast and secure tunnels to localhost",
|
||||
Long: `Drip - High-performance tunneling service with TCP over TLS 1.3
|
||||
|
||||
Expose your local services to the internet securely and easily.
|
||||
|
||||
Configuration:
|
||||
First time: Run 'drip config init' to set up server and token
|
||||
Subsequent: Just run 'drip http <port>' or 'drip tcp <port>'
|
||||
|
||||
Examples:
|
||||
drip config init # Set up configuration
|
||||
drip http 3000 # HTTP tunnel
|
||||
drip tcp 5432 # PostgreSQL tunnel
|
||||
drip http 8080 --subdomain myapp # Custom subdomain
|
||||
|
||||
Features:
|
||||
✓ TCP over TLS 1.3 (secure and fast)
|
||||
✓ HTTP and TCP tunnel support
|
||||
✓ Auto-save configuration
|
||||
✓ Custom subdomains
|
||||
✓ Authentication via token`,
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.PersistentFlags().StringVarP(&serverURL, "server", "s", "", "Server address (e.g., tunnel.example.com:443)")
|
||||
rootCmd.PersistentFlags().StringVarP(&authToken, "token", "t", "", "Authentication token")
|
||||
rootCmd.PersistentFlags().BoolVarP(&verbose, "verbose", "v", false, "Verbose output")
|
||||
rootCmd.PersistentFlags().BoolVarP(&insecure, "insecure", "k", false, "Skip TLS verification (testing only, NOT recommended)")
|
||||
|
||||
rootCmd.AddCommand(versionCmd)
|
||||
// http and tcp commands are added in their respective init() functions
|
||||
// config command is added in config.go init() function
|
||||
}
|
||||
|
||||
var versionCmd = &cobra.Command{
|
||||
Use: "version",
|
||||
Short: "Print version information",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
fmt.Printf("Drip Client\n")
|
||||
fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")
|
||||
fmt.Printf("Version: %s\n", Version)
|
||||
fmt.Printf("Git Commit: %s\n", GitCommit)
|
||||
fmt.Printf("Build Time: %s\n", BuildTime)
|
||||
fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")
|
||||
},
|
||||
}
|
||||
|
||||
// Execute runs the root command
|
||||
func Execute() error {
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
|
||||
// SetVersion sets the version information
|
||||
func SetVersion(version, commit, buildTime string) {
|
||||
Version = version
|
||||
GitCommit = commit
|
||||
BuildTime = buildTime
|
||||
}
|
||||
191
internal/client/cli/server.go
Normal file
191
internal/client/cli/server.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"syscall"
|
||||
|
||||
"drip/internal/server/proxy"
|
||||
"drip/internal/server/tcp"
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/constants"
|
||||
"drip/internal/shared/utils"
|
||||
"drip/pkg/config"
|
||||
"github.com/spf13/cobra"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var (
|
||||
serverPort int
|
||||
serverPublicPort int
|
||||
serverDomain string
|
||||
serverAuthToken string
|
||||
serverDebug bool
|
||||
serverTCPPortMin int
|
||||
serverTCPPortMax int
|
||||
serverTLSCert string
|
||||
serverTLSKey string
|
||||
serverPprofPort int
|
||||
)
|
||||
|
||||
var serverCmd = &cobra.Command{
|
||||
Use: "server",
|
||||
Short: "Start Drip server",
|
||||
Long: `Start the Drip tunnel server to accept client connections`,
|
||||
RunE: runServer,
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(serverCmd)
|
||||
|
||||
// Command line flags with environment variable defaults
|
||||
serverCmd.Flags().IntVarP(&serverPort, "port", "p", getEnvInt("DRIP_PORT", 8443), "Server port (env: DRIP_PORT)")
|
||||
serverCmd.Flags().IntVar(&serverPublicPort, "public-port", getEnvInt("DRIP_PUBLIC_PORT", 0), "Public port to display in URLs (env: DRIP_PUBLIC_PORT)")
|
||||
serverCmd.Flags().StringVarP(&serverDomain, "domain", "d", getEnvString("DRIP_DOMAIN", constants.DefaultDomain), "Server domain (env: DRIP_DOMAIN)")
|
||||
serverCmd.Flags().StringVarP(&serverAuthToken, "token", "t", getEnvString("DRIP_TOKEN", ""), "Authentication token (env: DRIP_TOKEN)")
|
||||
serverCmd.Flags().BoolVar(&serverDebug, "debug", false, "Enable debug logging")
|
||||
serverCmd.Flags().IntVar(&serverTCPPortMin, "tcp-port-min", getEnvInt("DRIP_TCP_PORT_MIN", constants.DefaultTCPPortMin), "Minimum TCP tunnel port (env: DRIP_TCP_PORT_MIN)")
|
||||
serverCmd.Flags().IntVar(&serverTCPPortMax, "tcp-port-max", getEnvInt("DRIP_TCP_PORT_MAX", constants.DefaultTCPPortMax), "Maximum TCP tunnel port (env: DRIP_TCP_PORT_MAX)")
|
||||
|
||||
// TLS options
|
||||
serverCmd.Flags().StringVar(&serverTLSCert, "tls-cert", getEnvString("DRIP_TLS_CERT", ""), "Path to TLS certificate file (env: DRIP_TLS_CERT)")
|
||||
serverCmd.Flags().StringVar(&serverTLSKey, "tls-key", getEnvString("DRIP_TLS_KEY", ""), "Path to TLS private key file (env: DRIP_TLS_KEY)")
|
||||
|
||||
// Performance profiling
|
||||
serverCmd.Flags().IntVar(&serverPprofPort, "pprof", getEnvInt("DRIP_PPROF_PORT", 0), "Enable pprof on specified port (env: DRIP_PPROF_PORT)")
|
||||
}
|
||||
|
||||
func runServer(cmd *cobra.Command, args []string) error {
|
||||
// Validate required TLS configuration
|
||||
if serverTLSCert == "" {
|
||||
return fmt.Errorf("TLS certificate path is required (use --tls-cert flag or DRIP_TLS_CERT environment variable)")
|
||||
}
|
||||
if serverTLSKey == "" {
|
||||
return fmt.Errorf("TLS private key path is required (use --tls-key flag or DRIP_TLS_KEY environment variable)")
|
||||
}
|
||||
|
||||
// Initialize logger
|
||||
if err := utils.InitServerLogger(serverDebug); err != nil {
|
||||
return fmt.Errorf("failed to initialize logger: %w", err)
|
||||
}
|
||||
defer utils.Sync()
|
||||
|
||||
logger := utils.GetLogger()
|
||||
|
||||
logger.Info("Starting Drip Server",
|
||||
zap.String("version", Version),
|
||||
zap.String("commit", GitCommit),
|
||||
)
|
||||
|
||||
// Start pprof server if enabled
|
||||
if serverPprofPort > 0 {
|
||||
go func() {
|
||||
pprofAddr := fmt.Sprintf("localhost:%d", serverPprofPort)
|
||||
logger.Info("Starting pprof server", zap.String("address", pprofAddr))
|
||||
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
|
||||
logger.Error("pprof server failed", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Create server config
|
||||
displayPort := serverPublicPort
|
||||
if displayPort == 0 {
|
||||
displayPort = serverPort
|
||||
}
|
||||
|
||||
serverConfig := &config.ServerConfig{
|
||||
Port: serverPort,
|
||||
PublicPort: displayPort,
|
||||
Domain: serverDomain,
|
||||
TCPPortMin: serverTCPPortMin,
|
||||
TCPPortMax: serverTCPPortMax,
|
||||
TLSEnabled: true,
|
||||
TLSCertFile: serverTLSCert,
|
||||
TLSKeyFile: serverTLSKey,
|
||||
AuthToken: serverAuthToken,
|
||||
Debug: serverDebug,
|
||||
}
|
||||
|
||||
// Load TLS configuration
|
||||
tlsConfig, err := serverConfig.LoadTLSConfig()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to load TLS configuration", zap.Error(err))
|
||||
}
|
||||
|
||||
logger.Info("TLS 1.3 configuration loaded",
|
||||
zap.String("cert", serverTLSCert),
|
||||
zap.String("key", serverTLSKey),
|
||||
)
|
||||
|
||||
// Create tunnel manager
|
||||
tunnelManager := tunnel.NewManager(logger)
|
||||
|
||||
// Create TCP port allocator
|
||||
portAllocator, err := tcp.NewPortAllocator(serverTCPPortMin, serverTCPPortMax)
|
||||
if err != nil {
|
||||
logger.Fatal("Invalid TCP port range", zap.Error(err))
|
||||
}
|
||||
|
||||
// Create TCP listener address
|
||||
listenAddr := fmt.Sprintf("0.0.0.0:%d", serverPort)
|
||||
|
||||
// Response handler for HTTP-over-frame responses
|
||||
responseHandler := proxy.NewResponseHandler(logger)
|
||||
|
||||
// Create HTTP proxy handler (for handling HTTP requests on TCP port)
|
||||
httpHandler := proxy.NewHandler(tunnelManager, logger, responseHandler, serverDomain, serverAuthToken)
|
||||
|
||||
// Create TCP listener (wsHandler also serves as response channel handler)
|
||||
listener := tcp.NewListener(listenAddr, tlsConfig, serverAuthToken, tunnelManager, logger, portAllocator, serverDomain, displayPort, httpHandler, responseHandler)
|
||||
|
||||
// Start listener
|
||||
if err := listener.Start(); err != nil {
|
||||
logger.Fatal("Failed to start TCP listener", zap.Error(err))
|
||||
}
|
||||
|
||||
logger.Info("Drip Server started",
|
||||
zap.String("address", listenAddr),
|
||||
zap.String("domain", serverDomain),
|
||||
zap.String("protocol", "TCP over TLS 1.3"),
|
||||
)
|
||||
|
||||
// Setup signal handler
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Wait for signal
|
||||
<-quit
|
||||
|
||||
logger.Info("Shutting down server...")
|
||||
|
||||
// Stop listener
|
||||
if err := listener.Stop(); err != nil {
|
||||
logger.Error("Error stopping listener", zap.Error(err))
|
||||
}
|
||||
|
||||
logger.Info("Server stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// getEnvInt returns the environment variable value as int, or defaultVal if not set
|
||||
func getEnvInt(key string, defaultVal int) int {
|
||||
if val := os.Getenv(key); val != "" {
|
||||
if i, err := strconv.Atoi(val); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
// getEnvString returns the environment variable value, or defaultVal if not set
|
||||
func getEnvString(key string, defaultVal string) string {
|
||||
if val := os.Getenv(key); val != "" {
|
||||
return val
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
125
internal/client/cli/stop.go
Normal file
125
internal/client/cli/stop.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var stopCmd = &cobra.Command{
|
||||
Use: "stop <type> <port>|all",
|
||||
Short: "Stop background tunnels",
|
||||
Long: `Stop one or all background tunnels.
|
||||
|
||||
Examples:
|
||||
drip stop http 3000 Stop HTTP tunnel on port 3000
|
||||
drip stop tcp 5432 Stop TCP tunnel on port 5432
|
||||
drip stop all Stop all running tunnels
|
||||
|
||||
Use 'drip list' to see running tunnels.`,
|
||||
Aliases: []string{"kill"},
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: runStop,
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(stopCmd)
|
||||
}
|
||||
|
||||
func runStop(cmd *cobra.Command, args []string) error {
|
||||
// Handle "stop all"
|
||||
if args[0] == "all" {
|
||||
return stopAllDaemons()
|
||||
}
|
||||
|
||||
// Handle "stop <type> <port>"
|
||||
if len(args) < 2 {
|
||||
return fmt.Errorf("usage: drip stop <type> <port> or drip stop all")
|
||||
}
|
||||
|
||||
tunnelType := args[0]
|
||||
if tunnelType != "http" && tunnelType != "tcp" {
|
||||
return fmt.Errorf("invalid tunnel type: %s (must be 'http' or 'tcp')", tunnelType)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(args[1])
|
||||
if err != nil || port < 1 || port > 65535 {
|
||||
return fmt.Errorf("invalid port number: %s", args[1])
|
||||
}
|
||||
|
||||
return stopDaemon(tunnelType, port)
|
||||
}
|
||||
|
||||
func stopDaemon(tunnelType string, port int) error {
|
||||
info, err := LoadDaemonInfo(tunnelType, port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load daemon info: %w", err)
|
||||
}
|
||||
|
||||
if info == nil {
|
||||
return fmt.Errorf("no %s tunnel running on port %d", tunnelType, port)
|
||||
}
|
||||
|
||||
// Check if process is still running
|
||||
if !IsProcessRunning(info.PID) {
|
||||
// Clean up stale entry
|
||||
RemoveDaemonInfo(tunnelType, port)
|
||||
return fmt.Errorf("tunnel was not running (cleaned up stale entry)")
|
||||
}
|
||||
|
||||
// Kill the process
|
||||
if err := KillProcess(info.PID); err != nil {
|
||||
return fmt.Errorf("failed to stop tunnel: %w", err)
|
||||
}
|
||||
|
||||
// Remove daemon info
|
||||
RemoveDaemonInfo(tunnelType, port)
|
||||
|
||||
fmt.Printf("\033[32m✓\033[0m Stopped %s tunnel on port %d (PID: %d)\n", tunnelType, port, info.PID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func stopAllDaemons() error {
|
||||
// Clean up stale daemons first
|
||||
CleanupStaleDaemons()
|
||||
|
||||
daemons, err := ListAllDaemons()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list daemons: %w", err)
|
||||
}
|
||||
|
||||
if len(daemons) == 0 {
|
||||
fmt.Println("\033[90mNo running tunnels to stop.\033[0m")
|
||||
return nil
|
||||
}
|
||||
|
||||
stopped := 0
|
||||
failed := 0
|
||||
|
||||
for _, d := range daemons {
|
||||
if !IsProcessRunning(d.PID) {
|
||||
RemoveDaemonInfo(d.Type, d.Port)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := KillProcess(d.PID); err != nil {
|
||||
fmt.Printf("\033[31m✗\033[0m Failed to stop %s tunnel on port %d: %v\n", d.Type, d.Port, err)
|
||||
failed++
|
||||
continue
|
||||
}
|
||||
|
||||
RemoveDaemonInfo(d.Type, d.Port)
|
||||
fmt.Printf("\033[32m✓\033[0m Stopped %s tunnel on port %d (PID: %d)\n", d.Type, d.Port, d.PID)
|
||||
stopped++
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
if failed > 0 {
|
||||
fmt.Printf("Stopped %d tunnel(s), %d failed\n", stopped, failed)
|
||||
} else {
|
||||
fmt.Printf("Stopped %d tunnel(s)\n", stopped)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
361
internal/client/cli/tcp.go
Normal file
361
internal/client/cli/tcp.go
Normal file
@@ -0,0 +1,361 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"drip/internal/client/tcp"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/utils"
|
||||
"drip/pkg/config"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var tcpCmd = &cobra.Command{
|
||||
Use: "tcp <port>",
|
||||
Short: "Start TCP tunnel",
|
||||
Long: `Start a TCP tunnel to expose any TCP service.
|
||||
|
||||
Example:
|
||||
drip tcp 5432 Tunnel PostgreSQL
|
||||
drip tcp 3306 Tunnel MySQL
|
||||
drip tcp 22 Tunnel SSH
|
||||
drip tcp 6379 --subdomain myredis Tunnel Redis with custom subdomain
|
||||
|
||||
Supported Services:
|
||||
- Databases: PostgreSQL (5432), MySQL (3306), Redis (6379), MongoDB (27017)
|
||||
- SSH: Port 22
|
||||
- Any TCP service
|
||||
|
||||
Configuration:
|
||||
First time: Run 'drip config init' to save server and token
|
||||
Subsequent: Just run 'drip tcp <port>'
|
||||
|
||||
Note: Uses TCP over TLS 1.3 for secure communication`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runTCP,
|
||||
}
|
||||
|
||||
func init() {
|
||||
tcpCmd.Flags().StringVarP(&subdomain, "subdomain", "n", "", "Custom subdomain (optional)")
|
||||
tcpCmd.Flags().BoolVarP(&daemonMode, "daemon", "d", false, "Run in background (daemon mode)")
|
||||
tcpCmd.Flags().StringVarP(&localAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)")
|
||||
tcpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
|
||||
tcpCmd.Flags().MarkHidden("daemon-child")
|
||||
rootCmd.AddCommand(tcpCmd)
|
||||
}
|
||||
|
||||
func runTCP(cmd *cobra.Command, args []string) error {
|
||||
// Parse port
|
||||
port, err := strconv.Atoi(args[0])
|
||||
if err != nil || port < 1 || port > 65535 {
|
||||
return fmt.Errorf("invalid port number: %s", args[0])
|
||||
}
|
||||
|
||||
// Handle daemon mode
|
||||
if daemonMode && !daemonMarker {
|
||||
// Start as daemon
|
||||
daemonArgs := append([]string{"tcp"}, args...)
|
||||
daemonArgs = append(daemonArgs, "--daemon-child")
|
||||
if subdomain != "" {
|
||||
daemonArgs = append(daemonArgs, "--subdomain", subdomain)
|
||||
}
|
||||
if localAddress != "127.0.0.1" {
|
||||
daemonArgs = append(daemonArgs, "--address", localAddress)
|
||||
}
|
||||
if serverURL != "" {
|
||||
daemonArgs = append(daemonArgs, "--server", serverURL)
|
||||
}
|
||||
if authToken != "" {
|
||||
daemonArgs = append(daemonArgs, "--token", authToken)
|
||||
}
|
||||
if insecure {
|
||||
daemonArgs = append(daemonArgs, "--insecure")
|
||||
}
|
||||
if verbose {
|
||||
daemonArgs = append(daemonArgs, "--verbose")
|
||||
}
|
||||
return StartDaemon("tcp", port, daemonArgs)
|
||||
}
|
||||
|
||||
// Initialize logger
|
||||
if err := utils.InitLogger(verbose); err != nil {
|
||||
return fmt.Errorf("failed to initialize logger: %w", err)
|
||||
}
|
||||
defer utils.Sync()
|
||||
|
||||
logger := utils.GetLogger()
|
||||
|
||||
// Load configuration or use command line flags
|
||||
var serverAddr, token string
|
||||
|
||||
if serverURL == "" {
|
||||
// Try to load from config file
|
||||
cfg, err := config.LoadClientConfig("")
|
||||
if err != nil {
|
||||
return fmt.Errorf(`configuration not found.
|
||||
|
||||
Please run 'drip config init' first, or use flags:
|
||||
drip tcp %d --server SERVER:PORT --token TOKEN`, port)
|
||||
}
|
||||
serverAddr = cfg.Server
|
||||
token = cfg.Token
|
||||
} else {
|
||||
// Use command line flags
|
||||
serverAddr = serverURL
|
||||
token = authToken
|
||||
}
|
||||
|
||||
// Validate server address
|
||||
if serverAddr == "" {
|
||||
return fmt.Errorf("server address is required")
|
||||
}
|
||||
|
||||
// Create connector config
|
||||
connConfig := &tcp.ConnectorConfig{
|
||||
ServerAddr: serverAddr,
|
||||
Token: token,
|
||||
TunnelType: protocol.TunnelTypeTCP,
|
||||
LocalHost: localAddress,
|
||||
LocalPort: port,
|
||||
Subdomain: subdomain,
|
||||
Insecure: insecure,
|
||||
}
|
||||
|
||||
// Setup signal handler for graceful shutdown
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Connection loop with reconnect support
|
||||
reconnectAttempts := 0
|
||||
serviceName := getServiceName(port)
|
||||
for {
|
||||
// Create connector
|
||||
connector := tcp.NewConnector(connConfig, logger)
|
||||
|
||||
// Connect to server
|
||||
if reconnectAttempts == 0 {
|
||||
fmt.Printf("\033[36m🔌 Connecting to %s...\033[0m\n", serverAddr)
|
||||
} else {
|
||||
fmt.Printf("\033[33m🔄 Reconnecting to %s (attempt %d/%d)...\033[0m\n", serverAddr, reconnectAttempts, maxReconnectAttempts)
|
||||
}
|
||||
|
||||
if err := connector.Connect(); err != nil {
|
||||
// Check if this is a non-retryable error
|
||||
if isNonRetryableErrorTCP(err) {
|
||||
return fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
reconnectAttempts++
|
||||
if reconnectAttempts >= maxReconnectAttempts {
|
||||
return fmt.Errorf("failed to connect after %d attempts: %w", maxReconnectAttempts, err)
|
||||
}
|
||||
fmt.Printf("\033[31m✗ Connection failed: %v\033[0m\n", err)
|
||||
fmt.Printf("\033[90m Retrying in %v...\033[0m\n", reconnectInterval)
|
||||
|
||||
// Wait before retry, but allow interrupt
|
||||
select {
|
||||
case <-quit:
|
||||
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
|
||||
return nil
|
||||
case <-time.After(reconnectInterval):
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Reset reconnect attempts on successful connection
|
||||
reconnectAttempts = 0
|
||||
|
||||
// Save daemon info if running as daemon child
|
||||
if daemonMarker {
|
||||
daemonInfo := &DaemonInfo{
|
||||
PID: os.Getpid(),
|
||||
Type: "tcp",
|
||||
Port: port,
|
||||
Subdomain: subdomain,
|
||||
Server: serverAddr,
|
||||
URL: connector.GetURL(),
|
||||
StartTime: time.Now(),
|
||||
Executable: os.Args[0],
|
||||
}
|
||||
if err := SaveDaemonInfo(daemonInfo); err != nil {
|
||||
// Log but don't fail
|
||||
logger.Warn("Failed to save daemon info", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Print tunnel information
|
||||
fmt.Println()
|
||||
fmt.Println("\033[1;35m╔══════════════════════════════════════════════════════════════════╗\033[0m")
|
||||
fmt.Println("\033[1;35m║\033[0m \033[1;37m🔌 TCP Tunnel Connected Successfully!\033[0m \033[1;35m║\033[0m")
|
||||
fmt.Println("\033[1;35m╠══════════════════════════════════════════════════════════════════╣\033[0m")
|
||||
fmt.Printf("\033[1;35m║\033[0m \033[1;37mTunnel URL:\033[0m \033[1;35m║\033[0m\n")
|
||||
fmt.Printf("\033[1;35m║\033[0m \033[1;36m%-60s\033[0m \033[1;35m║\033[0m\n", connector.GetURL())
|
||||
fmt.Println("\033[1;35m║\033[0m \033[1;35m║\033[0m")
|
||||
fmt.Printf("\033[1;35m║\033[0m \033[90mService:\033[0m \033[1;35m%-50s\033[0m \033[1;35m║\033[0m\n", serviceName)
|
||||
displayAddr := localAddress
|
||||
if displayAddr == "127.0.0.1" {
|
||||
displayAddr = "localhost"
|
||||
}
|
||||
fmt.Printf("\033[1;35m║\033[0m \033[90mForwarding:\033[0m \033[1m%s:%d\033[0m → \033[36m%s\033[0m%-15s\033[1;35m║\033[0m\n", displayAddr, port, "public", "")
|
||||
fmt.Printf("\033[1;35m║\033[0m \033[90mLatency:\033[0m \033[90mmeasuring...\033[0m%-40s\033[1;35m║\033[0m\n", "")
|
||||
fmt.Printf("\033[1;35m║\033[0m \033[90mTraffic:\033[0m \033[90m↓ 0 B ↑ 0 B\033[0m%-32s\033[1;35m║\033[0m\n", "")
|
||||
fmt.Printf("\033[1;35m║\033[0m \033[90mSpeed:\033[0m \033[90m↓ 0 B/s ↑ 0 B/s\033[0m%-28s\033[1;35m║\033[0m\n", "")
|
||||
fmt.Printf("\033[1;35m║\033[0m \033[90mRequests:\033[0m \033[90m0\033[0m%-43s\033[1;35m║\033[0m\n", "")
|
||||
fmt.Println("\033[1;35m╠══════════════════════════════════════════════════════════════════╣\033[0m")
|
||||
fmt.Println("\033[1;35m║\033[0m \033[90mPress Ctrl+C to stop the tunnel\033[0m \033[1;35m║\033[0m")
|
||||
fmt.Println("\033[1;35m╚══════════════════════════════════════════════════════════════════╝\033[0m")
|
||||
fmt.Println()
|
||||
|
||||
// Setup latency display
|
||||
latencyCh := make(chan time.Duration, 1)
|
||||
connector.SetLatencyCallback(func(latency time.Duration) {
|
||||
select {
|
||||
case latencyCh <- latency:
|
||||
default:
|
||||
}
|
||||
})
|
||||
|
||||
// Start stats display updater (updates every second)
|
||||
stopDisplay := make(chan struct{})
|
||||
disconnected := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
var lastLatency time.Duration
|
||||
for {
|
||||
select {
|
||||
case latency := <-latencyCh:
|
||||
lastLatency = latency
|
||||
case <-ticker.C:
|
||||
// Update speed calculation
|
||||
stats := connector.GetStats()
|
||||
if stats != nil {
|
||||
stats.UpdateSpeed()
|
||||
snapshot := stats.GetSnapshot()
|
||||
|
||||
// Move cursor up 8 lines to update display
|
||||
fmt.Print("\033[8A")
|
||||
|
||||
// Update latency line
|
||||
fmt.Printf("\r\033[1;35m║\033[0m \033[90mLatency:\033[0m %s%-40s\033[1;35m║\033[0m\n", formatLatencyTCP(lastLatency), "")
|
||||
|
||||
// Update traffic line
|
||||
trafficStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatBytes(snapshot.TotalBytesIn), tcp.FormatBytes(snapshot.TotalBytesOut))
|
||||
fmt.Printf("\r\033[1;35m║\033[0m \033[90mTraffic:\033[0m \033[36m%-48s\033[0m\033[1;35m║\033[0m\n", trafficStr)
|
||||
|
||||
// Update speed line
|
||||
speedStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatSpeed(snapshot.SpeedIn), tcp.FormatSpeed(snapshot.SpeedOut))
|
||||
fmt.Printf("\r\033[1;35m║\033[0m \033[90mSpeed:\033[0m \033[33m%-48s\033[0m\033[1;35m║\033[0m\n", speedStr)
|
||||
|
||||
// Update requests line
|
||||
fmt.Printf("\r\033[1;35m║\033[0m \033[90mRequests:\033[0m \033[35m%-47d\033[0m\033[1;35m║\033[0m\n", snapshot.TotalRequests)
|
||||
|
||||
// Move back down 4 lines
|
||||
fmt.Print("\033[4B")
|
||||
}
|
||||
case <-stopDisplay:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Monitor connection in background
|
||||
go func() {
|
||||
connector.Wait()
|
||||
close(disconnected)
|
||||
}()
|
||||
|
||||
// Wait for signal or disconnection
|
||||
select {
|
||||
case <-quit:
|
||||
close(stopDisplay)
|
||||
fmt.Println("\n\n\033[33m🛑 Shutting down...\033[0m")
|
||||
connector.Close()
|
||||
if daemonMarker {
|
||||
RemoveDaemonInfo("tcp", port)
|
||||
}
|
||||
fmt.Println("\033[32m✓\033[0m Tunnel closed")
|
||||
return nil
|
||||
case <-disconnected:
|
||||
close(stopDisplay)
|
||||
fmt.Println("\n\n\033[31m⚠ Connection lost!\033[0m")
|
||||
reconnectAttempts++
|
||||
if reconnectAttempts >= maxReconnectAttempts {
|
||||
return fmt.Errorf("connection lost after %d reconnect attempts", maxReconnectAttempts)
|
||||
}
|
||||
fmt.Printf("\033[90m Reconnecting in %v...\033[0m\n", reconnectInterval)
|
||||
|
||||
// Wait before reconnect, but allow interrupt
|
||||
select {
|
||||
case <-quit:
|
||||
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
|
||||
return nil
|
||||
case <-time.After(reconnectInterval):
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getServiceName returns a friendly name for common port numbers
|
||||
func getServiceName(port int) string {
|
||||
services := map[int]string{
|
||||
22: "SSH",
|
||||
80: "HTTP",
|
||||
443: "HTTPS",
|
||||
3306: "MySQL",
|
||||
5432: "PostgreSQL",
|
||||
6379: "Redis",
|
||||
27017: "MongoDB",
|
||||
3389: "RDP",
|
||||
5900: "VNC",
|
||||
8080: "HTTP (Alt)",
|
||||
8443: "HTTPS (Alt)",
|
||||
}
|
||||
|
||||
if name, ok := services[port]; ok {
|
||||
return fmt.Sprintf("%s (port %d)", name, port)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("TCP service on port %d", port)
|
||||
}
|
||||
|
||||
// formatLatencyTCP formats latency with color based on value
|
||||
func formatLatencyTCP(d time.Duration) string {
|
||||
ms := d.Milliseconds()
|
||||
if ms < 50 {
|
||||
return fmt.Sprintf("\033[32m%dms\033[0m", ms) // Green: excellent
|
||||
} else if ms < 100 {
|
||||
return fmt.Sprintf("\033[33m%dms\033[0m", ms) // Yellow: good
|
||||
} else if ms < 200 {
|
||||
return fmt.Sprintf("\033[38;5;208m%dms\033[0m", ms) // Orange: moderate
|
||||
}
|
||||
return fmt.Sprintf("\033[31m%dms\033[0m", ms) // Red: poor
|
||||
}
|
||||
|
||||
// isNonRetryableErrorTCP checks if an error should not be retried
|
||||
func isNonRetryableErrorTCP(err error) bool {
|
||||
errStr := err.Error()
|
||||
// Subdomain conflicts - no point retrying
|
||||
if strings.Contains(errStr, "subdomain is already taken") ||
|
||||
strings.Contains(errStr, "subdomain is reserved") ||
|
||||
strings.Contains(errStr, "invalid subdomain") {
|
||||
return true
|
||||
}
|
||||
// Authentication errors - no point retrying
|
||||
if strings.Contains(errStr, "authentication") ||
|
||||
strings.Contains(errStr, "Invalid authentication token") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
393
internal/client/tcp/connector.go
Normal file
393
internal/client/tcp/connector.go
Normal file
@@ -0,0 +1,393 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/constants"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/pkg/config"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// LatencyCallback is called when latency is measured
|
||||
type LatencyCallback func(latency time.Duration)
|
||||
|
||||
// Connector manages the TCP connection to the server
|
||||
type Connector struct {
|
||||
serverAddr string
|
||||
tlsConfig *tls.Config
|
||||
token string
|
||||
tunnelType protocol.TunnelType
|
||||
localHost string
|
||||
localPort int
|
||||
subdomain string
|
||||
conn net.Conn
|
||||
logger *zap.Logger
|
||||
stopCh chan struct{}
|
||||
once sync.Once
|
||||
registered bool
|
||||
assignedURL string
|
||||
frameHandler *FrameHandler
|
||||
frameWriter *protocol.FrameWriter
|
||||
latencyCallback LatencyCallback
|
||||
heartbeatSentAt time.Time
|
||||
heartbeatMu sync.Mutex
|
||||
lastLatency time.Duration
|
||||
handlerWg sync.WaitGroup // Tracks active data frame handlers
|
||||
closed bool
|
||||
closedMu sync.RWMutex
|
||||
}
|
||||
|
||||
// ConnectorConfig holds connector configuration
|
||||
type ConnectorConfig struct {
|
||||
ServerAddr string
|
||||
Token string
|
||||
TunnelType protocol.TunnelType
|
||||
LocalHost string // Local host address (default: 127.0.0.1)
|
||||
LocalPort int
|
||||
Subdomain string // Optional custom subdomain
|
||||
Insecure bool // Skip TLS verification (testing only)
|
||||
}
|
||||
|
||||
// NewConnector creates a new connector
|
||||
func NewConnector(cfg *ConnectorConfig, logger *zap.Logger) *Connector {
|
||||
var tlsConfig *tls.Config
|
||||
if cfg.Insecure {
|
||||
tlsConfig = config.GetClientTLSConfigInsecure()
|
||||
} else {
|
||||
host, _, _ := net.SplitHostPort(cfg.ServerAddr)
|
||||
tlsConfig = config.GetClientTLSConfig(host)
|
||||
}
|
||||
|
||||
localHost := cfg.LocalHost
|
||||
if localHost == "" {
|
||||
localHost = "127.0.0.1"
|
||||
}
|
||||
|
||||
return &Connector{
|
||||
serverAddr: cfg.ServerAddr,
|
||||
tlsConfig: tlsConfig,
|
||||
token: cfg.Token,
|
||||
tunnelType: cfg.TunnelType,
|
||||
localHost: localHost,
|
||||
localPort: cfg.LocalPort,
|
||||
subdomain: cfg.Subdomain,
|
||||
logger: logger,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Connect connects to the server and registers the tunnel
|
||||
func (c *Connector) Connect() error {
|
||||
c.logger.Info("Connecting to server",
|
||||
zap.String("server", c.serverAddr),
|
||||
zap.String("tunnel_type", string(c.tunnelType)),
|
||||
zap.String("local_host", c.localHost),
|
||||
zap.Int("local_port", c.localPort),
|
||||
)
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := tls.DialWithDialer(dialer, "tcp", c.serverAddr, c.tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect: %w", err)
|
||||
}
|
||||
|
||||
c.conn = conn
|
||||
|
||||
state := conn.ConnectionState()
|
||||
if state.Version != tls.VersionTLS13 {
|
||||
conn.Close()
|
||||
return fmt.Errorf("server not using TLS 1.3 (version: 0x%04x)", state.Version)
|
||||
}
|
||||
|
||||
c.logger.Info("TLS connection established",
|
||||
zap.String("cipher_suite", tls.CipherSuiteName(state.CipherSuite)),
|
||||
)
|
||||
|
||||
if err := c.register(); err != nil {
|
||||
conn.Close()
|
||||
return fmt.Errorf("registration failed: %w", err)
|
||||
}
|
||||
|
||||
c.frameWriter = protocol.NewFrameWriter(c.conn)
|
||||
bufferPool := pool.NewBufferPool()
|
||||
|
||||
c.frameHandler = NewFrameHandler(
|
||||
c.conn,
|
||||
c.frameWriter,
|
||||
c.localHost,
|
||||
c.localPort,
|
||||
c.tunnelType,
|
||||
c.logger,
|
||||
c.IsClosed,
|
||||
bufferPool,
|
||||
)
|
||||
|
||||
go c.frameHandler.WarmupConnectionPool(3)
|
||||
go c.handleFrames()
|
||||
go c.heartbeat()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// register sends registration request and waits for acknowledgment
|
||||
func (c *Connector) register() error {
|
||||
req := protocol.RegisterRequest{
|
||||
Token: c.token,
|
||||
CustomSubdomain: c.subdomain,
|
||||
TunnelType: c.tunnelType,
|
||||
LocalPort: c.localPort,
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
regFrame := protocol.NewFrame(protocol.FrameTypeRegister, payload)
|
||||
err = protocol.WriteFrame(c.conn, regFrame)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send registration: %w", err)
|
||||
}
|
||||
|
||||
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
|
||||
ackFrame, err := protocol.ReadFrame(c.conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read ack: %w", err)
|
||||
}
|
||||
defer ackFrame.Release()
|
||||
|
||||
c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
if ackFrame.Type == protocol.FrameTypeError {
|
||||
var errMsg protocol.ErrorMessage
|
||||
if err := json.Unmarshal(ackFrame.Payload, &errMsg); err == nil {
|
||||
return fmt.Errorf("registration error: %s - %s", errMsg.Code, errMsg.Message)
|
||||
}
|
||||
return fmt.Errorf("registration error")
|
||||
}
|
||||
|
||||
if ackFrame.Type != protocol.FrameTypeRegisterAck {
|
||||
return fmt.Errorf("unexpected frame type: %s", ackFrame.Type)
|
||||
}
|
||||
|
||||
var resp protocol.RegisterResponse
|
||||
if err := json.Unmarshal(ackFrame.Payload, &resp); err != nil {
|
||||
return fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
c.registered = true
|
||||
c.assignedURL = resp.URL
|
||||
c.subdomain = resp.Subdomain
|
||||
|
||||
c.logger.Info("Tunnel registered successfully",
|
||||
zap.String("subdomain", resp.Subdomain),
|
||||
zap.String("url", resp.URL),
|
||||
zap.Int("remote_port", resp.Port),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleFrames handles incoming frames from server
|
||||
func (c *Connector) handleFrames() {
|
||||
defer c.Close()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
|
||||
frame, err := protocol.ReadFrame(c.conn)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
c.logger.Warn("Read timeout")
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return
|
||||
default:
|
||||
c.logger.Error("Failed to read frame", zap.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
switch frame.Type {
|
||||
case protocol.FrameTypeHeartbeatAck:
|
||||
c.heartbeatMu.Lock()
|
||||
if !c.heartbeatSentAt.IsZero() {
|
||||
latency := time.Since(c.heartbeatSentAt)
|
||||
c.lastLatency = latency
|
||||
c.heartbeatMu.Unlock()
|
||||
|
||||
c.logger.Debug("Received heartbeat ack", zap.Duration("latency", latency))
|
||||
|
||||
if c.latencyCallback != nil {
|
||||
c.latencyCallback(latency)
|
||||
}
|
||||
} else {
|
||||
c.heartbeatMu.Unlock()
|
||||
c.logger.Debug("Received heartbeat ack")
|
||||
}
|
||||
frame.Release()
|
||||
|
||||
case protocol.FrameTypeData:
|
||||
c.handlerWg.Add(1)
|
||||
go func(f *protocol.Frame) {
|
||||
defer c.handlerWg.Done()
|
||||
defer f.Release()
|
||||
if err := c.frameHandler.HandleDataFrame(f); err != nil {
|
||||
c.logger.Error("Failed to handle data frame", zap.Error(err))
|
||||
}
|
||||
}(frame)
|
||||
|
||||
case protocol.FrameTypeClose:
|
||||
frame.Release()
|
||||
c.logger.Info("Server requested close")
|
||||
return
|
||||
|
||||
case protocol.FrameTypeError:
|
||||
var errMsg protocol.ErrorMessage
|
||||
if err := json.Unmarshal(frame.Payload, &errMsg); err == nil {
|
||||
c.logger.Error("Received error from server",
|
||||
zap.String("code", errMsg.Code),
|
||||
zap.String("message", errMsg.Message),
|
||||
)
|
||||
}
|
||||
frame.Release()
|
||||
return
|
||||
|
||||
default:
|
||||
frame.Release()
|
||||
c.logger.Warn("Unexpected frame type",
|
||||
zap.String("type", frame.Type.String()),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeat sends periodic heartbeat frames
|
||||
func (c *Connector) heartbeat() {
|
||||
ticker := time.NewTicker(constants.HeartbeatInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
c.sendHeartbeat()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.sendHeartbeat()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendHeartbeat sends a heartbeat frame and records the time
|
||||
func (c *Connector) sendHeartbeat() {
|
||||
hbFrame := protocol.NewFrame(protocol.FrameTypeHeartbeat, nil)
|
||||
|
||||
c.heartbeatMu.Lock()
|
||||
c.heartbeatSentAt = time.Now()
|
||||
c.heartbeatMu.Unlock()
|
||||
|
||||
err := c.frameWriter.WriteFrame(hbFrame)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to send heartbeat", zap.Error(err))
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
c.logger.Debug("Heartbeat sent")
|
||||
}
|
||||
|
||||
// SendFrame sends a frame to the server
|
||||
func (c *Connector) SendFrame(frame *protocol.Frame) error {
|
||||
if !c.registered {
|
||||
return fmt.Errorf("not registered")
|
||||
}
|
||||
|
||||
return c.frameWriter.WriteFrame(frame)
|
||||
}
|
||||
|
||||
// Close closes the connection
|
||||
func (c *Connector) Close() error {
|
||||
c.once.Do(func() {
|
||||
c.closedMu.Lock()
|
||||
c.closed = true
|
||||
c.closedMu.Unlock()
|
||||
|
||||
close(c.stopCh)
|
||||
|
||||
c.logger.Debug("Waiting for active handlers to complete")
|
||||
c.handlerWg.Wait()
|
||||
|
||||
if c.conn != nil {
|
||||
closeFrame := protocol.NewFrame(protocol.FrameTypeClose, nil)
|
||||
|
||||
if c.frameWriter != nil {
|
||||
c.frameWriter.WriteFrame(closeFrame)
|
||||
c.frameWriter.Close()
|
||||
} else {
|
||||
protocol.WriteFrame(c.conn, closeFrame)
|
||||
}
|
||||
|
||||
c.conn.Close()
|
||||
}
|
||||
c.logger.Info("Connector closed")
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Wait blocks until connection is closed
|
||||
func (c *Connector) Wait() {
|
||||
<-c.stopCh
|
||||
}
|
||||
|
||||
// GetURL returns the assigned tunnel URL
|
||||
func (c *Connector) GetURL() string {
|
||||
return c.assignedURL
|
||||
}
|
||||
|
||||
// GetSubdomain returns the assigned subdomain
|
||||
func (c *Connector) GetSubdomain() string {
|
||||
return c.subdomain
|
||||
}
|
||||
|
||||
// SetLatencyCallback sets the callback for latency updates
|
||||
func (c *Connector) SetLatencyCallback(cb LatencyCallback) {
|
||||
c.latencyCallback = cb
|
||||
}
|
||||
|
||||
// GetLatency returns the last measured latency
|
||||
func (c *Connector) GetLatency() time.Duration {
|
||||
c.heartbeatMu.Lock()
|
||||
defer c.heartbeatMu.Unlock()
|
||||
return c.lastLatency
|
||||
}
|
||||
|
||||
// GetStats returns the traffic stats from the frame handler
|
||||
func (c *Connector) GetStats() *TrafficStats {
|
||||
if c.frameHandler != nil {
|
||||
return c.frameHandler.GetStats()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsClosed returns whether the connector has been closed
|
||||
func (c *Connector) IsClosed() bool {
|
||||
c.closedMu.RLock()
|
||||
defer c.closedMu.RUnlock()
|
||||
return c.closed
|
||||
}
|
||||
440
internal/client/tcp/frame_handler.go
Normal file
440
internal/client/tcp/frame_handler.go
Normal file
@@ -0,0 +1,440 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/protocol"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// FrameHandler handles data frames and forwards to local service
|
||||
type FrameHandler struct {
|
||||
conn net.Conn
|
||||
frameWriter *protocol.FrameWriter // Async batch writer (replaces writeMu)
|
||||
localHost string
|
||||
localPort int
|
||||
logger *zap.Logger
|
||||
streams map[string]*Stream
|
||||
streamMu sync.RWMutex
|
||||
tunnelType protocol.TunnelType
|
||||
httpClient *http.Client
|
||||
stats *TrafficStats
|
||||
isClosedCheck func() bool // Function to check if connection is closed
|
||||
bufferPool *pool.BufferPool
|
||||
headerPool *pool.HeaderPool // Header pool for Priority 9 optimization
|
||||
}
|
||||
|
||||
// Stream represents a single request/response stream
|
||||
type Stream struct {
|
||||
ID string
|
||||
LocalConn net.Conn
|
||||
ResponseCh chan []byte
|
||||
Done chan struct{}
|
||||
}
|
||||
|
||||
func NewFrameHandler(conn net.Conn, frameWriter *protocol.FrameWriter, localHost string, localPort int, tunnelType protocol.TunnelType, logger *zap.Logger, isClosedCheck func() bool, bufferPool *pool.BufferPool) *FrameHandler {
|
||||
var tlsConfig *tls.Config
|
||||
if tunnelType == protocol.TunnelTypeHTTPS {
|
||||
tlsConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
}
|
||||
|
||||
return &FrameHandler{
|
||||
conn: conn,
|
||||
frameWriter: frameWriter,
|
||||
localHost: localHost,
|
||||
localPort: localPort,
|
||||
logger: logger,
|
||||
streams: make(map[string]*Stream),
|
||||
tunnelType: tunnelType,
|
||||
stats: NewTrafficStats(),
|
||||
isClosedCheck: isClosedCheck,
|
||||
bufferPool: bufferPool,
|
||||
headerPool: pool.NewHeaderPool(),
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 500,
|
||||
MaxIdleConnsPerHost: 200,
|
||||
MaxConnsPerHost: 0,
|
||||
IdleConnTimeout: 180 * time.Second,
|
||||
DisableCompression: true,
|
||||
DisableKeepAlives: false,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
TLSClientConfig: tlsConfig,
|
||||
ResponseHeaderTimeout: 15 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 5 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
},
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (h *FrameHandler) HandleDataFrame(frame *protocol.Frame) error {
|
||||
h.stats.AddBytesIn(int64(len(frame.Payload)))
|
||||
h.stats.AddRequest()
|
||||
|
||||
header, data, err := protocol.DecodeDataPayload(frame.Payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode data payload: %w", err)
|
||||
}
|
||||
|
||||
if header.Type == "http_request" {
|
||||
return h.handleHTTPFrame(header, data)
|
||||
}
|
||||
|
||||
if header.Type == "close" {
|
||||
h.closeStream(header.StreamID)
|
||||
return nil
|
||||
}
|
||||
|
||||
stream, err := h.getOrCreateStream(header.StreamID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get stream: %w", err)
|
||||
}
|
||||
|
||||
h.forwardToLocal(stream, data)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *FrameHandler) getOrCreateStream(streamID string) (*Stream, error) {
|
||||
h.streamMu.Lock()
|
||||
defer h.streamMu.Unlock()
|
||||
|
||||
if stream, ok := h.streams[streamID]; ok {
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
localAddr := fmt.Sprintf("%s:%d", h.localHost, h.localPort)
|
||||
localConn, err := net.DialTimeout("tcp", localAddr, 5*time.Second)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to local service: %w", err)
|
||||
}
|
||||
|
||||
stream := &Stream{
|
||||
ID: streamID,
|
||||
LocalConn: localConn,
|
||||
ResponseCh: make(chan []byte, 10),
|
||||
Done: make(chan struct{}),
|
||||
}
|
||||
|
||||
h.streams[streamID] = stream
|
||||
|
||||
go h.handleLocalResponse(stream)
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (h *FrameHandler) forwardToLocal(stream *Stream, data []byte) {
|
||||
if _, err := stream.LocalConn.Write(data); err != nil {
|
||||
h.logger.Error("Failed to write to local service",
|
||||
zap.String("stream_id", stream.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.closeStream(stream.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *FrameHandler) handleLocalResponse(stream *Stream) {
|
||||
defer h.closeStream(stream.ID)
|
||||
|
||||
bufPtr := h.bufferPool.Get(pool.SizeMedium)
|
||||
defer h.bufferPool.Put(bufPtr)
|
||||
buf := (*bufPtr)[:pool.SizeMedium]
|
||||
|
||||
for {
|
||||
n, err := stream.LocalConn.Read(buf)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
if h.isClosedCheck != nil && h.isClosedCheck() {
|
||||
break
|
||||
}
|
||||
|
||||
header := protocol.DataHeader{
|
||||
StreamID: stream.ID,
|
||||
Type: "response",
|
||||
IsLast: false,
|
||||
}
|
||||
|
||||
payload, err := protocol.EncodeDataPayload(header, buf[:n])
|
||||
if err != nil {
|
||||
h.logger.Error("Encode payload failed", zap.Error(err))
|
||||
break
|
||||
}
|
||||
|
||||
dataFrame := protocol.NewFrame(protocol.FrameTypeData, payload)
|
||||
err = h.frameWriter.WriteFrame(dataFrame)
|
||||
if err != nil {
|
||||
h.logger.Error("Send frame failed", zap.Error(err))
|
||||
break
|
||||
}
|
||||
|
||||
h.stats.AddBytesOut(int64(len(payload)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *FrameHandler) handleHTTPFrame(header protocol.DataHeader, payload []byte) error {
|
||||
if h.tunnelType != protocol.TunnelTypeHTTP && h.tunnelType != protocol.TunnelTypeHTTPS {
|
||||
return nil
|
||||
}
|
||||
|
||||
httpReq, err := protocol.DecodeHTTPRequest(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode HTTP request: %w", err)
|
||||
}
|
||||
|
||||
targetURL := httpReq.URL
|
||||
if !strings.HasPrefix(targetURL, "http://") && !strings.HasPrefix(targetURL, "https://") {
|
||||
scheme := "http"
|
||||
if h.tunnelType == protocol.TunnelTypeHTTPS {
|
||||
scheme = "https"
|
||||
}
|
||||
targetURL = fmt.Sprintf("%s://%s:%d%s", scheme, h.localHost, h.localPort, targetURL)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(httpReq.Method, targetURL, bytes.NewReader(httpReq.Body))
|
||||
if err != nil {
|
||||
return h.sendHTTPError(header.StreamID, header.RequestID, http.StatusBadGateway, fmt.Sprintf("build request: %v", err))
|
||||
}
|
||||
|
||||
origHost := ""
|
||||
for key, values := range httpReq.Headers {
|
||||
for _, value := range values {
|
||||
req.Header.Add(key, value)
|
||||
}
|
||||
}
|
||||
if host := req.Header.Get("Host"); host != "" {
|
||||
origHost = host
|
||||
}
|
||||
|
||||
isLocalTarget := h.isLocalAddress(h.localHost)
|
||||
|
||||
if isLocalTarget {
|
||||
if origHost != "" {
|
||||
req.Host = origHost
|
||||
req.Header.Set("Host", origHost)
|
||||
} else {
|
||||
localHostPort := fmt.Sprintf("%s:%d", h.localHost, h.localPort)
|
||||
req.Host = localHostPort
|
||||
req.Header.Set("Host", localHostPort)
|
||||
}
|
||||
if origHost != "" {
|
||||
req.Header.Set("X-Forwarded-Host", origHost)
|
||||
}
|
||||
} else {
|
||||
targetHost := h.localHost
|
||||
if h.localPort != 443 && h.localPort != 80 {
|
||||
targetHost = fmt.Sprintf("%s:%d", h.localHost, h.localPort)
|
||||
}
|
||||
req.Host = targetHost
|
||||
req.Header.Set("Host", targetHost)
|
||||
if origHost != "" {
|
||||
req.Header.Set("X-Forwarded-Host", origHost)
|
||||
}
|
||||
}
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
resp, err := h.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return h.sendHTTPError(header.StreamID, header.RequestID, http.StatusBadGateway, fmt.Sprintf("local request failed: %v", err))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return h.sendHTTPError(header.StreamID, header.RequestID, http.StatusBadGateway, fmt.Sprintf("read response: %v", err))
|
||||
}
|
||||
|
||||
httpResp := protocol.HTTPResponse{
|
||||
StatusCode: resp.StatusCode,
|
||||
Status: resp.Status,
|
||||
Headers: resp.Header,
|
||||
Body: body,
|
||||
}
|
||||
|
||||
return h.sendHTTPResponse(header.StreamID, header.RequestID, &httpResp)
|
||||
}
|
||||
|
||||
func (h *FrameHandler) sendHTTPError(streamID, requestID string, status int, message string) error {
|
||||
headers := h.headerPool.Get()
|
||||
headers.Set("Content-Type", "text/plain")
|
||||
|
||||
httpResp := protocol.HTTPResponse{
|
||||
StatusCode: status,
|
||||
Status: http.StatusText(status),
|
||||
Headers: headers,
|
||||
Body: []byte(message),
|
||||
}
|
||||
|
||||
err := h.sendHTTPResponse(streamID, requestID, &httpResp)
|
||||
|
||||
h.headerPool.Put(headers)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *FrameHandler) sendHTTPResponse(streamID, requestID string, resp *protocol.HTTPResponse) error {
|
||||
if h.isClosedCheck != nil && h.isClosedCheck() {
|
||||
return nil
|
||||
}
|
||||
|
||||
header := protocol.DataHeader{
|
||||
StreamID: streamID,
|
||||
RequestID: requestID,
|
||||
Type: "http_response",
|
||||
IsLast: true,
|
||||
}
|
||||
|
||||
respBytes, err := protocol.EncodeHTTPResponse(resp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode http response: %w", err)
|
||||
}
|
||||
|
||||
payload, err := protocol.EncodeDataPayload(header, respBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode payload: %w", err)
|
||||
}
|
||||
|
||||
dataFrame := protocol.NewFrame(protocol.FrameTypeData, payload)
|
||||
|
||||
h.stats.AddBytesOut(int64(len(payload)))
|
||||
|
||||
return h.frameWriter.WriteFrame(dataFrame)
|
||||
}
|
||||
|
||||
func (h *FrameHandler) closeStream(streamID string) {
|
||||
h.streamMu.Lock()
|
||||
defer h.streamMu.Unlock()
|
||||
|
||||
stream, ok := h.streams[streamID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if stream.LocalConn != nil {
|
||||
stream.LocalConn.Close()
|
||||
}
|
||||
|
||||
close(stream.Done)
|
||||
|
||||
delete(h.streams, streamID)
|
||||
|
||||
if h.isClosedCheck != nil && h.isClosedCheck() {
|
||||
return
|
||||
}
|
||||
|
||||
header := protocol.DataHeader{
|
||||
StreamID: streamID,
|
||||
RequestID: streamID,
|
||||
Type: "close",
|
||||
IsLast: true,
|
||||
}
|
||||
|
||||
payload, err := protocol.EncodeDataPayload(header, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
closeFrame := protocol.NewFrame(protocol.FrameTypeData, payload)
|
||||
|
||||
h.frameWriter.WriteFrame(closeFrame)
|
||||
}
|
||||
|
||||
// Close closes all streams
|
||||
func (h *FrameHandler) Close() {
|
||||
h.streamMu.Lock()
|
||||
defer h.streamMu.Unlock()
|
||||
|
||||
for streamID, stream := range h.streams {
|
||||
if stream.LocalConn != nil {
|
||||
stream.LocalConn.Close()
|
||||
}
|
||||
close(stream.Done)
|
||||
delete(h.streams, streamID)
|
||||
}
|
||||
}
|
||||
|
||||
// GetStats returns the traffic stats tracker
|
||||
func (h *FrameHandler) GetStats() *TrafficStats {
|
||||
return h.stats
|
||||
}
|
||||
|
||||
func (h *FrameHandler) WarmupConnectionPool(numConnections int) {
|
||||
if h.tunnelType != protocol.TunnelTypeHTTP {
|
||||
return
|
||||
}
|
||||
|
||||
targetURL := fmt.Sprintf("http://%s:%d/", h.localHost, h.localPort)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < numConnections; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
req, err := http.NewRequest("HEAD", targetURL, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (h *FrameHandler) isLocalAddress(addr string) bool {
|
||||
if addr == "localhost" || addr == "127.0.0.1" || addr == "::1" {
|
||||
return true
|
||||
}
|
||||
|
||||
if strings.HasPrefix(addr, "192.168.") ||
|
||||
strings.HasPrefix(addr, "10.") ||
|
||||
strings.HasPrefix(addr, "172.16.") ||
|
||||
strings.HasPrefix(addr, "172.17.") ||
|
||||
strings.HasPrefix(addr, "172.18.") ||
|
||||
strings.HasPrefix(addr, "172.19.") ||
|
||||
strings.HasPrefix(addr, "172.20.") ||
|
||||
strings.HasPrefix(addr, "172.21.") ||
|
||||
strings.HasPrefix(addr, "172.22.") ||
|
||||
strings.HasPrefix(addr, "172.23.") ||
|
||||
strings.HasPrefix(addr, "172.24.") ||
|
||||
strings.HasPrefix(addr, "172.25.") ||
|
||||
strings.HasPrefix(addr, "172.26.") ||
|
||||
strings.HasPrefix(addr, "172.27.") ||
|
||||
strings.HasPrefix(addr, "172.28.") ||
|
||||
strings.HasPrefix(addr, "172.29.") ||
|
||||
strings.HasPrefix(addr, "172.30.") ||
|
||||
strings.HasPrefix(addr, "172.31.") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
227
internal/client/tcp/stats.go
Normal file
227
internal/client/tcp/stats.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TrafficStats tracks traffic statistics for a tunnel connection
|
||||
type TrafficStats struct {
|
||||
// Total bytes
|
||||
totalBytesIn int64
|
||||
totalBytesOut int64
|
||||
|
||||
// Request counts
|
||||
totalRequests int64
|
||||
|
||||
// For speed calculation
|
||||
lastBytesIn int64
|
||||
lastBytesOut int64
|
||||
lastTime time.Time
|
||||
speedMu sync.Mutex
|
||||
|
||||
// Current speed (bytes per second)
|
||||
speedIn int64
|
||||
speedOut int64
|
||||
|
||||
// Start time
|
||||
startTime time.Time
|
||||
}
|
||||
|
||||
// NewTrafficStats creates a new traffic stats tracker
|
||||
func NewTrafficStats() *TrafficStats {
|
||||
now := time.Now()
|
||||
return &TrafficStats{
|
||||
startTime: now,
|
||||
lastTime: now,
|
||||
}
|
||||
}
|
||||
|
||||
// AddBytesIn adds incoming bytes to the counter
|
||||
func (s *TrafficStats) AddBytesIn(n int64) {
|
||||
atomic.AddInt64(&s.totalBytesIn, n)
|
||||
}
|
||||
|
||||
// AddBytesOut adds outgoing bytes to the counter
|
||||
func (s *TrafficStats) AddBytesOut(n int64) {
|
||||
atomic.AddInt64(&s.totalBytesOut, n)
|
||||
}
|
||||
|
||||
// AddRequest increments the request counter
|
||||
func (s *TrafficStats) AddRequest() {
|
||||
atomic.AddInt64(&s.totalRequests, 1)
|
||||
}
|
||||
|
||||
// GetTotalBytesIn returns total incoming bytes
|
||||
func (s *TrafficStats) GetTotalBytesIn() int64 {
|
||||
return atomic.LoadInt64(&s.totalBytesIn)
|
||||
}
|
||||
|
||||
// GetTotalBytesOut returns total outgoing bytes
|
||||
func (s *TrafficStats) GetTotalBytesOut() int64 {
|
||||
return atomic.LoadInt64(&s.totalBytesOut)
|
||||
}
|
||||
|
||||
// GetTotalRequests returns total request count
|
||||
func (s *TrafficStats) GetTotalRequests() int64 {
|
||||
return atomic.LoadInt64(&s.totalRequests)
|
||||
}
|
||||
|
||||
// GetTotalBytes returns total bytes (in + out)
|
||||
func (s *TrafficStats) GetTotalBytes() int64 {
|
||||
return s.GetTotalBytesIn() + s.GetTotalBytesOut()
|
||||
}
|
||||
|
||||
// UpdateSpeed calculates current transfer speed
|
||||
// Should be called periodically (e.g., every second)
|
||||
func (s *TrafficStats) UpdateSpeed() {
|
||||
s.speedMu.Lock()
|
||||
defer s.speedMu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(s.lastTime).Seconds()
|
||||
if elapsed < 0.1 {
|
||||
return // Avoid division by zero or too frequent updates
|
||||
}
|
||||
|
||||
currentIn := atomic.LoadInt64(&s.totalBytesIn)
|
||||
currentOut := atomic.LoadInt64(&s.totalBytesOut)
|
||||
|
||||
deltaIn := currentIn - s.lastBytesIn
|
||||
deltaOut := currentOut - s.lastBytesOut
|
||||
|
||||
s.speedIn = int64(float64(deltaIn) / elapsed)
|
||||
s.speedOut = int64(float64(deltaOut) / elapsed)
|
||||
|
||||
s.lastBytesIn = currentIn
|
||||
s.lastBytesOut = currentOut
|
||||
s.lastTime = now
|
||||
}
|
||||
|
||||
// GetSpeedIn returns current incoming speed in bytes per second
|
||||
func (s *TrafficStats) GetSpeedIn() int64 {
|
||||
s.speedMu.Lock()
|
||||
defer s.speedMu.Unlock()
|
||||
return s.speedIn
|
||||
}
|
||||
|
||||
// GetSpeedOut returns current outgoing speed in bytes per second
|
||||
func (s *TrafficStats) GetSpeedOut() int64 {
|
||||
s.speedMu.Lock()
|
||||
defer s.speedMu.Unlock()
|
||||
return s.speedOut
|
||||
}
|
||||
|
||||
// GetUptime returns how long the connection has been active
|
||||
func (s *TrafficStats) GetUptime() time.Duration {
|
||||
return time.Since(s.startTime)
|
||||
}
|
||||
|
||||
// Snapshot returns a snapshot of all stats
|
||||
type StatsSnapshot struct {
|
||||
TotalBytesIn int64
|
||||
TotalBytesOut int64
|
||||
TotalBytes int64
|
||||
TotalRequests int64
|
||||
SpeedIn int64 // bytes per second
|
||||
SpeedOut int64 // bytes per second
|
||||
Uptime time.Duration
|
||||
}
|
||||
|
||||
// GetSnapshot returns a snapshot of current stats
|
||||
func (s *TrafficStats) GetSnapshot() StatsSnapshot {
|
||||
s.speedMu.Lock()
|
||||
speedIn := s.speedIn
|
||||
speedOut := s.speedOut
|
||||
s.speedMu.Unlock()
|
||||
|
||||
totalIn := atomic.LoadInt64(&s.totalBytesIn)
|
||||
totalOut := atomic.LoadInt64(&s.totalBytesOut)
|
||||
|
||||
return StatsSnapshot{
|
||||
TotalBytesIn: totalIn,
|
||||
TotalBytesOut: totalOut,
|
||||
TotalBytes: totalIn + totalOut,
|
||||
TotalRequests: atomic.LoadInt64(&s.totalRequests),
|
||||
SpeedIn: speedIn,
|
||||
SpeedOut: speedOut,
|
||||
Uptime: time.Since(s.startTime),
|
||||
}
|
||||
}
|
||||
|
||||
// FormatBytes formats bytes to human readable string
|
||||
func FormatBytes(bytes int64) string {
|
||||
const (
|
||||
KB = 1024
|
||||
MB = KB * 1024
|
||||
GB = MB * 1024
|
||||
)
|
||||
|
||||
switch {
|
||||
case bytes >= GB:
|
||||
return formatFloat(float64(bytes)/float64(GB)) + " GB"
|
||||
case bytes >= MB:
|
||||
return formatFloat(float64(bytes)/float64(MB)) + " MB"
|
||||
case bytes >= KB:
|
||||
return formatFloat(float64(bytes)/float64(KB)) + " KB"
|
||||
default:
|
||||
return formatInt(bytes) + " B"
|
||||
}
|
||||
}
|
||||
|
||||
// FormatSpeed formats speed (bytes per second) to human readable string
|
||||
func FormatSpeed(bytesPerSec int64) string {
|
||||
if bytesPerSec == 0 {
|
||||
return "0 B/s"
|
||||
}
|
||||
return FormatBytes(bytesPerSec) + "/s"
|
||||
}
|
||||
|
||||
func formatFloat(f float64) string {
|
||||
if f >= 100 {
|
||||
return formatInt(int64(f))
|
||||
} else if f >= 10 {
|
||||
return formatOneDecimal(f)
|
||||
}
|
||||
return formatTwoDecimal(f)
|
||||
}
|
||||
|
||||
func formatInt(i int64) string {
|
||||
return intToStr(i)
|
||||
}
|
||||
|
||||
func formatOneDecimal(f float64) string {
|
||||
i := int64(f * 10)
|
||||
whole := i / 10
|
||||
frac := i % 10
|
||||
return intToStr(whole) + "." + intToStr(frac)
|
||||
}
|
||||
|
||||
func formatTwoDecimal(f float64) string {
|
||||
i := int64(f * 100)
|
||||
whole := i / 100
|
||||
frac := i % 100
|
||||
if frac < 10 {
|
||||
return intToStr(whole) + ".0" + intToStr(frac)
|
||||
}
|
||||
return intToStr(whole) + "." + intToStr(frac)
|
||||
}
|
||||
|
||||
func intToStr(i int64) string {
|
||||
if i == 0 {
|
||||
return "0"
|
||||
}
|
||||
if i < 0 {
|
||||
return "-" + intToStr(-i)
|
||||
}
|
||||
|
||||
var buf [20]byte
|
||||
pos := len(buf)
|
||||
for i > 0 {
|
||||
pos--
|
||||
buf[pos] = byte('0' + i%10)
|
||||
i /= 10
|
||||
}
|
||||
return string(buf[pos:])
|
||||
}
|
||||
331
internal/server/proxy/handler.go
Normal file
331
internal/server/proxy/handler.go
Normal file
@@ -0,0 +1,331 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/constants"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/utils"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
manager *tunnel.Manager
|
||||
logger *zap.Logger
|
||||
responses *ResponseHandler
|
||||
domain string
|
||||
authToken string
|
||||
headerPool *pool.HeaderPool
|
||||
}
|
||||
|
||||
func NewHandler(manager *tunnel.Manager, logger *zap.Logger, responses *ResponseHandler, domain string, authToken string) *Handler {
|
||||
return &Handler{
|
||||
manager: manager,
|
||||
logger: logger,
|
||||
responses: responses,
|
||||
domain: domain,
|
||||
authToken: authToken,
|
||||
headerPool: pool.NewHeaderPool(),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
subdomain := h.extractSubdomain(r.Host)
|
||||
|
||||
if subdomain == "" {
|
||||
h.serveHomePage(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
conn, ok := h.manager.Get(subdomain)
|
||||
if !ok {
|
||||
http.Error(w, "Tunnel not found. The tunnel may have been closed.", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
if conn.IsClosed() {
|
||||
http.Error(w, "Tunnel connection closed", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
transport := conn.GetTransport()
|
||||
if transport == nil {
|
||||
http.Error(w, "Tunnel control channel not ready", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
tType := conn.GetTunnelType()
|
||||
if tType != "" && tType != protocol.TunnelTypeHTTP && tType != protocol.TunnelTypeHTTPS {
|
||||
http.Error(w, "Tunnel does not accept HTTP traffic", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
limitedReader := io.LimitReader(r.Body, constants.MaxRequestBodySize)
|
||||
body, err := io.ReadAll(limitedReader)
|
||||
if err != nil {
|
||||
h.logger.Error("Read request body failed", zap.Error(err))
|
||||
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
requestID := utils.GenerateID()
|
||||
|
||||
headers := h.headerPool.Get()
|
||||
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
|
||||
|
||||
httpReq := protocol.HTTPRequest{
|
||||
Method: r.Method,
|
||||
URL: r.URL.String(),
|
||||
Headers: headers,
|
||||
Body: body,
|
||||
}
|
||||
|
||||
reqBytes, err := protocol.EncodeHTTPRequest(&httpReq)
|
||||
|
||||
h.headerPool.Put(headers)
|
||||
|
||||
if err != nil {
|
||||
h.logger.Error("Encode HTTP request failed", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
header := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: "http_request",
|
||||
IsLast: true,
|
||||
}
|
||||
|
||||
payload, err := protocol.EncodeDataPayload(header, reqBytes)
|
||||
if err != nil {
|
||||
h.logger.Error("Encode data payload failed", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
frame := protocol.NewFrame(protocol.FrameTypeData, payload)
|
||||
|
||||
respChan := h.responses.CreateResponseChan(requestID)
|
||||
defer h.responses.CleanupResponseChan(requestID)
|
||||
|
||||
if err := transport.SendFrame(frame); err != nil {
|
||||
h.logger.Error("Send frame to tunnel failed", zap.Error(err))
|
||||
http.Error(w, "Failed to forward request to tunnel", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), constants.RequestTimeout)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case respMsg := <-respChan:
|
||||
h.writeHTTPResponse(w, respMsg, subdomain, r)
|
||||
|
||||
case <-ctx.Done():
|
||||
http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPResponse, subdomain string, r *http.Request) {
|
||||
if resp == nil {
|
||||
http.Error(w, "Invalid response from tunnel", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
for key, values := range resp.Headers {
|
||||
if key == "Connection" || key == "Keep-Alive" || key == "Transfer-Encoding" || key == "Upgrade" {
|
||||
continue
|
||||
}
|
||||
|
||||
if key == "Location" && len(values) > 0 {
|
||||
rewrittenLocation := h.rewriteLocationHeader(values[0], r.Host)
|
||||
w.Header().Set("Location", rewrittenLocation)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
w.Header().Add(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
if w.Header().Get("Content-Length") == "" && len(resp.Body) > 0 {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(resp.Body)))
|
||||
}
|
||||
|
||||
statusCode := resp.StatusCode
|
||||
if statusCode == 0 {
|
||||
statusCode = http.StatusOK
|
||||
}
|
||||
|
||||
w.WriteHeader(statusCode)
|
||||
if len(resp.Body) > 0 {
|
||||
w.Write(resp.Body)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) rewriteLocationHeader(location, proxyHost string) string {
|
||||
if !strings.HasPrefix(location, "http://") && !strings.HasPrefix(location, "https://") {
|
||||
return location
|
||||
}
|
||||
|
||||
locationURL, err := url.Parse(location)
|
||||
if err != nil {
|
||||
return location
|
||||
}
|
||||
|
||||
if locationURL.Host == "localhost" ||
|
||||
strings.HasPrefix(locationURL.Host, "localhost:") ||
|
||||
locationURL.Host == "127.0.0.1" ||
|
||||
strings.HasPrefix(locationURL.Host, "127.0.0.1:") {
|
||||
scheme := "https"
|
||||
if strings.Contains(proxyHost, ":") && !strings.Contains(proxyHost, "https") {
|
||||
parts := strings.Split(proxyHost, ":")
|
||||
if len(parts) == 2 && parts[1] != "443" {
|
||||
scheme = "https"
|
||||
}
|
||||
}
|
||||
|
||||
rewritten := fmt.Sprintf("%s://%s%s", scheme, proxyHost, locationURL.Path)
|
||||
if locationURL.RawQuery != "" {
|
||||
rewritten += "?" + locationURL.RawQuery
|
||||
}
|
||||
if locationURL.Fragment != "" {
|
||||
rewritten += "#" + locationURL.Fragment
|
||||
}
|
||||
|
||||
return rewritten
|
||||
}
|
||||
|
||||
return location
|
||||
}
|
||||
|
||||
func (h *Handler) extractSubdomain(host string) string {
|
||||
if idx := strings.Index(host, ":"); idx != -1 {
|
||||
host = host[:idx]
|
||||
}
|
||||
|
||||
if host == h.domain {
|
||||
return ""
|
||||
}
|
||||
|
||||
suffix := "." + h.domain
|
||||
if strings.HasSuffix(host, suffix) {
|
||||
subdomain := strings.TrimSuffix(host, suffix)
|
||||
return subdomain
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (h *Handler) serveHomePage(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/health" {
|
||||
h.serveHealth(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/stats" {
|
||||
h.serveStats(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
html := `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<title>Drip - Tunnel to Localhost</title>
|
||||
<style>
|
||||
body { font-family: Arial, sans-serif; max-width: 800px; margin: 50px auto; padding: 20px; }
|
||||
h1 { color: #333; }
|
||||
code { background: #f4f4f4; padding: 2px 6px; border-radius: 3px; }
|
||||
.stats { background: #f9f9f9; padding: 15px; border-radius: 5px; margin: 20px 0; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>💧 Drip - Fast Tunnels to Localhost</h1>
|
||||
<p>A high-performance tunneling service.</p>
|
||||
|
||||
<h2>Quick Start</h2>
|
||||
<p>Install the client:</p>
|
||||
<code>bash <(curl -fsSL https:///install.sh)</code>
|
||||
|
||||
<p>Start a tunnel:</p>
|
||||
<code>drip http 3000</code><br><br>
|
||||
<code>drip https 443</code><br><br>
|
||||
<code>drip tcp 5432</code>
|
||||
<p><a href="/health">Health Check</a> | <a href="/stats">Statistics</a></p>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.Write([]byte(html))
|
||||
}
|
||||
|
||||
func (h *Handler) serveHealth(w http.ResponseWriter, r *http.Request) {
|
||||
health := map[string]interface{}{
|
||||
"status": "ok",
|
||||
"active_tunnels": h.manager.Count(),
|
||||
"timestamp": time.Now().Unix(),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(health)
|
||||
}
|
||||
|
||||
func cloneHeadersWithHost(src http.Header, host string) http.Header {
|
||||
dst := make(http.Header, len(src)+1)
|
||||
for k, v := range src {
|
||||
copied := make([]string, len(v))
|
||||
copy(copied, v)
|
||||
dst[k] = copied
|
||||
}
|
||||
if host != "" {
|
||||
dst.Set("Host", host)
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
|
||||
if h.authToken != "" {
|
||||
token := r.URL.Query().Get("token")
|
||||
if token == "" {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||
token = strings.TrimPrefix(authHeader, "Bearer ")
|
||||
}
|
||||
}
|
||||
|
||||
if token != h.authToken {
|
||||
http.Error(w, "Unauthorized: invalid or missing token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
connections := h.manager.List()
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"total_tunnels": len(connections),
|
||||
"tunnels": []map[string]interface{}{},
|
||||
}
|
||||
|
||||
for _, conn := range connections {
|
||||
stats["tunnels"] = append(stats["tunnels"].([]map[string]interface{}), map[string]interface{}{
|
||||
"subdomain": conn.Subdomain,
|
||||
"last_active": conn.LastActive.Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(stats)
|
||||
}
|
||||
159
internal/server/proxy/response_handler.go
Normal file
159
internal/server/proxy/response_handler.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/protocol"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// responseChanEntry holds a response channel and its creation time
|
||||
type responseChanEntry struct {
|
||||
ch chan *protocol.HTTPResponse
|
||||
createdAt time.Time
|
||||
}
|
||||
|
||||
// ResponseHandler manages response channels for HTTP requests over TCP/Frame protocol
|
||||
type ResponseHandler struct {
|
||||
channels map[string]*responseChanEntry
|
||||
mu sync.RWMutex
|
||||
logger *zap.Logger
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewResponseHandler creates a new response handler
|
||||
func NewResponseHandler(logger *zap.Logger) *ResponseHandler {
|
||||
h := &ResponseHandler{
|
||||
channels: make(map[string]*responseChanEntry),
|
||||
logger: logger,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start single cleanup goroutine instead of one per request
|
||||
go h.cleanupLoop()
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// CreateResponseChan creates a response channel for a request ID
|
||||
func (h *ResponseHandler) CreateResponseChan(requestID string) chan *protocol.HTTPResponse {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
ch := make(chan *protocol.HTTPResponse, 1)
|
||||
h.channels[requestID] = &responseChanEntry{
|
||||
ch: ch,
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
|
||||
return ch
|
||||
}
|
||||
|
||||
// GetResponseChan gets the response channel for a request ID
|
||||
func (h *ResponseHandler) GetResponseChan(requestID string) <-chan *protocol.HTTPResponse {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
if entry := h.channels[requestID]; entry != nil {
|
||||
return entry.ch
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendResponse sends a response to the waiting channel
|
||||
func (h *ResponseHandler) SendResponse(requestID string, resp *protocol.HTTPResponse) {
|
||||
h.mu.RLock()
|
||||
entry, exists := h.channels[requestID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists || entry == nil {
|
||||
h.logger.Warn("Response channel not found",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case entry.ch <- resp:
|
||||
h.logger.Debug("Response sent to channel",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
case <-time.After(5 * time.Second):
|
||||
h.logger.Warn("Timeout sending response to channel",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// CleanupResponseChan removes and closes a response channel
|
||||
func (h *ResponseHandler) CleanupResponseChan(requestID string) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if entry, exists := h.channels[requestID]; exists {
|
||||
close(entry.ch)
|
||||
delete(h.channels, requestID)
|
||||
}
|
||||
}
|
||||
|
||||
// GetPendingCount returns the number of pending responses
|
||||
func (h *ResponseHandler) GetPendingCount() int {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return len(h.channels)
|
||||
}
|
||||
|
||||
// cleanupLoop periodically cleans up expired response channels
|
||||
// This replaces the per-request goroutine approach with a single cleanup goroutine
|
||||
func (h *ResponseHandler) cleanupLoop() {
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
h.cleanupExpiredChannels()
|
||||
case <-h.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupExpiredChannels removes channels older than 30 seconds
|
||||
func (h *ResponseHandler) cleanupExpiredChannels() {
|
||||
now := time.Now()
|
||||
timeout := 30 * time.Second
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
expiredCount := 0
|
||||
for requestID, entry := range h.channels {
|
||||
if now.Sub(entry.createdAt) > timeout {
|
||||
close(entry.ch)
|
||||
delete(h.channels, requestID)
|
||||
expiredCount++
|
||||
}
|
||||
}
|
||||
|
||||
if expiredCount > 0 {
|
||||
h.logger.Debug("Cleaned up expired response channels",
|
||||
zap.Int("count", expiredCount),
|
||||
zap.Int("remaining", len(h.channels)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup loop
|
||||
func (h *ResponseHandler) Close() {
|
||||
close(h.stopCh)
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// Close all remaining channels
|
||||
for _, entry := range h.channels {
|
||||
close(entry.ch)
|
||||
}
|
||||
h.channels = make(map[string]*responseChanEntry)
|
||||
}
|
||||
588
internal/server/tcp/connection.go
Normal file
588
internal/server/tcp/connection.go
Normal file
@@ -0,0 +1,588 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/constants"
|
||||
"drip/internal/shared/protocol"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Connection represents a client TCP connection
|
||||
type Connection struct {
|
||||
conn net.Conn
|
||||
authToken string
|
||||
manager *tunnel.Manager
|
||||
logger *zap.Logger
|
||||
subdomain string
|
||||
port int
|
||||
domain string
|
||||
publicPort int
|
||||
portAlloc *PortAllocator
|
||||
tunnelConn *tunnel.Connection
|
||||
proxy *TunnelProxy
|
||||
stopCh chan struct{}
|
||||
once sync.Once
|
||||
lastHeartbeat time.Time
|
||||
mu sync.RWMutex
|
||||
frameWriter *protocol.FrameWriter
|
||||
httpHandler http.Handler
|
||||
responseChans HTTPResponseHandler
|
||||
tunnelType protocol.TunnelType // Track tunnel type
|
||||
}
|
||||
|
||||
// HTTPResponseHandler interface for response channel operations
|
||||
type HTTPResponseHandler interface {
|
||||
CreateResponseChan(requestID string) chan *protocol.HTTPResponse
|
||||
GetResponseChan(requestID string) <-chan *protocol.HTTPResponse
|
||||
CleanupResponseChan(requestID string)
|
||||
SendResponse(requestID string, resp *protocol.HTTPResponse)
|
||||
}
|
||||
|
||||
// NewConnection creates a new connection handler
|
||||
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, responseChans HTTPResponseHandler) *Connection {
|
||||
return &Connection{
|
||||
conn: conn,
|
||||
authToken: authToken,
|
||||
manager: manager,
|
||||
logger: logger,
|
||||
portAlloc: portAlloc,
|
||||
domain: domain,
|
||||
publicPort: publicPort,
|
||||
httpHandler: httpHandler,
|
||||
responseChans: responseChans,
|
||||
stopCh: make(chan struct{}),
|
||||
lastHeartbeat: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Handle handles the connection lifecycle
|
||||
func (c *Connection) Handle() error {
|
||||
// Ensure cleanup of control connection, proxy, port, and registry on exit.
|
||||
defer c.Close()
|
||||
|
||||
// Set initial read timeout for protocol detection
|
||||
c.conn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
|
||||
// Use buffered reader to support peeking
|
||||
reader := bufio.NewReader(c.conn)
|
||||
|
||||
// Peek first 8 bytes to detect protocol
|
||||
peek, err := reader.Peek(8)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to peek connection: %w", err)
|
||||
}
|
||||
|
||||
// Check if this is an HTTP request
|
||||
peekStr := string(peek)
|
||||
httpMethods := []string{"GET ", "POST", "PUT ", "DELE", "HEAD", "OPTI", "PATC", "CONN", "TRAC"}
|
||||
isHTTP := false
|
||||
for _, method := range httpMethods {
|
||||
if strings.HasPrefix(peekStr, method) {
|
||||
isHTTP = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isHTTP {
|
||||
c.logger.Info("Detected HTTP request on TCP port, handling as HTTP")
|
||||
return c.handleHTTPRequest(reader)
|
||||
}
|
||||
|
||||
// Continue with drip protocol
|
||||
// Wait for registration frame
|
||||
frame, err := protocol.ReadFrame(reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read registration frame: %w", err)
|
||||
}
|
||||
defer frame.Release() // Return pool buffer when done
|
||||
|
||||
if frame.Type != protocol.FrameTypeRegister {
|
||||
return fmt.Errorf("expected register frame, got %s", frame.Type)
|
||||
}
|
||||
|
||||
// Parse registration request
|
||||
var req protocol.RegisterRequest
|
||||
if err := json.Unmarshal(frame.Payload, &req); err != nil {
|
||||
return fmt.Errorf("failed to parse registration request: %w", err)
|
||||
}
|
||||
|
||||
// Store tunnel type
|
||||
c.tunnelType = req.TunnelType
|
||||
|
||||
// Authenticate
|
||||
if c.authToken != "" && req.Token != c.authToken {
|
||||
c.sendError("authentication_failed", "Invalid authentication token")
|
||||
return fmt.Errorf("authentication failed")
|
||||
}
|
||||
|
||||
// Allocate TCP port only for TCP tunnels
|
||||
if req.TunnelType == protocol.TunnelTypeTCP {
|
||||
if c.portAlloc == nil {
|
||||
return fmt.Errorf("port allocator not configured")
|
||||
}
|
||||
|
||||
port, err := c.portAlloc.Allocate()
|
||||
if err != nil {
|
||||
c.sendError("port_allocation_failed", err.Error())
|
||||
return fmt.Errorf("failed to allocate port: %w", err)
|
||||
}
|
||||
c.port = port
|
||||
|
||||
// For TCP tunnels, prefer deterministic subdomain tied to port when not provided by client.
|
||||
if req.CustomSubdomain == "" {
|
||||
req.CustomSubdomain = fmt.Sprintf("tcp-%d", port)
|
||||
}
|
||||
}
|
||||
|
||||
// Register tunnel
|
||||
subdomain, err := c.manager.Register(nil, req.CustomSubdomain)
|
||||
if err != nil {
|
||||
c.sendError("registration_failed", err.Error())
|
||||
c.portAlloc.Release(c.port)
|
||||
c.port = 0
|
||||
return fmt.Errorf("tunnel registration failed: %w", err)
|
||||
}
|
||||
|
||||
c.subdomain = subdomain
|
||||
|
||||
// Get tunnel connection
|
||||
tunnelConn, ok := c.manager.Get(subdomain)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to get registered tunnel")
|
||||
}
|
||||
c.tunnelConn = tunnelConn
|
||||
|
||||
// Store TCP connection reference and metadata for HTTP proxy routing
|
||||
c.tunnelConn.Conn = nil // We're using TCP, not WebSocket
|
||||
c.tunnelConn.SetTransport(c, req.TunnelType)
|
||||
c.tunnelConn.SetTunnelType(req.TunnelType)
|
||||
c.tunnelType = req.TunnelType
|
||||
|
||||
c.logger.Info("Tunnel registered",
|
||||
zap.String("subdomain", subdomain),
|
||||
zap.String("tunnel_type", string(req.TunnelType)),
|
||||
zap.Int("local_port", req.LocalPort),
|
||||
zap.Int("remote_port", c.port),
|
||||
)
|
||||
|
||||
// Send registration acknowledgment
|
||||
// Generate appropriate URL based on tunnel type
|
||||
var tunnelURL string
|
||||
|
||||
if req.TunnelType == protocol.TunnelTypeHTTP || req.TunnelType == protocol.TunnelTypeHTTPS {
|
||||
// HTTP/HTTPS tunnels use HTTPS with subdomain
|
||||
// Use publicPort for URL generation (configured via --public-port flag)
|
||||
if c.publicPort == 443 {
|
||||
tunnelURL = fmt.Sprintf("https://%s.%s", subdomain, c.domain)
|
||||
} else {
|
||||
tunnelURL = fmt.Sprintf("https://%s.%s:%d", subdomain, c.domain, c.publicPort)
|
||||
}
|
||||
} else {
|
||||
// TCP tunnels use tcp:// with port
|
||||
tunnelURL = fmt.Sprintf("tcp://%s:%d", c.domain, c.port)
|
||||
}
|
||||
|
||||
resp := protocol.RegisterResponse{
|
||||
Subdomain: subdomain,
|
||||
Port: c.port,
|
||||
URL: tunnelURL,
|
||||
Message: "Tunnel registered successfully",
|
||||
}
|
||||
|
||||
respData, _ := json.Marshal(resp)
|
||||
ackFrame := protocol.NewFrame(protocol.FrameTypeRegisterAck, respData)
|
||||
|
||||
// Send registration ack (sync write before frameWriter is created)
|
||||
err = protocol.WriteFrame(c.conn, ackFrame)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send registration ack: %w", err)
|
||||
}
|
||||
|
||||
// Create frame writer for async writes
|
||||
c.frameWriter = protocol.NewFrameWriter(c.conn)
|
||||
|
||||
// Clear read deadline
|
||||
c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// Start TCP proxy only for TCP tunnels
|
||||
if req.TunnelType == protocol.TunnelTypeTCP {
|
||||
c.proxy = NewTunnelProxy(c.port, subdomain, c.conn, c.logger)
|
||||
if err := c.proxy.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start TCP proxy: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Start heartbeat checker
|
||||
go c.heartbeatChecker()
|
||||
|
||||
// Handle frames (pass reader for consistent buffering)
|
||||
return c.handleFrames(reader)
|
||||
}
|
||||
|
||||
// handleHTTPRequest handles HTTP requests that arrive on the TCP port
|
||||
func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
|
||||
// If no HTTP handler is configured, return error
|
||||
if c.httpHandler == nil {
|
||||
c.logger.Warn("HTTP request received but no HTTP handler configured")
|
||||
response := "HTTP/1.1 503 Service Unavailable\r\n" +
|
||||
"Content-Type: text/plain\r\n" +
|
||||
"Content-Length: 47\r\n" +
|
||||
"\r\n" +
|
||||
"HTTP handler not configured for this TCP port\r\n"
|
||||
c.conn.Write([]byte(response))
|
||||
return fmt.Errorf("HTTP handler not configured")
|
||||
}
|
||||
|
||||
// Clear read deadline for HTTP processing
|
||||
c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// Handle multiple HTTP requests on the same connection (HTTP/1.1 keep-alive)
|
||||
for {
|
||||
// Set a read deadline for each request to avoid hanging forever
|
||||
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
|
||||
// Parse HTTP request
|
||||
req, err := http.ReadRequest(reader)
|
||||
if err != nil {
|
||||
// EOF or timeout is normal when client closes connection or no more requests
|
||||
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||
c.logger.Debug("Client closed HTTP connection")
|
||||
return nil
|
||||
}
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
c.logger.Debug("HTTP keep-alive timeout")
|
||||
return nil
|
||||
}
|
||||
// Connection reset by peer is normal - client closed connection abruptly
|
||||
errStr := err.Error()
|
||||
if strings.Contains(errStr, "connection reset by peer") ||
|
||||
strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection refused") {
|
||||
c.logger.Debug("Client disconnected abruptly", zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
c.logger.Error("Failed to parse HTTP request", zap.Error(err))
|
||||
return fmt.Errorf("failed to parse HTTP request: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Info("Processing HTTP request on TCP port",
|
||||
zap.String("method", req.Method),
|
||||
zap.String("url", req.URL.String()),
|
||||
zap.String("host", req.Host),
|
||||
)
|
||||
|
||||
// Create a response writer that writes directly to the connection
|
||||
respWriter := &httpResponseWriter{
|
||||
conn: c.conn,
|
||||
header: make(http.Header),
|
||||
}
|
||||
|
||||
// Handle the request
|
||||
c.httpHandler.ServeHTTP(respWriter, req)
|
||||
|
||||
// Check if we should close the connection
|
||||
// Close if: Connection: close header, or HTTP/1.0 without Connection: keep-alive
|
||||
shouldClose := false
|
||||
if req.Close {
|
||||
shouldClose = true
|
||||
} else if req.ProtoMajor == 1 && req.ProtoMinor == 0 {
|
||||
// HTTP/1.0 defaults to close unless keep-alive is explicitly requested
|
||||
if req.Header.Get("Connection") != "keep-alive" {
|
||||
shouldClose = true
|
||||
}
|
||||
}
|
||||
|
||||
if shouldClose {
|
||||
c.logger.Debug("Closing connection as requested by client")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Continue to next request on the same connection
|
||||
}
|
||||
}
|
||||
|
||||
// handleFrames handles incoming frames
|
||||
func (c *Connection) handleFrames(reader *bufio.Reader) error {
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
// Read frame with timeout
|
||||
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
|
||||
frame, err := protocol.ReadFrame(reader)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
c.logger.Warn("Read timeout, connection may be dead")
|
||||
return fmt.Errorf("read timeout")
|
||||
}
|
||||
// EOF is normal when client closes connection gracefully
|
||||
if err.Error() == "failed to read frame header: EOF" || err.Error() == "EOF" {
|
||||
c.logger.Info("Client disconnected")
|
||||
return nil
|
||||
}
|
||||
// Check if connection was closed (during shutdown)
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
// Connection was closed intentionally, don't log as error
|
||||
c.logger.Debug("Connection closed during shutdown")
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("failed to read frame: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle frame based on type
|
||||
switch frame.Type {
|
||||
case protocol.FrameTypeHeartbeat:
|
||||
c.handleHeartbeat()
|
||||
frame.Release()
|
||||
|
||||
case protocol.FrameTypeData:
|
||||
// Data frame from client (response to forwarded request)
|
||||
c.handleDataFrame(frame)
|
||||
frame.Release() // Release after processing
|
||||
|
||||
case protocol.FrameTypeClose:
|
||||
frame.Release()
|
||||
c.logger.Info("Client requested close")
|
||||
return nil
|
||||
|
||||
default:
|
||||
frame.Release()
|
||||
c.logger.Warn("Unexpected frame type",
|
||||
zap.String("type", frame.Type.String()),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleHeartbeat handles heartbeat frame
|
||||
func (c *Connection) handleHeartbeat() {
|
||||
c.mu.Lock()
|
||||
c.lastHeartbeat = time.Now()
|
||||
c.mu.Unlock()
|
||||
|
||||
// Send heartbeat ack
|
||||
ackFrame := protocol.NewFrame(protocol.FrameTypeHeartbeatAck, nil)
|
||||
|
||||
err := c.frameWriter.WriteFrame(ackFrame)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to send heartbeat ack", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// handleDataFrame handles data frame (response from client)
|
||||
func (c *Connection) handleDataFrame(frame *protocol.Frame) {
|
||||
// Decode payload (auto-detects protocol version)
|
||||
header, data, err := protocol.DecodeDataPayload(frame.Payload)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to decode data payload",
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
c.logger.Debug("Received data frame",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
zap.String("type", header.Type),
|
||||
zap.Int("data_size", len(data)),
|
||||
)
|
||||
|
||||
switch header.Type {
|
||||
case "response":
|
||||
// TCP tunnel response, forward to proxy
|
||||
if c.proxy != nil {
|
||||
if err := c.proxy.HandleResponse(header.StreamID, data); err != nil {
|
||||
c.logger.Error("Failed to handle response",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
case "http_response":
|
||||
if c.responseChans == nil {
|
||||
c.logger.Warn("No response channel handler for HTTP response",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Decode HTTP response (auto-detects JSON vs msgpack)
|
||||
httpResp, err := protocol.DecodeHTTPResponse(data)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to decode HTTP response",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Route by request ID when provided to keep request/response aligned.
|
||||
reqID := header.RequestID
|
||||
if reqID == "" {
|
||||
reqID = header.StreamID
|
||||
}
|
||||
|
||||
c.responseChans.SendResponse(reqID, httpResp)
|
||||
|
||||
c.logger.Debug("Routed HTTP response to channel",
|
||||
zap.String("request_id", reqID),
|
||||
)
|
||||
case "close":
|
||||
// Client is closing the stream
|
||||
if c.proxy != nil {
|
||||
c.proxy.CloseStream(header.StreamID)
|
||||
}
|
||||
default:
|
||||
c.logger.Warn("Unknown data frame type",
|
||||
zap.String("type", header.Type),
|
||||
zap.String("stream_id", header.StreamID),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeatChecker checks for heartbeat timeout
|
||||
func (c *Connection) heartbeatChecker() {
|
||||
ticker := time.NewTicker(constants.HeartbeatInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.mu.RLock()
|
||||
lastHB := c.lastHeartbeat
|
||||
c.mu.RUnlock()
|
||||
|
||||
if time.Since(lastHB) > constants.HeartbeatTimeout {
|
||||
c.logger.Warn("Heartbeat timeout",
|
||||
zap.String("subdomain", c.subdomain),
|
||||
zap.Duration("last_heartbeat", time.Since(lastHB)),
|
||||
)
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SendFrame sends a frame to the client
|
||||
func (c *Connection) SendFrame(frame *protocol.Frame) error {
|
||||
if c.frameWriter == nil {
|
||||
return protocol.WriteFrame(c.conn, frame)
|
||||
}
|
||||
return c.frameWriter.WriteFrame(frame)
|
||||
}
|
||||
|
||||
// sendError sends an error frame to the client
|
||||
func (c *Connection) sendError(code, message string) {
|
||||
errMsg := protocol.ErrorMessage{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
data, _ := json.Marshal(errMsg)
|
||||
errFrame := protocol.NewFrame(protocol.FrameTypeError, data)
|
||||
|
||||
if c.frameWriter == nil {
|
||||
// Fallback if frameWriter not initialized (early errors)
|
||||
protocol.WriteFrame(c.conn, errFrame)
|
||||
} else {
|
||||
c.frameWriter.WriteFrame(errFrame)
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the connection
|
||||
func (c *Connection) Close() {
|
||||
c.once.Do(func() {
|
||||
close(c.stopCh)
|
||||
|
||||
// Close frame writer
|
||||
if c.frameWriter != nil {
|
||||
c.frameWriter.Close()
|
||||
}
|
||||
|
||||
// Stop TCP proxy
|
||||
if c.proxy != nil {
|
||||
c.proxy.Stop()
|
||||
}
|
||||
|
||||
c.conn.Close()
|
||||
|
||||
// Release allocated port
|
||||
if c.port > 0 && c.portAlloc != nil {
|
||||
c.portAlloc.Release(c.port)
|
||||
}
|
||||
|
||||
// Unregister tunnel
|
||||
if c.subdomain != "" {
|
||||
c.manager.Unregister(c.subdomain)
|
||||
}
|
||||
|
||||
c.logger.Info("Connection closed",
|
||||
zap.String("subdomain", c.subdomain),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// GetSubdomain returns the assigned subdomain
|
||||
func (c *Connection) GetSubdomain() string {
|
||||
return c.subdomain
|
||||
}
|
||||
|
||||
// httpResponseWriter implements http.ResponseWriter for writing to a net.Conn
|
||||
type httpResponseWriter struct {
|
||||
conn net.Conn
|
||||
header http.Header
|
||||
statusCode int
|
||||
headerWritten bool
|
||||
}
|
||||
|
||||
func (w *httpResponseWriter) Header() http.Header {
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *httpResponseWriter) WriteHeader(statusCode int) {
|
||||
if w.headerWritten {
|
||||
return
|
||||
}
|
||||
w.statusCode = statusCode
|
||||
w.headerWritten = true
|
||||
|
||||
// Write status line
|
||||
statusText := http.StatusText(statusCode)
|
||||
if statusText == "" {
|
||||
statusText = "Unknown"
|
||||
}
|
||||
fmt.Fprintf(w.conn, "HTTP/1.1 %d %s\r\n", statusCode, statusText)
|
||||
|
||||
// Write headers
|
||||
for key, values := range w.header {
|
||||
for _, value := range values {
|
||||
fmt.Fprintf(w.conn, "%s: %s\r\n", key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// Write empty line to end headers
|
||||
fmt.Fprintf(w.conn, "\r\n")
|
||||
}
|
||||
|
||||
func (w *httpResponseWriter) Write(data []byte) (int, error) {
|
||||
if !w.headerWritten {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
return w.conn.Write(data)
|
||||
}
|
||||
254
internal/server/tcp/listener.go
Normal file
254
internal/server/tcp/listener.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/pool"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Listener handles TCP connections with TLS 1.3
|
||||
type Listener struct {
|
||||
address string
|
||||
tlsConfig *tls.Config
|
||||
authToken string
|
||||
manager *tunnel.Manager
|
||||
portAlloc *PortAllocator
|
||||
logger *zap.Logger
|
||||
domain string
|
||||
publicPort int
|
||||
httpHandler http.Handler
|
||||
responseChans HTTPResponseHandler
|
||||
listener net.Listener
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
connections map[string]*Connection
|
||||
connMu sync.RWMutex
|
||||
workerPool *pool.WorkerPool // Worker pool for connection handling
|
||||
}
|
||||
|
||||
// NewListener creates a new TCP listener
|
||||
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, responseChans HTTPResponseHandler) *Listener {
|
||||
// Create worker pool with 50 workers and queue size of 1000
|
||||
// This reduces goroutine creation overhead for connection handling
|
||||
workerPool := pool.NewWorkerPool(50, 1000)
|
||||
|
||||
return &Listener{
|
||||
address: address,
|
||||
tlsConfig: tlsConfig,
|
||||
authToken: authToken,
|
||||
manager: manager,
|
||||
portAlloc: portAlloc,
|
||||
logger: logger,
|
||||
domain: domain,
|
||||
publicPort: publicPort,
|
||||
httpHandler: httpHandler,
|
||||
responseChans: responseChans,
|
||||
stopCh: make(chan struct{}),
|
||||
connections: make(map[string]*Connection),
|
||||
workerPool: workerPool,
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the TCP listener
|
||||
func (l *Listener) Start() error {
|
||||
var err error
|
||||
|
||||
// Create TLS listener
|
||||
l.listener, err = tls.Listen("tcp", l.address, l.tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start TLS listener: %w", err)
|
||||
}
|
||||
|
||||
l.logger.Info("TCP listener started",
|
||||
zap.String("address", l.address),
|
||||
zap.String("tls_version", "TLS 1.3"),
|
||||
)
|
||||
|
||||
// Accept connections in background
|
||||
l.wg.Add(1)
|
||||
go l.acceptLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// acceptLoop accepts incoming connections
|
||||
func (l *Listener) acceptLoop() {
|
||||
defer l.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-l.stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Set accept deadline to allow checking stopCh
|
||||
if tcpListener, ok := l.listener.(*net.TCPListener); ok {
|
||||
tcpListener.SetDeadline(time.Now().Add(1 * time.Second))
|
||||
}
|
||||
|
||||
conn, err := l.listener.Accept()
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
continue // Timeout is expected due to deadline
|
||||
}
|
||||
select {
|
||||
case <-l.stopCh:
|
||||
return // Listener was stopped
|
||||
default:
|
||||
l.logger.Error("Failed to accept connection", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Handle connection using worker pool instead of creating new goroutine
|
||||
// This reduces goroutine creation overhead and improves performance
|
||||
l.wg.Add(1)
|
||||
submitted := l.workerPool.Submit(func() {
|
||||
l.handleConnection(conn)
|
||||
})
|
||||
|
||||
// If pool is full or closed, fall back to direct goroutine
|
||||
if !submitted {
|
||||
go l.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnection handles a single client connection
|
||||
func (l *Listener) handleConnection(netConn net.Conn) {
|
||||
defer l.wg.Done()
|
||||
defer netConn.Close()
|
||||
|
||||
// Get TLS connection info
|
||||
tlsConn, ok := netConn.(*tls.Conn)
|
||||
if !ok {
|
||||
l.logger.Error("Connection is not TLS")
|
||||
return
|
||||
}
|
||||
|
||||
// Force TLS handshake to complete
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
// TLS handshake failures are common (HTTP clients, scanners, etc.)
|
||||
// Log as WARN instead of ERROR
|
||||
l.logger.Warn("TLS handshake failed",
|
||||
zap.String("remote_addr", netConn.RemoteAddr().String()),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Log connection info
|
||||
state := tlsConn.ConnectionState()
|
||||
l.logger.Info("New connection",
|
||||
zap.String("remote_addr", netConn.RemoteAddr().String()),
|
||||
zap.Uint16("tls_version", state.Version),
|
||||
zap.String("cipher_suite", tls.CipherSuiteName(state.CipherSuite)),
|
||||
)
|
||||
|
||||
// Verify TLS 1.3
|
||||
if state.Version != tls.VersionTLS13 {
|
||||
l.logger.Warn("Connection not using TLS 1.3",
|
||||
zap.Uint16("version", state.Version),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Create connection handler
|
||||
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.publicPort, l.httpHandler, l.responseChans)
|
||||
|
||||
// Store connection
|
||||
connID := netConn.RemoteAddr().String()
|
||||
l.connMu.Lock()
|
||||
l.connections[connID] = conn
|
||||
l.connMu.Unlock()
|
||||
|
||||
// Remove connection on exit
|
||||
defer func() {
|
||||
l.connMu.Lock()
|
||||
delete(l.connections, connID)
|
||||
l.connMu.Unlock()
|
||||
}()
|
||||
|
||||
// Handle connection (blocking)
|
||||
if err := conn.Handle(); err != nil {
|
||||
errStr := err.Error()
|
||||
|
||||
// Client disconnection errors - normal network behavior, log as DEBUG
|
||||
if strings.Contains(errStr, "connection reset by peer") ||
|
||||
strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection refused") {
|
||||
l.logger.Debug("Client disconnected",
|
||||
zap.String("remote_addr", connID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Protocol errors (invalid clients, scanners) are expected - log as WARN
|
||||
if strings.Contains(errStr, "payload too large") ||
|
||||
strings.Contains(errStr, "failed to read registration frame") ||
|
||||
strings.Contains(errStr, "expected register frame") ||
|
||||
strings.Contains(errStr, "failed to parse registration request") ||
|
||||
strings.Contains(errStr, "failed to parse HTTP request") {
|
||||
l.logger.Warn("Protocol validation failed",
|
||||
zap.String("remote_addr", connID),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
// Legitimate errors (auth failures, registration failures, etc.)
|
||||
l.logger.Error("Connection handling failed",
|
||||
zap.String("remote_addr", connID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the listener and closes all connections
|
||||
func (l *Listener) Stop() error {
|
||||
l.logger.Info("Stopping TCP listener")
|
||||
|
||||
// Signal stop
|
||||
close(l.stopCh)
|
||||
|
||||
// Close listener
|
||||
if l.listener != nil {
|
||||
if err := l.listener.Close(); err != nil {
|
||||
l.logger.Error("Failed to close listener", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Close all connections
|
||||
l.connMu.Lock()
|
||||
for _, conn := range l.connections {
|
||||
conn.Close()
|
||||
}
|
||||
l.connMu.Unlock()
|
||||
|
||||
// Wait for all goroutines to finish
|
||||
l.wg.Wait()
|
||||
|
||||
// Close worker pool
|
||||
if l.workerPool != nil {
|
||||
l.workerPool.Close()
|
||||
}
|
||||
|
||||
l.logger.Info("TCP listener stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetActiveConnections returns the number of active connections
|
||||
func (l *Listener) GetActiveConnections() int {
|
||||
l.connMu.RLock()
|
||||
defer l.connMu.RUnlock()
|
||||
return len(l.connections)
|
||||
}
|
||||
79
internal/server/tcp/port_allocator.go
Normal file
79
internal/server/tcp/port_allocator.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// PortAllocator manages dynamic TCP port allocation within a configured range.
|
||||
// It keeps an in-memory reservation map; ports are held until Release is called.
|
||||
type PortAllocator struct {
|
||||
min int
|
||||
max int
|
||||
used map[int]bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewPortAllocator creates a new allocator with the given inclusive range.
|
||||
func NewPortAllocator(min, max int) (*PortAllocator, error) {
|
||||
if min <= 0 || max <= 0 || min >= max || max > 65535 {
|
||||
return nil, fmt.Errorf("invalid port range %d-%d", min, max)
|
||||
}
|
||||
|
||||
return &PortAllocator{
|
||||
min: min,
|
||||
max: max,
|
||||
used: make(map[int]bool),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Allocate finds a free port, marks it as used, and ensures it's currently available.
|
||||
func (p *PortAllocator) Allocate() (int, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
total := p.max - p.min + 1
|
||||
for attempts := 0; attempts < total; attempts++ {
|
||||
port := p.randomPort()
|
||||
if p.used[port] {
|
||||
continue
|
||||
}
|
||||
|
||||
// Probe the port to ensure it's not taken by the OS/other process.
|
||||
ln, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", port))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
ln.Close()
|
||||
|
||||
p.used[port] = true
|
||||
return port, nil
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("no available port in range %d-%d", p.min, p.max)
|
||||
}
|
||||
|
||||
// Release frees a previously allocated port.
|
||||
func (p *PortAllocator) Release(port int) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
delete(p.used, port)
|
||||
}
|
||||
|
||||
func (p *PortAllocator) randomPort() int {
|
||||
n := p.max - p.min + 1
|
||||
if n <= 0 {
|
||||
return p.min
|
||||
}
|
||||
|
||||
// crypto/rand for better distribution without needing a global seed.
|
||||
randInt, err := rand.Int(rand.Reader, big.NewInt(int64(n)))
|
||||
if err != nil {
|
||||
return p.min
|
||||
}
|
||||
|
||||
return p.min + int(randInt.Int64())
|
||||
}
|
||||
243
internal/server/tcp/proxy.go
Normal file
243
internal/server/tcp/proxy.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/protocol"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TunnelProxy handles TCP connections for a specific tunnel
|
||||
type TunnelProxy struct {
|
||||
port int
|
||||
subdomain string
|
||||
tcpConn net.Conn // The tunnel control connection
|
||||
listener net.Listener
|
||||
logger *zap.Logger
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
clientAddr string
|
||||
streams map[string]net.Conn // streamID -> external connection
|
||||
streamMu sync.RWMutex
|
||||
frameWriter *protocol.FrameWriter
|
||||
bufferPool *pool.BufferPool
|
||||
}
|
||||
|
||||
// NewTunnelProxy creates a new TCP tunnel proxy
|
||||
func NewTunnelProxy(port int, subdomain string, tcpConn net.Conn, logger *zap.Logger) *TunnelProxy {
|
||||
return &TunnelProxy{
|
||||
port: port,
|
||||
subdomain: subdomain,
|
||||
tcpConn: tcpConn,
|
||||
logger: logger,
|
||||
stopCh: make(chan struct{}),
|
||||
clientAddr: tcpConn.RemoteAddr().String(),
|
||||
streams: make(map[string]net.Conn),
|
||||
bufferPool: pool.NewBufferPool(),
|
||||
frameWriter: protocol.NewFrameWriter(tcpConn),
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts listening on the allocated port
|
||||
func (p *TunnelProxy) Start() error {
|
||||
addr := fmt.Sprintf("0.0.0.0:%d", p.port)
|
||||
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on port %d: %w", p.port, err)
|
||||
}
|
||||
|
||||
p.listener = listener
|
||||
|
||||
p.logger.Info("TCP proxy started",
|
||||
zap.Int("port", p.port),
|
||||
zap.String("subdomain", p.subdomain),
|
||||
)
|
||||
|
||||
// Accept connections in background
|
||||
p.wg.Add(1)
|
||||
go p.acceptLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// acceptLoop accepts incoming TCP connections
|
||||
func (p *TunnelProxy) acceptLoop() {
|
||||
defer p.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Set accept deadline
|
||||
p.listener.(*net.TCPListener).SetDeadline(time.Now().Add(1 * time.Second))
|
||||
|
||||
conn, err := p.listener.Accept()
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Handle connection
|
||||
p.wg.Add(1)
|
||||
go p.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *TunnelProxy) handleConnection(conn net.Conn) {
|
||||
defer p.wg.Done()
|
||||
defer conn.Close()
|
||||
|
||||
streamID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), p.port)
|
||||
|
||||
p.streamMu.Lock()
|
||||
p.streams[streamID] = conn
|
||||
p.streamMu.Unlock()
|
||||
|
||||
defer func() {
|
||||
p.streamMu.Lock()
|
||||
delete(p.streams, streamID)
|
||||
p.streamMu.Unlock()
|
||||
}()
|
||||
|
||||
bufPtr := p.bufferPool.Get(pool.SizeMedium)
|
||||
defer p.bufferPool.Put(bufPtr)
|
||||
|
||||
buffer := (*bufPtr)[:pool.SizeMedium]
|
||||
|
||||
for {
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
if err := p.sendDataToTunnel(streamID, buffer[:n]); err != nil {
|
||||
p.logger.Error("Send to tunnel failed", zap.Error(err))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
default:
|
||||
p.sendCloseToTunnel(streamID)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *TunnelProxy) sendDataToTunnel(streamID string, data []byte) error {
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return fmt.Errorf("tunnel proxy stopped")
|
||||
default:
|
||||
}
|
||||
|
||||
header := protocol.DataHeader{
|
||||
StreamID: streamID,
|
||||
RequestID: streamID,
|
||||
Type: "data",
|
||||
IsLast: false,
|
||||
}
|
||||
|
||||
payload, err := protocol.EncodeDataPayload(header, data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode payload: %w", err)
|
||||
}
|
||||
|
||||
frame := protocol.NewFrame(protocol.FrameTypeData, payload)
|
||||
|
||||
err = p.frameWriter.WriteFrame(frame)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write frame: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *TunnelProxy) sendCloseToTunnel(streamID string) {
|
||||
header := protocol.DataHeader{
|
||||
StreamID: streamID,
|
||||
RequestID: streamID,
|
||||
Type: "close",
|
||||
IsLast: true,
|
||||
}
|
||||
|
||||
payload, err := protocol.EncodeDataPayload(header, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
frame := protocol.NewFrame(protocol.FrameTypeData, payload)
|
||||
p.frameWriter.WriteFrame(frame)
|
||||
}
|
||||
|
||||
func (p *TunnelProxy) HandleResponse(streamID string, data []byte) error {
|
||||
p.streamMu.RLock()
|
||||
conn, ok := p.streams[streamID]
|
||||
p.streamMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return fmt.Errorf("stream not found: %s", streamID)
|
||||
}
|
||||
|
||||
if _, err := conn.Write(data); err != nil {
|
||||
p.logger.Error("Write to client failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseStream closes a stream
|
||||
func (p *TunnelProxy) CloseStream(streamID string) {
|
||||
p.streamMu.RLock()
|
||||
conn, ok := p.streams[streamID]
|
||||
p.streamMu.RUnlock()
|
||||
|
||||
if ok {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *TunnelProxy) Stop() {
|
||||
p.logger.Info("Stopping TCP proxy",
|
||||
zap.Int("port", p.port),
|
||||
zap.String("subdomain", p.subdomain),
|
||||
)
|
||||
|
||||
close(p.stopCh)
|
||||
|
||||
if p.listener != nil {
|
||||
p.listener.Close()
|
||||
}
|
||||
|
||||
p.streamMu.Lock()
|
||||
for _, conn := range p.streams {
|
||||
conn.Close()
|
||||
}
|
||||
p.streams = make(map[string]net.Conn)
|
||||
p.streamMu.Unlock()
|
||||
|
||||
p.wg.Wait()
|
||||
|
||||
if p.frameWriter != nil {
|
||||
p.frameWriter.Close()
|
||||
}
|
||||
|
||||
p.logger.Info("TCP proxy stopped", zap.Int("port", p.port))
|
||||
}
|
||||
73
internal/server/tls/autocert.go
Normal file
73
internal/server/tls/autocert.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
)
|
||||
|
||||
// AutoCertManager manages automatic certificate provisioning with Let's Encrypt
|
||||
type AutoCertManager struct {
|
||||
manager *autocert.Manager
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAutoCertManager creates a new AutoCert manager
|
||||
func NewAutoCertManager(domain, cacheDir string, logger *zap.Logger) *AutoCertManager {
|
||||
m := &autocert.Manager{
|
||||
Prompt: autocert.AcceptTOS,
|
||||
HostPolicy: autocert.HostWhitelist(domain, "*."+domain),
|
||||
Cache: autocert.DirCache(cacheDir),
|
||||
}
|
||||
|
||||
logger.Info("AutoTLS enabled",
|
||||
zap.String("domain", domain),
|
||||
zap.String("cache_dir", cacheDir),
|
||||
)
|
||||
|
||||
return &AutoCertManager{
|
||||
manager: m,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetTLSConfig returns the TLS configuration
|
||||
func (a *AutoCertManager) GetTLSConfig() *tls.Config {
|
||||
return a.manager.TLSConfig()
|
||||
}
|
||||
|
||||
// HTTPHandler returns the HTTP handler for ACME challenges
|
||||
func (a *AutoCertManager) HTTPHandler() http.Handler {
|
||||
return a.manager.HTTPHandler(nil)
|
||||
}
|
||||
|
||||
// GetCertificate gets a certificate for the given ClientHelloInfo
|
||||
func (a *AutoCertManager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
cert, err := a.manager.GetCertificate(hello)
|
||||
if err != nil {
|
||||
a.logger.Error("Failed to get certificate",
|
||||
zap.String("server_name", hello.ServerName),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
a.logger.Debug("Certificate obtained",
|
||||
zap.String("server_name", hello.ServerName),
|
||||
)
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// DefaultCacheDir returns the default cache directory for certificates
|
||||
func DefaultCacheDir() string {
|
||||
home := os.Getenv("HOME")
|
||||
if home == "" {
|
||||
home = "/tmp"
|
||||
}
|
||||
return filepath.Join(home, ".drip", "certs")
|
||||
}
|
||||
192
internal/server/tunnel/connection.go
Normal file
192
internal/server/tunnel/connection.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/protocol"
|
||||
"github.com/gorilla/websocket"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Transport represents the control channel to the client.
|
||||
// It is implemented by the TCP control connection so the HTTP proxy
|
||||
// can push frames directly to the client without depending on WebSockets.
|
||||
type Transport interface {
|
||||
SendFrame(frame *protocol.Frame) error
|
||||
}
|
||||
|
||||
// Connection represents a tunnel connection from a client
|
||||
type Connection struct {
|
||||
Subdomain string
|
||||
Conn *websocket.Conn
|
||||
SendCh chan []byte
|
||||
CloseCh chan struct{}
|
||||
LastActive time.Time
|
||||
mu sync.RWMutex
|
||||
logger *zap.Logger
|
||||
closed bool
|
||||
transport Transport
|
||||
tunnelType protocol.TunnelType
|
||||
}
|
||||
|
||||
// NewConnection creates a new tunnel connection
|
||||
func NewConnection(subdomain string, conn *websocket.Conn, logger *zap.Logger) *Connection {
|
||||
return &Connection{
|
||||
Subdomain: subdomain,
|
||||
Conn: conn,
|
||||
SendCh: make(chan []byte, 256),
|
||||
CloseCh: make(chan struct{}),
|
||||
LastActive: time.Now(),
|
||||
logger: logger,
|
||||
closed: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Send sends data through the WebSocket connection
|
||||
func (c *Connection) Send(data []byte) error {
|
||||
c.mu.RLock()
|
||||
if c.closed {
|
||||
c.mu.RUnlock()
|
||||
return ErrConnectionClosed
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
select {
|
||||
case c.SendCh <- data:
|
||||
return nil
|
||||
case <-time.After(5 * time.Second):
|
||||
return ErrSendTimeout
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateActivity updates the last activity timestamp
|
||||
func (c *Connection) UpdateActivity() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.LastActive = time.Now()
|
||||
}
|
||||
|
||||
// IsAlive checks if the connection is still alive based on last activity
|
||||
func (c *Connection) IsAlive(timeout time.Duration) bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return time.Since(c.LastActive) < timeout
|
||||
}
|
||||
|
||||
// Close closes the connection and all associated channels
|
||||
func (c *Connection) Close() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.closed {
|
||||
return
|
||||
}
|
||||
|
||||
c.closed = true
|
||||
close(c.CloseCh)
|
||||
close(c.SendCh)
|
||||
|
||||
if c.Conn != nil {
|
||||
// Send close message
|
||||
c.Conn.WriteMessage(websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||
c.Conn.Close()
|
||||
}
|
||||
|
||||
c.logger.Info("Connection closed",
|
||||
zap.String("subdomain", c.Subdomain),
|
||||
)
|
||||
}
|
||||
|
||||
// IsClosed returns whether the connection is closed
|
||||
func (c *Connection) IsClosed() bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.closed
|
||||
}
|
||||
|
||||
// SetTransport attaches the control transport and tunnel type.
|
||||
func (c *Connection) SetTransport(t Transport, tType protocol.TunnelType) {
|
||||
c.mu.Lock()
|
||||
c.transport = t
|
||||
c.tunnelType = tType
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// GetTransport returns the attached transport (if any).
|
||||
func (c *Connection) GetTransport() Transport {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.transport
|
||||
}
|
||||
|
||||
// SetTunnelType sets the tunnel type.
|
||||
func (c *Connection) SetTunnelType(tType protocol.TunnelType) {
|
||||
c.mu.Lock()
|
||||
c.tunnelType = tType
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// GetTunnelType returns the tunnel type.
|
||||
func (c *Connection) GetTunnelType() protocol.TunnelType {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.tunnelType
|
||||
}
|
||||
|
||||
// StartWritePump starts the write pump for sending messages
|
||||
func (c *Connection) StartWritePump() {
|
||||
// Skip write pump for TCP-only connections (no WebSocket)
|
||||
if c.Conn == nil {
|
||||
c.logger.Debug("Skipping WritePump for TCP connection",
|
||||
zap.String("subdomain", c.Subdomain),
|
||||
)
|
||||
// Still need to drain SendCh to prevent blocking
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-c.SendCh:
|
||||
// Discard messages for TCP mode
|
||||
case <-c.CloseCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
c.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case message, ok := <-c.SendCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := c.Conn.WriteMessage(websocket.TextMessage, message); err != nil {
|
||||
c.logger.Error("Write error",
|
||||
zap.String("subdomain", c.Subdomain),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
case <-ticker.C:
|
||||
// Send ping to keep connection alive
|
||||
c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
case <-c.CloseCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
263
internal/server/tunnel/connection_test.go
Normal file
263
internal/server/tunnel/connection_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestNewConnection(t *testing.T) {
|
||||
subdomain := "test123"
|
||||
logger := zap.NewNop()
|
||||
|
||||
// We can't create a real WebSocket connection in unit tests,
|
||||
// so we'll just test with nil
|
||||
conn := NewConnection(subdomain, nil, logger)
|
||||
|
||||
if conn == nil {
|
||||
t.Fatal("NewConnection() returned nil")
|
||||
}
|
||||
|
||||
if conn.Subdomain != subdomain {
|
||||
t.Errorf("Subdomain = %v, want %v", conn.Subdomain, subdomain)
|
||||
}
|
||||
|
||||
if conn.SendCh == nil {
|
||||
t.Error("SendCh is nil")
|
||||
}
|
||||
|
||||
if conn.CloseCh == nil {
|
||||
t.Error("CloseCh is nil")
|
||||
}
|
||||
|
||||
// Check that LastActive is recent (within last second)
|
||||
now := time.Now()
|
||||
if now.Sub(conn.LastActive) > time.Second {
|
||||
t.Error("LastActive is not recent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionUpdateActivity(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
conn := NewConnection("test", nil, logger)
|
||||
|
||||
// Get initial LastActive
|
||||
initial := conn.LastActive
|
||||
|
||||
// Wait a bit
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Update activity
|
||||
conn.UpdateActivity()
|
||||
|
||||
// Check that LastActive was updated
|
||||
if !conn.LastActive.After(initial) {
|
||||
t.Error("UpdateActivity() did not update LastActive")
|
||||
}
|
||||
|
||||
// Check that it's recent
|
||||
now := time.Now()
|
||||
if now.Sub(conn.LastActive) > time.Second {
|
||||
t.Error("UpdateActivity() did not set recent timestamp")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionIsAlive(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
lastActive time.Time
|
||||
timeout time.Duration
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "fresh connection is alive",
|
||||
lastActive: time.Now(),
|
||||
timeout: 90 * time.Second,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "stale connection is not alive",
|
||||
lastActive: time.Now().Add(-2 * time.Minute),
|
||||
timeout: 90 * time.Second,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "exactly at timeout is not alive",
|
||||
lastActive: time.Now().Add(-90 * time.Second),
|
||||
timeout: 90 * time.Second,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "just before timeout is alive",
|
||||
lastActive: time.Now().Add(-89 * time.Second),
|
||||
timeout: 90 * time.Second,
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
logger := zap.NewNop()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
conn := NewConnection("test", nil, logger)
|
||||
conn.mu.Lock()
|
||||
conn.LastActive = tt.lastActive
|
||||
conn.mu.Unlock()
|
||||
|
||||
got := conn.IsAlive(tt.timeout)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsAlive() = %v, want %v (age: %v, timeout: %v)",
|
||||
got, tt.want, time.Since(tt.lastActive), tt.timeout)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionSend(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
conn := NewConnection("test", nil, logger)
|
||||
|
||||
data := []byte("test message")
|
||||
|
||||
// Test successful send
|
||||
err := conn.Send(data)
|
||||
if err != nil {
|
||||
t.Errorf("Send() error = %v, want nil", err)
|
||||
}
|
||||
|
||||
// Verify data was sent to channel
|
||||
select {
|
||||
case received := <-conn.SendCh:
|
||||
if string(received) != string(data) {
|
||||
t.Errorf("Received data = %v, want %v", string(received), string(data))
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Send() did not send data to channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionSendTimeout(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
conn := NewConnection("test", nil, logger)
|
||||
|
||||
// Fill the channel
|
||||
for i := 0; i < 256; i++ {
|
||||
conn.SendCh <- []byte("fill")
|
||||
}
|
||||
|
||||
// Try to send when channel is full
|
||||
data := []byte("test message")
|
||||
err := conn.Send(data)
|
||||
|
||||
if err != ErrSendTimeout {
|
||||
t.Errorf("Send() on full channel error = %v, want %v", err, ErrSendTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionClose(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
conn := NewConnection("test", nil, logger)
|
||||
|
||||
// Close the connection
|
||||
conn.Close()
|
||||
|
||||
// Verify CloseCh is closed
|
||||
select {
|
||||
case <-conn.CloseCh:
|
||||
// Successfully received from closed channel
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Close() did not close CloseCh")
|
||||
}
|
||||
|
||||
// Try to close again (should not panic)
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Error("Close() panicked on second call")
|
||||
}
|
||||
}()
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func TestConnectionConcurrentUpdateActivity(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
conn := NewConnection("test", nil, logger)
|
||||
|
||||
// Update activity concurrently
|
||||
done := make(chan bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
go func() {
|
||||
conn.UpdateActivity()
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for i := 0; i < 100; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify LastActive is recent
|
||||
now := time.Now()
|
||||
if now.Sub(conn.LastActive) > time.Second {
|
||||
t.Error("Concurrent UpdateActivity() failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionConcurrentIsAlive(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
conn := NewConnection("test", nil, logger)
|
||||
|
||||
// Check IsAlive concurrently
|
||||
done := make(chan bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
go func() {
|
||||
conn.IsAlive(90 * time.Second)
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for i := 0; i < 100; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkConnectionSend(b *testing.B) {
|
||||
logger := zap.NewNop()
|
||||
conn := NewConnection("test", nil, logger)
|
||||
|
||||
// Drain channel in background
|
||||
go func() {
|
||||
for range conn.SendCh {
|
||||
}
|
||||
}()
|
||||
|
||||
data := []byte("test message")
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn.Send(data)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkConnectionUpdateActivity(b *testing.B) {
|
||||
logger := zap.NewNop()
|
||||
conn := NewConnection("test", nil, logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn.UpdateActivity()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkConnectionIsAlive(b *testing.B) {
|
||||
logger := zap.NewNop()
|
||||
conn := NewConnection("test", nil, logger)
|
||||
timeout := 90 * time.Second
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn.IsAlive(timeout)
|
||||
}
|
||||
}
|
||||
23
internal/server/tunnel/errors.go
Normal file
23
internal/server/tunnel/errors.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package tunnel
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrConnectionClosed is returned when trying to use a closed connection
|
||||
ErrConnectionClosed = errors.New("connection is closed")
|
||||
|
||||
// ErrSendTimeout is returned when send operation times out
|
||||
ErrSendTimeout = errors.New("send operation timed out")
|
||||
|
||||
// ErrTunnelNotFound is returned when a tunnel is not found
|
||||
ErrTunnelNotFound = errors.New("tunnel not found")
|
||||
|
||||
// ErrSubdomainTaken is returned when a subdomain is already in use
|
||||
ErrSubdomainTaken = errors.New("subdomain is already taken")
|
||||
|
||||
// ErrInvalidSubdomain is returned when a subdomain is invalid
|
||||
ErrInvalidSubdomain = errors.New("invalid subdomain format")
|
||||
|
||||
// ErrReservedSubdomain is returned when trying to use a reserved subdomain
|
||||
ErrReservedSubdomain = errors.New("subdomain is reserved")
|
||||
)
|
||||
185
internal/server/tunnel/manager.go
Normal file
185
internal/server/tunnel/manager.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/utils"
|
||||
"github.com/gorilla/websocket"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Manager manages all active tunnel connections
|
||||
type Manager struct {
|
||||
tunnels map[string]*Connection // subdomain -> connection
|
||||
mu sync.RWMutex
|
||||
used map[string]bool // track used subdomains
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewManager creates a new tunnel manager
|
||||
func NewManager(logger *zap.Logger) *Manager {
|
||||
return &Manager{
|
||||
tunnels: make(map[string]*Connection),
|
||||
used: make(map[string]bool),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Register registers a new tunnel connection
|
||||
// Returns the assigned subdomain and any error
|
||||
func (m *Manager) Register(conn *websocket.Conn, customSubdomain string) (string, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
var subdomain string
|
||||
|
||||
if customSubdomain != "" {
|
||||
// Validate custom subdomain
|
||||
if !utils.ValidateSubdomain(customSubdomain) {
|
||||
return "", ErrInvalidSubdomain
|
||||
}
|
||||
if utils.IsReserved(customSubdomain) {
|
||||
return "", ErrReservedSubdomain
|
||||
}
|
||||
if m.used[customSubdomain] {
|
||||
return "", ErrSubdomainTaken
|
||||
}
|
||||
subdomain = customSubdomain
|
||||
} else {
|
||||
// Generate unique random subdomain
|
||||
subdomain = m.generateUniqueSubdomain()
|
||||
}
|
||||
|
||||
// Create connection
|
||||
tc := NewConnection(subdomain, conn, m.logger)
|
||||
m.tunnels[subdomain] = tc
|
||||
m.used[subdomain] = true
|
||||
|
||||
// Start write pump in background
|
||||
go tc.StartWritePump()
|
||||
|
||||
m.logger.Info("Tunnel registered",
|
||||
zap.String("subdomain", subdomain),
|
||||
zap.Int("total_tunnels", len(m.tunnels)),
|
||||
)
|
||||
|
||||
return subdomain, nil
|
||||
}
|
||||
|
||||
// Unregister removes a tunnel connection
|
||||
func (m *Manager) Unregister(subdomain string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if tc, ok := m.tunnels[subdomain]; ok {
|
||||
tc.Close()
|
||||
delete(m.tunnels, subdomain)
|
||||
delete(m.used, subdomain)
|
||||
|
||||
m.logger.Info("Tunnel unregistered",
|
||||
zap.String("subdomain", subdomain),
|
||||
zap.Int("total_tunnels", len(m.tunnels)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a tunnel connection by subdomain
|
||||
func (m *Manager) Get(subdomain string) (*Connection, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
tc, ok := m.tunnels[subdomain]
|
||||
return tc, ok
|
||||
}
|
||||
|
||||
// List returns all active tunnel connections
|
||||
func (m *Manager) List() []*Connection {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
connections := make([]*Connection, 0, len(m.tunnels))
|
||||
for _, tc := range m.tunnels {
|
||||
connections = append(connections, tc)
|
||||
}
|
||||
return connections
|
||||
}
|
||||
|
||||
// Count returns the number of active tunnels
|
||||
func (m *Manager) Count() int {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return len(m.tunnels)
|
||||
}
|
||||
|
||||
// CleanupStale removes stale connections that haven't been active
|
||||
func (m *Manager) CleanupStale(timeout time.Duration) int {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
staleSubdomains := []string{}
|
||||
|
||||
for subdomain, tc := range m.tunnels {
|
||||
if !tc.IsAlive(timeout) {
|
||||
staleSubdomains = append(staleSubdomains, subdomain)
|
||||
}
|
||||
}
|
||||
|
||||
for _, subdomain := range staleSubdomains {
|
||||
if tc, ok := m.tunnels[subdomain]; ok {
|
||||
tc.Close()
|
||||
delete(m.tunnels, subdomain)
|
||||
delete(m.used, subdomain)
|
||||
}
|
||||
}
|
||||
|
||||
if len(staleSubdomains) > 0 {
|
||||
m.logger.Info("Cleaned up stale tunnels",
|
||||
zap.Int("count", len(staleSubdomains)),
|
||||
)
|
||||
}
|
||||
|
||||
return len(staleSubdomains)
|
||||
}
|
||||
|
||||
// StartCleanupTask starts a background task to clean up stale connections
|
||||
func (m *Manager) StartCleanupTask(interval, timeout time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
m.CleanupStale(timeout)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// generateUniqueSubdomain generates a unique random subdomain
|
||||
func (m *Manager) generateUniqueSubdomain() string {
|
||||
const maxAttempts = 10
|
||||
|
||||
for i := 0; i < maxAttempts; i++ {
|
||||
subdomain := utils.GenerateSubdomain(6)
|
||||
if !m.used[subdomain] && !utils.IsReserved(subdomain) {
|
||||
return subdomain
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: use longer subdomain if collision persists
|
||||
return utils.GenerateSubdomain(8)
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down all tunnels
|
||||
func (m *Manager) Shutdown() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.logger.Info("Shutting down tunnel manager",
|
||||
zap.Int("active_tunnels", len(m.tunnels)),
|
||||
)
|
||||
|
||||
for _, tc := range m.tunnels {
|
||||
tc.Close()
|
||||
}
|
||||
|
||||
m.tunnels = make(map[string]*Connection)
|
||||
m.used = make(map[string]bool)
|
||||
}
|
||||
376
internal/server/tunnel/manager_test.go
Normal file
376
internal/server/tunnel/manager_test.go
Normal file
@@ -0,0 +1,376 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestNewManager(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
if manager == nil {
|
||||
t.Fatal("NewManager() returned nil")
|
||||
}
|
||||
|
||||
if manager.tunnels == nil {
|
||||
t.Error("Manager tunnels map is nil")
|
||||
}
|
||||
|
||||
if manager.used == nil {
|
||||
t.Error("Manager used map is nil")
|
||||
}
|
||||
|
||||
if manager.logger == nil {
|
||||
t.Error("Manager logger is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerRegister(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
// Register with empty subdomain (auto-generate)
|
||||
subdomain, err := manager.Register(nil, "")
|
||||
if err != nil {
|
||||
t.Errorf("Register() error = %v, want nil", err)
|
||||
}
|
||||
|
||||
if subdomain == "" {
|
||||
t.Error("Register() returned empty subdomain")
|
||||
}
|
||||
|
||||
if len(subdomain) != 6 {
|
||||
t.Errorf("Register() subdomain length = %d, want 6", len(subdomain))
|
||||
}
|
||||
|
||||
// Verify connection is registered
|
||||
_, ok := manager.Get(subdomain)
|
||||
if !ok {
|
||||
t.Error("Get() failed to retrieve registered connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerRegisterCustomSubdomain(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
customSubdomain := "mytest"
|
||||
|
||||
// Register with custom subdomain
|
||||
subdomain, err := manager.Register(nil, customSubdomain)
|
||||
if err != nil {
|
||||
t.Errorf("Register() error = %v, want nil", err)
|
||||
}
|
||||
|
||||
if subdomain != customSubdomain {
|
||||
t.Errorf("Register() subdomain = %v, want %v", subdomain, customSubdomain)
|
||||
}
|
||||
|
||||
// Verify connection is registered
|
||||
conn, ok := manager.Get(subdomain)
|
||||
if !ok {
|
||||
t.Error("Get() failed to retrieve registered connection")
|
||||
}
|
||||
|
||||
if conn.Subdomain != customSubdomain {
|
||||
t.Errorf("Connection subdomain = %v, want %v", conn.Subdomain, customSubdomain)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerRegisterDuplicate(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
customSubdomain := "test123"
|
||||
|
||||
// Register first connection
|
||||
_, err := manager.Register(nil, customSubdomain)
|
||||
if err != nil {
|
||||
t.Fatalf("First Register() error = %v, want nil", err)
|
||||
}
|
||||
|
||||
// Try to register second connection with same subdomain
|
||||
_, err = manager.Register(nil, customSubdomain)
|
||||
if err != ErrSubdomainTaken {
|
||||
t.Errorf("Register() error = %v, want %v", err, ErrSubdomainTaken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerRegisterInvalidSubdomain(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
subdomain string
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "invalid uppercase",
|
||||
subdomain: "TEST",
|
||||
wantErr: ErrInvalidSubdomain,
|
||||
},
|
||||
{
|
||||
name: "invalid special char",
|
||||
subdomain: "test@123",
|
||||
wantErr: ErrInvalidSubdomain,
|
||||
},
|
||||
{
|
||||
name: "reserved www",
|
||||
subdomain: "www",
|
||||
wantErr: ErrReservedSubdomain,
|
||||
},
|
||||
{
|
||||
name: "reserved api",
|
||||
subdomain: "api",
|
||||
wantErr: ErrReservedSubdomain,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := manager.Register(nil, tt.subdomain)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("Register() error = %v, want %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerUnregister(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
// Register connection
|
||||
subdomain, err := manager.Register(nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Register() error = %v", err)
|
||||
}
|
||||
|
||||
// Unregister connection
|
||||
manager.Unregister(subdomain)
|
||||
|
||||
// Verify connection is removed
|
||||
_, ok := manager.Get(subdomain)
|
||||
if ok {
|
||||
t.Error("Get() succeeded after Unregister(), want failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerGet(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
customSubdomain := "test123"
|
||||
|
||||
// Test Get on non-existent connection
|
||||
_, ok := manager.Get(customSubdomain)
|
||||
if ok {
|
||||
t.Error("Get() succeeded for non-existent connection")
|
||||
}
|
||||
|
||||
// Register and test Get
|
||||
subdomain, _ := manager.Register(nil, customSubdomain)
|
||||
retrieved, ok := manager.Get(subdomain)
|
||||
if !ok {
|
||||
t.Error("Get() failed for existing connection")
|
||||
}
|
||||
if retrieved.Subdomain != customSubdomain {
|
||||
t.Error("Get() returned wrong connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerList(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
// Test empty manager
|
||||
all := manager.List()
|
||||
if len(all) != 0 {
|
||||
t.Errorf("List() on empty manager returned %d connections, want 0", len(all))
|
||||
}
|
||||
|
||||
// Add multiple connections
|
||||
count := 5
|
||||
for i := 0; i < count; i++ {
|
||||
manager.Register(nil, "")
|
||||
}
|
||||
|
||||
// Test List
|
||||
all = manager.List()
|
||||
if len(all) != count {
|
||||
t.Errorf("List() returned %d connections, want %d", len(all), count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerCount(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
// Test empty manager
|
||||
count := manager.Count()
|
||||
if count != 0 {
|
||||
t.Errorf("Count() on empty manager = %d, want 0", count)
|
||||
}
|
||||
|
||||
// Add connections
|
||||
numConns := 3
|
||||
for i := 0; i < numConns; i++ {
|
||||
manager.Register(nil, "")
|
||||
}
|
||||
|
||||
count = manager.Count()
|
||||
if count != numConns {
|
||||
t.Errorf("Count() = %d, want %d", count, numConns)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerGenerateSubdomain(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
// Generate subdomain via Register
|
||||
subdomain1, err := manager.Register(nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("First Register() error = %v", err)
|
||||
}
|
||||
|
||||
if subdomain1 == "" {
|
||||
t.Error("Register() returned empty subdomain")
|
||||
}
|
||||
|
||||
if len(subdomain1) != 6 {
|
||||
t.Errorf("Register() subdomain length = %d, want 6", len(subdomain1))
|
||||
}
|
||||
|
||||
// Generate another subdomain, should be different
|
||||
subdomain2, err := manager.Register(nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Second Register() error = %v", err)
|
||||
}
|
||||
|
||||
if subdomain1 == subdomain2 {
|
||||
t.Error("Register() generated duplicate subdomain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerGenerateSubdomainUniqueness(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
subdomains := make(map[string]bool)
|
||||
count := 100
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
subdomain, err := manager.Register(nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Register() error = %v", err)
|
||||
}
|
||||
if subdomains[subdomain] {
|
||||
t.Errorf("Register() generated duplicate: %s", subdomain)
|
||||
}
|
||||
subdomains[subdomain] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerCleanupStale(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
// Create fresh connection
|
||||
freshSubdomain, _ := manager.Register(nil, "fresh")
|
||||
|
||||
// Create stale connection
|
||||
staleSubdomain, _ := manager.Register(nil, "stale")
|
||||
|
||||
// Manually set LastActive to be stale
|
||||
if staleConn, ok := manager.Get(staleSubdomain); ok {
|
||||
staleConn.mu.Lock()
|
||||
staleConn.LastActive = time.Now().Add(-2 * time.Minute)
|
||||
staleConn.mu.Unlock()
|
||||
}
|
||||
|
||||
// Run cleanup with 90 second timeout
|
||||
count := manager.CleanupStale(90 * time.Second)
|
||||
if count != 1 {
|
||||
t.Errorf("CleanupStale() returned %d, want 1", count)
|
||||
}
|
||||
|
||||
// Fresh connection should still exist
|
||||
_, ok := manager.Get(freshSubdomain)
|
||||
if !ok {
|
||||
t.Error("CleanupStale() removed fresh connection")
|
||||
}
|
||||
|
||||
// Stale connection should be removed
|
||||
_, ok = manager.Get(staleSubdomain)
|
||||
if ok {
|
||||
t.Error("CleanupStale() did not remove stale connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerConcurrentAccess(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
count := 50 // Reduced from 100 to avoid potential issues
|
||||
|
||||
// Concurrent registrations
|
||||
for i := 0; i < count; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
_, err := manager.Register(nil, "")
|
||||
if err != nil {
|
||||
t.Errorf("Concurrent Register() error = %v", err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all connections are registered
|
||||
all := manager.List()
|
||||
if len(all) != count {
|
||||
t.Errorf("Expected %d connections, got %d", count, len(all))
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkManagerRegister(b *testing.B) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.Register(nil, "")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkManagerGet(b *testing.B) {
|
||||
logger := zap.NewNop()
|
||||
manager := NewManager(logger)
|
||||
|
||||
// Setup: register a connection
|
||||
subdomain, _ := manager.Register(nil, "test123")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.Get(subdomain)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkManagerGenerateSubdomain(b *testing.B) {
|
||||
logger := zap.NewNop()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager := NewManager(logger)
|
||||
manager.Register(nil, "")
|
||||
}
|
||||
}
|
||||
48
internal/shared/constants/constants.go
Normal file
48
internal/shared/constants/constants.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package constants
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
// DefaultServerPort is the default port for the tunnel server
|
||||
DefaultServerPort = 8080
|
||||
|
||||
// DefaultWSPort is the default WebSocket port
|
||||
DefaultWSPort = 8080
|
||||
|
||||
// HeartbeatInterval is how often clients send heartbeat messages
|
||||
HeartbeatInterval = 5 * time.Second
|
||||
|
||||
// HeartbeatTimeout is how long the server waits before considering a connection dead
|
||||
HeartbeatTimeout = 15 * time.Second
|
||||
|
||||
// RequestTimeout is the maximum time to wait for a response from the client
|
||||
RequestTimeout = 30 * time.Second
|
||||
|
||||
// MaxRequestBodySize is the maximum size of an HTTP request body (10MB)
|
||||
MaxRequestBodySize = 10 * 1024 * 1024
|
||||
|
||||
// ReconnectBaseDelay is the initial delay for reconnection attempts
|
||||
ReconnectBaseDelay = 1 * time.Second
|
||||
|
||||
// ReconnectMaxDelay is the maximum delay between reconnection attempts
|
||||
ReconnectMaxDelay = 60 * time.Second
|
||||
|
||||
// MaxReconnectAttempts is the maximum number of reconnection attempts (0 = infinite)
|
||||
MaxReconnectAttempts = 0
|
||||
|
||||
// DefaultTCPPortMin/Max define the default allocation range for TCP tunnels
|
||||
DefaultTCPPortMin = 20000
|
||||
DefaultTCPPortMax = 40000
|
||||
// DefaultDomain is the default domain for tunnels
|
||||
DefaultDomain = "tunnel.localhost"
|
||||
)
|
||||
|
||||
// Error codes
|
||||
const (
|
||||
ErrCodeTunnelNotFound = "TUNNEL_NOT_FOUND"
|
||||
ErrCodeTimeout = "TIMEOUT"
|
||||
ErrCodeConnectionFailed = "CONNECTION_FAILED"
|
||||
ErrCodeInvalidRequest = "INVALID_REQUEST"
|
||||
ErrCodeAuthFailed = "AUTH_FAILED"
|
||||
ErrCodeRateLimited = "RATE_LIMITED"
|
||||
)
|
||||
77
internal/shared/pool/buffer_pool.go
Normal file
77
internal/shared/pool/buffer_pool.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package pool
|
||||
|
||||
import "sync"
|
||||
|
||||
const (
|
||||
SizeSmall = 4 * 1024 // 4KB
|
||||
SizeMedium = 32 * 1024 // 32KB
|
||||
SizeLarge = 256 * 1024 // 256KB
|
||||
)
|
||||
|
||||
type BufferPool struct {
|
||||
small sync.Pool
|
||||
medium sync.Pool
|
||||
large sync.Pool
|
||||
}
|
||||
|
||||
func NewBufferPool() *BufferPool {
|
||||
return &BufferPool{
|
||||
small: sync.Pool{
|
||||
New: func() interface{} {
|
||||
b := make([]byte, SizeSmall)
|
||||
return &b
|
||||
},
|
||||
},
|
||||
medium: sync.Pool{
|
||||
New: func() interface{} {
|
||||
b := make([]byte, SizeMedium)
|
||||
return &b
|
||||
},
|
||||
},
|
||||
large: sync.Pool{
|
||||
New: func() interface{} {
|
||||
b := make([]byte, SizeLarge)
|
||||
return &b
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BufferPool) Get(size int) *[]byte {
|
||||
switch {
|
||||
case size <= SizeSmall:
|
||||
return p.small.Get().(*[]byte)
|
||||
case size <= SizeMedium:
|
||||
return p.medium.Get().(*[]byte)
|
||||
default:
|
||||
return p.large.Get().(*[]byte)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BufferPool) Put(buf *[]byte) {
|
||||
if buf == nil {
|
||||
return
|
||||
}
|
||||
|
||||
size := cap(*buf)
|
||||
*buf = (*buf)[:cap(*buf)]
|
||||
|
||||
switch size {
|
||||
case SizeSmall:
|
||||
p.small.Put(buf)
|
||||
case SizeMedium:
|
||||
p.medium.Put(buf)
|
||||
case SizeLarge:
|
||||
p.large.Put(buf)
|
||||
}
|
||||
}
|
||||
|
||||
var globalBufferPool = NewBufferPool()
|
||||
|
||||
func GetBuffer(size int) *[]byte {
|
||||
return globalBufferPool.Get(size)
|
||||
}
|
||||
|
||||
func PutBuffer(buf *[]byte) {
|
||||
globalBufferPool.Put(buf)
|
||||
}
|
||||
93
internal/shared/pool/header_pool.go
Normal file
93
internal/shared/pool/header_pool.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// HeaderPool manages a pool of http.Header objects for reuse
|
||||
// This reduces GC pressure from repeated header map allocations
|
||||
type HeaderPool struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
// NewHeaderPool creates a new header pool
|
||||
func NewHeaderPool() *HeaderPool {
|
||||
return &HeaderPool{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
// Pre-allocate with capacity for common header count (8-12 headers)
|
||||
return make(http.Header, 12)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a header from the pool
|
||||
// Returns a clean, empty header ready for use
|
||||
func (p *HeaderPool) Get() http.Header {
|
||||
h := p.pool.Get().(http.Header)
|
||||
// Clear any existing data (headers might be dirty from previous use)
|
||||
for k := range h {
|
||||
delete(h, k)
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
// Put returns a header to the pool
|
||||
// The header will be reused by future Get() calls
|
||||
func (p *HeaderPool) Put(h http.Header) {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
// Note: We don't clear here, clearing is done in Get() for better performance
|
||||
// (allows the GC to collect during idle time)
|
||||
p.pool.Put(h)
|
||||
}
|
||||
|
||||
// Clone creates a copy of src into dst, reusing dst's underlying storage
|
||||
// This is more efficient than creating a new header from scratch
|
||||
func (p *HeaderPool) Clone(dst, src http.Header) {
|
||||
// Clear dst first
|
||||
for k := range dst {
|
||||
delete(dst, k)
|
||||
}
|
||||
|
||||
// Copy all headers from src to dst
|
||||
for k, vv := range src {
|
||||
// Allocate new slice with exact capacity to avoid over-allocation
|
||||
dst[k] = make([]string, len(vv))
|
||||
copy(dst[k], vv)
|
||||
}
|
||||
}
|
||||
|
||||
// CloneWithExtra clones src into dst and adds/overwrites extra headers
|
||||
// This is optimized for the common pattern of cloning + adding Host header
|
||||
func (p *HeaderPool) CloneWithExtra(dst, src http.Header, extraKey, extraValue string) {
|
||||
// Clear dst first
|
||||
for k := range dst {
|
||||
delete(dst, k)
|
||||
}
|
||||
|
||||
// Copy all headers from src to dst
|
||||
for k, vv := range src {
|
||||
dst[k] = make([]string, len(vv))
|
||||
copy(dst[k], vv)
|
||||
}
|
||||
|
||||
// Set extra header (overwrite if exists)
|
||||
dst.Set(extraKey, extraValue)
|
||||
}
|
||||
|
||||
// globalHeaderPool is a package-level pool for convenience
|
||||
var globalHeaderPool = NewHeaderPool()
|
||||
|
||||
// GetHeader retrieves a header from the global pool
|
||||
func GetHeader() http.Header {
|
||||
return globalHeaderPool.Get()
|
||||
}
|
||||
|
||||
// PutHeader returns a header to the global pool
|
||||
func PutHeader(h http.Header) {
|
||||
globalHeaderPool.Put(h)
|
||||
}
|
||||
115
internal/shared/pool/worker_pool.go
Normal file
115
internal/shared/pool/worker_pool.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// WorkerPool is a fixed-size goroutine pool for handling tasks
|
||||
type WorkerPool struct {
|
||||
workers int
|
||||
jobQueue chan func()
|
||||
wg sync.WaitGroup
|
||||
once sync.Once
|
||||
closed bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewWorkerPool creates a new worker pool with the specified number of workers
|
||||
func NewWorkerPool(workers int, queueSize int) *WorkerPool {
|
||||
if workers <= 0 {
|
||||
workers = 50 // Default worker count
|
||||
}
|
||||
if queueSize <= 0 {
|
||||
queueSize = 1000 // Default queue size
|
||||
}
|
||||
|
||||
pool := &WorkerPool{
|
||||
workers: workers,
|
||||
jobQueue: make(chan func(), queueSize),
|
||||
}
|
||||
|
||||
// Start worker goroutines
|
||||
for i := 0; i < workers; i++ {
|
||||
pool.wg.Add(1)
|
||||
go pool.worker()
|
||||
}
|
||||
|
||||
return pool
|
||||
}
|
||||
|
||||
// worker is the worker goroutine that processes jobs from the queue
|
||||
func (p *WorkerPool) worker() {
|
||||
defer p.wg.Done()
|
||||
|
||||
for job := range p.jobQueue {
|
||||
if job != nil {
|
||||
job()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Submit submits a job to the worker pool
|
||||
// Returns false if the pool is closed or the queue is full
|
||||
func (p *WorkerPool) Submit(job func()) bool {
|
||||
if job == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
p.mu.RLock()
|
||||
if p.closed {
|
||||
p.mu.RUnlock()
|
||||
return false
|
||||
}
|
||||
p.mu.RUnlock()
|
||||
|
||||
// Non-blocking send
|
||||
select {
|
||||
case p.jobQueue <- job:
|
||||
return true
|
||||
default:
|
||||
// Queue is full, fall back to direct execution
|
||||
// This prevents blocking when pool is overloaded
|
||||
go job()
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// SubmitWait submits a job and waits for it to complete
|
||||
func (p *WorkerPool) SubmitWait(job func()) {
|
||||
if job == nil {
|
||||
return
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
wrappedJob := func() {
|
||||
defer close(done)
|
||||
job()
|
||||
}
|
||||
|
||||
if p.Submit(wrappedJob) {
|
||||
<-done
|
||||
} else {
|
||||
// Pool is closed or full, execute directly
|
||||
job()
|
||||
}
|
||||
}
|
||||
|
||||
// Close gracefully shuts down the worker pool
|
||||
// It waits for all pending jobs to complete
|
||||
func (p *WorkerPool) Close() {
|
||||
p.once.Do(func() {
|
||||
p.mu.Lock()
|
||||
p.closed = true
|
||||
p.mu.Unlock()
|
||||
|
||||
close(p.jobQueue)
|
||||
p.wg.Wait()
|
||||
})
|
||||
}
|
||||
|
||||
// IsClosed returns true if the pool is closed
|
||||
func (p *WorkerPool) IsClosed() bool {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
return p.closed
|
||||
}
|
||||
166
internal/shared/protocol/binary_header.go
Normal file
166
internal/shared/protocol/binary_header.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// DataHeaderV2 represents a binary-encoded data header (Protocol Version 2)
|
||||
// This replaces JSON encoding to improve performance
|
||||
type DataHeaderV2 struct {
|
||||
Type DataType
|
||||
IsLast bool
|
||||
StreamID string
|
||||
RequestID string
|
||||
}
|
||||
|
||||
// DataType represents the type of data frame
|
||||
type DataType uint8
|
||||
|
||||
const (
|
||||
DataTypeData DataType = 0x00 // 000
|
||||
DataTypeResponse DataType = 0x01 // 001
|
||||
DataTypeClose DataType = 0x02 // 010
|
||||
DataTypeHTTPRequest DataType = 0x03 // 011
|
||||
DataTypeHTTPResponse DataType = 0x04 // 100
|
||||
)
|
||||
|
||||
// String returns the string representation of DataType
|
||||
func (t DataType) String() string {
|
||||
switch t {
|
||||
case DataTypeData:
|
||||
return "data"
|
||||
case DataTypeResponse:
|
||||
return "response"
|
||||
case DataTypeClose:
|
||||
return "close"
|
||||
case DataTypeHTTPRequest:
|
||||
return "http_request"
|
||||
case DataTypeHTTPResponse:
|
||||
return "http_response"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// FromString converts a string to DataType
|
||||
func DataTypeFromString(s string) DataType {
|
||||
switch s {
|
||||
case "data":
|
||||
return DataTypeData
|
||||
case "response":
|
||||
return DataTypeResponse
|
||||
case "close":
|
||||
return DataTypeClose
|
||||
case "http_request":
|
||||
return DataTypeHTTPRequest
|
||||
case "http_response":
|
||||
return DataTypeHTTPResponse
|
||||
default:
|
||||
return DataTypeData
|
||||
}
|
||||
}
|
||||
|
||||
// Binary format:
|
||||
// +--------+--------+--------+--------+--------+
|
||||
// | Flags | StreamID Length | RequestID Len |
|
||||
// | 1 byte | 2 bytes | 2 bytes |
|
||||
// +--------+--------+--------+--------+--------+
|
||||
// | StreamID (variable) |
|
||||
// +--------+--------+--------+--------+--------+
|
||||
// | RequestID (variable) |
|
||||
// +--------+--------+--------+--------+--------+
|
||||
//
|
||||
// Flags (8 bits):
|
||||
// - Bit 0-2: Type (3 bits)
|
||||
// - Bit 3: IsLast (1 bit)
|
||||
// - Bit 4-7: Reserved (4 bits)
|
||||
|
||||
const (
|
||||
binaryHeaderMinSize = 5 // 1 byte flags + 2 bytes streamID len + 2 bytes requestID len
|
||||
)
|
||||
|
||||
// MarshalBinary encodes the header to binary format
|
||||
func (h *DataHeaderV2) MarshalBinary() []byte {
|
||||
streamIDLen := len(h.StreamID)
|
||||
requestIDLen := len(h.RequestID)
|
||||
|
||||
totalLen := binaryHeaderMinSize + streamIDLen + requestIDLen
|
||||
buf := make([]byte, totalLen)
|
||||
|
||||
// Encode flags
|
||||
flags := uint8(h.Type) & 0x07 // Type uses bits 0-2
|
||||
if h.IsLast {
|
||||
flags |= 0x08 // IsLast uses bit 3
|
||||
}
|
||||
buf[0] = flags
|
||||
|
||||
// Encode lengths (big-endian)
|
||||
binary.BigEndian.PutUint16(buf[1:3], uint16(streamIDLen))
|
||||
binary.BigEndian.PutUint16(buf[3:5], uint16(requestIDLen))
|
||||
|
||||
// Encode StreamID
|
||||
offset := binaryHeaderMinSize
|
||||
copy(buf[offset:], h.StreamID)
|
||||
offset += streamIDLen
|
||||
|
||||
// Encode RequestID
|
||||
copy(buf[offset:], h.RequestID)
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// UnmarshalBinary decodes the header from binary format
|
||||
func (h *DataHeaderV2) UnmarshalBinary(data []byte) error {
|
||||
if len(data) < binaryHeaderMinSize {
|
||||
return errors.New("invalid binary header: too short")
|
||||
}
|
||||
|
||||
// Decode flags
|
||||
flags := data[0]
|
||||
h.Type = DataType(flags & 0x07) // Bits 0-2
|
||||
h.IsLast = (flags & 0x08) != 0 // Bit 3
|
||||
|
||||
// Decode lengths
|
||||
streamIDLen := int(binary.BigEndian.Uint16(data[1:3]))
|
||||
requestIDLen := int(binary.BigEndian.Uint16(data[3:5]))
|
||||
|
||||
// Validate total length
|
||||
expectedLen := binaryHeaderMinSize + streamIDLen + requestIDLen
|
||||
if len(data) < expectedLen {
|
||||
return errors.New("invalid binary header: length mismatch")
|
||||
}
|
||||
|
||||
// Decode StreamID
|
||||
offset := binaryHeaderMinSize
|
||||
h.StreamID = string(data[offset : offset+streamIDLen])
|
||||
offset += streamIDLen
|
||||
|
||||
// Decode RequestID
|
||||
h.RequestID = string(data[offset : offset+requestIDLen])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToDataHeader converts binary header to JSON header (for compatibility)
|
||||
func (h *DataHeaderV2) ToDataHeader() DataHeader {
|
||||
return DataHeader{
|
||||
StreamID: h.StreamID,
|
||||
RequestID: h.RequestID,
|
||||
Type: h.Type.String(),
|
||||
IsLast: h.IsLast,
|
||||
}
|
||||
}
|
||||
|
||||
// FromDataHeader converts JSON header to binary header
|
||||
func (h *DataHeaderV2) FromDataHeader(dh DataHeader) {
|
||||
h.StreamID = dh.StreamID
|
||||
h.RequestID = dh.RequestID
|
||||
h.Type = DataTypeFromString(dh.Type)
|
||||
h.IsLast = dh.IsLast
|
||||
}
|
||||
|
||||
// Size returns the size of the binary-encoded header
|
||||
func (h *DataHeaderV2) Size() int {
|
||||
return binaryHeaderMinSize + len(h.StreamID) + len(h.RequestID)
|
||||
}
|
||||
136
internal/shared/protocol/frame.go
Normal file
136
internal/shared/protocol/frame.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"drip/internal/shared/pool"
|
||||
)
|
||||
|
||||
const (
|
||||
FrameHeaderSize = 5
|
||||
MaxFrameSize = 10 * 1024 * 1024
|
||||
)
|
||||
|
||||
// FrameType defines the type of frame
|
||||
type FrameType byte
|
||||
|
||||
const (
|
||||
FrameTypeRegister FrameType = 0x01
|
||||
FrameTypeRegisterAck FrameType = 0x02
|
||||
FrameTypeHeartbeat FrameType = 0x03
|
||||
FrameTypeHeartbeatAck FrameType = 0x04
|
||||
FrameTypeData FrameType = 0x05
|
||||
FrameTypeClose FrameType = 0x06
|
||||
FrameTypeError FrameType = 0x07
|
||||
)
|
||||
|
||||
// String returns the string representation of frame type
|
||||
func (t FrameType) String() string {
|
||||
switch t {
|
||||
case FrameTypeRegister:
|
||||
return "Register"
|
||||
case FrameTypeRegisterAck:
|
||||
return "RegisterAck"
|
||||
case FrameTypeHeartbeat:
|
||||
return "Heartbeat"
|
||||
case FrameTypeHeartbeatAck:
|
||||
return "HeartbeatAck"
|
||||
case FrameTypeData:
|
||||
return "Data"
|
||||
case FrameTypeClose:
|
||||
return "Close"
|
||||
case FrameTypeError:
|
||||
return "Error"
|
||||
default:
|
||||
return fmt.Sprintf("Unknown(%d)", t)
|
||||
}
|
||||
}
|
||||
|
||||
type Frame struct {
|
||||
Type FrameType
|
||||
Payload []byte
|
||||
poolBuffer *[]byte
|
||||
}
|
||||
|
||||
func WriteFrame(w io.Writer, frame *Frame) error {
|
||||
payloadLen := len(frame.Payload)
|
||||
if payloadLen > MaxFrameSize {
|
||||
return fmt.Errorf("payload too large: %d bytes (max %d)", payloadLen, MaxFrameSize)
|
||||
}
|
||||
|
||||
lengthBuf := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(lengthBuf, uint32(payloadLen))
|
||||
if _, err := w.Write(lengthBuf); err != nil {
|
||||
return fmt.Errorf("failed to write length: %w", err)
|
||||
}
|
||||
|
||||
if _, err := w.Write([]byte{byte(frame.Type)}); err != nil {
|
||||
return fmt.Errorf("failed to write type: %w", err)
|
||||
}
|
||||
|
||||
if payloadLen > 0 {
|
||||
if _, err := w.Write(frame.Payload); err != nil {
|
||||
return fmt.Errorf("failed to write payload: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ReadFrame(r io.Reader) (*Frame, error) {
|
||||
header := make([]byte, FrameHeaderSize)
|
||||
if _, err := io.ReadFull(r, header); err != nil {
|
||||
return nil, fmt.Errorf("failed to read frame header: %w", err)
|
||||
}
|
||||
|
||||
payloadLen := binary.BigEndian.Uint32(header[0:4])
|
||||
if payloadLen > MaxFrameSize {
|
||||
return nil, fmt.Errorf("payload too large: %d bytes (max %d)", payloadLen, MaxFrameSize)
|
||||
}
|
||||
|
||||
frameType := FrameType(header[4])
|
||||
|
||||
var payload []byte
|
||||
var poolBuf *[]byte
|
||||
|
||||
if payloadLen > 0 {
|
||||
if payloadLen > pool.SizeLarge {
|
||||
payload = make([]byte, payloadLen)
|
||||
if _, err := io.ReadFull(r, payload); err != nil {
|
||||
return nil, fmt.Errorf("failed to read payload: %w", err)
|
||||
}
|
||||
} else {
|
||||
poolBuf = pool.GetBuffer(int(payloadLen))
|
||||
payload = (*poolBuf)[:payloadLen]
|
||||
|
||||
if _, err := io.ReadFull(r, payload); err != nil {
|
||||
pool.PutBuffer(poolBuf)
|
||||
return nil, fmt.Errorf("failed to read payload: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &Frame{
|
||||
Type: frameType,
|
||||
Payload: payload,
|
||||
poolBuffer: poolBuf,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *Frame) Release() {
|
||||
if f.poolBuffer != nil {
|
||||
pool.PutBuffer(f.poolBuffer)
|
||||
f.poolBuffer = nil
|
||||
f.Payload = nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewFrame creates a new frame
|
||||
func NewFrame(frameType FrameType, payload []byte) *Frame {
|
||||
return &Frame{
|
||||
Type: frameType,
|
||||
Payload: payload,
|
||||
}
|
||||
}
|
||||
68
internal/shared/protocol/http_codec.go
Normal file
68
internal/shared/protocol/http_codec.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
)
|
||||
|
||||
// EncodeHTTPRequest encodes HTTPRequest using msgpack encoding (optimized)
|
||||
func EncodeHTTPRequest(req *HTTPRequest) ([]byte, error) {
|
||||
return msgpack.Marshal(req)
|
||||
}
|
||||
|
||||
// DecodeHTTPRequest decodes HTTPRequest with automatic version detection
|
||||
// Detects based on first byte: '{' = JSON, else = msgpack
|
||||
func DecodeHTTPRequest(data []byte) (*HTTPRequest, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("empty data")
|
||||
}
|
||||
|
||||
var req HTTPRequest
|
||||
|
||||
// Auto-detect: JSON starts with '{', msgpack starts with 0x80-0x8f (fixmap)
|
||||
if data[0] == '{' {
|
||||
// v1: JSON
|
||||
if err := json.Unmarshal(data, &req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// v2: msgpack
|
||||
if err := msgpack.Unmarshal(data, &req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &req, nil
|
||||
}
|
||||
|
||||
// EncodeHTTPResponse encodes HTTPResponse using msgpack encoding (optimized)
|
||||
func EncodeHTTPResponse(resp *HTTPResponse) ([]byte, error) {
|
||||
return msgpack.Marshal(resp)
|
||||
}
|
||||
|
||||
// DecodeHTTPResponse decodes HTTPResponse with automatic version detection
|
||||
// Detects based on first byte: '{' = JSON, else = msgpack
|
||||
func DecodeHTTPResponse(data []byte) (*HTTPResponse, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("empty data")
|
||||
}
|
||||
|
||||
var resp HTTPResponse
|
||||
|
||||
// Auto-detect: JSON starts with '{', msgpack starts with 0x80-0x8f (fixmap)
|
||||
if data[0] == '{' {
|
||||
// v1: JSON
|
||||
if err := json.Unmarshal(data, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// v2: msgpack
|
||||
if err := msgpack.Unmarshal(data, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
55
internal/shared/protocol/message.go
Normal file
55
internal/shared/protocol/message.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package protocol
|
||||
|
||||
// MessageType defines the type of tunnel message
|
||||
type MessageType string
|
||||
|
||||
const (
|
||||
// TypeRegister is sent when a client connects and gets a subdomain assigned
|
||||
TypeRegister MessageType = "register"
|
||||
// TypeRequest is sent from server to client when an HTTP request arrives
|
||||
TypeRequest MessageType = "request"
|
||||
// TypeResponse is sent from client to server with the HTTP response
|
||||
TypeResponse MessageType = "response"
|
||||
// TypeHeartbeat is sent periodically to keep the connection alive
|
||||
TypeHeartbeat MessageType = "heartbeat"
|
||||
// TypeError is sent when an error occurs
|
||||
TypeError MessageType = "error"
|
||||
)
|
||||
|
||||
// Message represents a tunnel protocol message
|
||||
type Message struct {
|
||||
Type MessageType `json:"type"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Subdomain string `json:"subdomain,omitempty"`
|
||||
Data map[string]interface{} `json:"data,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// HTTPRequest represents an HTTP request to be forwarded
|
||||
type HTTPRequest struct {
|
||||
Method string `json:"method"`
|
||||
URL string `json:"url"`
|
||||
Headers map[string][]string `json:"headers"`
|
||||
Body []byte `json:"body,omitempty"`
|
||||
}
|
||||
|
||||
// HTTPResponse represents an HTTP response from the local service
|
||||
type HTTPResponse struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
Status string `json:"status"`
|
||||
Headers map[string][]string `json:"headers"`
|
||||
Body []byte `json:"body,omitempty"`
|
||||
}
|
||||
|
||||
// RegisterData contains information sent when a tunnel is registered
|
||||
type RegisterData struct {
|
||||
Subdomain string `json:"subdomain"`
|
||||
URL string `json:"url"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ErrorData contains error information
|
||||
type ErrorData struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
307
internal/shared/protocol/message_test.go
Normal file
307
internal/shared/protocol/message_test.go
Normal file
@@ -0,0 +1,307 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMessageType_Values(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mt MessageType
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "register",
|
||||
mt: TypeRegister,
|
||||
want: "register",
|
||||
},
|
||||
{
|
||||
name: "request",
|
||||
mt: TypeRequest,
|
||||
want: "request",
|
||||
},
|
||||
{
|
||||
name: "response",
|
||||
mt: TypeResponse,
|
||||
want: "response",
|
||||
},
|
||||
{
|
||||
name: "heartbeat",
|
||||
mt: TypeHeartbeat,
|
||||
want: "heartbeat",
|
||||
},
|
||||
{
|
||||
name: "error",
|
||||
mt: TypeError,
|
||||
want: "error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := string(tt.mt)
|
||||
if got != tt.want {
|
||||
t.Errorf("MessageType value = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessage_JSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
message *Message
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple message",
|
||||
message: &Message{
|
||||
Type: TypeRegister,
|
||||
ID: "test-id-123",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "message with subdomain",
|
||||
message: &Message{
|
||||
Type: TypeRegister,
|
||||
ID: "test-id-456",
|
||||
Subdomain: "abc123",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "message with data",
|
||||
message: &Message{
|
||||
Type: TypeRequest,
|
||||
ID: "test-id-789",
|
||||
Data: map[string]interface{}{
|
||||
"method": "GET",
|
||||
"path": "/test",
|
||||
"headers": map[string]interface{}{
|
||||
"User-Agent": "Test",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "error message",
|
||||
message: &Message{
|
||||
Type: TypeError,
|
||||
ID: "test-id-error",
|
||||
Data: map[string]interface{}{
|
||||
"error": "something went wrong",
|
||||
"code": float64(500), // JSON unmarshals numbers as float64
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Marshal to JSON
|
||||
data, err := json.Marshal(tt.message)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("json.Marshal() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
|
||||
// Unmarshal back
|
||||
var decoded Message
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Errorf("json.Unmarshal() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Compare
|
||||
if decoded.Type != tt.message.Type {
|
||||
t.Errorf("Type = %v, want %v", decoded.Type, tt.message.Type)
|
||||
}
|
||||
if decoded.ID != tt.message.ID {
|
||||
t.Errorf("ID = %v, want %v", decoded.ID, tt.message.ID)
|
||||
}
|
||||
if decoded.Subdomain != tt.message.Subdomain {
|
||||
t.Errorf("Subdomain = %v, want %v", decoded.Subdomain, tt.message.Subdomain)
|
||||
}
|
||||
|
||||
// Deep compare Data if present
|
||||
if tt.message.Data != nil {
|
||||
if !reflect.DeepEqual(decoded.Data, tt.message.Data) {
|
||||
t.Errorf("Data = %v, want %v", decoded.Data, tt.message.Data)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPRequest_JSON(t *testing.T) {
|
||||
req := &HTTPRequest{
|
||||
Method: "POST",
|
||||
URL: "http://localhost:3000/api/test",
|
||||
Headers: map[string][]string{
|
||||
"Content-Type": {"application/json"},
|
||||
"User-Agent": {"Test Agent"},
|
||||
},
|
||||
Body: []byte(`{"key":"value"}`),
|
||||
}
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var decoded HTTPRequest
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
|
||||
// Compare
|
||||
if decoded.Method != req.Method {
|
||||
t.Errorf("Method = %v, want %v", decoded.Method, req.Method)
|
||||
}
|
||||
if decoded.URL != req.URL {
|
||||
t.Errorf("URL = %v, want %v", decoded.URL, req.URL)
|
||||
}
|
||||
if !reflect.DeepEqual(decoded.Headers, req.Headers) {
|
||||
t.Errorf("Headers = %v, want %v", decoded.Headers, req.Headers)
|
||||
}
|
||||
if string(decoded.Body) != string(req.Body) {
|
||||
t.Errorf("Body = %v, want %v", string(decoded.Body), string(req.Body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPResponse_JSON(t *testing.T) {
|
||||
resp := &HTTPResponse{
|
||||
StatusCode: 200,
|
||||
Status: "200 OK",
|
||||
Headers: map[string][]string{
|
||||
"Content-Type": {"text/html"},
|
||||
},
|
||||
Body: []byte("<html>Test</html>"),
|
||||
}
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var decoded HTTPResponse
|
||||
err = json.Unmarshal(data, &decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
|
||||
// Compare
|
||||
if decoded.StatusCode != resp.StatusCode {
|
||||
t.Errorf("StatusCode = %v, want %v", decoded.StatusCode, resp.StatusCode)
|
||||
}
|
||||
if decoded.Status != resp.Status {
|
||||
t.Errorf("Status = %v, want %v", decoded.Status, resp.Status)
|
||||
}
|
||||
if !reflect.DeepEqual(decoded.Headers, resp.Headers) {
|
||||
t.Errorf("Headers = %v, want %v", decoded.Headers, resp.Headers)
|
||||
}
|
||||
if string(decoded.Body) != string(resp.Body) {
|
||||
t.Errorf("Body = %v, want %v", string(decoded.Body), string(resp.Body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessage_ToMap(t *testing.T) {
|
||||
msg := &Message{
|
||||
Type: TypeRequest,
|
||||
ID: "test-123",
|
||||
Subdomain: "abc",
|
||||
Data: map[string]interface{}{
|
||||
"test": "value",
|
||||
},
|
||||
}
|
||||
|
||||
// Convert to map (simulated by marshaling and unmarshaling)
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
err = json.Unmarshal(data, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify fields exist
|
||||
if result["type"] == nil {
|
||||
t.Error("Map missing 'type' field")
|
||||
}
|
||||
if result["id"] == nil {
|
||||
t.Error("Map missing 'id' field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewMessage(t *testing.T) {
|
||||
msgType := TypeRegister
|
||||
id := "test-id"
|
||||
|
||||
msg := &Message{
|
||||
Type: msgType,
|
||||
ID: id,
|
||||
}
|
||||
|
||||
if msg.Type != msgType {
|
||||
t.Errorf("Type = %v, want %v", msg.Type, msgType)
|
||||
}
|
||||
if msg.ID != id {
|
||||
t.Errorf("ID = %v, want %v", msg.ID, id)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkMessageMarshal(b *testing.B) {
|
||||
msg := &Message{
|
||||
Type: TypeRequest,
|
||||
ID: "test-id-123",
|
||||
Subdomain: "abc123",
|
||||
Data: map[string]interface{}{
|
||||
"method": "GET",
|
||||
"path": "/test",
|
||||
},
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
json.Marshal(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMessageUnmarshal(b *testing.B) {
|
||||
msg := &Message{
|
||||
Type: TypeRequest,
|
||||
ID: "test-id-123",
|
||||
Subdomain: "abc123",
|
||||
Data: map[string]interface{}{
|
||||
"method": "GET",
|
||||
"path": "/test",
|
||||
},
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(msg)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var decoded Message
|
||||
json.Unmarshal(data, &decoded)
|
||||
}
|
||||
}
|
||||
49
internal/shared/protocol/messages.go
Normal file
49
internal/shared/protocol/messages.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package protocol
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// RegisterRequest is sent by client to register a tunnel
|
||||
type RegisterRequest struct {
|
||||
Token string `json:"token"` // Authentication token
|
||||
CustomSubdomain string `json:"custom_subdomain"` // Optional custom subdomain
|
||||
TunnelType TunnelType `json:"tunnel_type"` // http, tcp, udp
|
||||
LocalPort int `json:"local_port"` // Local port to forward to
|
||||
}
|
||||
|
||||
// RegisterResponse is sent by server after successful registration
|
||||
type RegisterResponse struct {
|
||||
Subdomain string `json:"subdomain"` // Assigned subdomain
|
||||
Port int `json:"port,omitempty"` // Assigned TCP port (for TCP tunnels)
|
||||
URL string `json:"url"` // Full tunnel URL
|
||||
Message string `json:"message"` // Success message
|
||||
}
|
||||
|
||||
// ErrorMessage represents an error
|
||||
type ErrorMessage struct {
|
||||
Code string `json:"code"` // Error code
|
||||
Message string `json:"message"` // Error message
|
||||
}
|
||||
|
||||
// DataHeader represents metadata for a data frame
|
||||
type DataHeader struct {
|
||||
StreamID string `json:"stream_id"` // Unique stream identifier
|
||||
RequestID string `json:"request_id"` // Request identifier (for HTTP)
|
||||
Type string `json:"type"` // "data", "response", "close", "http_request", "http_response"
|
||||
IsLast bool `json:"is_last"` // Is this the last frame for this stream
|
||||
}
|
||||
|
||||
// TCPData represents TCP tunnel data
|
||||
type TCPData struct {
|
||||
StreamID string `json:"stream_id"` // Stream identifier
|
||||
Data []byte `json:"data"` // Raw TCP data
|
||||
IsClose bool `json:"is_close"` // Close this stream
|
||||
}
|
||||
|
||||
// Marshal helpers
|
||||
func MarshalJSON(v interface{}) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
}
|
||||
|
||||
func UnmarshalJSON(data []byte, v interface{}) error {
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
129
internal/shared/protocol/payload.go
Normal file
129
internal/shared/protocol/payload.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// EncodeDataPayload encodes a data header and payload into a frame payload
|
||||
// Uses binary encoding (optimized format)
|
||||
func EncodeDataPayload(header DataHeader, data []byte) ([]byte, error) {
|
||||
return EncodeDataPayloadV2(header, data)
|
||||
}
|
||||
|
||||
// EncodeDataPayloadV1 encodes using JSON (legacy)
|
||||
// Format: JSON_HEADER\nDATA
|
||||
func EncodeDataPayloadV1(header DataHeader, data []byte) ([]byte, error) {
|
||||
headerBytes, err := json.Marshal(header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Combine: header + newline + data
|
||||
payload := make([]byte, 0, len(headerBytes)+1+len(data))
|
||||
payload = append(payload, headerBytes...)
|
||||
payload = append(payload, '\n')
|
||||
payload = append(payload, data...)
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// EncodeDataPayloadV2 encodes using binary format (optimized)
|
||||
// Format: BINARY_HEADER + DATA
|
||||
func EncodeDataPayloadV2(header DataHeader, data []byte) ([]byte, error) {
|
||||
// Convert to binary header
|
||||
var h2 DataHeaderV2
|
||||
h2.FromDataHeader(header)
|
||||
|
||||
// Encode header to binary
|
||||
headerBytes := h2.MarshalBinary()
|
||||
|
||||
// Combine: binary header + data
|
||||
payload := make([]byte, 0, len(headerBytes)+len(data))
|
||||
payload = append(payload, headerBytes...)
|
||||
payload = append(payload, data...)
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// DecodeDataPayload decodes a frame payload into header and data
|
||||
// Auto-detects protocol version
|
||||
func DecodeDataPayload(payload []byte) (DataHeader, []byte, error) {
|
||||
if len(payload) == 0 {
|
||||
return DataHeader{}, nil, errors.New("empty payload")
|
||||
}
|
||||
|
||||
// Try to detect version:
|
||||
// - V1 (JSON): starts with '{'
|
||||
// - V2 (Binary): first byte is flags (0x00-0x1F typically)
|
||||
if payload[0] == '{' {
|
||||
// V1: JSON format
|
||||
return DecodeDataPayloadV1(payload)
|
||||
}
|
||||
|
||||
// V2: Binary format
|
||||
return DecodeDataPayloadV2(payload)
|
||||
}
|
||||
|
||||
// DecodeDataPayloadV1 decodes JSON format (legacy)
|
||||
// Format: JSON_HEADER\nDATA
|
||||
func DecodeDataPayloadV1(payload []byte) (DataHeader, []byte, error) {
|
||||
// Find newline separator
|
||||
sepIdx := -1
|
||||
for i, b := range payload {
|
||||
if b == '\n' {
|
||||
sepIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if sepIdx == -1 {
|
||||
return DataHeader{}, nil, errors.New("invalid v1 payload: no newline separator")
|
||||
}
|
||||
|
||||
// Parse JSON header
|
||||
var header DataHeader
|
||||
if err := json.Unmarshal(payload[:sepIdx], &header); err != nil {
|
||||
return DataHeader{}, nil, err
|
||||
}
|
||||
|
||||
// Extract data (after newline)
|
||||
data := payload[sepIdx+1:]
|
||||
|
||||
return header, data, nil
|
||||
}
|
||||
|
||||
// DecodeDataPayloadV2 decodes binary format (optimized)
|
||||
// Format: BINARY_HEADER + DATA
|
||||
func DecodeDataPayloadV2(payload []byte) (DataHeader, []byte, error) {
|
||||
if len(payload) < binaryHeaderMinSize {
|
||||
return DataHeader{}, nil, errors.New("invalid v2 payload: too short")
|
||||
}
|
||||
|
||||
// Decode binary header
|
||||
var h2 DataHeaderV2
|
||||
if err := h2.UnmarshalBinary(payload); err != nil {
|
||||
return DataHeader{}, nil, err
|
||||
}
|
||||
|
||||
// Extract data (after header)
|
||||
headerSize := h2.Size()
|
||||
if len(payload) < headerSize {
|
||||
return DataHeader{}, nil, errors.New("invalid v2 payload: data missing")
|
||||
}
|
||||
|
||||
data := payload[headerSize:]
|
||||
|
||||
// Convert to DataHeader
|
||||
header := h2.ToDataHeader()
|
||||
|
||||
return header, data, nil
|
||||
}
|
||||
|
||||
// GetPayloadHeaderSize returns the size of the header in the payload
|
||||
// This is useful for pre-allocating buffers
|
||||
func GetPayloadHeaderSize(header DataHeader) int {
|
||||
var h2 DataHeaderV2
|
||||
h2.FromDataHeader(header)
|
||||
return h2.Size()
|
||||
}
|
||||
30
internal/shared/protocol/tunnel_type.go
Normal file
30
internal/shared/protocol/tunnel_type.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package protocol
|
||||
|
||||
// TunnelType defines the type of tunnel
|
||||
type TunnelType string
|
||||
|
||||
const (
|
||||
// TunnelTypeHTTP is for HTTP traffic
|
||||
TunnelTypeHTTP TunnelType = "http"
|
||||
// TunnelTypeHTTPS is for HTTPS traffic
|
||||
TunnelTypeHTTPS TunnelType = "https"
|
||||
// TunnelTypeTCP is for generic TCP traffic
|
||||
TunnelTypeTCP TunnelType = "tcp"
|
||||
// TunnelTypeUDP is for UDP traffic (future support)
|
||||
TunnelTypeUDP TunnelType = "udp"
|
||||
)
|
||||
|
||||
// String returns the string representation
|
||||
func (t TunnelType) String() string {
|
||||
return string(t)
|
||||
}
|
||||
|
||||
// IsValid checks if tunnel type is valid
|
||||
func (t TunnelType) IsValid() bool {
|
||||
switch t {
|
||||
case TunnelTypeHTTP, TunnelTypeHTTPS, TunnelTypeTCP, TunnelTypeUDP:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
126
internal/shared/protocol/writer.go
Normal file
126
internal/shared/protocol/writer.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type FrameWriter struct {
|
||||
conn io.Writer
|
||||
queue chan *Frame
|
||||
batch []*Frame
|
||||
mu sync.Mutex
|
||||
done chan struct{}
|
||||
closed bool
|
||||
|
||||
maxBatch int
|
||||
maxBatchWait time.Duration
|
||||
}
|
||||
|
||||
func NewFrameWriter(conn io.Writer) *FrameWriter {
|
||||
return NewFrameWriterWithConfig(conn, 128, 2*time.Millisecond, 1024)
|
||||
}
|
||||
|
||||
func NewFrameWriterWithConfig(conn io.Writer, maxBatch int, maxBatchWait time.Duration, queueSize int) *FrameWriter {
|
||||
w := &FrameWriter{
|
||||
conn: conn,
|
||||
queue: make(chan *Frame, queueSize),
|
||||
batch: make([]*Frame, 0, maxBatch),
|
||||
maxBatch: maxBatch,
|
||||
maxBatchWait: maxBatchWait,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go w.writeLoop()
|
||||
return w
|
||||
}
|
||||
|
||||
func (w *FrameWriter) writeLoop() {
|
||||
ticker := time.NewTicker(w.maxBatchWait)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case frame, ok := <-w.queue:
|
||||
if !ok {
|
||||
w.mu.Lock()
|
||||
w.flushBatchLocked()
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
w.mu.Lock()
|
||||
w.batch = append(w.batch, frame)
|
||||
|
||||
if len(w.batch) >= w.maxBatch {
|
||||
w.flushBatchLocked()
|
||||
}
|
||||
w.mu.Unlock()
|
||||
|
||||
case <-ticker.C:
|
||||
w.mu.Lock()
|
||||
if len(w.batch) > 0 {
|
||||
w.flushBatchLocked()
|
||||
}
|
||||
w.mu.Unlock()
|
||||
|
||||
case <-w.done:
|
||||
w.mu.Lock()
|
||||
w.flushBatchLocked()
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *FrameWriter) flushBatchLocked() {
|
||||
if len(w.batch) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, frame := range w.batch {
|
||||
_ = WriteFrame(w.conn, frame)
|
||||
}
|
||||
|
||||
w.batch = w.batch[:0]
|
||||
}
|
||||
|
||||
func (w *FrameWriter) WriteFrame(frame *Frame) error {
|
||||
w.mu.Lock()
|
||||
if w.closed {
|
||||
w.mu.Unlock()
|
||||
return errors.New("writer closed")
|
||||
}
|
||||
w.mu.Unlock()
|
||||
|
||||
select {
|
||||
case w.queue <- frame:
|
||||
return nil
|
||||
case <-w.done:
|
||||
return errors.New("writer closed")
|
||||
default:
|
||||
return WriteFrame(w.conn, frame)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *FrameWriter) Close() error {
|
||||
w.mu.Lock()
|
||||
if w.closed {
|
||||
w.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
w.closed = true
|
||||
w.mu.Unlock()
|
||||
|
||||
close(w.queue)
|
||||
close(w.done)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *FrameWriter) Flush() {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.flushBatchLocked()
|
||||
}
|
||||
31
internal/shared/utils/id.go
Normal file
31
internal/shared/utils/id.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GenerateID generates a random unique ID
|
||||
func GenerateID() string {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// Fallback to timestamp-based ID if crypto/rand fails
|
||||
return generateFallbackID()
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// GenerateShortID generates a shorter random ID (8 chars)
|
||||
func GenerateShortID() string {
|
||||
b := make([]byte, 4)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return generateFallbackID()[:8]
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func generateFallbackID() string {
|
||||
// Simple fallback using timestamp
|
||||
return hex.EncodeToString([]byte(time.Now().String()))
|
||||
}
|
||||
158
internal/shared/utils/id_test.go
Normal file
158
internal/shared/utils/id_test.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenerateID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wantLength int // expected minimum length
|
||||
}{
|
||||
{
|
||||
name: "generate valid ID",
|
||||
wantLength: 16, // At least 16 characters for hex-encoded random bytes
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GenerateID()
|
||||
|
||||
// Check that ID is not empty
|
||||
if got == "" {
|
||||
t.Error("GenerateID() returned empty string")
|
||||
}
|
||||
|
||||
// Check minimum length
|
||||
if len(got) < tt.wantLength {
|
||||
t.Errorf("GenerateID() length = %v, want at least %v", len(got), tt.wantLength)
|
||||
}
|
||||
|
||||
// Check that it's a valid hex string
|
||||
for _, char := range got {
|
||||
if !isHexChar(char) {
|
||||
t.Errorf("GenerateID() contains non-hex character: %c", char)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateIDUniqueness(t *testing.T) {
|
||||
// Generate 10000 IDs and check for uniqueness
|
||||
ids := make(map[string]bool)
|
||||
count := 10000
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
id := GenerateID()
|
||||
if ids[id] {
|
||||
t.Errorf("GenerateID() generated duplicate: %s", id)
|
||||
}
|
||||
ids[id] = true
|
||||
}
|
||||
|
||||
if len(ids) != count {
|
||||
t.Errorf("Expected %d unique IDs, got %d", count, len(ids))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateIDFormat(t *testing.T) {
|
||||
id := GenerateID()
|
||||
|
||||
// Check that it's lowercase
|
||||
if id != strings.ToLower(id) {
|
||||
t.Errorf("GenerateID() is not lowercase: %s", id)
|
||||
}
|
||||
|
||||
// Check that it doesn't contain special characters
|
||||
for _, char := range id {
|
||||
if !isHexChar(char) {
|
||||
t.Errorf("GenerateID() contains invalid character: %c in %s", char, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateIDConsistency(t *testing.T) {
|
||||
// Generate multiple IDs and ensure they all follow the same format
|
||||
count := 100
|
||||
firstID := GenerateID()
|
||||
firstLen := len(firstID)
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
id := GenerateID()
|
||||
|
||||
// All IDs should have the same length
|
||||
if len(id) != firstLen {
|
||||
t.Errorf("ID length inconsistency: first=%d, current=%d", firstLen, len(id))
|
||||
}
|
||||
|
||||
// All IDs should be hex strings
|
||||
for _, char := range id {
|
||||
if !isHexChar(char) {
|
||||
t.Errorf("Invalid hex character %c in ID: %s", char, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateIDNotEmpty(t *testing.T) {
|
||||
// Generate 1000 IDs and ensure none are empty
|
||||
for i := 0; i < 1000; i++ {
|
||||
id := GenerateID()
|
||||
if id == "" {
|
||||
t.Error("GenerateID() returned empty string")
|
||||
}
|
||||
if len(id) == 0 {
|
||||
t.Error("GenerateID() returned zero-length string")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to check if a character is a valid hex character
|
||||
func isHexChar(char rune) bool {
|
||||
return (char >= '0' && char <= '9') || (char >= 'a' && char <= 'f')
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkGenerateID(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
GenerateID()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGenerateIDParallel(b *testing.B) {
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
GenerateID()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test for concurrent ID generation
|
||||
func TestGenerateIDConcurrent(t *testing.T) {
|
||||
count := 1000
|
||||
ch := make(chan string, count)
|
||||
|
||||
// Generate IDs concurrently
|
||||
for i := 0; i < count; i++ {
|
||||
go func() {
|
||||
ch <- GenerateID()
|
||||
}()
|
||||
}
|
||||
|
||||
// Collect all IDs
|
||||
ids := make(map[string]bool)
|
||||
for i := 0; i < count; i++ {
|
||||
id := <-ch
|
||||
if ids[id] {
|
||||
t.Errorf("Concurrent GenerateID() generated duplicate: %s", id)
|
||||
}
|
||||
ids[id] = true
|
||||
}
|
||||
|
||||
if len(ids) != count {
|
||||
t.Errorf("Expected %d unique IDs, got %d", count, len(ids))
|
||||
}
|
||||
}
|
||||
104
internal/shared/utils/logger.go
Normal file
104
internal/shared/utils/logger.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
var logger *zap.Logger
|
||||
|
||||
// InitLogger initializes the global logger for client
|
||||
// verbose: if true, shows debug level logs; if false, shows error level only
|
||||
func InitLogger(verbose bool) error {
|
||||
var config zap.Config
|
||||
|
||||
if verbose {
|
||||
// Verbose mode: show debug and above
|
||||
config = zap.NewDevelopmentConfig()
|
||||
config.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
|
||||
} else {
|
||||
// Production mode: only show errors
|
||||
config = zap.NewProductionConfig()
|
||||
config.Level = zap.NewAtomicLevelAt(zapcore.ErrorLevel)
|
||||
}
|
||||
|
||||
config.OutputPaths = []string{"stdout"}
|
||||
config.ErrorOutputPaths = []string{"stderr"}
|
||||
|
||||
var err error
|
||||
logger, err = config.Build()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitServerLogger initializes logger for server with info level by default
|
||||
func InitServerLogger(debug bool) error {
|
||||
var config zap.Config
|
||||
|
||||
if debug {
|
||||
// Debug mode: show all logs
|
||||
config = zap.NewDevelopmentConfig()
|
||||
config.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
|
||||
} else {
|
||||
// Production mode: show info and above
|
||||
config = zap.NewProductionConfig()
|
||||
}
|
||||
|
||||
config.OutputPaths = []string{"stdout"}
|
||||
config.ErrorOutputPaths = []string{"stderr"}
|
||||
|
||||
var err error
|
||||
logger, err = config.Build()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLogger returns the global logger instance
|
||||
func GetLogger() *zap.Logger {
|
||||
if logger == nil {
|
||||
// Fallback to a basic logger if not initialized
|
||||
logger, _ = zap.NewProduction()
|
||||
}
|
||||
return logger
|
||||
}
|
||||
|
||||
// Info logs an info message
|
||||
func Info(msg string, fields ...zap.Field) {
|
||||
GetLogger().Info(msg, fields...)
|
||||
}
|
||||
|
||||
// Debug logs a debug message
|
||||
func Debug(msg string, fields ...zap.Field) {
|
||||
GetLogger().Debug(msg, fields...)
|
||||
}
|
||||
|
||||
// Warn logs a warning message
|
||||
func Warn(msg string, fields ...zap.Field) {
|
||||
GetLogger().Warn(msg, fields...)
|
||||
}
|
||||
|
||||
// Error logs an error message
|
||||
func Error(msg string, fields ...zap.Field) {
|
||||
GetLogger().Error(msg, fields...)
|
||||
}
|
||||
|
||||
// Fatal logs a fatal message and exits
|
||||
func Fatal(msg string, fields ...zap.Field) {
|
||||
GetLogger().Fatal(msg, fields...)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Sync flushes any buffered log entries
|
||||
func Sync() {
|
||||
if logger != nil {
|
||||
logger.Sync()
|
||||
}
|
||||
}
|
||||
66
internal/shared/utils/subdomain.go
Normal file
66
internal/shared/utils/subdomain.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"math/big"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
const (
|
||||
// SubdomainChars defines the allowed characters for subdomain generation
|
||||
SubdomainChars = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
// DefaultSubdomainLength is the default length of generated subdomains
|
||||
DefaultSubdomainLength = 6
|
||||
)
|
||||
|
||||
var subdomainRegex = regexp.MustCompile(`^[a-z0-9][a-z0-9-]{1,61}[a-z0-9]$`)
|
||||
|
||||
// GenerateSubdomain generates a random subdomain
|
||||
func GenerateSubdomain(length int) string {
|
||||
if length <= 0 {
|
||||
length = DefaultSubdomainLength
|
||||
}
|
||||
|
||||
result := make([]byte, length)
|
||||
charsLen := big.NewInt(int64(len(SubdomainChars)))
|
||||
|
||||
for i := 0; i < length; i++ {
|
||||
num, err := rand.Int(rand.Reader, charsLen)
|
||||
if err != nil {
|
||||
// Fallback to simple random if crypto/rand fails
|
||||
result[i] = SubdomainChars[i%len(SubdomainChars)]
|
||||
continue
|
||||
}
|
||||
result[i] = SubdomainChars[num.Int64()]
|
||||
}
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
// ValidateSubdomain checks if a subdomain is valid
|
||||
func ValidateSubdomain(subdomain string) bool {
|
||||
if len(subdomain) < 3 || len(subdomain) > 63 {
|
||||
return false
|
||||
}
|
||||
return subdomainRegex.MatchString(subdomain)
|
||||
}
|
||||
|
||||
// IsReserved checks if a subdomain is reserved
|
||||
func IsReserved(subdomain string) bool {
|
||||
reserved := map[string]bool{
|
||||
"www": true,
|
||||
"api": true,
|
||||
"admin": true,
|
||||
"app": true,
|
||||
"mail": true,
|
||||
"ftp": true,
|
||||
"blog": true,
|
||||
"shop": true,
|
||||
"status": true,
|
||||
"health": true,
|
||||
"test": true,
|
||||
"dev": true,
|
||||
"staging": true,
|
||||
}
|
||||
return reserved[subdomain]
|
||||
}
|
||||
266
internal/shared/utils/subdomain_test.go
Normal file
266
internal/shared/utils/subdomain_test.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenerateSubdomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
length int
|
||||
want int // expected length
|
||||
}{
|
||||
{
|
||||
name: "default length 6",
|
||||
length: 6,
|
||||
want: 6,
|
||||
},
|
||||
{
|
||||
name: "length 8",
|
||||
length: 8,
|
||||
want: 8,
|
||||
},
|
||||
{
|
||||
name: "length 10",
|
||||
length: 10,
|
||||
want: 10,
|
||||
},
|
||||
{
|
||||
name: "minimum length 4",
|
||||
length: 4,
|
||||
want: 4,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GenerateSubdomain(tt.length)
|
||||
|
||||
// Check length
|
||||
if len(got) != tt.want {
|
||||
t.Errorf("GenerateSubdomain() length = %v, want %v", len(got), tt.want)
|
||||
}
|
||||
|
||||
// Check that it only contains alphanumeric characters
|
||||
for _, char := range got {
|
||||
if !isAlphanumeric(char) {
|
||||
t.Errorf("GenerateSubdomain() contains non-alphanumeric character: %c", char)
|
||||
}
|
||||
}
|
||||
|
||||
// Check that it's lowercase
|
||||
if got != strings.ToLower(got) {
|
||||
t.Errorf("GenerateSubdomain() is not lowercase: %s", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSubdomainUniqueness(t *testing.T) {
|
||||
// Generate 1000 subdomains and check for uniqueness
|
||||
subdomains := make(map[string]bool)
|
||||
count := 1000
|
||||
length := 6
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
subdomain := GenerateSubdomain(length)
|
||||
if subdomains[subdomain] {
|
||||
t.Errorf("GenerateSubdomain() generated duplicate: %s", subdomain)
|
||||
}
|
||||
subdomains[subdomain] = true
|
||||
}
|
||||
|
||||
if len(subdomains) != count {
|
||||
t.Errorf("Expected %d unique subdomains, got %d", count, len(subdomains))
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSubdomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
subdomain string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "valid lowercase",
|
||||
subdomain: "abc123",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "valid all letters",
|
||||
subdomain: "abcdef",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "valid all numbers",
|
||||
subdomain: "123456",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "invalid uppercase",
|
||||
subdomain: "ABC123",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "valid with hyphen",
|
||||
subdomain: "abc-123",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "invalid starting with hyphen",
|
||||
subdomain: "-abc123",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid ending with hyphen",
|
||||
subdomain: "abc123-",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid with underscore",
|
||||
subdomain: "abc_123",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid with dot",
|
||||
subdomain: "abc.123",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid with space",
|
||||
subdomain: "abc 123",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid empty",
|
||||
subdomain: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid special characters",
|
||||
subdomain: "abc@123",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "valid minimum length",
|
||||
subdomain: "abc",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "invalid too short",
|
||||
subdomain: "ab",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ValidateSubdomain(tt.subdomain)
|
||||
if got != tt.want {
|
||||
t.Errorf("ValidateSubdomain(%q) = %v, want %v", tt.subdomain, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsReserved(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
subdomain string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "reserved www",
|
||||
subdomain: "www",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "reserved api",
|
||||
subdomain: "api",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "reserved admin",
|
||||
subdomain: "admin",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "reserved mail",
|
||||
subdomain: "mail",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "reserved ftp",
|
||||
subdomain: "ftp",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "reserved health",
|
||||
subdomain: "health",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "reserved test",
|
||||
subdomain: "test",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "reserved dev",
|
||||
subdomain: "dev",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "reserved staging",
|
||||
subdomain: "staging",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "not reserved random",
|
||||
subdomain: "abc123",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "not reserved user",
|
||||
subdomain: "myapp",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsReserved(tt.subdomain)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsReserved(%q) = %v, want %v", tt.subdomain, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to check if a character is alphanumeric
|
||||
func isAlphanumeric(char rune) bool {
|
||||
return (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9')
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkGenerateSubdomain(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
GenerateSubdomain(6)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkValidateSubdomain(b *testing.B) {
|
||||
subdomain := "abc123"
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
ValidateSubdomain(subdomain)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIsReserved(b *testing.B) {
|
||||
subdomain := "www"
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
IsReserved(subdomain)
|
||||
}
|
||||
}
|
||||
60
nginx.example.conf
Normal file
60
nginx.example.conf
Normal file
@@ -0,0 +1,60 @@
|
||||
# Drip Tunnel Server - Nginx 配置
|
||||
#
|
||||
# 架构:外部用户 -> Nginx (443) -> Drip Server (8443) -> 客户端
|
||||
#
|
||||
# 前置条件:
|
||||
# 1. 获取通配符 SSL 证书:
|
||||
# certbot certonly --manual --preferred-challenges dns \
|
||||
# -d "*.tunnel.example.com" -d "tunnel.example.com"
|
||||
#
|
||||
# 2. DNS 配置:
|
||||
# A tunnel.example.com -> YOUR_SERVER_IP
|
||||
# A *.tunnel.example.com -> YOUR_SERVER_IP
|
||||
#
|
||||
# 3. 启动 Drip Server:
|
||||
# ./bin/drip-server --port 8443 --domain tunnel.example.com \
|
||||
# --tls-cert /etc/letsencrypt/live/tunnel.example.com/fullchain.pem \
|
||||
# --tls-key /etc/letsencrypt/live/tunnel.example.com/privkey.pem
|
||||
|
||||
# HTTP 重定向到 HTTPS
|
||||
server {
|
||||
listen 80;
|
||||
server_name tunnel.example.com *.tunnel.example.com;
|
||||
return 301 https://$host$request_uri;
|
||||
}
|
||||
|
||||
# HTTPS 代理到 Drip Server
|
||||
server {
|
||||
listen 443 ssl http2;
|
||||
server_name tunnel.example.com *.tunnel.example.com;
|
||||
|
||||
# SSL 证书
|
||||
ssl_certificate /etc/letsencrypt/live/tunnel.example.com/fullchain.pem;
|
||||
ssl_certificate_key /etc/letsencrypt/live/tunnel.example.com/privkey.pem;
|
||||
ssl_protocols TLSv1.2 TLSv1.3;
|
||||
|
||||
# 代理到 Drip Server
|
||||
location / {
|
||||
proxy_pass https://127.0.0.1:8443;
|
||||
proxy_ssl_verify off;
|
||||
proxy_http_version 1.1;
|
||||
|
||||
# 转发请求头
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
# 超时配置
|
||||
proxy_connect_timeout 60s;
|
||||
proxy_send_timeout 300s;
|
||||
proxy_read_timeout 300s;
|
||||
|
||||
# 禁用缓冲
|
||||
proxy_buffering off;
|
||||
proxy_request_buffering off;
|
||||
|
||||
# 大文件支持
|
||||
client_max_body_size 100m;
|
||||
}
|
||||
}
|
||||
87
pkg/config/client_config.go
Normal file
87
pkg/config/client_config.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ClientConfig represents the client configuration
|
||||
type ClientConfig struct {
|
||||
Server string `yaml:"server"` // Server address (e.g., tunnel.example.com:443)
|
||||
Token string `yaml:"token"` // Authentication token
|
||||
TLS bool `yaml:"tls"` // Use TLS (always true for production)
|
||||
}
|
||||
|
||||
// DefaultClientConfig returns the default configuration path
|
||||
func DefaultClientConfigPath() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ".drip/config.yaml"
|
||||
}
|
||||
return filepath.Join(home, ".drip", "config.yaml")
|
||||
}
|
||||
|
||||
// LoadClientConfig loads configuration from file
|
||||
func LoadClientConfig(path string) (*ClientConfig, error) {
|
||||
if path == "" {
|
||||
path = DefaultClientConfigPath()
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("config file not found at %s, please run 'drip config init' first", path)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
|
||||
var config ClientConfig
|
||||
if err := yaml.Unmarshal(data, &config); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if config.Server == "" {
|
||||
return nil, fmt.Errorf("server address is required in config")
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// SaveClientConfig saves configuration to file
|
||||
func SaveClientConfig(config *ClientConfig, path string) error {
|
||||
if path == "" {
|
||||
path = DefaultClientConfigPath()
|
||||
}
|
||||
|
||||
// Create directory if not exists
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create config directory: %w", err)
|
||||
}
|
||||
|
||||
// Marshal to YAML
|
||||
data, err := yaml.Marshal(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
// Write to file with secure permissions
|
||||
if err := os.WriteFile(path, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write config file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConfigExists checks if config file exists
|
||||
func ConfigExists(path string) bool {
|
||||
if path == "" {
|
||||
path = DefaultClientConfigPath()
|
||||
}
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
}
|
||||
130
pkg/config/config.go
Normal file
130
pkg/config/config.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// ServerConfig holds the server configuration
|
||||
type ServerConfig struct {
|
||||
// Server settings
|
||||
Port int
|
||||
PublicPort int // Port to display in URLs (for reverse proxy scenarios)
|
||||
Domain string
|
||||
|
||||
// TCP tunnel dynamic port allocation
|
||||
TCPPortMin int
|
||||
TCPPortMax int
|
||||
|
||||
// TLS/SSL settings
|
||||
TLSEnabled bool
|
||||
TLSCertFile string
|
||||
TLSKeyFile string
|
||||
AutoTLS bool // Automatic Let's Encrypt
|
||||
|
||||
// Security
|
||||
AuthToken string
|
||||
|
||||
// Logging
|
||||
Debug bool
|
||||
}
|
||||
|
||||
// LegacyClientConfig holds the legacy client configuration
|
||||
// Deprecated: Use config.ClientConfig from client_config.go instead
|
||||
type LegacyClientConfig struct {
|
||||
ServerURL string
|
||||
LocalTarget string
|
||||
AuthToken string
|
||||
Subdomain string
|
||||
Verbose bool
|
||||
Insecure bool // Skip TLS verification (for testing only)
|
||||
}
|
||||
|
||||
// LoadTLSConfig loads TLS configuration
|
||||
func (c *ServerConfig) LoadTLSConfig() (*tls.Config, error) {
|
||||
if !c.TLSEnabled {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Check if certificate files exist
|
||||
if c.TLSCertFile == "" || c.TLSKeyFile == "" {
|
||||
return nil, fmt.Errorf("TLS enabled but certificate files not specified")
|
||||
}
|
||||
|
||||
if _, err := os.Stat(c.TLSCertFile); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("certificate file not found: %s", c.TLSCertFile)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(c.TLSKeyFile); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("key file not found: %s", c.TLSKeyFile)
|
||||
}
|
||||
|
||||
// Load certificate
|
||||
cert, err := tls.LoadX509KeyPair(c.TLSCertFile, c.TLSKeyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load certificate: %w", err)
|
||||
}
|
||||
|
||||
// Force TLS 1.3 only
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
MinVersion: tls.VersionTLS13, // Only TLS 1.3
|
||||
MaxVersion: tls.VersionTLS13, // Only TLS 1.3
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_AES_128_GCM_SHA256,
|
||||
tls.TLS_AES_256_GCM_SHA384,
|
||||
tls.TLS_CHACHA20_POLY1305_SHA256,
|
||||
},
|
||||
}
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
// GetClientTLSConfig returns TLS config for client connections
|
||||
func GetClientTLSConfig(serverName string) *tls.Config {
|
||||
return &tls.Config{
|
||||
ServerName: serverName,
|
||||
MinVersion: tls.VersionTLS13, // Only TLS 1.3
|
||||
MaxVersion: tls.VersionTLS13, // Only TLS 1.3
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_AES_128_GCM_SHA256,
|
||||
tls.TLS_AES_256_GCM_SHA384,
|
||||
tls.TLS_CHACHA20_POLY1305_SHA256,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetClientTLSConfigInsecure returns TLS config for client with InsecureSkipVerify
|
||||
// WARNING: Only use for testing!
|
||||
func GetClientTLSConfigInsecure() *tls.Config {
|
||||
return &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
MinVersion: tls.VersionTLS13, // Only TLS 1.3
|
||||
MaxVersion: tls.VersionTLS13, // Only TLS 1.3
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_AES_128_GCM_SHA256,
|
||||
tls.TLS_AES_256_GCM_SHA384,
|
||||
tls.TLS_CHACHA20_POLY1305_SHA256,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GetServerURL returns the server URL based on configuration
|
||||
func (c *ServerConfig) GetServerURL() string {
|
||||
protocol := "http"
|
||||
if c.TLSEnabled {
|
||||
protocol = "https"
|
||||
}
|
||||
|
||||
if c.Port == 80 || (c.TLSEnabled && c.Port == 443) {
|
||||
return fmt.Sprintf("%s://%s", protocol, c.Domain)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s://%s:%d", protocol, c.Domain, c.Port)
|
||||
}
|
||||
|
||||
// GetTCPAddress returns the TCP address for tunnel connections
|
||||
func (c *ServerConfig) GetTCPAddress() string {
|
||||
return fmt.Sprintf("%s:%d", c.Domain, c.Port)
|
||||
}
|
||||
52
scripts/generate-cert.sh
Executable file
52
scripts/generate-cert.sh
Executable file
@@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Generate self-signed certificate for development/testing using ECDSA
|
||||
# ECDSA provides better performance and smaller key size than RSA
|
||||
# WARNING: Do NOT use self-signed certificates in production!
|
||||
|
||||
set -e
|
||||
|
||||
CERT_DIR="${1:-./certs}"
|
||||
DOMAIN="${2:-localhost}"
|
||||
DAYS=365
|
||||
|
||||
echo "🔒 Generating self-signed certificate for development..."
|
||||
echo " Domain: $DOMAIN"
|
||||
echo " Output directory: $CERT_DIR"
|
||||
echo ""
|
||||
|
||||
# Create directory if it doesn't exist
|
||||
mkdir -p "$CERT_DIR"
|
||||
|
||||
# Generate ECDSA private key (using P-256 curve)
|
||||
openssl ecparam -genkey -name prime256v1 -out "$CERT_DIR/server.key"
|
||||
|
||||
# Generate certificate signing request
|
||||
openssl req -new -key "$CERT_DIR/server.key" -out "$CERT_DIR/server.csr" \
|
||||
-subj "/C=US/ST=State/L=City/O=Organization/CN=$DOMAIN"
|
||||
|
||||
# Generate self-signed certificate
|
||||
openssl x509 -req -days $DAYS \
|
||||
-in "$CERT_DIR/server.csr" \
|
||||
-signkey "$CERT_DIR/server.key" \
|
||||
-out "$CERT_DIR/server.crt" \
|
||||
-extfile <(printf "subjectAltName=DNS:$DOMAIN,DNS:*.$DOMAIN")
|
||||
|
||||
# Clean up CSR
|
||||
rm "$CERT_DIR/server.csr"
|
||||
|
||||
echo "✅ Certificate generated successfully!"
|
||||
echo ""
|
||||
echo "Files created:"
|
||||
echo " Certificate: $CERT_DIR/server.crt"
|
||||
echo " Private Key: $CERT_DIR/server.key"
|
||||
echo ""
|
||||
echo "Usage:"
|
||||
echo " ./bin/drip server \\"
|
||||
echo " --domain $DOMAIN \\"
|
||||
echo " --tls-cert $CERT_DIR/server.crt \\"
|
||||
echo " --tls-key $CERT_DIR/server.key"
|
||||
echo ""
|
||||
echo "⚠️ WARNING: This is a self-signed certificate for development only!"
|
||||
echo " Clients will need to skip certificate verification (insecure)."
|
||||
echo " For production, use --auto-tls or get a certificate from Let's Encrypt."
|
||||
1007
scripts/install-server.sh
Executable file
1007
scripts/install-server.sh
Executable file
File diff suppressed because it is too large
Load Diff
642
scripts/install.sh
Executable file
642
scripts/install.sh
Executable file
@@ -0,0 +1,642 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# ============================================================================
|
||||
# Configuration
|
||||
# ============================================================================
|
||||
DOWNLOAD_BASE_URL="https://"
|
||||
INSTALL_DIR="${INSTALL_DIR:-}"
|
||||
VERSION="${VERSION:-latest}"
|
||||
BINARY_NAME="drip"
|
||||
|
||||
# Colors
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
BLUE='\033[0;34m'
|
||||
CYAN='\033[0;36m'
|
||||
BOLD='\033[1m'
|
||||
NC='\033[0m'
|
||||
|
||||
# Language (default: en)
|
||||
LANG_CODE="${LANG_CODE:-en}"
|
||||
|
||||
# ============================================================================
|
||||
# Internationalization
|
||||
# ============================================================================
|
||||
declare -A MSG_EN
|
||||
declare -A MSG_ZH
|
||||
|
||||
# English messages
|
||||
MSG_EN=(
|
||||
["banner_title"]="Drip Client - One-Click Installer"
|
||||
["select_lang"]="Select language / 选择语言"
|
||||
["lang_en"]="English"
|
||||
["lang_zh"]="中文"
|
||||
["checking_os"]="Checking operating system..."
|
||||
["detected_os"]="Detected OS"
|
||||
["unsupported_os"]="Unsupported operating system"
|
||||
["checking_arch"]="Checking system architecture..."
|
||||
["detected_arch"]="Detected architecture"
|
||||
["unsupported_arch"]="Unsupported architecture"
|
||||
["checking_deps"]="Checking dependencies..."
|
||||
["deps_ok"]="Dependencies check passed"
|
||||
["downloading"]="Downloading Drip client"
|
||||
["download_failed"]="Download failed"
|
||||
["download_ok"]="Download completed"
|
||||
["select_install_dir"]="Select installation directory"
|
||||
["option_user"]="User directory (no sudo required)"
|
||||
["option_system"]="System directory (requires sudo)"
|
||||
["option_current"]="Current directory"
|
||||
["option_custom"]="Custom path"
|
||||
["enter_custom_path"]="Enter custom path"
|
||||
["installing"]="Installing binary..."
|
||||
["install_ok"]="Installation completed"
|
||||
["updating_path"]="Updating PATH..."
|
||||
["path_updated"]="PATH updated"
|
||||
["path_note"]="Please restart your terminal or run: source ~/.bashrc"
|
||||
["config_title"]="Client Configuration"
|
||||
["configure_now"]="Configure client now?"
|
||||
["enter_server"]="Enter server address (e.g., tunnel.example.com:8443)"
|
||||
["server_required"]="Server address is required"
|
||||
["enter_token"]="Enter authentication token"
|
||||
["token_required"]="Token is required"
|
||||
["skip_verify"]="Skip TLS certificate verification? (for self-signed certs)"
|
||||
["config_saved"]="Configuration saved"
|
||||
["install_complete"]="Installation completed!"
|
||||
["usage_title"]="Usage"
|
||||
["usage_http"]="Expose HTTP service on port 3000"
|
||||
["usage_tcp"]="Expose TCP service on port 5432"
|
||||
["usage_config"]="Show/modify configuration"
|
||||
["usage_daemon"]="Run as background daemon"
|
||||
["run_test"]="Test connection now?"
|
||||
["test_running"]="Testing connection..."
|
||||
["test_success"]="Connection successful"
|
||||
["test_failed"]="Connection failed"
|
||||
["yes"]="y"
|
||||
["no"]="n"
|
||||
["press_enter"]="Press Enter to continue..."
|
||||
["windows_note"]="For Windows, please download the .exe file from GitHub Releases"
|
||||
["already_installed"]="Drip is already installed"
|
||||
["update_now"]="Update to the latest version?"
|
||||
["updating"]="Updating..."
|
||||
["update_ok"]="Update completed"
|
||||
["verify_install"]="Verifying installation..."
|
||||
["verify_ok"]="Verification passed"
|
||||
["verify_failed"]="Verification failed"
|
||||
["insecure_note"]="Only use --insecure for development/testing"
|
||||
)
|
||||
|
||||
# Chinese messages
|
||||
MSG_ZH=(
|
||||
["banner_title"]="Drip 客户端 - 一键安装脚本"
|
||||
["select_lang"]="Select language / 选择语言"
|
||||
["lang_en"]="English"
|
||||
["lang_zh"]="中文"
|
||||
["checking_os"]="检查操作系统..."
|
||||
["detected_os"]="检测到操作系统"
|
||||
["unsupported_os"]="不支持的操作系统"
|
||||
["checking_arch"]="检查系统架构..."
|
||||
["detected_arch"]="检测到架构"
|
||||
["unsupported_arch"]="不支持的架构"
|
||||
["checking_deps"]="检查依赖..."
|
||||
["deps_ok"]="依赖检查通过"
|
||||
["downloading"]="下载 Drip 客户端"
|
||||
["download_failed"]="下载失败"
|
||||
["download_ok"]="下载完成"
|
||||
["select_install_dir"]="选择安装目录"
|
||||
["option_user"]="用户目录(无需 sudo)"
|
||||
["option_system"]="系统目录(需要 sudo)"
|
||||
["option_current"]="当前目录"
|
||||
["option_custom"]="自定义路径"
|
||||
["enter_custom_path"]="输入自定义路径"
|
||||
["installing"]="安装二进制文件..."
|
||||
["install_ok"]="安装完成"
|
||||
["updating_path"]="更新 PATH..."
|
||||
["path_updated"]="PATH 已更新"
|
||||
["path_note"]="请重启终端或运行: source ~/.bashrc"
|
||||
["config_title"]="客户端配置"
|
||||
["configure_now"]="现在配置客户端?"
|
||||
["enter_server"]="输入服务器地址(例如:tunnel.example.com:8443)"
|
||||
["server_required"]="服务器地址是必填项"
|
||||
["enter_token"]="输入认证令牌"
|
||||
["token_required"]="认证令牌是必填项"
|
||||
["skip_verify"]="跳过 TLS 证书验证?(用于自签名证书)"
|
||||
["config_saved"]="配置已保存"
|
||||
["install_complete"]="安装完成!"
|
||||
["usage_title"]="使用方法"
|
||||
["usage_http"]="暴露本地 3000 端口的 HTTP 服务"
|
||||
["usage_tcp"]="暴露本地 5432 端口的 TCP 服务"
|
||||
["usage_config"]="显示/修改配置"
|
||||
["usage_daemon"]="作为后台守护进程运行"
|
||||
["run_test"]="现在测试连接?"
|
||||
["test_running"]="正在测试连接..."
|
||||
["test_success"]="连接成功"
|
||||
["test_failed"]="连接失败"
|
||||
["yes"]="y"
|
||||
["no"]="n"
|
||||
["press_enter"]="按 Enter 继续..."
|
||||
["windows_note"]="Windows 用户请从 GitHub Releases 下载 .exe 文件"
|
||||
["already_installed"]="Drip 已安装"
|
||||
["update_now"]="是否更新到最新版本?"
|
||||
["updating"]="正在更新..."
|
||||
["update_ok"]="更新完成"
|
||||
["verify_install"]="验证安装..."
|
||||
["verify_ok"]="验证通过"
|
||||
["verify_failed"]="验证失败"
|
||||
["insecure_note"]="--insecure 仅用于开发/测试环境"
|
||||
)
|
||||
|
||||
# Get message by key
|
||||
msg() {
|
||||
local key="$1"
|
||||
if [[ "$LANG_CODE" == "zh" ]]; then
|
||||
echo "${MSG_ZH[$key]:-$key}"
|
||||
else
|
||||
echo "${MSG_EN[$key]:-$key}"
|
||||
fi
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# Output functions
|
||||
# ============================================================================
|
||||
print_info() { echo -e "${BLUE}[INFO]${NC} $1"; }
|
||||
print_success() { echo -e "${GREEN}[✓]${NC} $1"; }
|
||||
print_warning() { echo -e "${YELLOW}[!]${NC} $1"; }
|
||||
print_error() { echo -e "${RED}[✗]${NC} $1"; }
|
||||
print_step() { echo -e "${CYAN}[→]${NC} $1"; }
|
||||
|
||||
# Print banner
|
||||
print_banner() {
|
||||
echo -e "${GREEN}"
|
||||
cat << "EOF"
|
||||
____ _
|
||||
/ __ \_____(_)___
|
||||
/ / / / ___/ / __ \
|
||||
/ /_/ / / / / /_/ /
|
||||
/_____/_/ /_/ .___/
|
||||
/_/
|
||||
EOF
|
||||
echo -e "${BOLD}$(msg banner_title)${NC}"
|
||||
echo ""
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# Language selection
|
||||
# ============================================================================
|
||||
select_language() {
|
||||
echo ""
|
||||
echo -e "${CYAN}╔════════════════════════════════════════╗${NC}"
|
||||
echo -e "${CYAN}║ $(msg select_lang) ║${NC}"
|
||||
echo -e "${CYAN}╠════════════════════════════════════════╣${NC}"
|
||||
echo -e "${CYAN}║${NC} ${GREEN}1)${NC} English ${CYAN}║${NC}"
|
||||
echo -e "${CYAN}║${NC} ${GREEN}2)${NC} 中文 ${CYAN}║${NC}"
|
||||
echo -e "${CYAN}╚════════════════════════════════════════╝${NC}"
|
||||
echo ""
|
||||
|
||||
read -p "Select [1]: " lang_choice < /dev/tty
|
||||
case "$lang_choice" in
|
||||
2)
|
||||
LANG_CODE="zh"
|
||||
;;
|
||||
*)
|
||||
LANG_CODE="en"
|
||||
;;
|
||||
esac
|
||||
echo ""
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# System checks
|
||||
# ============================================================================
|
||||
check_os() {
|
||||
print_step "$(msg checking_os)"
|
||||
|
||||
case "$(uname -s)" in
|
||||
Linux*)
|
||||
OS="linux"
|
||||
;;
|
||||
Darwin*)
|
||||
OS="darwin"
|
||||
;;
|
||||
MINGW*|MSYS*|CYGWIN*)
|
||||
OS="windows"
|
||||
print_warning "$(msg windows_note)"
|
||||
;;
|
||||
*)
|
||||
print_error "$(msg unsupported_os): $(uname -s)"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
print_success "$(msg detected_os): $OS"
|
||||
}
|
||||
|
||||
check_arch() {
|
||||
print_step "$(msg checking_arch)"
|
||||
|
||||
case "$(uname -m)" in
|
||||
x86_64|amd64)
|
||||
ARCH="amd64"
|
||||
;;
|
||||
aarch64|arm64)
|
||||
ARCH="arm64"
|
||||
;;
|
||||
armv7l)
|
||||
ARCH="arm"
|
||||
;;
|
||||
i386|i686)
|
||||
ARCH="386"
|
||||
;;
|
||||
*)
|
||||
print_error "$(msg unsupported_arch): $(uname -m)"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
print_success "$(msg detected_arch): $ARCH"
|
||||
}
|
||||
|
||||
check_dependencies() {
|
||||
print_step "$(msg checking_deps)"
|
||||
|
||||
# Check for download tool
|
||||
if ! command -v curl &> /dev/null && ! command -v wget &> /dev/null; then
|
||||
print_error "curl or wget is required"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
print_success "$(msg deps_ok)"
|
||||
}
|
||||
|
||||
get_remote_version() {
|
||||
# Determine binary name based on OS and ARCH
|
||||
local binary_file="drip-${OS}-${ARCH}"
|
||||
if [[ "$OS" == "windows" ]]; then
|
||||
binary_file="${binary_file}.exe"
|
||||
fi
|
||||
|
||||
local download_url="${DOWNLOAD_BASE_URL}/${VERSION}/${binary_file}"
|
||||
local tmp_file="/tmp/drip-check-$$"
|
||||
|
||||
if command -v curl &> /dev/null; then
|
||||
curl -fsSL "$download_url" -o "$tmp_file" 2>/dev/null || return 1
|
||||
else
|
||||
wget -q "$download_url" -O "$tmp_file" 2>/dev/null || return 1
|
||||
fi
|
||||
|
||||
chmod +x "$tmp_file"
|
||||
local remote_version=$("$tmp_file" version 2>/dev/null | awk '/Version:/ {print $2}' || echo "unknown")
|
||||
rm -f "$tmp_file"
|
||||
echo "$remote_version"
|
||||
}
|
||||
|
||||
check_existing_install() {
|
||||
if command -v drip &> /dev/null; then
|
||||
local current_path=$(command -v drip)
|
||||
local current_version=$(drip version 2>/dev/null | awk '/Version:/ {print $2}' || echo "unknown")
|
||||
|
||||
print_warning "$(msg already_installed): $current_path"
|
||||
print_info "$(msg current_version): $current_version"
|
||||
|
||||
# Check remote version
|
||||
print_step "Checking for updates..."
|
||||
local remote_version=$(get_remote_version)
|
||||
|
||||
if [[ "$remote_version" == "unknown" ]]; then
|
||||
print_warning "Could not check remote version"
|
||||
echo ""
|
||||
read -p "$(msg update_now) [Y/n]: " update_choice < /dev/tty
|
||||
elif [[ "$current_version" == "$remote_version" ]]; then
|
||||
print_success "Already up to date ($current_version)"
|
||||
exit 0
|
||||
else
|
||||
print_info "Latest version: $remote_version"
|
||||
echo ""
|
||||
read -p "$(msg update_now) [Y/n]: " update_choice < /dev/tty
|
||||
fi
|
||||
|
||||
if [[ "$update_choice" =~ ^[Nn]$ ]]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
INSTALL_DIR=$(dirname "$current_path")
|
||||
IS_UPDATE=true
|
||||
fi
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# Download and install
|
||||
# ============================================================================
|
||||
get_download_url() {
|
||||
local binary_name
|
||||
|
||||
if [[ "$OS" == "windows" ]]; then
|
||||
binary_name="drip-windows-${ARCH}.exe"
|
||||
else
|
||||
binary_name="drip-${OS}-${ARCH}"
|
||||
fi
|
||||
|
||||
echo "${DOWNLOAD_BASE_URL}/${VERSION}/${binary_name}"
|
||||
}
|
||||
|
||||
download_binary() {
|
||||
local url=$(get_download_url)
|
||||
|
||||
if [[ "$IS_UPDATE" == true ]]; then
|
||||
print_step "$(msg updating)..."
|
||||
else
|
||||
print_step "$(msg downloading)..."
|
||||
fi
|
||||
|
||||
local tmp_file="/tmp/drip-download"
|
||||
|
||||
if command -v curl &> /dev/null; then
|
||||
# Use -# for progress bar instead of -s (silent)
|
||||
if ! curl -f#L "$url" -o "$tmp_file"; then
|
||||
print_error "$(msg download_failed): $url"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
# Use --show-progress to display download progress
|
||||
if ! wget --show-progress "$url" -O "$tmp_file" 2>&1 | grep -v "^$"; then
|
||||
print_error "$(msg download_failed): $url"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
chmod +x "$tmp_file"
|
||||
print_success "$(msg download_ok)"
|
||||
}
|
||||
|
||||
select_install_dir() {
|
||||
if [[ -n "$INSTALL_DIR" ]]; then
|
||||
return
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo -e "${CYAN}╔════════════════════════════════════════╗${NC}"
|
||||
echo -e "${CYAN}║ $(msg select_install_dir) ${CYAN}║${NC}"
|
||||
echo -e "${CYAN}╠════════════════════════════════════════╣${NC}"
|
||||
echo -e "${CYAN}║${NC} ${GREEN}1)${NC} ~/.local/bin $(msg option_user) ${CYAN}║${NC}"
|
||||
echo -e "${CYAN}║${NC} ${GREEN}2)${NC} /usr/local/bin $(msg option_system) ${CYAN}║${NC}"
|
||||
echo -e "${CYAN}║${NC} ${GREEN}3)${NC} ./ $(msg option_current) ${CYAN}║${NC}"
|
||||
echo -e "${CYAN}║${NC} ${GREEN}4)${NC} $(msg option_custom) ${CYAN}║${NC}"
|
||||
echo -e "${CYAN}╚════════════════════════════════════════╝${NC}"
|
||||
echo ""
|
||||
|
||||
read -p "Select [1]: " dir_choice < /dev/tty
|
||||
|
||||
case "$dir_choice" in
|
||||
2)
|
||||
INSTALL_DIR="/usr/local/bin"
|
||||
NEED_SUDO=true
|
||||
;;
|
||||
3)
|
||||
INSTALL_DIR="."
|
||||
;;
|
||||
4)
|
||||
read -p "$(msg enter_custom_path): " INSTALL_DIR < /dev/tty
|
||||
;;
|
||||
*)
|
||||
INSTALL_DIR="$HOME/.local/bin"
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
install_binary() {
|
||||
print_step "$(msg installing)"
|
||||
|
||||
# Create directory if needed
|
||||
if [[ ! -d "$INSTALL_DIR" ]]; then
|
||||
if [[ "$NEED_SUDO" == true ]]; then
|
||||
sudo mkdir -p "$INSTALL_DIR"
|
||||
else
|
||||
mkdir -p "$INSTALL_DIR"
|
||||
fi
|
||||
fi
|
||||
|
||||
local target_path="$INSTALL_DIR/$BINARY_NAME"
|
||||
if [[ "$OS" == "windows" ]]; then
|
||||
target_path="$INSTALL_DIR/$BINARY_NAME.exe"
|
||||
fi
|
||||
|
||||
# Install binary
|
||||
if [[ "$NEED_SUDO" == true ]]; then
|
||||
sudo mv /tmp/drip-download "$target_path"
|
||||
sudo chmod +x "$target_path"
|
||||
else
|
||||
mv /tmp/drip-download "$target_path"
|
||||
chmod +x "$target_path"
|
||||
fi
|
||||
|
||||
print_success "$(msg install_ok): $target_path"
|
||||
}
|
||||
|
||||
update_path() {
|
||||
# Skip if already in PATH
|
||||
if command -v drip &> /dev/null; then
|
||||
return
|
||||
fi
|
||||
|
||||
# Skip for system directories (usually already in PATH)
|
||||
if [[ "$INSTALL_DIR" == "/usr/local/bin" ]] || [[ "$INSTALL_DIR" == "/usr/bin" ]]; then
|
||||
return
|
||||
fi
|
||||
|
||||
print_step "$(msg updating_path)"
|
||||
|
||||
local shell_rc=""
|
||||
local export_line="export PATH=\"\$PATH:$INSTALL_DIR\""
|
||||
|
||||
# Determine shell config file
|
||||
if [[ -n "$ZSH_VERSION" ]] || [[ "$SHELL" == *"zsh"* ]]; then
|
||||
shell_rc="$HOME/.zshrc"
|
||||
elif [[ -n "$BASH_VERSION" ]] || [[ "$SHELL" == *"bash"* ]]; then
|
||||
if [[ "$OS" == "darwin" ]]; then
|
||||
shell_rc="$HOME/.bash_profile"
|
||||
else
|
||||
shell_rc="$HOME/.bashrc"
|
||||
fi
|
||||
elif [[ "$SHELL" == *"fish"* ]]; then
|
||||
shell_rc="$HOME/.config/fish/config.fish"
|
||||
export_line="set -gx PATH \$PATH $INSTALL_DIR"
|
||||
fi
|
||||
|
||||
if [[ -n "$shell_rc" ]]; then
|
||||
# Check if already added
|
||||
if ! grep -q "$INSTALL_DIR" "$shell_rc" 2>/dev/null; then
|
||||
echo "" >> "$shell_rc"
|
||||
echo "# Drip client" >> "$shell_rc"
|
||||
echo "$export_line" >> "$shell_rc"
|
||||
print_success "$(msg path_updated): $shell_rc"
|
||||
fi
|
||||
fi
|
||||
|
||||
print_warning "$(msg path_note)"
|
||||
}
|
||||
|
||||
verify_installation() {
|
||||
print_step "$(msg verify_install)"
|
||||
|
||||
local binary_path="$INSTALL_DIR/$BINARY_NAME"
|
||||
if [[ "$OS" == "windows" ]]; then
|
||||
binary_path="$INSTALL_DIR/$BINARY_NAME.exe"
|
||||
fi
|
||||
|
||||
if [[ -x "$binary_path" ]]; then
|
||||
local version=$("$binary_path" version 2>/dev/null | awk '/Version:/ {print $2}' || echo "installed")
|
||||
print_success "$(msg verify_ok): $version"
|
||||
else
|
||||
print_error "$(msg verify_failed)"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# Configuration
|
||||
# ============================================================================
|
||||
configure_client() {
|
||||
echo ""
|
||||
read -p "$(msg configure_now) [Y/n]: " config_choice < /dev/tty
|
||||
|
||||
if [[ "$config_choice" =~ ^[Nn]$ ]]; then
|
||||
return
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo -e "${CYAN}╔════════════════════════════════════════╗${NC}"
|
||||
echo -e "${CYAN}║ $(msg config_title) ${CYAN}║${NC}"
|
||||
echo -e "${CYAN}╚════════════════════════════════════════╝${NC}"
|
||||
echo ""
|
||||
|
||||
local binary_path="$INSTALL_DIR/$BINARY_NAME"
|
||||
|
||||
# Server address
|
||||
while true; do
|
||||
read -p "$(msg enter_server): " SERVER < /dev/tty
|
||||
if [[ -n "$SERVER" ]]; then
|
||||
break
|
||||
fi
|
||||
print_error "$(msg server_required)"
|
||||
done
|
||||
|
||||
# Token
|
||||
while true; do
|
||||
read -p "$(msg enter_token): " TOKEN < /dev/tty
|
||||
if [[ -n "$TOKEN" ]]; then
|
||||
break
|
||||
fi
|
||||
print_error "$(msg token_required)"
|
||||
done
|
||||
|
||||
# Insecure mode
|
||||
read -p "$(msg skip_verify) [y/N]: " insecure_choice < /dev/tty
|
||||
INSECURE=""
|
||||
if [[ "$insecure_choice" =~ ^[Yy]$ ]]; then
|
||||
INSECURE="--insecure"
|
||||
print_warning "$(msg insecure_note)"
|
||||
fi
|
||||
|
||||
# Save configuration
|
||||
"$binary_path" config set --server "$SERVER" --token "$TOKEN" $INSECURE 2>/dev/null || true
|
||||
|
||||
print_success "$(msg config_saved)"
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# Test connection
|
||||
# ============================================================================
|
||||
test_connection() {
|
||||
echo ""
|
||||
read -p "$(msg run_test) [y/N]: " test_choice < /dev/tty
|
||||
|
||||
if [[ ! "$test_choice" =~ ^[Yy]$ ]]; then
|
||||
return
|
||||
fi
|
||||
|
||||
print_step "$(msg test_running)"
|
||||
|
||||
local binary_path="$INSTALL_DIR/$BINARY_NAME"
|
||||
|
||||
# Try to validate config
|
||||
if "$binary_path" config validate 2>/dev/null; then
|
||||
print_success "$(msg test_success)"
|
||||
else
|
||||
print_warning "$(msg test_failed)"
|
||||
fi
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# Final output
|
||||
# ============================================================================
|
||||
show_completion() {
|
||||
local binary_path="$INSTALL_DIR/$BINARY_NAME"
|
||||
|
||||
echo ""
|
||||
echo -e "${GREEN}╔════════════════════════════════════════════════════════════╗${NC}"
|
||||
echo -e "${GREEN}║ $(msg install_complete) ${GREEN}║${NC}"
|
||||
echo -e "${GREEN}╚════════════════════════════════════════════════════════════╝${NC}"
|
||||
echo ""
|
||||
|
||||
echo -e "${CYAN}$(msg usage_title):${NC}"
|
||||
echo ""
|
||||
echo -e " ${GREEN}# $(msg usage_http)${NC}"
|
||||
echo -e " ${YELLOW}$BINARY_NAME http 3000${NC}"
|
||||
echo ""
|
||||
echo -e " ${GREEN}# $(msg usage_tcp)${NC}"
|
||||
echo -e " ${YELLOW}$BINARY_NAME tcp 5432${NC}"
|
||||
echo ""
|
||||
echo -e " ${GREEN}# $(msg usage_config)${NC}"
|
||||
echo -e " ${YELLOW}$BINARY_NAME config show${NC}"
|
||||
echo -e " ${YELLOW}$BINARY_NAME config init${NC}"
|
||||
echo ""
|
||||
echo -e " ${GREEN}# $(msg usage_daemon)${NC}"
|
||||
echo -e " ${YELLOW}$BINARY_NAME daemon start http 3000${NC}"
|
||||
echo -e " ${YELLOW}$BINARY_NAME daemon list${NC}"
|
||||
echo ""
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# Main
|
||||
# ============================================================================
|
||||
main() {
|
||||
clear
|
||||
print_banner
|
||||
select_language
|
||||
|
||||
echo -e "${BOLD}────────────────────────────────────────────${NC}"
|
||||
|
||||
check_os
|
||||
check_arch
|
||||
check_dependencies
|
||||
check_existing_install
|
||||
|
||||
echo ""
|
||||
download_binary
|
||||
select_install_dir
|
||||
install_binary
|
||||
update_path
|
||||
verify_installation
|
||||
|
||||
# Skip configuration for updates
|
||||
if [[ "$IS_UPDATE" != true ]]; then
|
||||
configure_client
|
||||
test_connection
|
||||
else
|
||||
echo ""
|
||||
local new_version=$("$INSTALL_DIR/$BINARY_NAME" version 2>/dev/null | awk '/Version:/ {print $2}' || echo "installed")
|
||||
echo -e "${GREEN}╔════════════════════════════════════════════════════════════╗${NC}"
|
||||
echo -e "${GREEN}║ $(msg update_ok) ${GREEN}║${NC}"
|
||||
echo -e "${GREEN}╚════════════════════════════════════════════════════════════╝${NC}"
|
||||
echo ""
|
||||
print_info "Version: $new_version"
|
||||
echo ""
|
||||
return
|
||||
fi
|
||||
|
||||
show_completion
|
||||
}
|
||||
|
||||
# Run
|
||||
main "$@"
|
||||
503
scripts/test/one-click-test.sh
Executable file
503
scripts/test/one-click-test.sh
Executable file
@@ -0,0 +1,503 @@
|
||||
#!/bin/bash
|
||||
# Drip One-Click Performance Test Script
|
||||
# Automatically starts all services, runs tests, and generates reports
|
||||
|
||||
set -e
|
||||
|
||||
# Color definitions
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
BLUE='\033[0;34m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# Configuration
|
||||
RESULTS_DIR="benchmark-results"
|
||||
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||
LOG_DIR="/tmp/drip-test-${TIMESTAMP}"
|
||||
REPORT_FILE="${RESULTS_DIR}/test-report-${TIMESTAMP}.txt"
|
||||
|
||||
# Port configuration
|
||||
HTTP_TEST_PORT=3000
|
||||
DRIP_SERVER_PORT=8443
|
||||
PPROF_PORT=6060
|
||||
|
||||
# PID file
|
||||
PIDS_FILE="${LOG_DIR}/pids.txt"
|
||||
|
||||
# Create directories
|
||||
mkdir -p "$RESULTS_DIR"
|
||||
mkdir -p "$LOG_DIR"
|
||||
|
||||
# ============================================
|
||||
# Helper functions
|
||||
# ============================================
|
||||
|
||||
# All logs go to stderr to avoid being consumed by command substitution $(...)
|
||||
log_info() {
|
||||
echo -e "${GREEN}[INFO]${NC} $1" >&2
|
||||
}
|
||||
|
||||
log_warn() {
|
||||
echo -e "${YELLOW}[WARN]${NC} $1" >&2
|
||||
}
|
||||
|
||||
log_error() {
|
||||
echo -e "${RED}[ERROR]${NC} $1" >&2
|
||||
}
|
||||
|
||||
log_step() {
|
||||
echo -e "\n${BLUE}==>${NC} $1\n" >&2
|
||||
}
|
||||
|
||||
# Cleanup function
|
||||
cleanup() {
|
||||
log_step "Cleaning up test environment..."
|
||||
|
||||
if [ -f "$PIDS_FILE" ]; then
|
||||
log_info "Stopping all test processes..."
|
||||
while read -r pid; do
|
||||
if ps -p "$pid" > /dev/null 2>&1; then
|
||||
log_info "Stopping process $pid"
|
||||
kill "$pid" 2>/dev/null || true
|
||||
fi
|
||||
done < "$PIDS_FILE"
|
||||
rm -f "$PIDS_FILE"
|
||||
fi
|
||||
|
||||
# Extra cleanup: ensure ports are released
|
||||
pkill -f "python.*${HTTP_TEST_PORT}" 2>/dev/null || true
|
||||
pkill -f "drip server.*${DRIP_SERVER_PORT}" 2>/dev/null || true
|
||||
pkill -f "drip http ${HTTP_TEST_PORT}" 2>/dev/null || true
|
||||
|
||||
log_info "Cleanup completed"
|
||||
}
|
||||
|
||||
# Register cleanup function
|
||||
trap cleanup EXIT INT TERM
|
||||
|
||||
# Check dependencies
|
||||
check_dependencies() {
|
||||
log_step "Checking dependencies..."
|
||||
|
||||
local missing=""
|
||||
|
||||
if ! command -v wrk &> /dev/null; then
|
||||
missing="${missing}\n - wrk (brew install wrk)"
|
||||
fi
|
||||
|
||||
if ! command -v python3 &> /dev/null; then
|
||||
missing="${missing}\n - python3"
|
||||
fi
|
||||
|
||||
if ! command -v openssl &> /dev/null; then
|
||||
missing="${missing}\n - openssl"
|
||||
fi
|
||||
|
||||
if ! command -v nc &> /dev/null; then
|
||||
missing="${missing}\n - nc (netcat)"
|
||||
fi
|
||||
|
||||
if [ ! -f "./bin/drip" ]; then
|
||||
log_error "Cannot find drip executable"
|
||||
log_info "Please run: make build"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -n "$missing" ]; then
|
||||
log_error "Missing dependencies:${missing}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
log_info "✓ All dependencies satisfied"
|
||||
}
|
||||
|
||||
# Generate self-signed ECDSA certificate for testing
|
||||
generate_test_certs() {
|
||||
log_step "Generating test TLS certificate (ECDSA)..."
|
||||
|
||||
local cert_dir="${LOG_DIR}/certs"
|
||||
mkdir -p "$cert_dir"
|
||||
|
||||
# Generate ECDSA private key (prime256v1 = P-256)
|
||||
openssl ecparam -name prime256v1 -genkey -noout \
|
||||
-out "${cert_dir}/server.key" >/dev/null 2>&1
|
||||
|
||||
# Generate self-signed certificate with this private key
|
||||
openssl req -new -x509 \
|
||||
-key "${cert_dir}/server.key" \
|
||||
-out "${cert_dir}/server.crt" \
|
||||
-days 1 \
|
||||
-subj "/C=US/ST=Test/L=Test/O=Test/CN=localhost" \
|
||||
>/dev/null 2>&1
|
||||
|
||||
if [ -f "${cert_dir}/server.crt" ] && [ -f "${cert_dir}/server.key" ]; then
|
||||
log_info "✓ ECDSA test certificate generated"
|
||||
# Note: this echo is the "return value", stdout only outputs this line
|
||||
echo "${cert_dir}/server.crt ${cert_dir}/server.key"
|
||||
else
|
||||
log_error "ECDSA certificate generation failed"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# Wait for port to be available
|
||||
wait_for_port() {
|
||||
local port=$1
|
||||
local max_wait=${2:-30}
|
||||
local waited=0
|
||||
|
||||
while ! nc -z localhost "$port" 2>/dev/null; do
|
||||
if [ "$waited" -ge "$max_wait" ]; then
|
||||
return 1
|
||||
fi
|
||||
sleep 1
|
||||
waited=$((waited + 1))
|
||||
done
|
||||
return 0
|
||||
}
|
||||
|
||||
# Start HTTP test server
|
||||
start_http_server() {
|
||||
log_step "Starting HTTP test server (port $HTTP_TEST_PORT)..."
|
||||
|
||||
# Create simple test server
|
||||
cat > "${LOG_DIR}/test-server.py" << 'EOF'
|
||||
import http.server
|
||||
import socketserver
|
||||
import json
|
||||
from datetime import datetime
|
||||
import sys
|
||||
|
||||
PORT = int(sys.argv[1]) if len(sys.argv) > 1 else 3000
|
||||
|
||||
class TestHandler(http.server.SimpleHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
response = {
|
||||
"status": "ok",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"message": "Test server response"
|
||||
}
|
||||
|
||||
self.send_response(200)
|
||||
self.send_header('Content-type', 'application/json')
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps(response).encode())
|
||||
|
||||
def log_message(self, format, *args):
|
||||
pass # Silent logging
|
||||
|
||||
with socketserver.TCPServer(("", PORT), TestHandler) as httpd:
|
||||
print(f"Server started on port {PORT}", flush=True)
|
||||
httpd.serve_forever()
|
||||
EOF
|
||||
|
||||
python3 "${LOG_DIR}/test-server.py" "$HTTP_TEST_PORT" \
|
||||
> "${LOG_DIR}/http-server.log" 2>&1 &
|
||||
local pid=$!
|
||||
echo "$pid" >> "$PIDS_FILE"
|
||||
|
||||
if wait_for_port "$HTTP_TEST_PORT" 10; then
|
||||
log_info "✓ HTTP test server started (PID: $pid)"
|
||||
else
|
||||
log_error "HTTP test server failed to start"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# Start Drip server
|
||||
start_drip_server() {
|
||||
log_step "Starting Drip server (port $DRIP_SERVER_PORT)..."
|
||||
|
||||
local cert_path=$1
|
||||
local key_path=$2
|
||||
|
||||
./bin/drip server \
|
||||
--port "$DRIP_SERVER_PORT" \
|
||||
--domain localhost \
|
||||
--tls-cert "$cert_path" \
|
||||
--tls-key "$key_path" \
|
||||
> "${LOG_DIR}/drip-server.log" 2>&1 &
|
||||
local pid=$!
|
||||
echo "$pid" >> "$PIDS_FILE"
|
||||
|
||||
if wait_for_port "$DRIP_SERVER_PORT" 10; then
|
||||
log_info "✓ Drip server started (PID: $pid)"
|
||||
else
|
||||
log_error "Drip server failed to start"
|
||||
cat "${LOG_DIR}/drip-server.log"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# Start Drip client and extract URL
|
||||
start_drip_client() {
|
||||
log_step "Starting Drip client..."
|
||||
|
||||
./bin/drip http "$HTTP_TEST_PORT" \
|
||||
--server "localhost:${DRIP_SERVER_PORT}" \
|
||||
--insecure \
|
||||
> "${LOG_DIR}/drip-client.log" 2>&1 &
|
||||
local pid=$!
|
||||
echo "$pid" >> "$PIDS_FILE"
|
||||
|
||||
# Wait for client to start and extract URL
|
||||
log_info "Waiting for tunnel to establish..."
|
||||
sleep 3
|
||||
|
||||
# Extract tunnel URL from logs
|
||||
local tunnel_url=""
|
||||
local max_attempts=10
|
||||
local attempt=0
|
||||
|
||||
while [ "$attempt" -lt "$max_attempts" ]; do
|
||||
# Use grep to extract URL starting with https:// and remove ANSI color codes
|
||||
tunnel_url=$(grep -oE 'https://[a-zA-Z0-9.-]+:[0-9]+' "${LOG_DIR}/drip-client.log" 2>/dev/null | head -1)
|
||||
if [ -n "$tunnel_url" ]; then
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
attempt=$((attempt + 1))
|
||||
done
|
||||
|
||||
if [ -z "$tunnel_url" ]; then
|
||||
log_error "Cannot get tunnel URL"
|
||||
log_info "Client logs:"
|
||||
cat "${LOG_DIR}/drip-client.log"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
log_info "✓ Drip client started (PID: $pid)"
|
||||
log_info "✓ Tunnel URL: $tunnel_url"
|
||||
|
||||
# Return URL
|
||||
echo "$tunnel_url"
|
||||
}
|
||||
|
||||
# Verify connectivity
|
||||
verify_connectivity() {
|
||||
local url=$1
|
||||
log_step "Verifying tunnel connectivity..."
|
||||
|
||||
local max_attempts=5
|
||||
local attempt=0
|
||||
|
||||
while [ "$attempt" -lt "$max_attempts" ]; do
|
||||
if curl -sk --max-time 5 "$url" > /dev/null 2>&1; then
|
||||
log_info "✓ Tunnel connectivity normal"
|
||||
return 0
|
||||
fi
|
||||
attempt=$((attempt + 1))
|
||||
log_warn "Attempt $attempt/$max_attempts..."
|
||||
sleep 2
|
||||
done
|
||||
|
||||
log_error "Tunnel connectivity test failed"
|
||||
return 1
|
||||
}
|
||||
|
||||
# Run performance tests
|
||||
run_performance_tests() {
|
||||
local url=$1
|
||||
|
||||
log_step "Starting performance tests..."
|
||||
|
||||
# Test 1: Quick benchmark
|
||||
log_info "[1/3] Quick benchmark (10s)..."
|
||||
wrk -t 4 -c 50 -d 10s --latency "$url" \
|
||||
> "${RESULTS_DIR}/quick-benchmark-${TIMESTAMP}.txt" 2>&1
|
||||
|
||||
# Test 2: Standard load test
|
||||
log_info "[2/3] Standard load test (30s)..."
|
||||
wrk -t 8 -c 100 -d 30s --latency "$url" \
|
||||
> "${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt" 2>&1
|
||||
|
||||
# Test 3: High concurrency test
|
||||
log_info "[3/3] High concurrency test (30s)..."
|
||||
wrk -t 12 -c 400 -d 30s --latency "$url" \
|
||||
> "${RESULTS_DIR}/high-concurrency-${TIMESTAMP}.txt" 2>&1
|
||||
|
||||
log_info "✓ Performance tests completed"
|
||||
}
|
||||
|
||||
# Generate test report
|
||||
generate_report() {
|
||||
log_step "Generating test report..."
|
||||
|
||||
cat > "$REPORT_FILE" << EOF
|
||||
========================================
|
||||
Drip Performance Test Report
|
||||
========================================
|
||||
|
||||
Test Time: $(date)
|
||||
Test Version: $(./bin/drip version 2>/dev/null | head -1 || echo "unknown")
|
||||
|
||||
========================================
|
||||
Test Environment
|
||||
========================================
|
||||
|
||||
OS: $(uname -s)
|
||||
CPU Cores: $(sysctl -n hw.ncpu 2>/dev/null || nproc 2>/dev/null || echo "unknown")
|
||||
Memory: $(sysctl -n hw.memsize 2>/dev/null | awk '{print int($1/1024/1024/1024)"GB"}' || echo "unknown")
|
||||
|
||||
========================================
|
||||
Test Results
|
||||
========================================
|
||||
|
||||
EOF
|
||||
|
||||
# Parse and add results from each test
|
||||
if [ -f "${RESULTS_DIR}/quick-benchmark-${TIMESTAMP}.txt" ]; then
|
||||
{
|
||||
echo "--- Quick Benchmark (10s, 50 connections) ---"
|
||||
grep "Requests/sec:" "${RESULTS_DIR}/quick-benchmark-${TIMESTAMP}.txt"
|
||||
grep "Transfer/sec:" "${RESULTS_DIR}/quick-benchmark-${TIMESTAMP}.txt"
|
||||
echo ""
|
||||
grep "Latency" "${RESULTS_DIR}/quick-benchmark-${TIMESTAMP}.txt" | head -1
|
||||
grep -A 3 "Latency Distribution" "${RESULTS_DIR}/quick-benchmark-${TIMESTAMP}.txt"
|
||||
echo ""
|
||||
} >> "$REPORT_FILE"
|
||||
fi
|
||||
|
||||
if [ -f "${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt" ]; then
|
||||
{
|
||||
echo "--- Standard Load Test (30s, 100 connections) ---"
|
||||
grep "Requests/sec:" "${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt"
|
||||
grep "Transfer/sec:" "${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt"
|
||||
echo ""
|
||||
grep "Latency" "${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt" | head -1
|
||||
grep -A 3 "Latency Distribution" "${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt"
|
||||
echo ""
|
||||
} >> "$REPORT_FILE"
|
||||
fi
|
||||
|
||||
if [ -f "${RESULTS_DIR}/high-concurrency-${TIMESTAMP}.txt" ]; then
|
||||
{
|
||||
echo "--- High Concurrency Test (30s, 400 connections) ---"
|
||||
grep "Requests/sec:" "${RESULTS_DIR}/high-concurrency-${TIMESTAMP}.txt"
|
||||
grep "Transfer/sec:" "${RESULTS_DIR}/high-concurrency-${TIMESTAMP}.txt"
|
||||
echo ""
|
||||
grep "Latency" "${RESULTS_DIR}/high-concurrency-${TIMESTAMP}.txt" | head -1
|
||||
grep -A 3 "Latency Distribution" "${RESULTS_DIR}/high-concurrency-${TIMESTAMP}.txt"
|
||||
echo ""
|
||||
} >> "$REPORT_FILE"
|
||||
fi
|
||||
|
||||
# Add performance evaluation
|
||||
cat >> "$REPORT_FILE" << 'EOF'
|
||||
========================================
|
||||
Performance Evaluation Standards
|
||||
========================================
|
||||
|
||||
Excellent (Phase 1 optimization target):
|
||||
✓ QPS > 5000
|
||||
✓ P99 latency < 50ms
|
||||
✓ Error rate = 0%
|
||||
|
||||
Good:
|
||||
✓ QPS > 2000
|
||||
✓ P99 latency < 100ms
|
||||
✓ Error rate < 0.1%
|
||||
|
||||
Needs Optimization:
|
||||
✗ QPS < 1000
|
||||
✗ P99 latency > 200ms
|
||||
|
||||
========================================
|
||||
Detailed Result Files
|
||||
========================================
|
||||
|
||||
EOF
|
||||
|
||||
{
|
||||
echo "Quick test: ${RESULTS_DIR}/quick-benchmark-${TIMESTAMP}.txt"
|
||||
echo "Standard test: ${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt"
|
||||
echo "High concurrency test: ${RESULTS_DIR}/high-concurrency-${TIMESTAMP}.txt"
|
||||
echo ""
|
||||
echo "Log directory: ${LOG_DIR}/"
|
||||
} >> "$REPORT_FILE"
|
||||
|
||||
log_info "✓ Test report generated: $REPORT_FILE"
|
||||
}
|
||||
|
||||
# Show report summary
|
||||
show_summary() {
|
||||
log_step "Test Results Summary"
|
||||
|
||||
if [ -f "${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt" ]; then
|
||||
echo ""
|
||||
echo "========================================="
|
||||
echo " Standard Load Test Results"
|
||||
echo "========================================="
|
||||
grep "Requests/sec:" "${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt"
|
||||
grep "Transfer/sec:" "${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt"
|
||||
echo ""
|
||||
grep "50%" "${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt"
|
||||
grep "99%" "${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt"
|
||||
echo "========================================="
|
||||
echo ""
|
||||
fi
|
||||
|
||||
log_info "Full report: $REPORT_FILE"
|
||||
log_info "Detailed results: $RESULTS_DIR/"
|
||||
}
|
||||
|
||||
# ============================================
|
||||
# Main flow
|
||||
# ============================================
|
||||
|
||||
main() {
|
||||
clear
|
||||
echo "========================================="
|
||||
echo " Drip One-Click Performance Test"
|
||||
echo "========================================="
|
||||
echo ""
|
||||
|
||||
# Check dependencies
|
||||
check_dependencies
|
||||
|
||||
# Generate test certificate (ECDSA), only capture the last line of stdout with paths
|
||||
CERT_PATHS=$(generate_test_certs)
|
||||
CERT_FILE=$(echo "$CERT_PATHS" | awk '{print $1}')
|
||||
KEY_FILE=$(echo "$CERT_PATHS" | awk '{print $2}')
|
||||
|
||||
log_info "Using certificate: $CERT_FILE"
|
||||
log_info "Using private key: $KEY_FILE"
|
||||
|
||||
# Start all services
|
||||
start_http_server
|
||||
start_drip_server "$CERT_FILE" "$KEY_FILE"
|
||||
TUNNEL_URL=$(start_drip_client)
|
||||
|
||||
# Verify connectivity
|
||||
if ! verify_connectivity "$TUNNEL_URL"; then
|
||||
log_error "Test aborted: tunnel not accessible"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Warm up
|
||||
log_info "Warming up tunnel (5s)..."
|
||||
for _ in {1..5}; do
|
||||
curl -sk "$TUNNEL_URL" > /dev/null 2>&1 || true
|
||||
sleep 1
|
||||
done
|
||||
|
||||
# Run tests
|
||||
run_performance_tests "$TUNNEL_URL"
|
||||
|
||||
# Generate report
|
||||
generate_report
|
||||
|
||||
# Show summary
|
||||
show_summary
|
||||
|
||||
log_step "Testing completed!"
|
||||
echo ""
|
||||
echo "Press any key to view full report, or Ctrl+C to exit..."
|
||||
read -n 1 -s
|
||||
|
||||
cat "$REPORT_FILE"
|
||||
}
|
||||
|
||||
# Run main flow
|
||||
main
|
||||
Reference in New Issue
Block a user