feat(init): Initializes the project's basic structure and configuration files.

This commit is contained in:
Gouryella
2025-12-02 16:12:18 +08:00
commit 07eea862d5
66 changed files with 12029 additions and 0 deletions

68
.dockerignore Normal file
View File

@@ -0,0 +1,68 @@
# Git
.git
.gitignore
.gitattributes
# Documentation
*.md
docs/
LICENSE
# Build artifacts
bin/
dist/
*.exe
*.dll
*.so
*.dylib
# Test files
*_test.go
**/*_test.go
coverage.*
# IDE
.idea/
.vscode/
*.swp
*.swo
*~
# OS
.DS_Store
Thumbs.db
# Temporary files
tmp/
temp/
*.tmp
# Certificates (will be generated in container)
*.pem
*.key
*.crt
*.csr
certs/
# Logs
*.log
logs/
# Environment
.env
.env.local
# Dependencies (will be downloaded in container)
vendor/
# CI/CD
.github/
.gitlab-ci.yml
.travis.yml
# Development tools
Makefile
scripts/
# Examples
examples/

49
.env.example Normal file
View File

@@ -0,0 +1,49 @@
# ===========================================
# Server Configuration (docker-compose.yml)
# ===========================================
# Domain for tunnel service
DOMAIN=tunnel.example.com
# Authentication token
AUTH_TOKEN=your-secret-token-here
# Server port
PORT=8080
# Timezone
TZ=UTC
# TLS Configuration (choose one)
# Option 1: Auto TLS with Let's Encrypt
# AUTO_TLS=1
# Option 2: Manual certificates (place in ./certs/)
# TLS_CERT=1
# TLS_KEY=1
# Build version
VERSION=latest
# GIT_COMMIT=
# ===========================================
# Client Configuration (docker-compose.client.yml)
# ===========================================
# Server address
SERVER_ADDR=tunnel.example.com:443
# Tunnel type: http, https, or tcp
TUNNEL_TYPE=http
# Local port to expose
LOCAL_PORT=3000
# Local address (default: 127.0.0.1)
# LOCAL_ADDRESS=192.168.1.100
# Custom subdomain (optional)
# SUBDOMAIN=myapp
# Run as daemon
# DAEMON=1

54
.gitignore vendored Normal file
View File

@@ -0,0 +1,54 @@
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
bin/
dist/
# Test binary, built with `go test -c`
*.test
# Output of the go coverage tool
*.out
coverage.html
coverage.txt
# Dependency directories
vendor/
# Go workspace file
go.work
# IDEs
.idea/
.vscode/
*.swp
*.swo
*~
# OS
.DS_Store
Thumbs.db
# Environment files
.env
.env.local
*.pem
*.key
# Logs
*.log
logs/
# Build artifacts
*.tar.gz
*.zip
# Temporary files
tmp/
temp/
certs/
.drip-server.env
benchmark-results/

136
Makefile Normal file
View File

@@ -0,0 +1,136 @@
.PHONY: all build build-all clean test run-server run-client install deps fmt lint
# Variables
BINARY=bin/drip
VERSION?=dev
COMMIT=$(shell git rev-parse --short=10 HEAD 2>/dev/null || echo "unknown")
BUILD_TIME=$(shell date -u '+%Y-%m-%d_%H:%M:%S')
LDFLAGS=-ldflags "-s -w -X main.Version=${VERSION} -X main.GitCommit=${COMMIT} -X main.BuildTime=${BUILD_TIME}"
# Default target
all: clean deps test build
# Install dependencies
deps:
go mod download
go mod tidy
# Build unified binary
build:
@echo "Building Drip..."
@mkdir -p bin
go build ${LDFLAGS} -o ${BINARY} ./cmd/drip
@echo "Build complete!"
# Build for all platforms
build-all: clean
@echo "Building for multiple platforms..."
@mkdir -p bin
# Linux AMD64
GOOS=linux GOARCH=amd64 go build ${LDFLAGS} -o bin/drip-linux-amd64 ./cmd/drip
# Linux ARM64
GOOS=linux GOARCH=arm64 go build ${LDFLAGS} -o bin/drip-linux-arm64 ./cmd/drip
# macOS AMD64
GOOS=darwin GOARCH=amd64 go build ${LDFLAGS} -o bin/drip-darwin-amd64 ./cmd/drip
# macOS ARM64 (Apple Silicon)
GOOS=darwin GOARCH=arm64 go build ${LDFLAGS} -o bin/drip-darwin-arm64 ./cmd/drip
# Windows AMD64
GOOS=windows GOARCH=amd64 go build ${LDFLAGS} -o bin/drip-windows-amd64.exe ./cmd/drip
# Windows ARM64
GOOS=windows GOARCH=arm64 go build ${LDFLAGS} -o bin/drip-windows-arm64.exe ./cmd/drip
@echo "Multi-platform build complete!"
# Run tests
test:
go test -v -race -cover ./...
# Run tests with coverage
test-coverage:
go test -v -race -coverprofile=coverage.out -covermode=atomic ./...
go tool cover -html=coverage.out -o coverage.html
@echo "Coverage report generated: coverage.html"
# Benchmark tests
bench:
go test -bench=. -benchmem ./...
# Run server locally
run-server:
go run ./cmd/drip server
# Run client locally (example)
run-client:
go run ./cmd/drip http 3000
# Install globally
install:
go install ${LDFLAGS} ./cmd/drip
# Format code
fmt:
go fmt ./...
gofmt -s -w .
# Lint code
lint:
@which golangci-lint > /dev/null || (echo "Installing golangci-lint..." && go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest)
golangci-lint run ./...
# Clean build artifacts
clean:
@echo "Cleaning..."
@rm -rf bin/
@rm -f coverage.out coverage.html
@echo "Clean complete!"
# Docker build
docker-build:
docker build -t drip-server:${VERSION} -f deployments/Dockerfile .
# Docker run
docker-run:
docker run -p 80:80 -p 8080:8080 drip-server:${VERSION}
# Generate test certificates
gen-certs:
@echo "Generating test TLS 1.3 certificates..."
@mkdir -p certs
openssl req -x509 -newkey rsa:4096 -nodes \
-keyout certs/server-key.pem \
-out certs/server-cert.pem \
-days 365 \
-subj "/CN=localhost"
@echo "Test certificates generated in certs/"
@echo "⚠️ Warning: These are self-signed certificates for testing only!"
# Help
help:
@echo "Drip - Available Make Targets:"
@echo ""
@echo " make build - Build server and client"
@echo " make build-all - Build for all platforms"
@echo " make test - Run tests"
@echo " make test-coverage - Run tests with coverage report"
@echo " make bench - Run benchmark tests"
@echo " make run-server - Run server locally"
@echo " make run-client - Run client locally (port 3000)"
@echo " make gen-certs - Generate test TLS certificates"
@echo " make install - Install client globally"
@echo " make fmt - Format code"
@echo " make lint - Lint code"
@echo " make clean - Clean build artifacts"
@echo " make deps - Install dependencies"
@echo " make docker-build - Build Docker image"
@echo " make docker-run - Run Docker container"
@echo ""
@echo "Build info:"
@echo " VERSION=${VERSION}"
@echo " COMMIT=${COMMIT}"
@echo " BUILD_TIME=${BUILD_TIME}"

244
README.md Normal file
View File

@@ -0,0 +1,244 @@
# Drip - Fast Tunnels to Localhost
Self-hosted tunneling solution. Expose your localhost to the internet securely.
[中文文档](README_CN.md)
[![Go Version](https://img.shields.io/badge/Go-1.21+-00ADD8?style=flat&logo=go)](https://golang.org/)
[![License](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE)
[![TLS](https://img.shields.io/badge/TLS-1.3-green.svg)](https://tools.ietf.org/html/rfc8446)
## Why?
**Control your data.** No third-party servers means your traffic stays between your client and your server.
**No limits.** Run as many tunnels as you need, use as much bandwidth as your server can handle.
**Actually free.** Use your own domain, no paid tiers or feature restrictions.
| Feature | Drip | ngrok Free |
|---------|------|------------|
| Privacy | Your infrastructure | Third-party servers |
| Domain | Your domain | 1 static subdomain |
| Bandwidth | Unlimited | 1 GB/month |
| Active Endpoints | Unlimited | 1 endpoint |
| Tunnels per Agent | Unlimited | Up to 3 |
| Requests | Unlimited | 20,000/month |
| Interstitial Page | None | Yes (removable with header) |
| Open Source | ✓ | ✗ |
## Quick Install
### Client (macOS/Linux)
```bash
bash <(curl -sL https:///install.sh)
```
### Server (Linux)
```bash
bash <(curl -sL https:///install-server.sh)
```
## Usage
### Basic Tunnels
```bash
# Expose local HTTP server
drip http 3000
# Expose local HTTPS server
drip https 443
# Pick your subdomain
drip http 3000 --subdomain myapp
# → https://myapp.your-domain.com
# Expose TCP service (database, SSH, etc.)
drip tcp 5432
```
### Forward to Any Address
Not just localhost - forward to any device on your network:
```bash
# Forward to another machine on LAN
drip http 8080 --address 192.168.1.100
# Forward to Docker container
drip http 3000 --address 172.17.0.2
# Forward to specific interface
drip http 3000 --address 10.0.0.5
```
### Daemon Mode
Run tunnels in the background:
```bash
# Start tunnel as daemon
drip daemon start http 3000
drip daemon start https 8443 --subdomain api
# Manage daemons
drip daemon list
drip daemon stop http-3000
drip daemon logs http-3000
```
## Server Deployment
### Prerequisites
- A domain with DNS pointing to your server (A record)
- Wildcard DNS for subdomains: `*.tunnel.example.com -> YOUR_IP`
- SSL certificate (wildcard recommended)
### Option 1: Direct (Recommended)
Drip server handles TLS directly on port 443:
```bash
# Get wildcard certificate
sudo certbot certonly --manual --preferred-challenges dns \
-d "*.tunnel.example.com" -d "tunnel.example.com"
# Start server
drip-server \
--port 443 \
--domain tunnel.example.com \
--tls-cert /etc/letsencrypt/live/tunnel.example.com/fullchain.pem \
--tls-key /etc/letsencrypt/live/tunnel.example.com/privkey.pem \
--token YOUR_SECRET_TOKEN
```
### Option 2: Behind Nginx
Run Drip on port 8443, let Nginx handle SSL termination:
```nginx
server {
listen 443 ssl http2;
server_name *.tunnel.example.com;
ssl_certificate /etc/letsencrypt/live/tunnel.example.com/fullchain.pem;
ssl_certificate_key /etc/letsencrypt/live/tunnel.example.com/privkey.pem;
location / {
proxy_pass https://127.0.0.1:8443;
proxy_ssl_verify off;
proxy_http_version 1.1;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_buffering off;
}
}
```
### Systemd Service
The install script creates `/etc/systemd/system/drip-server.service` automatically. Manage with:
```bash
sudo systemctl start drip-server
sudo systemctl enable drip-server
sudo journalctl -u drip-server -f
```
## Features
**Security**
- TLS 1.3 encryption for all connections
- Token-based authentication
- No legacy protocol support
**Flexibility**
- HTTP, HTTPS, and TCP tunnels
- Forward to localhost or any LAN address
- Custom subdomains or auto-generated
- Daemon mode for persistent tunnels
**Performance**
- Binary protocol with msgpack encoding
- Connection pooling and reuse
- Minimal overhead between client and server
**Simplicity**
- One-line installation
- Save config once, use everywhere
- Real-time connection stats
## Architecture
```
┌─────────────┐ ┌──────────────┐ ┌─────────────┐
│ Internet │ ──────> │ Server │ <────── │ Client │
│ User │ HTTPS │ (Drip) │ TLS 1.3 │ localhost │
└─────────────┘ └──────────────┘ └─────────────┘
```
## Common Use Cases
**Development & Testing**
```bash
# Show local dev site to client
drip http 3000
# Test webhooks from services like Stripe
drip http 8000 --subdomain webhooks
```
**Home Server Access**
```bash
# Access home NAS remotely
drip http 5000 --address 192.168.1.50
# Remote into home network
drip tcp 22
```
**Docker & Containers**
```bash
# Expose containerized app
drip http 8080 --address 172.17.0.3
# Database access for debugging
drip tcp 5432 --address db-container
```
## Command Reference
```bash
# HTTP tunnel
drip http <port> [flags]
--subdomain, -n Custom subdomain
--address, -a Target address (default: 127.0.0.1)
--server Server address
--token Auth token
# HTTPS tunnel
drip https <port> [flags]
# TCP tunnel
drip tcp <port> [flags]
# Daemon commands
drip daemon start <type> <port> [flags]
drip daemon list
drip daemon stop <name>
drip daemon logs <name>
# Configuration
drip config init
drip config show
```
## License
MIT License - see [LICENSE](LICENSE) for details

244
README_CN.md Normal file
View File

@@ -0,0 +1,244 @@
# Drip - 快速内网穿透工具
自建隧道服务,让本地服务安全地暴露到公网。
[English](README.md)
[![Go Version](https://img.shields.io/badge/Go-1.21+-00ADD8?style=flat&logo=go)](https://golang.org/)
[![License](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE)
[![TLS](https://img.shields.io/badge/TLS-1.3-green.svg)](https://tools.ietf.org/html/rfc8446)
## 为什么用它?
**数据掌控。** 没有第三方服务器,流量只在你的客户端和服务器之间传输。
**没有限制。** 想开多少隧道就开多少,带宽只受服务器性能限制。
**真正免费。** 用自己的域名,没有付费功能,没有阉割版。
| 特性 | Drip | ngrok 免费版 |
|------|------|-------------|
| 隐私 | 自己的基础设施 | 第三方服务器 |
| 域名 | 你的域名 | 1 个固定子域名 |
| 带宽 | 无限制 | 1 GB/月 |
| 活跃端点 | 无限制 | 1 个端点 |
| 每个代理的隧道数 | 无限制 | 最多 3 个 |
| 请求数 | 无限制 | 20,000 次/月 |
| 警告页面 | 无 | 有(可用请求头移除) |
| 开源 | ✓ | ✗ |
## 一键安装
### 客户端 (macOS/Linux)
```bash
bash <(curl -sL https:///install.sh)
```
### 服务端 (Linux)
```bash
bash <(curl -sL https:///install-server.sh)
```
## 使用方法
### 基础隧道
```bash
# 暴露本地 HTTP 服务
drip http 3000
# 暴露本地 HTTPS 服务
drip https 443
# 自定义子域名
drip http 3000 --subdomain myapp
# → https://myapp.你的域名.com
# 暴露 TCP 服务数据库、SSH 等)
drip tcp 5432
```
### 转发到任意地址
不只是 localhost可以转发到局域网内的任何设备
```bash
# 转发到局域网其他设备
drip http 8080 --address 192.168.1.100
# 转发到 Docker 容器
drip http 3000 --address 172.17.0.2
# 转发到特定网卡
drip http 3000 --address 10.0.0.5
```
### 后台守护进程
让隧道在后台持续运行:
```bash
# 启动后台隧道
drip daemon start http 3000
drip daemon start https 8443 --subdomain api
# 管理后台隧道
drip daemon list
drip daemon stop http-3000
drip daemon logs http-3000
```
## 服务端部署
### 前置条件
- 域名已解析到服务器A 记录)
- 泛域名解析:`*.tunnel.example.com -> 服务器IP`
- SSL 证书(推荐通配符证书)
### 方式一:直接部署(推荐)
Drip 服务端直接监听 443 端口处理 TLS
```bash
# 获取通配符证书
sudo certbot certonly --manual --preferred-challenges dns \
-d "*.tunnel.example.com" -d "tunnel.example.com"
# 启动服务
drip-server \
--port 443 \
--domain tunnel.example.com \
--tls-cert /etc/letsencrypt/live/tunnel.example.com/fullchain.pem \
--tls-key /etc/letsencrypt/live/tunnel.example.com/privkey.pem \
--token 你的密钥
```
### 方式二Nginx 反向代理
Drip 监听 8443 端口Nginx 处理 SSL
```nginx
server {
listen 443 ssl http2;
server_name *.tunnel.example.com;
ssl_certificate /etc/letsencrypt/live/tunnel.example.com/fullchain.pem;
ssl_certificate_key /etc/letsencrypt/live/tunnel.example.com/privkey.pem;
location / {
proxy_pass https://127.0.0.1:8443;
proxy_ssl_verify off;
proxy_http_version 1.1;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_buffering off;
}
}
```
### Systemd 服务
安装脚本会自动创建 `/etc/systemd/system/drip-server.service`,使用以下命令管理:
```bash
sudo systemctl start drip-server # 启动
sudo systemctl enable drip-server # 开机启动
sudo journalctl -u drip-server -f # 查看日志
```
## 核心特性
**安全性**
- 所有连接使用 TLS 1.3 加密
- 基于 Token 的认证
- 不支持旧版不安全协议
**灵活性**
- 支持 HTTP、HTTPS 和 TCP 隧道
- 转发到 localhost 或局域网任意地址
- 自定义子域名或自动生成
- 后台守护进程模式
**性能**
- 二进制协议配合 msgpack 编码
- 连接池复用
- 客户端与服务器之间开销极小
**简单易用**
- 一行命令安装
- 配置一次处处使用
- 实时连接统计
## 架构
```
┌─────────────┐ ┌──────────────┐ ┌─────────────┐
│ 外部用户 │ ──────> │ 服务器 │ <────── │ 你的电脑 │
│ │ HTTPS │ (Drip) │ TLS 1.3 │ localhost │
└─────────────┘ └──────────────┘ └─────────────┘
```
## 常见使用场景
**开发和测试**
```bash
# 给客户演示本地开发的网站
drip http 3000
# 测试第三方 webhook如 Stripe
drip http 8000 --subdomain webhooks
```
**家庭服务器**
```bash
# 远程访问家里的 NAS
drip http 5000 --address 192.168.1.50
# 远程连接家庭网络
drip tcp 22
```
**Docker 和容器**
```bash
# 暴露容器化应用
drip http 8080 --address 172.17.0.3
# 调试数据库
drip tcp 5432 --address db-container
```
## 命令参考
```bash
# HTTP 隧道
drip http <端口> [选项]
--subdomain, -n 自定义子域名
--address, -a 目标地址默认127.0.0.1
--server 服务器地址
--token 认证令牌
# HTTPS 隧道
drip https <端口> [选项]
# TCP 隧道
drip tcp <端口> [选项]
# 守护进程命令
drip daemon start <类型> <端口> [选项]
drip daemon list
drip daemon stop <名称>
drip daemon logs <名称>
# 配置管理
drip config init
drip config show
```
## 开源协议
MIT License - 详见 [LICENSE](LICENSE)

25
cmd/drip/main.go Normal file
View File

@@ -0,0 +1,25 @@
package main
import (
"fmt"
"os"
"drip/internal/client/cli"
)
var (
Version = "dev"
GitCommit = "unknown"
BuildTime = "unknown"
)
func main() {
// Set version information
cli.SetVersion(Version, GitCommit, BuildTime)
// Execute CLI
if err := cli.Execute(); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
}

38
deployments/Dockerfile Normal file
View File

@@ -0,0 +1,38 @@
FROM golang:1.25-alpine AS builder
RUN apk add --no-cache git ca-certificates tzdata
WORKDIR /app
COPY go.mod go.sum ./
RUN go mod download
COPY . .
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
-ldflags="-s -w -X main.Version=${VERSION:-dev} -X main.GitCommit=${GIT_COMMIT:-unknown} -X main.BuildTime=$(date -u '+%Y-%m-%d_%H:%M:%S')" \
-o drip \
./cmd/drip
FROM alpine:latest
RUN apk add --no-cache ca-certificates tzdata
RUN addgroup -S drip && adduser -S -G drip drip
WORKDIR /app
RUN mkdir -p /app/data/certs && \
chown -R drip:drip /app
COPY --from=builder /app/drip /app/drip
USER drip
EXPOSE 80 443 8080 20000-40000
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD wget --no-verbose --tries=1 --spider http://localhost:8080/health || exit 1
ENTRYPOINT ["/app/drip"]
CMD ["server", "--port", "8080"]

View File

@@ -0,0 +1,33 @@
FROM golang:1.25-alpine AS builder
RUN apk add --no-cache git ca-certificates
WORKDIR /app
COPY go.mod go.sum ./
RUN go mod download
COPY . .
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
-ldflags="-s -w -X main.Version=${VERSION:-dev} -X main.GitCommit=${GIT_COMMIT:-unknown} -X main.BuildTime=$(date -u '+%Y-%m-%d_%H:%M:%S')" \
-o drip \
./cmd/drip
FROM alpine:latest
RUN apk add --no-cache ca-certificates
RUN addgroup -S drip && adduser -S -G drip drip
WORKDIR /app
RUN mkdir -p /app/data && \
chown -R drip:drip /app
COPY --from=builder /app/drip /app/drip
USER drip
ENTRYPOINT ["/app/drip"]
CMD ["--help"]

220
deployments/README.md Normal file
View File

@@ -0,0 +1,220 @@
# Docker Deployment
## Quick Start
### Server (Production)
```bash
# Copy and configure environment
cp .env.example .env
nano .env
# Edit server configuration
DOMAIN=tunnel.example.com
AUTH_TOKEN=your-secret-token
TLS_CERT=1
TLS_KEY=1
# Place certificates
mkdir -p certs
cp /path/to/fullchain.pem certs/
cp /path/to/privkey.pem certs/
# Uncomment volume mount in docker-compose.yml
# - ./certs:/app/data/certs:ro
# Start server
docker compose up -d
# View logs
docker compose logs -f
```
### Client (Development/Testing)
```bash
# Copy and configure client environment
cp .env.example .env.client
nano .env.client
# Edit client configuration
SERVER_ADDR=tunnel.example.com:443
AUTH_TOKEN=your-secret-token
TUNNEL_TYPE=http
LOCAL_PORT=3000
# Start client
docker compose -f docker-compose.client.yml --env-file .env.client up -d
# View logs
docker compose -f docker-compose.client.yml logs -f
```
## Configuration
### Environment Variables
Create `.env` from `.env.example`:
```bash
DOMAIN=tunnel.example.com
AUTH_TOKEN=your-secret-token
```
### TLS Certificates
**Option 1: Auto TLS (Let's Encrypt)**
```bash
# Enable in .env
AUTO_TLS=1
# Ensure port 80 is accessible for ACME challenges
```
**Option 2: Manual Certificates**
```bash
# Place certificates in ./certs/
mkdir -p certs
cp fullchain.pem certs/cert.pem
cp privkey.pem certs/key.pem
# Uncomment in docker-compose.yml
# - ./certs:/app/data/certs:ro
# Enable in .env
TLS_CERT=1
TLS_KEY=1
```
## Data Persistence
All data is stored in Docker volumes:
- `drip-data`: Server data and certificates at `/app/data`
- `client-data`: Client configuration at `/app/data`
### Backup
```bash
# Backup server data
docker run --rm -v drip-data:/data -v $(pwd):/backup alpine tar czf /backup/drip-backup.tar.gz -C /data .
# Restore
docker run --rm -v drip-data:/data -v $(pwd):/backup alpine tar xzf /backup/drip-backup.tar.gz -C /data
```
## Port Mapping
| Container Port | Host Port | Purpose |
|---------------|-----------|---------|
| 80 | 80 | HTTP (ACME challenges) |
| 443 | 443 | HTTPS (main service) |
| 8080 | 8080 | HTTP (no TLS) |
| 20000-20100 | 20000-20100 | TCP tunnels |
## Management
### Server
```bash
# Start
docker compose up -d
# Stop
docker compose down
# Restart
docker compose restart
# View logs
docker compose logs -f
# Shell access
docker compose exec server sh
# Update
docker compose pull
docker compose up -d
```
### Client
```bash
# Start
docker compose -f docker-compose.client.yml up -d
# Stop
docker compose -f docker-compose.client.yml down
# View logs
docker compose -f docker-compose.client.yml logs -f
# Different tunnel types
TUNNEL_TYPE=http LOCAL_PORT=3000 docker compose -f docker-compose.client.yml up -d
TUNNEL_TYPE=https LOCAL_PORT=8443 docker compose -f docker-compose.client.yml up -d
TUNNEL_TYPE=tcp LOCAL_PORT=5432 docker compose -f docker-compose.client.yml up -d
```
## Production Deployment
### With Reverse Proxy
If using Nginx/Traefik in front:
```yaml
services:
server:
ports:
- "127.0.0.1:8080:8080" # Only expose to localhost
command: >
server
--domain tunnel.example.com
--port 8080
--token ${AUTH_TOKEN}
```
### Resource Limits
Adjust in `docker-compose.yml`:
```yaml
deploy:
resources:
limits:
cpus: '2'
memory: 512M
```
## Troubleshooting
**Certificate errors**
```bash
# Check certificate files
docker compose exec server ls -la /app/data/certs
# Check server logs
docker compose logs server | grep -i tls
```
**Connection issues**
```bash
# Verify port accessibility
curl -I https://tunnel.example.com
# Check server status
docker compose exec server /app/drip server --help
```
**Reset everything**
```bash
# Stop and remove everything
docker compose down -v
# Start fresh
docker compose up -d
```

38
docker-compose.client.yml Normal file
View File

@@ -0,0 +1,38 @@
services:
client:
build:
context: .
dockerfile: deployments/Dockerfile.client
args:
VERSION: ${VERSION:-dev}
GIT_COMMIT: ${GIT_COMMIT:-unknown}
image: drip-client:${VERSION:-latest}
container_name: drip-client
restart: unless-stopped
network_mode: host
volumes:
- drip-client-data:/app/data
# Optional: mount config file
# - ./client-config.yaml:/app/data/config.yaml:ro
environment:
TZ: ${TZ:-UTC}
command: >
${TUNNEL_TYPE:-http} ${LOCAL_PORT:-3000}
--server ${SERVER_ADDR}
${AUTH_TOKEN:+--token ${AUTH_TOKEN}}
${SUBDOMAIN:+--subdomain ${SUBDOMAIN}}
${LOCAL_ADDRESS:+--address ${LOCAL_ADDRESS}}
${DAEMON:+--daemon}
logging:
driver: json-file
options:
max-size: 10m
max-file: "3"
volumes:
drip-client-data:
driver: local

67
docker-compose.yml Normal file
View File

@@ -0,0 +1,67 @@
services:
server:
build:
context: .
dockerfile: deployments/Dockerfile
args:
VERSION: ${VERSION:-dev}
GIT_COMMIT: ${GIT_COMMIT:-unknown}
image: drip-server:${VERSION:-latest}
container_name: drip-server
restart: unless-stopped
ports:
- "80:80"
- "443:443"
- "8080:8080"
- "20000-20100:20000-20100"
volumes:
- drip-data:/app/data
# Mount TLS certificates if not using auto-TLS
# - ./certs:/app/data/certs:ro
environment:
TZ: ${TZ:-UTC}
command: >
server
--domain ${DOMAIN:-tunnel.localhost}
--port ${PORT:-8080}
${TLS_CERT:+--tls-cert /app/data/certs/fullchain.pem}
${TLS_KEY:+--tls-key /app/data/certs/privkey.pem}
${AUTO_TLS:+--auto-tls}
${AUTH_TOKEN:+--token ${AUTH_TOKEN}}
networks:
- drip-net
logging:
driver: json-file
options:
max-size: 10m
max-file: "3"
deploy:
resources:
limits:
cpus: '2'
memory: 512M
reservations:
cpus: '0.5'
memory: 128M
healthcheck:
test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://localhost:${PORT:-8080}/health"]
interval: 30s
timeout: 3s
retries: 3
start_period: 5s
volumes:
drip-data:
driver: local
networks:
drip-net:
driver: bridge

22
go.mod Normal file
View File

@@ -0,0 +1,22 @@
module drip
go 1.25.4
require (
github.com/gorilla/websocket v1.5.3
github.com/spf13/cobra v1.10.1
github.com/vmihailenco/msgpack/v5 v5.4.1
go.uber.org/zap v1.27.1
golang.org/x/crypto v0.45.0
gopkg.in/yaml.v3 v3.0.1
)
require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/stretchr/testify v1.11.1 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/text v0.31.0 // indirect
)

37
go.sum Normal file
View File

@@ -0,0 +1,37 @@
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s=
github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -0,0 +1,244 @@
package cli
import (
"bufio"
"fmt"
"os"
"os/exec"
"os/signal"
"path/filepath"
"strconv"
"strings"
"syscall"
"time"
"github.com/spf13/cobra"
)
var attachCmd = &cobra.Command{
Use: "attach [type] [port]",
Short: "Attach to a running background tunnel",
Long: `Attach to a running background tunnel to view its logs in real-time.
Examples:
drip attach List running tunnels and select one
drip attach http 3000 Attach to HTTP tunnel on port 3000
drip attach tcp 5432 Attach to TCP tunnel on port 5432
Press Ctrl+C to detach (tunnel will continue running).`,
Aliases: []string{"logs", "tail"},
Args: cobra.MaximumNArgs(2),
RunE: runAttach,
}
func init() {
rootCmd.AddCommand(attachCmd)
}
func runAttach(cmd *cobra.Command, args []string) error {
// Clean up stale daemons first
CleanupStaleDaemons()
// Get all running daemons
daemons, err := ListAllDaemons()
if err != nil {
return fmt.Errorf("failed to list daemons: %w", err)
}
if len(daemons) == 0 {
fmt.Println("\033[90mNo running tunnels.\033[0m")
fmt.Println()
fmt.Println("Start a tunnel in background with:")
fmt.Println(" \033[36mdrip http 3000 -d\033[0m")
fmt.Println(" \033[36mdrip tcp 5432 -d\033[0m")
return nil
}
var selectedDaemon *DaemonInfo
// If type and port are specified, find the specific daemon
if len(args) == 2 {
tunnelType := args[0]
if tunnelType != "http" && tunnelType != "tcp" {
return fmt.Errorf("invalid tunnel type: %s (must be 'http' or 'tcp')", tunnelType)
}
port, err := strconv.Atoi(args[1])
if err != nil || port < 1 || port > 65535 {
return fmt.Errorf("invalid port number: %s", args[1])
}
// Find the daemon
for _, d := range daemons {
if d.Type == tunnelType && d.Port == port {
if !IsProcessRunning(d.PID) {
RemoveDaemonInfo(d.Type, d.Port)
return fmt.Errorf("tunnel is not running (cleaned up stale entry)")
}
selectedDaemon = d
break
}
}
if selectedDaemon == nil {
return fmt.Errorf("no %s tunnel running on port %d", tunnelType, port)
}
} else if len(args) == 0 {
// Interactive selection
selectedDaemon, err = selectDaemonInteractive(daemons)
if err != nil {
return err
}
if selectedDaemon == nil {
return nil // User cancelled
}
} else {
return fmt.Errorf("usage: drip attach [type port]")
}
// Attach to the selected daemon
return attachToDaemon(selectedDaemon)
}
func selectDaemonInteractive(daemons []*DaemonInfo) (*DaemonInfo, error) {
// Print header
fmt.Println()
fmt.Println("\033[1;37mSelect a tunnel to attach:\033[0m")
fmt.Println("\033[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\033[0m")
// Filter out non-running daemons
var runningDaemons []*DaemonInfo
for _, d := range daemons {
if IsProcessRunning(d.PID) {
runningDaemons = append(runningDaemons, d)
} else {
RemoveDaemonInfo(d.Type, d.Port)
}
}
if len(runningDaemons) == 0 {
fmt.Println("\033[90mNo running tunnels.\033[0m")
return nil, nil
}
// Print list
for i, d := range runningDaemons {
uptime := time.Since(d.StartTime)
// Format type with color
var typeStr string
if d.Type == "http" {
typeStr = "\033[32mHTTP\033[0m"
} else {
typeStr = "\033[35mTCP\033[0m"
}
// Truncate URL if too long
url := d.URL
if len(url) > 50 {
url = url[:47] + "..."
}
fmt.Printf("\033[1;36m%d.\033[0m %-15s \033[90mPort:\033[0m %-6d \033[90mURL:\033[0m %-50s \033[90mUptime:\033[0m %s\n",
i+1, typeStr, d.Port, url, FormatDuration(uptime))
}
fmt.Println("\033[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\033[0m")
fmt.Printf("Enter number (1-%d) or 'q' to quit: ", len(runningDaemons))
// Read user input
reader := bufio.NewReader(os.Stdin)
input, err := reader.ReadString('\n')
if err != nil {
return nil, fmt.Errorf("failed to read input: %w", err)
}
input = strings.TrimSpace(input)
if input == "q" || input == "Q" {
return nil, nil
}
// Parse selection
selection, err := strconv.Atoi(input)
if err != nil || selection < 1 || selection > len(runningDaemons) {
return nil, fmt.Errorf("invalid selection: %s", input)
}
return runningDaemons[selection-1], nil
}
func attachToDaemon(daemon *DaemonInfo) error {
// Get log file path
logPath := filepath.Join(getDaemonDir(), fmt.Sprintf("%s_%d.log", daemon.Type, daemon.Port))
// Check if log file exists
if _, err := os.Stat(logPath); os.IsNotExist(err) {
return fmt.Errorf("log file not found: %s", logPath)
}
// Print header
fmt.Println()
fmt.Println("\033[1;32m╔══════════════════════════════════════════════════════════════════╗\033[0m")
fmt.Printf("\033[1;32m║\033[0m \033[1;37mAttached to %s tunnel on port %d\033[0m", strings.ToUpper(daemon.Type), daemon.Port)
fmt.Printf("%s\033[1;32m║\033[0m\n", strings.Repeat(" ", 32-len(daemon.Type)))
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
fmt.Printf("\033[1;32m║\033[0m \033[90mURL:\033[0m \033[36m%-52s\033[0m \033[1;32m║\033[0m\n", daemon.URL)
uptime := time.Since(daemon.StartTime)
fmt.Printf("\033[1;32m║\033[0m \033[90mPID:\033[0m \033[90m%-52d\033[0m \033[1;32m║\033[0m\n", daemon.PID)
fmt.Printf("\033[1;32m║\033[0m \033[90mUptime:\033[0m \033[90m%-52s\033[0m \033[1;32m║\033[0m\n", FormatDuration(uptime))
fmt.Printf("\033[1;32m║\033[0m \033[90mLog:\033[0m \033[90m%-52s\033[0m \033[1;32m║\033[0m\n", truncatePath(logPath, 52))
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
fmt.Println("\033[1;32m║\033[0m \033[33mPress Ctrl+C to detach (tunnel will continue running)\033[0m \033[1;32m║\033[0m")
fmt.Println("\033[1;32m╚══════════════════════════════════════════════════════════════════╝\033[0m")
fmt.Println()
// Setup signal handler
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
// Start tail command
tailCmd := exec.Command("tail", "-f", logPath)
tailCmd.Stdout = os.Stdout
tailCmd.Stderr = os.Stderr
if err := tailCmd.Start(); err != nil {
return fmt.Errorf("failed to start tail: %w", err)
}
// Wait for signal or tail to exit
done := make(chan error, 1)
go func() {
done <- tailCmd.Wait()
}()
select {
case <-sigCh:
// Kill tail process
if tailCmd.Process != nil {
tailCmd.Process.Kill()
}
fmt.Println()
fmt.Println("\033[33mDetached from tunnel (tunnel is still running)\033[0m")
fmt.Printf("Use '\033[36mdrip attach %s %d\033[0m' to reattach\n", daemon.Type, daemon.Port)
fmt.Printf("Use '\033[36mdrip stop %s %d\033[0m' to stop the tunnel\n", daemon.Type, daemon.Port)
return nil
case err := <-done:
if err != nil {
return fmt.Errorf("tail process exited: %w", err)
}
return nil
}
}
func truncatePath(path string, maxLen int) string {
if len(path) <= maxLen {
return path
}
// Try to keep filename and show ... in the middle
filename := filepath.Base(path)
if len(filename) >= maxLen-3 {
return "..." + filename[len(filename)-(maxLen-3):]
}
dirLen := maxLen - len(filename) - 3
return path[:dirLen] + "..." + filename
}

View File

@@ -0,0 +1,273 @@
package cli
import (
"bufio"
"fmt"
"os"
"strings"
"drip/pkg/config"
"github.com/spf13/cobra"
)
var configCmd = &cobra.Command{
Use: "config",
Short: "Manage configuration",
Long: "Manage Drip client configuration (server, token, etc.)",
}
var configInitCmd = &cobra.Command{
Use: "init",
Short: "Initialize configuration interactively",
Long: "Initialize Drip configuration with interactive prompts",
RunE: runConfigInit,
}
var configShowCmd = &cobra.Command{
Use: "show",
Short: "Show current configuration",
Long: "Display the current Drip configuration",
RunE: runConfigShow,
}
var configSetCmd = &cobra.Command{
Use: "set",
Short: "Set configuration values",
Long: "Set specific configuration values (server, token)",
RunE: runConfigSet,
}
var configResetCmd = &cobra.Command{
Use: "reset",
Short: "Reset configuration",
Long: "Delete the configuration file",
RunE: runConfigReset,
}
var configValidateCmd = &cobra.Command{
Use: "validate",
Short: "Validate configuration",
Long: "Validate the configuration file",
RunE: runConfigValidate,
}
var (
configFull bool
configForce bool
configServer string
configToken string
)
func init() {
// Add subcommands
configCmd.AddCommand(configInitCmd)
configCmd.AddCommand(configShowCmd)
configCmd.AddCommand(configSetCmd)
configCmd.AddCommand(configResetCmd)
configCmd.AddCommand(configValidateCmd)
// Flags for show
configShowCmd.Flags().BoolVar(&configFull, "full", false, "Show full token (not hidden)")
// Flags for set
configSetCmd.Flags().StringVar(&configServer, "server", "", "Server address (e.g., tunnel.example.com:443)")
configSetCmd.Flags().StringVar(&configToken, "token", "", "Authentication token")
// Flags for reset
configResetCmd.Flags().BoolVar(&configForce, "force", false, "Force reset without confirmation")
// Add to root
rootCmd.AddCommand(configCmd)
}
func runConfigInit(cmd *cobra.Command, args []string) error {
fmt.Println("\n╔═══════════════════════════════════════╗")
fmt.Println("║ Drip Configuration Setup ║")
fmt.Println("╚═══════════════════════════════════════╝")
reader := bufio.NewReader(os.Stdin)
// Get server address
fmt.Print("Server address (e.g., tunnel.example.com:443): ")
serverAddr, _ := reader.ReadString('\n')
serverAddr = strings.TrimSpace(serverAddr)
if serverAddr == "" {
return fmt.Errorf("server address is required")
}
// Get token
fmt.Print("Authentication token (leave empty to skip): ")
token, _ := reader.ReadString('\n')
token = strings.TrimSpace(token)
// Create config
cfg := &config.ClientConfig{
Server: serverAddr,
Token: token,
TLS: true,
}
// Save config
if err := config.SaveClientConfig(cfg, ""); err != nil {
return fmt.Errorf("failed to save configuration: %w", err)
}
fmt.Println("\n✓ Configuration saved to", config.DefaultClientConfigPath())
fmt.Println("✓ You can now use 'drip' without --server and --token")
return nil
}
func runConfigShow(cmd *cobra.Command, args []string) error {
// Load config
cfg, err := config.LoadClientConfig("")
if err != nil {
return err
}
fmt.Println("\n╔═══════════════════════════════════════╗")
fmt.Println("║ Current Configuration ║")
fmt.Println("╚═══════════════════════════════════════╝")
fmt.Printf("Server: %s\n", cfg.Server)
// Show token (hidden or full)
if cfg.Token != "" {
if configFull {
fmt.Printf("Token: %s\n", cfg.Token)
} else {
// Hide middle part of token
if len(cfg.Token) > 10 {
fmt.Printf("Token: %s***%s (hidden)\n",
cfg.Token[:3],
cfg.Token[len(cfg.Token)-3:],
)
} else {
fmt.Printf("Token: %s (hidden)\n", cfg.Token[:3]+"***")
}
}
} else {
fmt.Println("Token: (not set)")
}
fmt.Printf("TLS: %s\n", enabledDisabled(cfg.TLS))
fmt.Printf("Config: %s\n\n", config.DefaultClientConfigPath())
return nil
}
func runConfigSet(cmd *cobra.Command, args []string) error {
// Load existing config or create new
cfg, err := config.LoadClientConfig("")
if err != nil {
// Create new config if not exists
cfg = &config.ClientConfig{
TLS: true,
}
}
// Update fields if provided
modified := false
if configServer != "" {
cfg.Server = configServer
modified = true
fmt.Printf("✓ Server updated: %s\n", configServer)
}
if configToken != "" {
cfg.Token = configToken
modified = true
fmt.Println("✓ Token updated")
}
if !modified {
return fmt.Errorf("no changes specified. Use --server or --token")
}
// Save config
if err := config.SaveClientConfig(cfg, ""); err != nil {
return fmt.Errorf("failed to save configuration: %w", err)
}
fmt.Println("✓ Configuration saved")
return nil
}
func runConfigReset(cmd *cobra.Command, args []string) error {
configPath := config.DefaultClientConfigPath()
// Check if config exists
if !config.ConfigExists("") {
fmt.Println("No configuration file found")
return nil
}
// Confirm deletion
if !configForce {
fmt.Print("Are you sure you want to delete the configuration? (y/N): ")
reader := bufio.NewReader(os.Stdin)
response, _ := reader.ReadString('\n')
response = strings.ToLower(strings.TrimSpace(response))
if response != "y" && response != "yes" {
fmt.Println("Cancelled")
return nil
}
}
// Delete config file
if err := os.Remove(configPath); err != nil {
return fmt.Errorf("failed to delete configuration: %w", err)
}
fmt.Println("✓ Configuration file deleted")
return nil
}
func runConfigValidate(cmd *cobra.Command, args []string) error {
fmt.Println("\nValidating configuration...")
fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
// Load config
cfg, err := config.LoadClientConfig("")
if err != nil {
fmt.Println("✗ Failed to load configuration")
return err
}
// Validate server address
if cfg.Server == "" {
fmt.Println("✗ Server address is not set")
return fmt.Errorf("invalid configuration")
}
fmt.Println("✓ Server address is valid")
// Validate token
if cfg.Token != "" {
fmt.Println("✓ Token is set")
} else {
fmt.Println("⚠ Token is not set (authentication may fail)")
}
// Validate TLS
if cfg.TLS {
fmt.Println("✓ TLS is enabled")
} else {
fmt.Println("⚠ TLS is disabled (not recommended for production)")
}
fmt.Println("\n✓ Configuration is valid")
return nil
}
func enabledDisabled(value bool) string {
if value {
return "enabled"
}
return "disabled"
}

View File

@@ -0,0 +1,263 @@
package cli
import (
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"strconv"
"time"
)
// DaemonInfo stores information about a running daemon process
type DaemonInfo struct {
PID int `json:"pid"`
Type string `json:"type"` // "http" or "tcp"
Port int `json:"port"` // Local port being tunneled
Subdomain string `json:"subdomain"` // Subdomain if specified
Server string `json:"server"` // Server address
URL string `json:"url"` // Tunnel URL
StartTime time.Time `json:"start_time"` // When the daemon started
Executable string `json:"executable"` // Path to the executable
}
// getDaemonDir returns the directory for storing daemon info
func getDaemonDir() string {
home, err := os.UserHomeDir()
if err != nil {
return ".drip"
}
return filepath.Join(home, ".drip", "daemons")
}
// getDaemonFilePath returns the path to a daemon info file
func getDaemonFilePath(tunnelType string, port int) string {
return filepath.Join(getDaemonDir(), fmt.Sprintf("%s_%d.json", tunnelType, port))
}
// SaveDaemonInfo saves daemon information to a file
func SaveDaemonInfo(info *DaemonInfo) error {
dir := getDaemonDir()
if err := os.MkdirAll(dir, 0700); err != nil {
return fmt.Errorf("failed to create daemon directory: %w", err)
}
data, err := json.MarshalIndent(info, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal daemon info: %w", err)
}
path := getDaemonFilePath(info.Type, info.Port)
if err := os.WriteFile(path, data, 0600); err != nil {
return fmt.Errorf("failed to write daemon info: %w", err)
}
return nil
}
// LoadDaemonInfo loads daemon information from a file
func LoadDaemonInfo(tunnelType string, port int) (*DaemonInfo, error) {
path := getDaemonFilePath(tunnelType, port)
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, fmt.Errorf("failed to read daemon info: %w", err)
}
var info DaemonInfo
if err := json.Unmarshal(data, &info); err != nil {
return nil, fmt.Errorf("failed to parse daemon info: %w", err)
}
return &info, nil
}
// RemoveDaemonInfo removes a daemon info file
func RemoveDaemonInfo(tunnelType string, port int) error {
path := getDaemonFilePath(tunnelType, port)
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove daemon info: %w", err)
}
return nil
}
// ListAllDaemons returns all daemon info files
func ListAllDaemons() ([]*DaemonInfo, error) {
dir := getDaemonDir()
entries, err := os.ReadDir(dir)
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, fmt.Errorf("failed to read daemon directory: %w", err)
}
var daemons []*DaemonInfo
for _, entry := range entries {
if entry.IsDir() || filepath.Ext(entry.Name()) != ".json" {
continue
}
data, err := os.ReadFile(filepath.Join(dir, entry.Name()))
if err != nil {
continue
}
var info DaemonInfo
if err := json.Unmarshal(data, &info); err != nil {
continue
}
daemons = append(daemons, &info)
}
return daemons, nil
}
// IsProcessRunning checks if a process with the given PID is running
func IsProcessRunning(pid int) bool {
process, err := os.FindProcess(pid)
if err != nil {
return false
}
return isProcessRunningOS(process)
}
// KillProcess kills a process by PID
func KillProcess(pid int) error {
process, err := os.FindProcess(pid)
if err != nil {
return fmt.Errorf("process not found: %w", err)
}
if err := killProcessOS(process); err != nil {
return fmt.Errorf("failed to kill process: %w", err)
}
return nil
}
// StartDaemon starts the current process as a daemon
func StartDaemon(tunnelType string, port int, args []string) error {
// Get the executable path
executable, err := os.Executable()
if err != nil {
return fmt.Errorf("failed to get executable path: %w", err)
}
// Build command arguments (remove -D/--daemon flag to prevent recursion)
var cleanArgs []string
skipNext := false
for i, arg := range args {
if skipNext {
skipNext = false
continue
}
// Skip -D or --daemon flags (but NOT --daemon-child)
if arg == "-D" || arg == "--daemon" {
continue
}
// Handle -d (short form) - skip it
if arg == "-d" {
continue
}
// Skip if next arg would be a value for a removed flag (not applicable for boolean)
_ = i
cleanArgs = append(cleanArgs, arg)
}
// Create the command
cmd := exec.Command(executable, cleanArgs...)
// Detach from parent process (platform-specific)
setupDaemonCmd(cmd)
// Create log file for daemon output
logDir := getDaemonDir()
if err := os.MkdirAll(logDir, 0700); err != nil {
return fmt.Errorf("failed to create daemon directory: %w", err)
}
logPath := filepath.Join(logDir, fmt.Sprintf("%s_%d.log", tunnelType, port))
logFile, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("failed to create log file: %w", err)
}
// Redirect stdin to /dev/null
devNull, err := os.OpenFile(os.DevNull, os.O_RDONLY, 0)
if err != nil {
logFile.Close()
return fmt.Errorf("failed to open /dev/null: %w", err)
}
cmd.Stdin = devNull
cmd.Stdout = logFile
cmd.Stderr = logFile
// Start the process
if err := cmd.Start(); err != nil {
logFile.Close()
devNull.Close()
return fmt.Errorf("failed to start daemon: %w", err)
}
// Don't wait for the process - let it run in background
// The child process will save its own daemon info after connecting
fmt.Printf("\033[32m✓\033[0m Started %s tunnel on port %d in background (PID: %d)\n", tunnelType, port, cmd.Process.Pid)
fmt.Printf(" Use '\033[36mdrip list\033[0m' to check tunnel status\n")
fmt.Printf(" Use '\033[36mdrip attach %s %d\033[0m' to view logs\n", tunnelType, port)
fmt.Printf(" Use '\033[36mdrip stop %s %d\033[0m' to stop this tunnel\n", tunnelType, port)
fmt.Printf(" Logs: \033[90m%s\033[0m\n", logPath)
return nil
}
// CleanupStaleDaemons removes daemon info for processes that are no longer running
func CleanupStaleDaemons() error {
daemons, err := ListAllDaemons()
if err != nil {
return err
}
for _, info := range daemons {
if !IsProcessRunning(info.PID) {
RemoveDaemonInfo(info.Type, info.Port)
}
}
return nil
}
// FormatDuration formats a duration in a human-readable way
func FormatDuration(d time.Duration) string {
if d < time.Minute {
return fmt.Sprintf("%ds", int(d.Seconds()))
} else if d < time.Hour {
return fmt.Sprintf("%dm %ds", int(d.Minutes()), int(d.Seconds())%60)
} else if d < 24*time.Hour {
return fmt.Sprintf("%dh %dm", int(d.Hours()), int(d.Minutes())%60)
}
days := int(d.Hours()) / 24
hours := int(d.Hours()) % 24
return fmt.Sprintf("%dd %dh", days, hours)
}
// ParsePortFromArgs extracts the port number from command arguments
func ParsePortFromArgs(args []string) (int, error) {
for _, arg := range args {
// Skip flags
if len(arg) > 0 && arg[0] == '-' {
continue
}
// Try to parse as port number
port, err := strconv.Atoi(arg)
if err == nil && port > 0 && port <= 65535 {
return port, nil
}
}
return 0, fmt.Errorf("port number not found in arguments")
}

View File

@@ -0,0 +1,38 @@
//go:build !windows
package cli
import (
"os"
"os/exec"
"syscall"
)
// getSysProcAttr returns platform-specific process attributes for daemonization
func getSysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{
Setsid: true, // Create new session (Unix only)
}
}
// isProcessRunningOS checks if a process is running using OS-specific method
func isProcessRunningOS(process *os.Process) bool {
// On Unix, FindProcess always succeeds, so we need to send signal 0 to check
err := process.Signal(syscall.Signal(0))
return err == nil
}
// killProcessOS kills a process using OS-specific signals
func killProcessOS(process *os.Process) error {
// First try SIGTERM for graceful shutdown
if err := process.Signal(syscall.SIGTERM); err != nil {
// If SIGTERM fails, try SIGKILL
return process.Signal(syscall.SIGKILL)
}
return nil
}
// setupDaemonCmd configures the command for daemon mode
func setupDaemonCmd(cmd *exec.Cmd) {
cmd.SysProcAttr = getSysProcAttr()
}

View File

@@ -0,0 +1,42 @@
//go:build windows
package cli
import (
"os"
"os/exec"
"syscall"
)
// getSysProcAttr returns platform-specific process attributes for daemonization
func getSysProcAttr() *syscall.SysProcAttr {
return &syscall.SysProcAttr{
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP,
}
}
// isProcessRunningOS checks if a process is running using OS-specific method
func isProcessRunningOS(process *os.Process) bool {
// On Windows, we try to open the process to check if it exists
// FindProcess doesn't actually check if process exists on Windows
// We can try to send signal, but Windows doesn't support signal 0
// Instead, we'll try to kill with signal 0 which returns an error if process doesn't exist
err := process.Signal(os.Signal(syscall.Signal(0)))
if err != nil {
// Try alternative: check if we can get process info
// If the process doesn't exist, Signal will fail
return false
}
return true
}
// killProcessOS kills a process using OS-specific method
func killProcessOS(process *os.Process) error {
// On Windows, use Kill() directly
return process.Kill()
}
// setupDaemonCmd configures the command for daemon mode
func setupDaemonCmd(cmd *exec.Cmd) {
cmd.SysProcAttr = getSysProcAttr()
}

306
internal/client/cli/http.go Normal file
View File

@@ -0,0 +1,306 @@
package cli
import (
"fmt"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
"drip/internal/client/tcp"
"drip/internal/shared/protocol"
"drip/internal/shared/utils"
"drip/pkg/config"
"github.com/spf13/cobra"
"go.uber.org/zap"
)
const (
maxReconnectAttempts = 5
reconnectInterval = 3 * time.Second
)
var (
subdomain string
daemonMode bool
daemonMarker bool
localAddress string
)
var httpCmd = &cobra.Command{
Use: "http <port>",
Short: "Start HTTP tunnel",
Long: `Start an HTTP tunnel to expose a local HTTP server.
Example:
drip http 3000 Tunnel localhost:3000
drip http 8080 --subdomain myapp Use custom subdomain
Configuration:
First time: Run 'drip config init' to save server and token
Subsequent: Just run 'drip http <port>'
Note: Uses TCP over TLS 1.3 for secure communication`,
Args: cobra.ExactArgs(1),
RunE: runHTTP,
}
func init() {
httpCmd.Flags().StringVarP(&subdomain, "subdomain", "n", "", "Custom subdomain (optional)")
httpCmd.Flags().BoolVarP(&daemonMode, "daemon", "d", false, "Run in background (daemon mode)")
httpCmd.Flags().StringVarP(&localAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)")
httpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
httpCmd.Flags().MarkHidden("daemon-child")
rootCmd.AddCommand(httpCmd)
}
func runHTTP(cmd *cobra.Command, args []string) error {
port, err := strconv.Atoi(args[0])
if err != nil || port < 1 || port > 65535 {
return fmt.Errorf("invalid port number: %s", args[0])
}
if daemonMode && !daemonMarker {
daemonArgs := append([]string{"http"}, args...)
daemonArgs = append(daemonArgs, "--daemon-child")
if subdomain != "" {
daemonArgs = append(daemonArgs, "--subdomain", subdomain)
}
if localAddress != "127.0.0.1" {
daemonArgs = append(daemonArgs, "--address", localAddress)
}
if serverURL != "" {
daemonArgs = append(daemonArgs, "--server", serverURL)
}
if authToken != "" {
daemonArgs = append(daemonArgs, "--token", authToken)
}
if insecure {
daemonArgs = append(daemonArgs, "--insecure")
}
if verbose {
daemonArgs = append(daemonArgs, "--verbose")
}
return StartDaemon("http", port, daemonArgs)
}
if err := utils.InitLogger(verbose); err != nil {
return fmt.Errorf("failed to initialize logger: %w", err)
}
defer utils.Sync()
logger := utils.GetLogger()
var serverAddr, token string
if serverURL == "" {
cfg, err := config.LoadClientConfig("")
if err != nil {
return fmt.Errorf(`configuration not found.
Please run 'drip config init' first, or use flags:
drip http %d --server SERVER:PORT --token TOKEN`, port)
}
serverAddr = cfg.Server
token = cfg.Token
} else {
serverAddr = serverURL
token = authToken
}
if serverAddr == "" {
return fmt.Errorf("server address is required")
}
connConfig := &tcp.ConnectorConfig{
ServerAddr: serverAddr,
Token: token,
TunnelType: protocol.TunnelTypeHTTP,
LocalHost: localAddress,
LocalPort: port,
Subdomain: subdomain,
Insecure: insecure,
}
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
reconnectAttempts := 0
for {
connector := tcp.NewConnector(connConfig, logger)
if reconnectAttempts == 0 {
fmt.Printf("\033[36m🔌 Connecting to %s...\033[0m\n", serverAddr)
} else {
fmt.Printf("\033[33m🔄 Reconnecting to %s (attempt %d/%d)...\033[0m\n", serverAddr, reconnectAttempts, maxReconnectAttempts)
}
if err := connector.Connect(); err != nil {
if isNonRetryableError(err) {
return fmt.Errorf("failed to connect: %w", err)
}
reconnectAttempts++
if reconnectAttempts >= maxReconnectAttempts {
return fmt.Errorf("failed to connect after %d attempts: %w", maxReconnectAttempts, err)
}
fmt.Printf("\033[31m✗ Connection failed: %v\033[0m\n", err)
fmt.Printf("\033[90m Retrying in %v...\033[0m\n", reconnectInterval)
select {
case <-quit:
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
return nil
case <-time.After(reconnectInterval):
continue
}
}
reconnectAttempts = 0
if daemonMarker {
daemonInfo := &DaemonInfo{
PID: os.Getpid(),
Type: "http",
Port: port,
Subdomain: subdomain,
Server: serverAddr,
URL: connector.GetURL(),
StartTime: time.Now(),
Executable: os.Args[0],
}
if err := SaveDaemonInfo(daemonInfo); err != nil {
logger.Warn("Failed to save daemon info", zap.Error(err))
}
}
fmt.Println()
fmt.Println("\033[1;32m╔══════════════════════════════════════════════════════════════════╗\033[0m")
fmt.Println("\033[1;32m║\033[0m \033[1;37m🚀 HTTP Tunnel Connected Successfully!\033[0m \033[1;32m║\033[0m")
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
fmt.Printf("\033[1;32m║\033[0m \033[1;37mTunnel URL:\033[0m \033[1;32m║\033[0m\n")
fmt.Printf("\033[1;32m║\033[0m \033[1;36m%-60s\033[0m \033[1;32m║\033[0m\n", connector.GetURL())
fmt.Println("\033[1;32m║\033[0m \033[1;32m║\033[0m")
displayAddr := localAddress
if displayAddr == "127.0.0.1" {
displayAddr = "localhost"
}
fmt.Printf("\033[1;32m║\033[0m \033[90mForwarding:\033[0m \033[1m%s:%d\033[0m → \033[36m%s\033[0m%-15s\033[1;32m║\033[0m\n", displayAddr, port, "public", "")
fmt.Printf("\033[1;32m║\033[0m \033[90mLatency:\033[0m \033[90mmeasuring...\033[0m%-40s\033[1;32m║\033[0m\n", "")
fmt.Printf("\033[1;32m║\033[0m \033[90mTraffic:\033[0m \033[90m↓ 0 B ↑ 0 B\033[0m%-32s\033[1;32m║\033[0m\n", "")
fmt.Printf("\033[1;32m║\033[0m \033[90mSpeed:\033[0m \033[90m↓ 0 B/s ↑ 0 B/s\033[0m%-28s\033[1;32m║\033[0m\n", "")
fmt.Printf("\033[1;32m║\033[0m \033[90mRequests:\033[0m \033[90m0\033[0m%-43s\033[1;32m║\033[0m\n", "")
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
fmt.Println("\033[1;32m║\033[0m \033[90mPress Ctrl+C to stop the tunnel\033[0m \033[1;32m║\033[0m")
fmt.Println("\033[1;32m╚══════════════════════════════════════════════════════════════════╝\033[0m")
fmt.Println()
latencyCh := make(chan time.Duration, 1)
connector.SetLatencyCallback(func(latency time.Duration) {
select {
case latencyCh <- latency:
default:
}
})
stopDisplay := make(chan struct{})
disconnected := make(chan struct{})
go func() {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
var lastLatency time.Duration
for {
select {
case latency := <-latencyCh:
lastLatency = latency
case <-ticker.C:
stats := connector.GetStats()
if stats != nil {
stats.UpdateSpeed()
snapshot := stats.GetSnapshot()
fmt.Print("\033[8A")
fmt.Printf("\r\033[1;32m║\033[0m \033[90mLatency:\033[0m %s%-40s\033[1;32m║\033[0m\n", formatLatency(lastLatency), "")
trafficStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatBytes(snapshot.TotalBytesIn), tcp.FormatBytes(snapshot.TotalBytesOut))
fmt.Printf("\r\033[1;32m║\033[0m \033[90mTraffic:\033[0m \033[36m%-48s\033[0m\033[1;32m║\033[0m\n", trafficStr)
speedStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatSpeed(snapshot.SpeedIn), tcp.FormatSpeed(snapshot.SpeedOut))
fmt.Printf("\r\033[1;32m║\033[0m \033[90mSpeed:\033[0m \033[33m%-48s\033[0m\033[1;32m║\033[0m\n", speedStr)
fmt.Printf("\r\033[1;32m║\033[0m \033[90mRequests:\033[0m \033[35m%-47d\033[0m\033[1;32m║\033[0m\n", snapshot.TotalRequests)
fmt.Print("\033[4B")
}
case <-stopDisplay:
return
}
}
}()
go func() {
connector.Wait()
close(disconnected)
}()
select {
case <-quit:
close(stopDisplay)
fmt.Println("\n\n\033[33m🛑 Shutting down...\033[0m")
connector.Close()
if daemonMarker {
RemoveDaemonInfo("http", port)
}
fmt.Println("\033[32m✓\033[0m Tunnel closed")
return nil
case <-disconnected:
close(stopDisplay)
fmt.Println("\n\n\033[31m⚠ Connection lost!\033[0m")
reconnectAttempts++
if reconnectAttempts >= maxReconnectAttempts {
return fmt.Errorf("connection lost after %d reconnect attempts", maxReconnectAttempts)
}
fmt.Printf("\033[90m Reconnecting in %v...\033[0m\n", reconnectInterval)
select {
case <-quit:
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
return nil
case <-time.After(reconnectInterval):
continue
}
}
}
}
func formatLatency(d time.Duration) string {
ms := d.Milliseconds()
if ms < 50 {
return fmt.Sprintf("\033[32m%dms\033[0m", ms)
} else if ms < 100 {
return fmt.Sprintf("\033[33m%dms\033[0m", ms)
} else if ms < 200 {
return fmt.Sprintf("\033[38;5;208m%dms\033[0m", ms)
}
return fmt.Sprintf("\033[31m%dms\033[0m", ms)
}
func isNonRetryableError(err error) bool {
errStr := err.Error()
if strings.Contains(errStr, "subdomain is already taken") ||
strings.Contains(errStr, "subdomain is reserved") ||
strings.Contains(errStr, "invalid subdomain") {
return true
}
if strings.Contains(errStr, "authentication") ||
strings.Contains(errStr, "Invalid authentication token") {
return true
}
return false
}

View File

@@ -0,0 +1,305 @@
package cli
import (
"fmt"
"os"
"os/signal"
"strconv"
"syscall"
"time"
"drip/internal/client/tcp"
"drip/internal/shared/protocol"
"drip/internal/shared/utils"
"drip/pkg/config"
"github.com/spf13/cobra"
"go.uber.org/zap"
)
var (
httpsSubdomain string
httpsDaemonMode bool
httpsDaemonMarker bool
httpsLocalAddress string
)
var httpsCmd = &cobra.Command{
Use: "https <port>",
Short: "Start HTTPS tunnel",
Long: `Start an HTTPS tunnel to expose a local HTTPS server.
Example:
drip https 443 Tunnel localhost:443
drip https 8443 --subdomain myapp Use custom subdomain
Configuration:
First time: Run 'drip config init' to save server and token
Subsequent: Just run 'drip https <port>'
Note: Uses TCP over TLS 1.3 for secure communication`,
Args: cobra.ExactArgs(1),
RunE: runHTTPS,
}
func init() {
httpsCmd.Flags().StringVarP(&httpsSubdomain, "subdomain", "n", "", "Custom subdomain (optional)")
httpsCmd.Flags().BoolVarP(&httpsDaemonMode, "daemon", "d", false, "Run in background (daemon mode)")
httpsCmd.Flags().StringVarP(&httpsLocalAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)")
httpsCmd.Flags().BoolVar(&httpsDaemonMarker, "daemon-child", false, "Internal flag for daemon child process")
httpsCmd.Flags().MarkHidden("daemon-child")
rootCmd.AddCommand(httpsCmd)
}
func runHTTPS(cmd *cobra.Command, args []string) error {
// Parse port
port, err := strconv.Atoi(args[0])
if err != nil || port < 1 || port > 65535 {
return fmt.Errorf("invalid port number: %s", args[0])
}
// Handle daemon mode
if httpsDaemonMode && !httpsDaemonMarker {
// Start as daemon
daemonArgs := append([]string{"https"}, args...)
daemonArgs = append(daemonArgs, "--daemon-child")
if httpsSubdomain != "" {
daemonArgs = append(daemonArgs, "--subdomain", httpsSubdomain)
}
if httpsLocalAddress != "127.0.0.1" {
daemonArgs = append(daemonArgs, "--address", httpsLocalAddress)
}
if serverURL != "" {
daemonArgs = append(daemonArgs, "--server", serverURL)
}
if authToken != "" {
daemonArgs = append(daemonArgs, "--token", authToken)
}
if insecure {
daemonArgs = append(daemonArgs, "--insecure")
}
if verbose {
daemonArgs = append(daemonArgs, "--verbose")
}
return StartDaemon("https", port, daemonArgs)
}
// Initialize logger
if err := utils.InitLogger(verbose); err != nil {
return fmt.Errorf("failed to initialize logger: %w", err)
}
defer utils.Sync()
logger := utils.GetLogger()
// Load configuration or use command line flags
var serverAddr, token string
if serverURL == "" {
// Try to load from config file
cfg, err := config.LoadClientConfig("")
if err != nil {
return fmt.Errorf(`configuration not found.
Please run 'drip config init' first, or use flags:
drip https %d --server SERVER:PORT --token TOKEN`, port)
}
serverAddr = cfg.Server
token = cfg.Token
} else {
// Use command line flags
serverAddr = serverURL
token = authToken
}
// Validate server address
if serverAddr == "" {
return fmt.Errorf("server address is required")
}
// Create connector config
connConfig := &tcp.ConnectorConfig{
ServerAddr: serverAddr,
Token: token,
TunnelType: protocol.TunnelTypeHTTPS,
LocalHost: httpsLocalAddress,
LocalPort: port,
Subdomain: httpsSubdomain,
Insecure: insecure,
}
// Setup signal handler for graceful shutdown
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
// Connection loop with reconnect support
reconnectAttempts := 0
for {
// Create connector
connector := tcp.NewConnector(connConfig, logger)
// Connect to server
if reconnectAttempts == 0 {
fmt.Printf("\033[36m🔌 Connecting to %s...\033[0m\n", serverAddr)
} else {
fmt.Printf("\033[33m🔄 Reconnecting to %s (attempt %d/%d)...\033[0m\n", serverAddr, reconnectAttempts, maxReconnectAttempts)
}
if err := connector.Connect(); err != nil {
// Check if this is a non-retryable error
if isNonRetryableError(err) {
return fmt.Errorf("failed to connect: %w", err)
}
reconnectAttempts++
if reconnectAttempts >= maxReconnectAttempts {
return fmt.Errorf("failed to connect after %d attempts: %w", maxReconnectAttempts, err)
}
fmt.Printf("\033[31m✗ Connection failed: %v\033[0m\n", err)
fmt.Printf("\033[90m Retrying in %v...\033[0m\n", reconnectInterval)
// Wait before retry, but allow interrupt
select {
case <-quit:
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
return nil
case <-time.After(reconnectInterval):
continue
}
}
// Reset reconnect attempts on successful connection
reconnectAttempts = 0
// Save daemon info if running as daemon child
if httpsDaemonMarker {
daemonInfo := &DaemonInfo{
PID: os.Getpid(),
Type: "https",
Port: port,
Subdomain: httpsSubdomain,
Server: serverAddr,
URL: connector.GetURL(),
StartTime: time.Now(),
Executable: os.Args[0],
}
if err := SaveDaemonInfo(daemonInfo); err != nil {
// Log but don't fail
logger.Warn("Failed to save daemon info", zap.Error(err))
}
}
// Print tunnel information
fmt.Println()
fmt.Println("\033[1;32m╔══════════════════════════════════════════════════════════════════╗\033[0m")
fmt.Println("\033[1;32m║\033[0m \033[1;37m🔒 HTTPS Tunnel Connected Successfully!\033[0m \033[1;32m║\033[0m")
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
fmt.Printf("\033[1;32m║\033[0m \033[1;37mTunnel URL:\033[0m \033[1;32m║\033[0m\n")
fmt.Printf("\033[1;32m║\033[0m \033[1;36m%-60s\033[0m \033[1;32m║\033[0m\n", connector.GetURL())
fmt.Println("\033[1;32m║\033[0m \033[1;32m║\033[0m")
displayAddr := httpsLocalAddress
if displayAddr == "127.0.0.1" {
displayAddr = "localhost"
}
fmt.Printf("\033[1;32m║\033[0m \033[90mForwarding:\033[0m \033[1m%s:%d\033[0m → \033[36m%s\033[0m%-15s\033[1;32m║\033[0m\n", displayAddr, port, "public", "")
fmt.Printf("\033[1;32m║\033[0m \033[90mLatency:\033[0m \033[90mmeasuring...\033[0m%-40s\033[1;32m║\033[0m\n", "")
fmt.Printf("\033[1;32m║\033[0m \033[90mTraffic:\033[0m \033[90m↓ 0 B ↑ 0 B\033[0m%-32s\033[1;32m║\033[0m\n", "")
fmt.Printf("\033[1;32m║\033[0m \033[90mSpeed:\033[0m \033[90m↓ 0 B/s ↑ 0 B/s\033[0m%-28s\033[1;32m║\033[0m\n", "")
fmt.Printf("\033[1;32m║\033[0m \033[90mRequests:\033[0m \033[90m0\033[0m%-43s\033[1;32m║\033[0m\n", "")
fmt.Println("\033[1;32m╠══════════════════════════════════════════════════════════════════╣\033[0m")
fmt.Println("\033[1;32m║\033[0m \033[90mPress Ctrl+C to stop the tunnel\033[0m \033[1;32m║\033[0m")
fmt.Println("\033[1;32m╚══════════════════════════════════════════════════════════════════╝\033[0m")
fmt.Println()
// Setup latency display
latencyCh := make(chan time.Duration, 1)
connector.SetLatencyCallback(func(latency time.Duration) {
select {
case latencyCh <- latency:
default:
}
})
// Start stats display updater (updates every second)
stopDisplay := make(chan struct{})
disconnected := make(chan struct{})
go func() {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
var lastLatency time.Duration
for {
select {
case latency := <-latencyCh:
lastLatency = latency
case <-ticker.C:
// Update speed calculation
stats := connector.GetStats()
if stats != nil {
stats.UpdateSpeed()
snapshot := stats.GetSnapshot()
// Move cursor up 8 lines to update display
fmt.Print("\033[8A")
// Update latency line
fmt.Printf("\r\033[1;32m║\033[0m \033[90mLatency:\033[0m %s%-40s\033[1;32m║\033[0m\n", formatLatency(lastLatency), "")
// Update traffic line
trafficStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatBytes(snapshot.TotalBytesIn), tcp.FormatBytes(snapshot.TotalBytesOut))
fmt.Printf("\r\033[1;32m║\033[0m \033[90mTraffic:\033[0m \033[36m%-48s\033[0m\033[1;32m║\033[0m\n", trafficStr)
// Update speed line
speedStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatSpeed(snapshot.SpeedIn), tcp.FormatSpeed(snapshot.SpeedOut))
fmt.Printf("\r\033[1;32m║\033[0m \033[90mSpeed:\033[0m \033[33m%-48s\033[0m\033[1;32m║\033[0m\n", speedStr)
// Update requests line
fmt.Printf("\r\033[1;32m║\033[0m \033[90mRequests:\033[0m \033[35m%-47d\033[0m\033[1;32m║\033[0m\n", snapshot.TotalRequests)
// Move back down 4 lines
fmt.Print("\033[4B")
}
case <-stopDisplay:
return
}
}
}()
// Monitor connection in background
go func() {
connector.Wait()
close(disconnected)
}()
// Wait for signal or disconnection
select {
case <-quit:
close(stopDisplay)
fmt.Println("\n\n\033[33m🛑 Shutting down...\033[0m")
connector.Close()
if httpsDaemonMarker {
RemoveDaemonInfo("https", port)
}
fmt.Println("\033[32m✓\033[0m Tunnel closed")
return nil
case <-disconnected:
close(stopDisplay)
fmt.Println("\n\n\033[31m⚠ Connection lost!\033[0m")
reconnectAttempts++
if reconnectAttempts >= maxReconnectAttempts {
return fmt.Errorf("connection lost after %d reconnect attempts", maxReconnectAttempts)
}
fmt.Printf("\033[90m Reconnecting in %v...\033[0m\n", reconnectInterval)
// Wait before reconnect, but allow interrupt
select {
case <-quit:
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
return nil
case <-time.After(reconnectInterval):
continue
}
}
}
}

194
internal/client/cli/list.go Normal file
View File

@@ -0,0 +1,194 @@
package cli
import (
"bufio"
"fmt"
"os"
"strconv"
"strings"
"time"
"github.com/spf13/cobra"
)
var (
interactiveMode bool
)
var listCmd = &cobra.Command{
Use: "list",
Short: "List all running background tunnels",
Long: `List all running background tunnels.
Example:
drip list Show all running tunnels
drip list -i Interactive mode (select to attach/stop)
This command shows:
- Tunnel type (HTTP/TCP)
- Local port being tunneled
- Public URL
- Process ID (PID)
- Uptime
In interactive mode, you can select a tunnel to:
- Attach: View real-time logs
- Stop: Terminate the tunnel`,
Aliases: []string{"ls", "ps", "status"},
RunE: runList,
}
func init() {
listCmd.Flags().BoolVarP(&interactiveMode, "interactive", "i", false, "Interactive mode for attach/stop")
rootCmd.AddCommand(listCmd)
}
func runList(cmd *cobra.Command, args []string) error {
// Clean up stale daemons first
CleanupStaleDaemons()
// Get all running daemons
daemons, err := ListAllDaemons()
if err != nil {
return fmt.Errorf("failed to list daemons: %w", err)
}
if len(daemons) == 0 {
fmt.Println("\033[90mNo running tunnels.\033[0m")
fmt.Println()
fmt.Println("Start a tunnel in background with:")
fmt.Println(" \033[36mdrip http 3000 -d\033[0m")
fmt.Println(" \033[36mdrip tcp 5432 -d\033[0m")
return nil
}
// Print header
fmt.Println()
fmt.Println("\033[1;37mRunning Tunnels\033[0m")
fmt.Println("\033[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\033[0m")
fmt.Printf("\033[1m%-4s %-6s %-6s %-40s %-8s %s\033[0m\n", "#", "TYPE", "PORT", "URL", "PID", "UPTIME")
fmt.Println("\033[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\033[0m")
idx := 1
for _, d := range daemons {
// Check if process is still running
if !IsProcessRunning(d.PID) {
// Clean up stale entry
RemoveDaemonInfo(d.Type, d.Port)
continue
}
// Calculate uptime
uptime := time.Since(d.StartTime)
// Format type with color
var typeStr string
if d.Type == "http" {
typeStr = "\033[32mHTTP\033[0m"
} else {
typeStr = "\033[35mTCP\033[0m"
}
// Truncate URL if too long
url := d.URL
if len(url) > 40 {
url = url[:37] + "..."
}
fmt.Printf("\033[1;36m%-4d\033[0m %-15s %-6d %-40s %-8d %s\n",
idx, typeStr, d.Port, url, d.PID, FormatDuration(uptime))
idx++
}
fmt.Println("\033[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\033[0m")
fmt.Println()
// Interactive mode or show commands
if interactiveMode || shouldPromptForAction() {
return runInteractiveList(daemons)
}
fmt.Println("Commands:")
fmt.Println(" \033[36mdrip list -i\033[0m Interactive mode")
fmt.Println(" \033[36mdrip attach http 3000\033[0m Attach to tunnel (view logs)")
fmt.Println(" \033[36mdrip stop http 3000\033[0m Stop tunnel")
fmt.Println(" \033[36mdrip stop all\033[0m Stop all tunnels")
return nil
}
func shouldPromptForAction() bool {
// Check if running in a terminal
if fileInfo, _ := os.Stdout.Stat(); (fileInfo.Mode() & os.ModeCharDevice) == 0 {
return false
}
// Always prompt when there are tunnels running
return true
}
func runInteractiveList(daemons []*DaemonInfo) error {
// Filter out non-running daemons
var runningDaemons []*DaemonInfo
for _, d := range daemons {
if IsProcessRunning(d.PID) {
runningDaemons = append(runningDaemons, d)
} else {
RemoveDaemonInfo(d.Type, d.Port)
}
}
if len(runningDaemons) == 0 {
return nil
}
// Prompt for action
fmt.Print("Select a tunnel (number) or 'q' to quit: ")
reader := bufio.NewReader(os.Stdin)
input, err := reader.ReadString('\n')
if err != nil {
return fmt.Errorf("failed to read input: %w", err)
}
input = strings.TrimSpace(input)
if input == "" || input == "q" || input == "Q" {
return nil
}
// Parse selection
selection, err := strconv.Atoi(input)
if err != nil || selection < 1 || selection > len(runningDaemons) {
return fmt.Errorf("invalid selection: %s", input)
}
selectedDaemon := runningDaemons[selection-1]
// Prompt for action
fmt.Println()
fmt.Printf("Selected: \033[1m%s\033[0m tunnel on port \033[1m%d\033[0m\n", strings.ToUpper(selectedDaemon.Type), selectedDaemon.Port)
fmt.Println()
fmt.Println("What would you like to do?")
fmt.Println(" \033[36m1.\033[0m Attach (view logs)")
fmt.Println(" \033[36m2.\033[0m Stop tunnel")
fmt.Println(" \033[90mq. Cancel\033[0m")
fmt.Println()
fmt.Print("Choose an action: ")
actionInput, err := reader.ReadString('\n')
if err != nil {
return fmt.Errorf("failed to read input: %w", err)
}
actionInput = strings.TrimSpace(actionInput)
switch actionInput {
case "1":
// Attach to daemon
return attachToDaemon(selectedDaemon)
case "2":
// Stop daemon
return stopDaemon(selectedDaemon.Type, selectedDaemon.Port)
case "q", "Q", "":
return nil
default:
return fmt.Errorf("invalid action: %s", actionInput)
}
}

View File

@@ -0,0 +1,81 @@
package cli
import (
"fmt"
"github.com/spf13/cobra"
)
var (
// Version information
Version = "dev"
GitCommit = "unknown"
BuildTime = "unknown"
// Global flags
serverURL string
authToken string
verbose bool
insecure bool
)
var rootCmd = &cobra.Command{
Use: "drip",
Short: "Drip - Fast and secure tunnels to localhost",
Long: `Drip - High-performance tunneling service with TCP over TLS 1.3
Expose your local services to the internet securely and easily.
Configuration:
First time: Run 'drip config init' to set up server and token
Subsequent: Just run 'drip http <port>' or 'drip tcp <port>'
Examples:
drip config init # Set up configuration
drip http 3000 # HTTP tunnel
drip tcp 5432 # PostgreSQL tunnel
drip http 8080 --subdomain myapp # Custom subdomain
Features:
✓ TCP over TLS 1.3 (secure and fast)
✓ HTTP and TCP tunnel support
✓ Auto-save configuration
✓ Custom subdomains
✓ Authentication via token`,
}
func init() {
rootCmd.PersistentFlags().StringVarP(&serverURL, "server", "s", "", "Server address (e.g., tunnel.example.com:443)")
rootCmd.PersistentFlags().StringVarP(&authToken, "token", "t", "", "Authentication token")
rootCmd.PersistentFlags().BoolVarP(&verbose, "verbose", "v", false, "Verbose output")
rootCmd.PersistentFlags().BoolVarP(&insecure, "insecure", "k", false, "Skip TLS verification (testing only, NOT recommended)")
rootCmd.AddCommand(versionCmd)
// http and tcp commands are added in their respective init() functions
// config command is added in config.go init() function
}
var versionCmd = &cobra.Command{
Use: "version",
Short: "Print version information",
Run: func(cmd *cobra.Command, args []string) {
fmt.Printf("Drip Client\n")
fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")
fmt.Printf("Version: %s\n", Version)
fmt.Printf("Git Commit: %s\n", GitCommit)
fmt.Printf("Build Time: %s\n", BuildTime)
fmt.Printf("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n")
},
}
// Execute runs the root command
func Execute() error {
return rootCmd.Execute()
}
// SetVersion sets the version information
func SetVersion(version, commit, buildTime string) {
Version = version
GitCommit = commit
BuildTime = buildTime
}

View File

@@ -0,0 +1,191 @@
package cli
import (
"fmt"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
"strconv"
"syscall"
"drip/internal/server/proxy"
"drip/internal/server/tcp"
"drip/internal/server/tunnel"
"drip/internal/shared/constants"
"drip/internal/shared/utils"
"drip/pkg/config"
"github.com/spf13/cobra"
"go.uber.org/zap"
)
var (
serverPort int
serverPublicPort int
serverDomain string
serverAuthToken string
serverDebug bool
serverTCPPortMin int
serverTCPPortMax int
serverTLSCert string
serverTLSKey string
serverPprofPort int
)
var serverCmd = &cobra.Command{
Use: "server",
Short: "Start Drip server",
Long: `Start the Drip tunnel server to accept client connections`,
RunE: runServer,
}
func init() {
rootCmd.AddCommand(serverCmd)
// Command line flags with environment variable defaults
serverCmd.Flags().IntVarP(&serverPort, "port", "p", getEnvInt("DRIP_PORT", 8443), "Server port (env: DRIP_PORT)")
serverCmd.Flags().IntVar(&serverPublicPort, "public-port", getEnvInt("DRIP_PUBLIC_PORT", 0), "Public port to display in URLs (env: DRIP_PUBLIC_PORT)")
serverCmd.Flags().StringVarP(&serverDomain, "domain", "d", getEnvString("DRIP_DOMAIN", constants.DefaultDomain), "Server domain (env: DRIP_DOMAIN)")
serverCmd.Flags().StringVarP(&serverAuthToken, "token", "t", getEnvString("DRIP_TOKEN", ""), "Authentication token (env: DRIP_TOKEN)")
serverCmd.Flags().BoolVar(&serverDebug, "debug", false, "Enable debug logging")
serverCmd.Flags().IntVar(&serverTCPPortMin, "tcp-port-min", getEnvInt("DRIP_TCP_PORT_MIN", constants.DefaultTCPPortMin), "Minimum TCP tunnel port (env: DRIP_TCP_PORT_MIN)")
serverCmd.Flags().IntVar(&serverTCPPortMax, "tcp-port-max", getEnvInt("DRIP_TCP_PORT_MAX", constants.DefaultTCPPortMax), "Maximum TCP tunnel port (env: DRIP_TCP_PORT_MAX)")
// TLS options
serverCmd.Flags().StringVar(&serverTLSCert, "tls-cert", getEnvString("DRIP_TLS_CERT", ""), "Path to TLS certificate file (env: DRIP_TLS_CERT)")
serverCmd.Flags().StringVar(&serverTLSKey, "tls-key", getEnvString("DRIP_TLS_KEY", ""), "Path to TLS private key file (env: DRIP_TLS_KEY)")
// Performance profiling
serverCmd.Flags().IntVar(&serverPprofPort, "pprof", getEnvInt("DRIP_PPROF_PORT", 0), "Enable pprof on specified port (env: DRIP_PPROF_PORT)")
}
func runServer(cmd *cobra.Command, args []string) error {
// Validate required TLS configuration
if serverTLSCert == "" {
return fmt.Errorf("TLS certificate path is required (use --tls-cert flag or DRIP_TLS_CERT environment variable)")
}
if serverTLSKey == "" {
return fmt.Errorf("TLS private key path is required (use --tls-key flag or DRIP_TLS_KEY environment variable)")
}
// Initialize logger
if err := utils.InitServerLogger(serverDebug); err != nil {
return fmt.Errorf("failed to initialize logger: %w", err)
}
defer utils.Sync()
logger := utils.GetLogger()
logger.Info("Starting Drip Server",
zap.String("version", Version),
zap.String("commit", GitCommit),
)
// Start pprof server if enabled
if serverPprofPort > 0 {
go func() {
pprofAddr := fmt.Sprintf("localhost:%d", serverPprofPort)
logger.Info("Starting pprof server", zap.String("address", pprofAddr))
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
logger.Error("pprof server failed", zap.Error(err))
}
}()
}
// Create server config
displayPort := serverPublicPort
if displayPort == 0 {
displayPort = serverPort
}
serverConfig := &config.ServerConfig{
Port: serverPort,
PublicPort: displayPort,
Domain: serverDomain,
TCPPortMin: serverTCPPortMin,
TCPPortMax: serverTCPPortMax,
TLSEnabled: true,
TLSCertFile: serverTLSCert,
TLSKeyFile: serverTLSKey,
AuthToken: serverAuthToken,
Debug: serverDebug,
}
// Load TLS configuration
tlsConfig, err := serverConfig.LoadTLSConfig()
if err != nil {
logger.Fatal("Failed to load TLS configuration", zap.Error(err))
}
logger.Info("TLS 1.3 configuration loaded",
zap.String("cert", serverTLSCert),
zap.String("key", serverTLSKey),
)
// Create tunnel manager
tunnelManager := tunnel.NewManager(logger)
// Create TCP port allocator
portAllocator, err := tcp.NewPortAllocator(serverTCPPortMin, serverTCPPortMax)
if err != nil {
logger.Fatal("Invalid TCP port range", zap.Error(err))
}
// Create TCP listener address
listenAddr := fmt.Sprintf("0.0.0.0:%d", serverPort)
// Response handler for HTTP-over-frame responses
responseHandler := proxy.NewResponseHandler(logger)
// Create HTTP proxy handler (for handling HTTP requests on TCP port)
httpHandler := proxy.NewHandler(tunnelManager, logger, responseHandler, serverDomain, serverAuthToken)
// Create TCP listener (wsHandler also serves as response channel handler)
listener := tcp.NewListener(listenAddr, tlsConfig, serverAuthToken, tunnelManager, logger, portAllocator, serverDomain, displayPort, httpHandler, responseHandler)
// Start listener
if err := listener.Start(); err != nil {
logger.Fatal("Failed to start TCP listener", zap.Error(err))
}
logger.Info("Drip Server started",
zap.String("address", listenAddr),
zap.String("domain", serverDomain),
zap.String("protocol", "TCP over TLS 1.3"),
)
// Setup signal handler
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
// Wait for signal
<-quit
logger.Info("Shutting down server...")
// Stop listener
if err := listener.Stop(); err != nil {
logger.Error("Error stopping listener", zap.Error(err))
}
logger.Info("Server stopped")
return nil
}
// getEnvInt returns the environment variable value as int, or defaultVal if not set
func getEnvInt(key string, defaultVal int) int {
if val := os.Getenv(key); val != "" {
if i, err := strconv.Atoi(val); err == nil {
return i
}
}
return defaultVal
}
// getEnvString returns the environment variable value, or defaultVal if not set
func getEnvString(key string, defaultVal string) string {
if val := os.Getenv(key); val != "" {
return val
}
return defaultVal
}

125
internal/client/cli/stop.go Normal file
View File

@@ -0,0 +1,125 @@
package cli
import (
"fmt"
"strconv"
"github.com/spf13/cobra"
)
var stopCmd = &cobra.Command{
Use: "stop <type> <port>|all",
Short: "Stop background tunnels",
Long: `Stop one or all background tunnels.
Examples:
drip stop http 3000 Stop HTTP tunnel on port 3000
drip stop tcp 5432 Stop TCP tunnel on port 5432
drip stop all Stop all running tunnels
Use 'drip list' to see running tunnels.`,
Aliases: []string{"kill"},
Args: cobra.MinimumNArgs(1),
RunE: runStop,
}
func init() {
rootCmd.AddCommand(stopCmd)
}
func runStop(cmd *cobra.Command, args []string) error {
// Handle "stop all"
if args[0] == "all" {
return stopAllDaemons()
}
// Handle "stop <type> <port>"
if len(args) < 2 {
return fmt.Errorf("usage: drip stop <type> <port> or drip stop all")
}
tunnelType := args[0]
if tunnelType != "http" && tunnelType != "tcp" {
return fmt.Errorf("invalid tunnel type: %s (must be 'http' or 'tcp')", tunnelType)
}
port, err := strconv.Atoi(args[1])
if err != nil || port < 1 || port > 65535 {
return fmt.Errorf("invalid port number: %s", args[1])
}
return stopDaemon(tunnelType, port)
}
func stopDaemon(tunnelType string, port int) error {
info, err := LoadDaemonInfo(tunnelType, port)
if err != nil {
return fmt.Errorf("failed to load daemon info: %w", err)
}
if info == nil {
return fmt.Errorf("no %s tunnel running on port %d", tunnelType, port)
}
// Check if process is still running
if !IsProcessRunning(info.PID) {
// Clean up stale entry
RemoveDaemonInfo(tunnelType, port)
return fmt.Errorf("tunnel was not running (cleaned up stale entry)")
}
// Kill the process
if err := KillProcess(info.PID); err != nil {
return fmt.Errorf("failed to stop tunnel: %w", err)
}
// Remove daemon info
RemoveDaemonInfo(tunnelType, port)
fmt.Printf("\033[32m✓\033[0m Stopped %s tunnel on port %d (PID: %d)\n", tunnelType, port, info.PID)
return nil
}
func stopAllDaemons() error {
// Clean up stale daemons first
CleanupStaleDaemons()
daemons, err := ListAllDaemons()
if err != nil {
return fmt.Errorf("failed to list daemons: %w", err)
}
if len(daemons) == 0 {
fmt.Println("\033[90mNo running tunnels to stop.\033[0m")
return nil
}
stopped := 0
failed := 0
for _, d := range daemons {
if !IsProcessRunning(d.PID) {
RemoveDaemonInfo(d.Type, d.Port)
continue
}
if err := KillProcess(d.PID); err != nil {
fmt.Printf("\033[31m✗\033[0m Failed to stop %s tunnel on port %d: %v\n", d.Type, d.Port, err)
failed++
continue
}
RemoveDaemonInfo(d.Type, d.Port)
fmt.Printf("\033[32m✓\033[0m Stopped %s tunnel on port %d (PID: %d)\n", d.Type, d.Port, d.PID)
stopped++
}
fmt.Println()
if failed > 0 {
fmt.Printf("Stopped %d tunnel(s), %d failed\n", stopped, failed)
} else {
fmt.Printf("Stopped %d tunnel(s)\n", stopped)
}
return nil
}

361
internal/client/cli/tcp.go Normal file
View File

@@ -0,0 +1,361 @@
package cli
import (
"fmt"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
"drip/internal/client/tcp"
"drip/internal/shared/protocol"
"drip/internal/shared/utils"
"drip/pkg/config"
"github.com/spf13/cobra"
"go.uber.org/zap"
)
var tcpCmd = &cobra.Command{
Use: "tcp <port>",
Short: "Start TCP tunnel",
Long: `Start a TCP tunnel to expose any TCP service.
Example:
drip tcp 5432 Tunnel PostgreSQL
drip tcp 3306 Tunnel MySQL
drip tcp 22 Tunnel SSH
drip tcp 6379 --subdomain myredis Tunnel Redis with custom subdomain
Supported Services:
- Databases: PostgreSQL (5432), MySQL (3306), Redis (6379), MongoDB (27017)
- SSH: Port 22
- Any TCP service
Configuration:
First time: Run 'drip config init' to save server and token
Subsequent: Just run 'drip tcp <port>'
Note: Uses TCP over TLS 1.3 for secure communication`,
Args: cobra.ExactArgs(1),
RunE: runTCP,
}
func init() {
tcpCmd.Flags().StringVarP(&subdomain, "subdomain", "n", "", "Custom subdomain (optional)")
tcpCmd.Flags().BoolVarP(&daemonMode, "daemon", "d", false, "Run in background (daemon mode)")
tcpCmd.Flags().StringVarP(&localAddress, "address", "a", "127.0.0.1", "Local address to forward to (default: 127.0.0.1)")
tcpCmd.Flags().BoolVar(&daemonMarker, "daemon-child", false, "Internal flag for daemon child process")
tcpCmd.Flags().MarkHidden("daemon-child")
rootCmd.AddCommand(tcpCmd)
}
func runTCP(cmd *cobra.Command, args []string) error {
// Parse port
port, err := strconv.Atoi(args[0])
if err != nil || port < 1 || port > 65535 {
return fmt.Errorf("invalid port number: %s", args[0])
}
// Handle daemon mode
if daemonMode && !daemonMarker {
// Start as daemon
daemonArgs := append([]string{"tcp"}, args...)
daemonArgs = append(daemonArgs, "--daemon-child")
if subdomain != "" {
daemonArgs = append(daemonArgs, "--subdomain", subdomain)
}
if localAddress != "127.0.0.1" {
daemonArgs = append(daemonArgs, "--address", localAddress)
}
if serverURL != "" {
daemonArgs = append(daemonArgs, "--server", serverURL)
}
if authToken != "" {
daemonArgs = append(daemonArgs, "--token", authToken)
}
if insecure {
daemonArgs = append(daemonArgs, "--insecure")
}
if verbose {
daemonArgs = append(daemonArgs, "--verbose")
}
return StartDaemon("tcp", port, daemonArgs)
}
// Initialize logger
if err := utils.InitLogger(verbose); err != nil {
return fmt.Errorf("failed to initialize logger: %w", err)
}
defer utils.Sync()
logger := utils.GetLogger()
// Load configuration or use command line flags
var serverAddr, token string
if serverURL == "" {
// Try to load from config file
cfg, err := config.LoadClientConfig("")
if err != nil {
return fmt.Errorf(`configuration not found.
Please run 'drip config init' first, or use flags:
drip tcp %d --server SERVER:PORT --token TOKEN`, port)
}
serverAddr = cfg.Server
token = cfg.Token
} else {
// Use command line flags
serverAddr = serverURL
token = authToken
}
// Validate server address
if serverAddr == "" {
return fmt.Errorf("server address is required")
}
// Create connector config
connConfig := &tcp.ConnectorConfig{
ServerAddr: serverAddr,
Token: token,
TunnelType: protocol.TunnelTypeTCP,
LocalHost: localAddress,
LocalPort: port,
Subdomain: subdomain,
Insecure: insecure,
}
// Setup signal handler for graceful shutdown
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
// Connection loop with reconnect support
reconnectAttempts := 0
serviceName := getServiceName(port)
for {
// Create connector
connector := tcp.NewConnector(connConfig, logger)
// Connect to server
if reconnectAttempts == 0 {
fmt.Printf("\033[36m🔌 Connecting to %s...\033[0m\n", serverAddr)
} else {
fmt.Printf("\033[33m🔄 Reconnecting to %s (attempt %d/%d)...\033[0m\n", serverAddr, reconnectAttempts, maxReconnectAttempts)
}
if err := connector.Connect(); err != nil {
// Check if this is a non-retryable error
if isNonRetryableErrorTCP(err) {
return fmt.Errorf("failed to connect: %w", err)
}
reconnectAttempts++
if reconnectAttempts >= maxReconnectAttempts {
return fmt.Errorf("failed to connect after %d attempts: %w", maxReconnectAttempts, err)
}
fmt.Printf("\033[31m✗ Connection failed: %v\033[0m\n", err)
fmt.Printf("\033[90m Retrying in %v...\033[0m\n", reconnectInterval)
// Wait before retry, but allow interrupt
select {
case <-quit:
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
return nil
case <-time.After(reconnectInterval):
continue
}
}
// Reset reconnect attempts on successful connection
reconnectAttempts = 0
// Save daemon info if running as daemon child
if daemonMarker {
daemonInfo := &DaemonInfo{
PID: os.Getpid(),
Type: "tcp",
Port: port,
Subdomain: subdomain,
Server: serverAddr,
URL: connector.GetURL(),
StartTime: time.Now(),
Executable: os.Args[0],
}
if err := SaveDaemonInfo(daemonInfo); err != nil {
// Log but don't fail
logger.Warn("Failed to save daemon info", zap.Error(err))
}
}
// Print tunnel information
fmt.Println()
fmt.Println("\033[1;35m╔══════════════════════════════════════════════════════════════════╗\033[0m")
fmt.Println("\033[1;35m║\033[0m \033[1;37m🔌 TCP Tunnel Connected Successfully!\033[0m \033[1;35m║\033[0m")
fmt.Println("\033[1;35m╠══════════════════════════════════════════════════════════════════╣\033[0m")
fmt.Printf("\033[1;35m║\033[0m \033[1;37mTunnel URL:\033[0m \033[1;35m║\033[0m\n")
fmt.Printf("\033[1;35m║\033[0m \033[1;36m%-60s\033[0m \033[1;35m║\033[0m\n", connector.GetURL())
fmt.Println("\033[1;35m║\033[0m \033[1;35m║\033[0m")
fmt.Printf("\033[1;35m║\033[0m \033[90mService:\033[0m \033[1;35m%-50s\033[0m \033[1;35m║\033[0m\n", serviceName)
displayAddr := localAddress
if displayAddr == "127.0.0.1" {
displayAddr = "localhost"
}
fmt.Printf("\033[1;35m║\033[0m \033[90mForwarding:\033[0m \033[1m%s:%d\033[0m → \033[36m%s\033[0m%-15s\033[1;35m║\033[0m\n", displayAddr, port, "public", "")
fmt.Printf("\033[1;35m║\033[0m \033[90mLatency:\033[0m \033[90mmeasuring...\033[0m%-40s\033[1;35m║\033[0m\n", "")
fmt.Printf("\033[1;35m║\033[0m \033[90mTraffic:\033[0m \033[90m↓ 0 B ↑ 0 B\033[0m%-32s\033[1;35m║\033[0m\n", "")
fmt.Printf("\033[1;35m║\033[0m \033[90mSpeed:\033[0m \033[90m↓ 0 B/s ↑ 0 B/s\033[0m%-28s\033[1;35m║\033[0m\n", "")
fmt.Printf("\033[1;35m║\033[0m \033[90mRequests:\033[0m \033[90m0\033[0m%-43s\033[1;35m║\033[0m\n", "")
fmt.Println("\033[1;35m╠══════════════════════════════════════════════════════════════════╣\033[0m")
fmt.Println("\033[1;35m║\033[0m \033[90mPress Ctrl+C to stop the tunnel\033[0m \033[1;35m║\033[0m")
fmt.Println("\033[1;35m╚══════════════════════════════════════════════════════════════════╝\033[0m")
fmt.Println()
// Setup latency display
latencyCh := make(chan time.Duration, 1)
connector.SetLatencyCallback(func(latency time.Duration) {
select {
case latencyCh <- latency:
default:
}
})
// Start stats display updater (updates every second)
stopDisplay := make(chan struct{})
disconnected := make(chan struct{})
go func() {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
var lastLatency time.Duration
for {
select {
case latency := <-latencyCh:
lastLatency = latency
case <-ticker.C:
// Update speed calculation
stats := connector.GetStats()
if stats != nil {
stats.UpdateSpeed()
snapshot := stats.GetSnapshot()
// Move cursor up 8 lines to update display
fmt.Print("\033[8A")
// Update latency line
fmt.Printf("\r\033[1;35m║\033[0m \033[90mLatency:\033[0m %s%-40s\033[1;35m║\033[0m\n", formatLatencyTCP(lastLatency), "")
// Update traffic line
trafficStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatBytes(snapshot.TotalBytesIn), tcp.FormatBytes(snapshot.TotalBytesOut))
fmt.Printf("\r\033[1;35m║\033[0m \033[90mTraffic:\033[0m \033[36m%-48s\033[0m\033[1;35m║\033[0m\n", trafficStr)
// Update speed line
speedStr := fmt.Sprintf("↓ %s ↑ %s", tcp.FormatSpeed(snapshot.SpeedIn), tcp.FormatSpeed(snapshot.SpeedOut))
fmt.Printf("\r\033[1;35m║\033[0m \033[90mSpeed:\033[0m \033[33m%-48s\033[0m\033[1;35m║\033[0m\n", speedStr)
// Update requests line
fmt.Printf("\r\033[1;35m║\033[0m \033[90mRequests:\033[0m \033[35m%-47d\033[0m\033[1;35m║\033[0m\n", snapshot.TotalRequests)
// Move back down 4 lines
fmt.Print("\033[4B")
}
case <-stopDisplay:
return
}
}
}()
// Monitor connection in background
go func() {
connector.Wait()
close(disconnected)
}()
// Wait for signal or disconnection
select {
case <-quit:
close(stopDisplay)
fmt.Println("\n\n\033[33m🛑 Shutting down...\033[0m")
connector.Close()
if daemonMarker {
RemoveDaemonInfo("tcp", port)
}
fmt.Println("\033[32m✓\033[0m Tunnel closed")
return nil
case <-disconnected:
close(stopDisplay)
fmt.Println("\n\n\033[31m⚠ Connection lost!\033[0m")
reconnectAttempts++
if reconnectAttempts >= maxReconnectAttempts {
return fmt.Errorf("connection lost after %d reconnect attempts", maxReconnectAttempts)
}
fmt.Printf("\033[90m Reconnecting in %v...\033[0m\n", reconnectInterval)
// Wait before reconnect, but allow interrupt
select {
case <-quit:
fmt.Println("\n\033[33m🛑 Shutting down...\033[0m")
return nil
case <-time.After(reconnectInterval):
continue
}
}
}
}
// getServiceName returns a friendly name for common port numbers
func getServiceName(port int) string {
services := map[int]string{
22: "SSH",
80: "HTTP",
443: "HTTPS",
3306: "MySQL",
5432: "PostgreSQL",
6379: "Redis",
27017: "MongoDB",
3389: "RDP",
5900: "VNC",
8080: "HTTP (Alt)",
8443: "HTTPS (Alt)",
}
if name, ok := services[port]; ok {
return fmt.Sprintf("%s (port %d)", name, port)
}
return fmt.Sprintf("TCP service on port %d", port)
}
// formatLatencyTCP formats latency with color based on value
func formatLatencyTCP(d time.Duration) string {
ms := d.Milliseconds()
if ms < 50 {
return fmt.Sprintf("\033[32m%dms\033[0m", ms) // Green: excellent
} else if ms < 100 {
return fmt.Sprintf("\033[33m%dms\033[0m", ms) // Yellow: good
} else if ms < 200 {
return fmt.Sprintf("\033[38;5;208m%dms\033[0m", ms) // Orange: moderate
}
return fmt.Sprintf("\033[31m%dms\033[0m", ms) // Red: poor
}
// isNonRetryableErrorTCP checks if an error should not be retried
func isNonRetryableErrorTCP(err error) bool {
errStr := err.Error()
// Subdomain conflicts - no point retrying
if strings.Contains(errStr, "subdomain is already taken") ||
strings.Contains(errStr, "subdomain is reserved") ||
strings.Contains(errStr, "invalid subdomain") {
return true
}
// Authentication errors - no point retrying
if strings.Contains(errStr, "authentication") ||
strings.Contains(errStr, "Invalid authentication token") {
return true
}
return false
}

View File

@@ -0,0 +1,393 @@
package tcp
import (
"crypto/tls"
"encoding/json"
"fmt"
"net"
"sync"
"time"
"drip/internal/shared/constants"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"drip/pkg/config"
"go.uber.org/zap"
)
// LatencyCallback is called when latency is measured
type LatencyCallback func(latency time.Duration)
// Connector manages the TCP connection to the server
type Connector struct {
serverAddr string
tlsConfig *tls.Config
token string
tunnelType protocol.TunnelType
localHost string
localPort int
subdomain string
conn net.Conn
logger *zap.Logger
stopCh chan struct{}
once sync.Once
registered bool
assignedURL string
frameHandler *FrameHandler
frameWriter *protocol.FrameWriter
latencyCallback LatencyCallback
heartbeatSentAt time.Time
heartbeatMu sync.Mutex
lastLatency time.Duration
handlerWg sync.WaitGroup // Tracks active data frame handlers
closed bool
closedMu sync.RWMutex
}
// ConnectorConfig holds connector configuration
type ConnectorConfig struct {
ServerAddr string
Token string
TunnelType protocol.TunnelType
LocalHost string // Local host address (default: 127.0.0.1)
LocalPort int
Subdomain string // Optional custom subdomain
Insecure bool // Skip TLS verification (testing only)
}
// NewConnector creates a new connector
func NewConnector(cfg *ConnectorConfig, logger *zap.Logger) *Connector {
var tlsConfig *tls.Config
if cfg.Insecure {
tlsConfig = config.GetClientTLSConfigInsecure()
} else {
host, _, _ := net.SplitHostPort(cfg.ServerAddr)
tlsConfig = config.GetClientTLSConfig(host)
}
localHost := cfg.LocalHost
if localHost == "" {
localHost = "127.0.0.1"
}
return &Connector{
serverAddr: cfg.ServerAddr,
tlsConfig: tlsConfig,
token: cfg.Token,
tunnelType: cfg.TunnelType,
localHost: localHost,
localPort: cfg.LocalPort,
subdomain: cfg.Subdomain,
logger: logger,
stopCh: make(chan struct{}),
}
}
// Connect connects to the server and registers the tunnel
func (c *Connector) Connect() error {
c.logger.Info("Connecting to server",
zap.String("server", c.serverAddr),
zap.String("tunnel_type", string(c.tunnelType)),
zap.String("local_host", c.localHost),
zap.Int("local_port", c.localPort),
)
dialer := &net.Dialer{
Timeout: 10 * time.Second,
}
conn, err := tls.DialWithDialer(dialer, "tcp", c.serverAddr, c.tlsConfig)
if err != nil {
return fmt.Errorf("failed to connect: %w", err)
}
c.conn = conn
state := conn.ConnectionState()
if state.Version != tls.VersionTLS13 {
conn.Close()
return fmt.Errorf("server not using TLS 1.3 (version: 0x%04x)", state.Version)
}
c.logger.Info("TLS connection established",
zap.String("cipher_suite", tls.CipherSuiteName(state.CipherSuite)),
)
if err := c.register(); err != nil {
conn.Close()
return fmt.Errorf("registration failed: %w", err)
}
c.frameWriter = protocol.NewFrameWriter(c.conn)
bufferPool := pool.NewBufferPool()
c.frameHandler = NewFrameHandler(
c.conn,
c.frameWriter,
c.localHost,
c.localPort,
c.tunnelType,
c.logger,
c.IsClosed,
bufferPool,
)
go c.frameHandler.WarmupConnectionPool(3)
go c.handleFrames()
go c.heartbeat()
return nil
}
// register sends registration request and waits for acknowledgment
func (c *Connector) register() error {
req := protocol.RegisterRequest{
Token: c.token,
CustomSubdomain: c.subdomain,
TunnelType: c.tunnelType,
LocalPort: c.localPort,
}
payload, err := json.Marshal(req)
if err != nil {
return fmt.Errorf("failed to marshal request: %w", err)
}
regFrame := protocol.NewFrame(protocol.FrameTypeRegister, payload)
err = protocol.WriteFrame(c.conn, regFrame)
if err != nil {
return fmt.Errorf("failed to send registration: %w", err)
}
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
ackFrame, err := protocol.ReadFrame(c.conn)
if err != nil {
return fmt.Errorf("failed to read ack: %w", err)
}
defer ackFrame.Release()
c.conn.SetReadDeadline(time.Time{})
if ackFrame.Type == protocol.FrameTypeError {
var errMsg protocol.ErrorMessage
if err := json.Unmarshal(ackFrame.Payload, &errMsg); err == nil {
return fmt.Errorf("registration error: %s - %s", errMsg.Code, errMsg.Message)
}
return fmt.Errorf("registration error")
}
if ackFrame.Type != protocol.FrameTypeRegisterAck {
return fmt.Errorf("unexpected frame type: %s", ackFrame.Type)
}
var resp protocol.RegisterResponse
if err := json.Unmarshal(ackFrame.Payload, &resp); err != nil {
return fmt.Errorf("failed to parse response: %w", err)
}
c.registered = true
c.assignedURL = resp.URL
c.subdomain = resp.Subdomain
c.logger.Info("Tunnel registered successfully",
zap.String("subdomain", resp.Subdomain),
zap.String("url", resp.URL),
zap.Int("remote_port", resp.Port),
)
return nil
}
// handleFrames handles incoming frames from server
func (c *Connector) handleFrames() {
defer c.Close()
for {
select {
case <-c.stopCh:
return
default:
}
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
frame, err := protocol.ReadFrame(c.conn)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
c.logger.Warn("Read timeout")
return
}
select {
case <-c.stopCh:
return
default:
c.logger.Error("Failed to read frame", zap.Error(err))
return
}
}
switch frame.Type {
case protocol.FrameTypeHeartbeatAck:
c.heartbeatMu.Lock()
if !c.heartbeatSentAt.IsZero() {
latency := time.Since(c.heartbeatSentAt)
c.lastLatency = latency
c.heartbeatMu.Unlock()
c.logger.Debug("Received heartbeat ack", zap.Duration("latency", latency))
if c.latencyCallback != nil {
c.latencyCallback(latency)
}
} else {
c.heartbeatMu.Unlock()
c.logger.Debug("Received heartbeat ack")
}
frame.Release()
case protocol.FrameTypeData:
c.handlerWg.Add(1)
go func(f *protocol.Frame) {
defer c.handlerWg.Done()
defer f.Release()
if err := c.frameHandler.HandleDataFrame(f); err != nil {
c.logger.Error("Failed to handle data frame", zap.Error(err))
}
}(frame)
case protocol.FrameTypeClose:
frame.Release()
c.logger.Info("Server requested close")
return
case protocol.FrameTypeError:
var errMsg protocol.ErrorMessage
if err := json.Unmarshal(frame.Payload, &errMsg); err == nil {
c.logger.Error("Received error from server",
zap.String("code", errMsg.Code),
zap.String("message", errMsg.Message),
)
}
frame.Release()
return
default:
frame.Release()
c.logger.Warn("Unexpected frame type",
zap.String("type", frame.Type.String()),
)
}
}
}
// heartbeat sends periodic heartbeat frames
func (c *Connector) heartbeat() {
ticker := time.NewTicker(constants.HeartbeatInterval)
defer ticker.Stop()
c.sendHeartbeat()
for {
select {
case <-c.stopCh:
return
case <-ticker.C:
c.sendHeartbeat()
}
}
}
// sendHeartbeat sends a heartbeat frame and records the time
func (c *Connector) sendHeartbeat() {
hbFrame := protocol.NewFrame(protocol.FrameTypeHeartbeat, nil)
c.heartbeatMu.Lock()
c.heartbeatSentAt = time.Now()
c.heartbeatMu.Unlock()
err := c.frameWriter.WriteFrame(hbFrame)
if err != nil {
c.logger.Error("Failed to send heartbeat", zap.Error(err))
c.Close()
return
}
c.logger.Debug("Heartbeat sent")
}
// SendFrame sends a frame to the server
func (c *Connector) SendFrame(frame *protocol.Frame) error {
if !c.registered {
return fmt.Errorf("not registered")
}
return c.frameWriter.WriteFrame(frame)
}
// Close closes the connection
func (c *Connector) Close() error {
c.once.Do(func() {
c.closedMu.Lock()
c.closed = true
c.closedMu.Unlock()
close(c.stopCh)
c.logger.Debug("Waiting for active handlers to complete")
c.handlerWg.Wait()
if c.conn != nil {
closeFrame := protocol.NewFrame(protocol.FrameTypeClose, nil)
if c.frameWriter != nil {
c.frameWriter.WriteFrame(closeFrame)
c.frameWriter.Close()
} else {
protocol.WriteFrame(c.conn, closeFrame)
}
c.conn.Close()
}
c.logger.Info("Connector closed")
})
return nil
}
// Wait blocks until connection is closed
func (c *Connector) Wait() {
<-c.stopCh
}
// GetURL returns the assigned tunnel URL
func (c *Connector) GetURL() string {
return c.assignedURL
}
// GetSubdomain returns the assigned subdomain
func (c *Connector) GetSubdomain() string {
return c.subdomain
}
// SetLatencyCallback sets the callback for latency updates
func (c *Connector) SetLatencyCallback(cb LatencyCallback) {
c.latencyCallback = cb
}
// GetLatency returns the last measured latency
func (c *Connector) GetLatency() time.Duration {
c.heartbeatMu.Lock()
defer c.heartbeatMu.Unlock()
return c.lastLatency
}
// GetStats returns the traffic stats from the frame handler
func (c *Connector) GetStats() *TrafficStats {
if c.frameHandler != nil {
return c.frameHandler.GetStats()
}
return nil
}
// IsClosed returns whether the connector has been closed
func (c *Connector) IsClosed() bool {
c.closedMu.RLock()
defer c.closedMu.RUnlock()
return c.closed
}

View File

@@ -0,0 +1,440 @@
package tcp
import (
"bytes"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"time"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
// FrameHandler handles data frames and forwards to local service
type FrameHandler struct {
conn net.Conn
frameWriter *protocol.FrameWriter // Async batch writer (replaces writeMu)
localHost string
localPort int
logger *zap.Logger
streams map[string]*Stream
streamMu sync.RWMutex
tunnelType protocol.TunnelType
httpClient *http.Client
stats *TrafficStats
isClosedCheck func() bool // Function to check if connection is closed
bufferPool *pool.BufferPool
headerPool *pool.HeaderPool // Header pool for Priority 9 optimization
}
// Stream represents a single request/response stream
type Stream struct {
ID string
LocalConn net.Conn
ResponseCh chan []byte
Done chan struct{}
}
func NewFrameHandler(conn net.Conn, frameWriter *protocol.FrameWriter, localHost string, localPort int, tunnelType protocol.TunnelType, logger *zap.Logger, isClosedCheck func() bool, bufferPool *pool.BufferPool) *FrameHandler {
var tlsConfig *tls.Config
if tunnelType == protocol.TunnelTypeHTTPS {
tlsConfig = &tls.Config{
InsecureSkipVerify: true,
}
}
return &FrameHandler{
conn: conn,
frameWriter: frameWriter,
localHost: localHost,
localPort: localPort,
logger: logger,
streams: make(map[string]*Stream),
tunnelType: tunnelType,
stats: NewTrafficStats(),
isClosedCheck: isClosedCheck,
bufferPool: bufferPool,
headerPool: pool.NewHeaderPool(),
httpClient: &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
MaxIdleConns: 500,
MaxIdleConnsPerHost: 200,
MaxConnsPerHost: 0,
IdleConnTimeout: 180 * time.Second,
DisableCompression: true,
DisableKeepAlives: false,
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: tlsConfig,
ResponseHeaderTimeout: 15 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DialContext: (&net.Dialer{
Timeout: 5 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
},
}
}
func (h *FrameHandler) HandleDataFrame(frame *protocol.Frame) error {
h.stats.AddBytesIn(int64(len(frame.Payload)))
h.stats.AddRequest()
header, data, err := protocol.DecodeDataPayload(frame.Payload)
if err != nil {
return fmt.Errorf("failed to decode data payload: %w", err)
}
if header.Type == "http_request" {
return h.handleHTTPFrame(header, data)
}
if header.Type == "close" {
h.closeStream(header.StreamID)
return nil
}
stream, err := h.getOrCreateStream(header.StreamID)
if err != nil {
return fmt.Errorf("failed to get stream: %w", err)
}
h.forwardToLocal(stream, data)
return nil
}
func (h *FrameHandler) getOrCreateStream(streamID string) (*Stream, error) {
h.streamMu.Lock()
defer h.streamMu.Unlock()
if stream, ok := h.streams[streamID]; ok {
return stream, nil
}
localAddr := fmt.Sprintf("%s:%d", h.localHost, h.localPort)
localConn, err := net.DialTimeout("tcp", localAddr, 5*time.Second)
if err != nil {
return nil, fmt.Errorf("failed to connect to local service: %w", err)
}
stream := &Stream{
ID: streamID,
LocalConn: localConn,
ResponseCh: make(chan []byte, 10),
Done: make(chan struct{}),
}
h.streams[streamID] = stream
go h.handleLocalResponse(stream)
return stream, nil
}
func (h *FrameHandler) forwardToLocal(stream *Stream, data []byte) {
if _, err := stream.LocalConn.Write(data); err != nil {
h.logger.Error("Failed to write to local service",
zap.String("stream_id", stream.ID),
zap.Error(err),
)
h.closeStream(stream.ID)
}
}
func (h *FrameHandler) handleLocalResponse(stream *Stream) {
defer h.closeStream(stream.ID)
bufPtr := h.bufferPool.Get(pool.SizeMedium)
defer h.bufferPool.Put(bufPtr)
buf := (*bufPtr)[:pool.SizeMedium]
for {
n, err := stream.LocalConn.Read(buf)
if err != nil {
break
}
if n > 0 {
if h.isClosedCheck != nil && h.isClosedCheck() {
break
}
header := protocol.DataHeader{
StreamID: stream.ID,
Type: "response",
IsLast: false,
}
payload, err := protocol.EncodeDataPayload(header, buf[:n])
if err != nil {
h.logger.Error("Encode payload failed", zap.Error(err))
break
}
dataFrame := protocol.NewFrame(protocol.FrameTypeData, payload)
err = h.frameWriter.WriteFrame(dataFrame)
if err != nil {
h.logger.Error("Send frame failed", zap.Error(err))
break
}
h.stats.AddBytesOut(int64(len(payload)))
}
}
}
func (h *FrameHandler) handleHTTPFrame(header protocol.DataHeader, payload []byte) error {
if h.tunnelType != protocol.TunnelTypeHTTP && h.tunnelType != protocol.TunnelTypeHTTPS {
return nil
}
httpReq, err := protocol.DecodeHTTPRequest(payload)
if err != nil {
return fmt.Errorf("failed to decode HTTP request: %w", err)
}
targetURL := httpReq.URL
if !strings.HasPrefix(targetURL, "http://") && !strings.HasPrefix(targetURL, "https://") {
scheme := "http"
if h.tunnelType == protocol.TunnelTypeHTTPS {
scheme = "https"
}
targetURL = fmt.Sprintf("%s://%s:%d%s", scheme, h.localHost, h.localPort, targetURL)
}
req, err := http.NewRequest(httpReq.Method, targetURL, bytes.NewReader(httpReq.Body))
if err != nil {
return h.sendHTTPError(header.StreamID, header.RequestID, http.StatusBadGateway, fmt.Sprintf("build request: %v", err))
}
origHost := ""
for key, values := range httpReq.Headers {
for _, value := range values {
req.Header.Add(key, value)
}
}
if host := req.Header.Get("Host"); host != "" {
origHost = host
}
isLocalTarget := h.isLocalAddress(h.localHost)
if isLocalTarget {
if origHost != "" {
req.Host = origHost
req.Header.Set("Host", origHost)
} else {
localHostPort := fmt.Sprintf("%s:%d", h.localHost, h.localPort)
req.Host = localHostPort
req.Header.Set("Host", localHostPort)
}
if origHost != "" {
req.Header.Set("X-Forwarded-Host", origHost)
}
} else {
targetHost := h.localHost
if h.localPort != 443 && h.localPort != 80 {
targetHost = fmt.Sprintf("%s:%d", h.localHost, h.localPort)
}
req.Host = targetHost
req.Header.Set("Host", targetHost)
if origHost != "" {
req.Header.Set("X-Forwarded-Host", origHost)
}
}
req.Header.Set("X-Forwarded-Proto", "https")
resp, err := h.httpClient.Do(req)
if err != nil {
return h.sendHTTPError(header.StreamID, header.RequestID, http.StatusBadGateway, fmt.Sprintf("local request failed: %v", err))
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return h.sendHTTPError(header.StreamID, header.RequestID, http.StatusBadGateway, fmt.Sprintf("read response: %v", err))
}
httpResp := protocol.HTTPResponse{
StatusCode: resp.StatusCode,
Status: resp.Status,
Headers: resp.Header,
Body: body,
}
return h.sendHTTPResponse(header.StreamID, header.RequestID, &httpResp)
}
func (h *FrameHandler) sendHTTPError(streamID, requestID string, status int, message string) error {
headers := h.headerPool.Get()
headers.Set("Content-Type", "text/plain")
httpResp := protocol.HTTPResponse{
StatusCode: status,
Status: http.StatusText(status),
Headers: headers,
Body: []byte(message),
}
err := h.sendHTTPResponse(streamID, requestID, &httpResp)
h.headerPool.Put(headers)
return err
}
func (h *FrameHandler) sendHTTPResponse(streamID, requestID string, resp *protocol.HTTPResponse) error {
if h.isClosedCheck != nil && h.isClosedCheck() {
return nil
}
header := protocol.DataHeader{
StreamID: streamID,
RequestID: requestID,
Type: "http_response",
IsLast: true,
}
respBytes, err := protocol.EncodeHTTPResponse(resp)
if err != nil {
return fmt.Errorf("encode http response: %w", err)
}
payload, err := protocol.EncodeDataPayload(header, respBytes)
if err != nil {
return fmt.Errorf("encode payload: %w", err)
}
dataFrame := protocol.NewFrame(protocol.FrameTypeData, payload)
h.stats.AddBytesOut(int64(len(payload)))
return h.frameWriter.WriteFrame(dataFrame)
}
func (h *FrameHandler) closeStream(streamID string) {
h.streamMu.Lock()
defer h.streamMu.Unlock()
stream, ok := h.streams[streamID]
if !ok {
return
}
if stream.LocalConn != nil {
stream.LocalConn.Close()
}
close(stream.Done)
delete(h.streams, streamID)
if h.isClosedCheck != nil && h.isClosedCheck() {
return
}
header := protocol.DataHeader{
StreamID: streamID,
RequestID: streamID,
Type: "close",
IsLast: true,
}
payload, err := protocol.EncodeDataPayload(header, nil)
if err != nil {
return
}
closeFrame := protocol.NewFrame(protocol.FrameTypeData, payload)
h.frameWriter.WriteFrame(closeFrame)
}
// Close closes all streams
func (h *FrameHandler) Close() {
h.streamMu.Lock()
defer h.streamMu.Unlock()
for streamID, stream := range h.streams {
if stream.LocalConn != nil {
stream.LocalConn.Close()
}
close(stream.Done)
delete(h.streams, streamID)
}
}
// GetStats returns the traffic stats tracker
func (h *FrameHandler) GetStats() *TrafficStats {
return h.stats
}
func (h *FrameHandler) WarmupConnectionPool(numConnections int) {
if h.tunnelType != protocol.TunnelTypeHTTP {
return
}
targetURL := fmt.Sprintf("http://%s:%d/", h.localHost, h.localPort)
var wg sync.WaitGroup
for i := 0; i < numConnections; i++ {
wg.Add(1)
go func() {
defer wg.Done()
req, err := http.NewRequest("HEAD", targetURL, nil)
if err != nil {
return
}
resp, err := h.httpClient.Do(req)
if err != nil {
return
}
defer resp.Body.Close()
io.Copy(io.Discard, resp.Body)
}()
}
wg.Wait()
}
func (h *FrameHandler) isLocalAddress(addr string) bool {
if addr == "localhost" || addr == "127.0.0.1" || addr == "::1" {
return true
}
if strings.HasPrefix(addr, "192.168.") ||
strings.HasPrefix(addr, "10.") ||
strings.HasPrefix(addr, "172.16.") ||
strings.HasPrefix(addr, "172.17.") ||
strings.HasPrefix(addr, "172.18.") ||
strings.HasPrefix(addr, "172.19.") ||
strings.HasPrefix(addr, "172.20.") ||
strings.HasPrefix(addr, "172.21.") ||
strings.HasPrefix(addr, "172.22.") ||
strings.HasPrefix(addr, "172.23.") ||
strings.HasPrefix(addr, "172.24.") ||
strings.HasPrefix(addr, "172.25.") ||
strings.HasPrefix(addr, "172.26.") ||
strings.HasPrefix(addr, "172.27.") ||
strings.HasPrefix(addr, "172.28.") ||
strings.HasPrefix(addr, "172.29.") ||
strings.HasPrefix(addr, "172.30.") ||
strings.HasPrefix(addr, "172.31.") {
return true
}
return false
}

View File

@@ -0,0 +1,227 @@
package tcp
import (
"sync"
"sync/atomic"
"time"
)
// TrafficStats tracks traffic statistics for a tunnel connection
type TrafficStats struct {
// Total bytes
totalBytesIn int64
totalBytesOut int64
// Request counts
totalRequests int64
// For speed calculation
lastBytesIn int64
lastBytesOut int64
lastTime time.Time
speedMu sync.Mutex
// Current speed (bytes per second)
speedIn int64
speedOut int64
// Start time
startTime time.Time
}
// NewTrafficStats creates a new traffic stats tracker
func NewTrafficStats() *TrafficStats {
now := time.Now()
return &TrafficStats{
startTime: now,
lastTime: now,
}
}
// AddBytesIn adds incoming bytes to the counter
func (s *TrafficStats) AddBytesIn(n int64) {
atomic.AddInt64(&s.totalBytesIn, n)
}
// AddBytesOut adds outgoing bytes to the counter
func (s *TrafficStats) AddBytesOut(n int64) {
atomic.AddInt64(&s.totalBytesOut, n)
}
// AddRequest increments the request counter
func (s *TrafficStats) AddRequest() {
atomic.AddInt64(&s.totalRequests, 1)
}
// GetTotalBytesIn returns total incoming bytes
func (s *TrafficStats) GetTotalBytesIn() int64 {
return atomic.LoadInt64(&s.totalBytesIn)
}
// GetTotalBytesOut returns total outgoing bytes
func (s *TrafficStats) GetTotalBytesOut() int64 {
return atomic.LoadInt64(&s.totalBytesOut)
}
// GetTotalRequests returns total request count
func (s *TrafficStats) GetTotalRequests() int64 {
return atomic.LoadInt64(&s.totalRequests)
}
// GetTotalBytes returns total bytes (in + out)
func (s *TrafficStats) GetTotalBytes() int64 {
return s.GetTotalBytesIn() + s.GetTotalBytesOut()
}
// UpdateSpeed calculates current transfer speed
// Should be called periodically (e.g., every second)
func (s *TrafficStats) UpdateSpeed() {
s.speedMu.Lock()
defer s.speedMu.Unlock()
now := time.Now()
elapsed := now.Sub(s.lastTime).Seconds()
if elapsed < 0.1 {
return // Avoid division by zero or too frequent updates
}
currentIn := atomic.LoadInt64(&s.totalBytesIn)
currentOut := atomic.LoadInt64(&s.totalBytesOut)
deltaIn := currentIn - s.lastBytesIn
deltaOut := currentOut - s.lastBytesOut
s.speedIn = int64(float64(deltaIn) / elapsed)
s.speedOut = int64(float64(deltaOut) / elapsed)
s.lastBytesIn = currentIn
s.lastBytesOut = currentOut
s.lastTime = now
}
// GetSpeedIn returns current incoming speed in bytes per second
func (s *TrafficStats) GetSpeedIn() int64 {
s.speedMu.Lock()
defer s.speedMu.Unlock()
return s.speedIn
}
// GetSpeedOut returns current outgoing speed in bytes per second
func (s *TrafficStats) GetSpeedOut() int64 {
s.speedMu.Lock()
defer s.speedMu.Unlock()
return s.speedOut
}
// GetUptime returns how long the connection has been active
func (s *TrafficStats) GetUptime() time.Duration {
return time.Since(s.startTime)
}
// Snapshot returns a snapshot of all stats
type StatsSnapshot struct {
TotalBytesIn int64
TotalBytesOut int64
TotalBytes int64
TotalRequests int64
SpeedIn int64 // bytes per second
SpeedOut int64 // bytes per second
Uptime time.Duration
}
// GetSnapshot returns a snapshot of current stats
func (s *TrafficStats) GetSnapshot() StatsSnapshot {
s.speedMu.Lock()
speedIn := s.speedIn
speedOut := s.speedOut
s.speedMu.Unlock()
totalIn := atomic.LoadInt64(&s.totalBytesIn)
totalOut := atomic.LoadInt64(&s.totalBytesOut)
return StatsSnapshot{
TotalBytesIn: totalIn,
TotalBytesOut: totalOut,
TotalBytes: totalIn + totalOut,
TotalRequests: atomic.LoadInt64(&s.totalRequests),
SpeedIn: speedIn,
SpeedOut: speedOut,
Uptime: time.Since(s.startTime),
}
}
// FormatBytes formats bytes to human readable string
func FormatBytes(bytes int64) string {
const (
KB = 1024
MB = KB * 1024
GB = MB * 1024
)
switch {
case bytes >= GB:
return formatFloat(float64(bytes)/float64(GB)) + " GB"
case bytes >= MB:
return formatFloat(float64(bytes)/float64(MB)) + " MB"
case bytes >= KB:
return formatFloat(float64(bytes)/float64(KB)) + " KB"
default:
return formatInt(bytes) + " B"
}
}
// FormatSpeed formats speed (bytes per second) to human readable string
func FormatSpeed(bytesPerSec int64) string {
if bytesPerSec == 0 {
return "0 B/s"
}
return FormatBytes(bytesPerSec) + "/s"
}
func formatFloat(f float64) string {
if f >= 100 {
return formatInt(int64(f))
} else if f >= 10 {
return formatOneDecimal(f)
}
return formatTwoDecimal(f)
}
func formatInt(i int64) string {
return intToStr(i)
}
func formatOneDecimal(f float64) string {
i := int64(f * 10)
whole := i / 10
frac := i % 10
return intToStr(whole) + "." + intToStr(frac)
}
func formatTwoDecimal(f float64) string {
i := int64(f * 100)
whole := i / 100
frac := i % 100
if frac < 10 {
return intToStr(whole) + ".0" + intToStr(frac)
}
return intToStr(whole) + "." + intToStr(frac)
}
func intToStr(i int64) string {
if i == 0 {
return "0"
}
if i < 0 {
return "-" + intToStr(-i)
}
var buf [20]byte
pos := len(buf)
for i > 0 {
pos--
buf[pos] = byte('0' + i%10)
i /= 10
}
return string(buf[pos:])
}

View File

@@ -0,0 +1,331 @@
package proxy
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"drip/internal/server/tunnel"
"drip/internal/shared/constants"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"drip/internal/shared/utils"
"go.uber.org/zap"
)
type Handler struct {
manager *tunnel.Manager
logger *zap.Logger
responses *ResponseHandler
domain string
authToken string
headerPool *pool.HeaderPool
}
func NewHandler(manager *tunnel.Manager, logger *zap.Logger, responses *ResponseHandler, domain string, authToken string) *Handler {
return &Handler{
manager: manager,
logger: logger,
responses: responses,
domain: domain,
authToken: authToken,
headerPool: pool.NewHeaderPool(),
}
}
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
subdomain := h.extractSubdomain(r.Host)
if subdomain == "" {
h.serveHomePage(w, r)
return
}
conn, ok := h.manager.Get(subdomain)
if !ok {
http.Error(w, "Tunnel not found. The tunnel may have been closed.", http.StatusNotFound)
return
}
if conn.IsClosed() {
http.Error(w, "Tunnel connection closed", http.StatusBadGateway)
return
}
transport := conn.GetTransport()
if transport == nil {
http.Error(w, "Tunnel control channel not ready", http.StatusBadGateway)
return
}
tType := conn.GetTunnelType()
if tType != "" && tType != protocol.TunnelTypeHTTP && tType != protocol.TunnelTypeHTTPS {
http.Error(w, "Tunnel does not accept HTTP traffic", http.StatusBadGateway)
return
}
limitedReader := io.LimitReader(r.Body, constants.MaxRequestBodySize)
body, err := io.ReadAll(limitedReader)
if err != nil {
h.logger.Error("Read request body failed", zap.Error(err))
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
return
}
defer r.Body.Close()
requestID := utils.GenerateID()
headers := h.headerPool.Get()
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
httpReq := protocol.HTTPRequest{
Method: r.Method,
URL: r.URL.String(),
Headers: headers,
Body: body,
}
reqBytes, err := protocol.EncodeHTTPRequest(&httpReq)
h.headerPool.Put(headers)
if err != nil {
h.logger.Error("Encode HTTP request failed", zap.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
header := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: "http_request",
IsLast: true,
}
payload, err := protocol.EncodeDataPayload(header, reqBytes)
if err != nil {
h.logger.Error("Encode data payload failed", zap.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
frame := protocol.NewFrame(protocol.FrameTypeData, payload)
respChan := h.responses.CreateResponseChan(requestID)
defer h.responses.CleanupResponseChan(requestID)
if err := transport.SendFrame(frame); err != nil {
h.logger.Error("Send frame to tunnel failed", zap.Error(err))
http.Error(w, "Failed to forward request to tunnel", http.StatusBadGateway)
return
}
ctx, cancel := context.WithTimeout(context.Background(), constants.RequestTimeout)
defer cancel()
select {
case respMsg := <-respChan:
h.writeHTTPResponse(w, respMsg, subdomain, r)
case <-ctx.Done():
http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
}
}
func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPResponse, subdomain string, r *http.Request) {
if resp == nil {
http.Error(w, "Invalid response from tunnel", http.StatusBadGateway)
return
}
for key, values := range resp.Headers {
if key == "Connection" || key == "Keep-Alive" || key == "Transfer-Encoding" || key == "Upgrade" {
continue
}
if key == "Location" && len(values) > 0 {
rewrittenLocation := h.rewriteLocationHeader(values[0], r.Host)
w.Header().Set("Location", rewrittenLocation)
continue
}
for _, value := range values {
w.Header().Add(key, value)
}
}
if w.Header().Get("Content-Length") == "" && len(resp.Body) > 0 {
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(resp.Body)))
}
statusCode := resp.StatusCode
if statusCode == 0 {
statusCode = http.StatusOK
}
w.WriteHeader(statusCode)
if len(resp.Body) > 0 {
w.Write(resp.Body)
}
}
func (h *Handler) rewriteLocationHeader(location, proxyHost string) string {
if !strings.HasPrefix(location, "http://") && !strings.HasPrefix(location, "https://") {
return location
}
locationURL, err := url.Parse(location)
if err != nil {
return location
}
if locationURL.Host == "localhost" ||
strings.HasPrefix(locationURL.Host, "localhost:") ||
locationURL.Host == "127.0.0.1" ||
strings.HasPrefix(locationURL.Host, "127.0.0.1:") {
scheme := "https"
if strings.Contains(proxyHost, ":") && !strings.Contains(proxyHost, "https") {
parts := strings.Split(proxyHost, ":")
if len(parts) == 2 && parts[1] != "443" {
scheme = "https"
}
}
rewritten := fmt.Sprintf("%s://%s%s", scheme, proxyHost, locationURL.Path)
if locationURL.RawQuery != "" {
rewritten += "?" + locationURL.RawQuery
}
if locationURL.Fragment != "" {
rewritten += "#" + locationURL.Fragment
}
return rewritten
}
return location
}
func (h *Handler) extractSubdomain(host string) string {
if idx := strings.Index(host, ":"); idx != -1 {
host = host[:idx]
}
if host == h.domain {
return ""
}
suffix := "." + h.domain
if strings.HasSuffix(host, suffix) {
subdomain := strings.TrimSuffix(host, suffix)
return subdomain
}
return ""
}
func (h *Handler) serveHomePage(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/health" {
h.serveHealth(w, r)
return
}
if r.URL.Path == "/stats" {
h.serveStats(w, r)
return
}
html := `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8" />
<title>Drip - Tunnel to Localhost</title>
<style>
body { font-family: Arial, sans-serif; max-width: 800px; margin: 50px auto; padding: 20px; }
h1 { color: #333; }
code { background: #f4f4f4; padding: 2px 6px; border-radius: 3px; }
.stats { background: #f9f9f9; padding: 15px; border-radius: 5px; margin: 20px 0; }
</style>
</head>
<body>
<h1>💧 Drip - Fast Tunnels to Localhost</h1>
<p>A high-performance tunneling service.</p>
<h2>Quick Start</h2>
<p>Install the client:</p>
<code>bash <(curl -fsSL https:///install.sh)</code>
<p>Start a tunnel:</p>
<code>drip http 3000</code><br><br>
<code>drip https 443</code><br><br>
<code>drip tcp 5432</code>
<p><a href="/health">Health Check</a> | <a href="/stats">Statistics</a></p>
</body>
</html>`
w.Header().Set("Content-Type", "text/html")
w.Write([]byte(html))
}
func (h *Handler) serveHealth(w http.ResponseWriter, r *http.Request) {
health := map[string]interface{}{
"status": "ok",
"active_tunnels": h.manager.Count(),
"timestamp": time.Now().Unix(),
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(health)
}
func cloneHeadersWithHost(src http.Header, host string) http.Header {
dst := make(http.Header, len(src)+1)
for k, v := range src {
copied := make([]string, len(v))
copy(copied, v)
dst[k] = copied
}
if host != "" {
dst.Set("Host", host)
}
return dst
}
func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
if h.authToken != "" {
token := r.URL.Query().Get("token")
if token == "" {
authHeader := r.Header.Get("Authorization")
if strings.HasPrefix(authHeader, "Bearer ") {
token = strings.TrimPrefix(authHeader, "Bearer ")
}
}
if token != h.authToken {
http.Error(w, "Unauthorized: invalid or missing token", http.StatusUnauthorized)
return
}
}
connections := h.manager.List()
stats := map[string]interface{}{
"total_tunnels": len(connections),
"tunnels": []map[string]interface{}{},
}
for _, conn := range connections {
stats["tunnels"] = append(stats["tunnels"].([]map[string]interface{}), map[string]interface{}{
"subdomain": conn.Subdomain,
"last_active": conn.LastActive.Unix(),
})
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(stats)
}

View File

@@ -0,0 +1,159 @@
package proxy
import (
"sync"
"time"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
// responseChanEntry holds a response channel and its creation time
type responseChanEntry struct {
ch chan *protocol.HTTPResponse
createdAt time.Time
}
// ResponseHandler manages response channels for HTTP requests over TCP/Frame protocol
type ResponseHandler struct {
channels map[string]*responseChanEntry
mu sync.RWMutex
logger *zap.Logger
stopCh chan struct{}
}
// NewResponseHandler creates a new response handler
func NewResponseHandler(logger *zap.Logger) *ResponseHandler {
h := &ResponseHandler{
channels: make(map[string]*responseChanEntry),
logger: logger,
stopCh: make(chan struct{}),
}
// Start single cleanup goroutine instead of one per request
go h.cleanupLoop()
return h
}
// CreateResponseChan creates a response channel for a request ID
func (h *ResponseHandler) CreateResponseChan(requestID string) chan *protocol.HTTPResponse {
h.mu.Lock()
defer h.mu.Unlock()
ch := make(chan *protocol.HTTPResponse, 1)
h.channels[requestID] = &responseChanEntry{
ch: ch,
createdAt: time.Now(),
}
return ch
}
// GetResponseChan gets the response channel for a request ID
func (h *ResponseHandler) GetResponseChan(requestID string) <-chan *protocol.HTTPResponse {
h.mu.RLock()
defer h.mu.RUnlock()
if entry := h.channels[requestID]; entry != nil {
return entry.ch
}
return nil
}
// SendResponse sends a response to the waiting channel
func (h *ResponseHandler) SendResponse(requestID string, resp *protocol.HTTPResponse) {
h.mu.RLock()
entry, exists := h.channels[requestID]
h.mu.RUnlock()
if !exists || entry == nil {
h.logger.Warn("Response channel not found",
zap.String("request_id", requestID),
)
return
}
select {
case entry.ch <- resp:
h.logger.Debug("Response sent to channel",
zap.String("request_id", requestID),
)
case <-time.After(5 * time.Second):
h.logger.Warn("Timeout sending response to channel",
zap.String("request_id", requestID),
)
}
}
// CleanupResponseChan removes and closes a response channel
func (h *ResponseHandler) CleanupResponseChan(requestID string) {
h.mu.Lock()
defer h.mu.Unlock()
if entry, exists := h.channels[requestID]; exists {
close(entry.ch)
delete(h.channels, requestID)
}
}
// GetPendingCount returns the number of pending responses
func (h *ResponseHandler) GetPendingCount() int {
h.mu.RLock()
defer h.mu.RUnlock()
return len(h.channels)
}
// cleanupLoop periodically cleans up expired response channels
// This replaces the per-request goroutine approach with a single cleanup goroutine
func (h *ResponseHandler) cleanupLoop() {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
h.cleanupExpiredChannels()
case <-h.stopCh:
return
}
}
}
// cleanupExpiredChannels removes channels older than 30 seconds
func (h *ResponseHandler) cleanupExpiredChannels() {
now := time.Now()
timeout := 30 * time.Second
h.mu.Lock()
defer h.mu.Unlock()
expiredCount := 0
for requestID, entry := range h.channels {
if now.Sub(entry.createdAt) > timeout {
close(entry.ch)
delete(h.channels, requestID)
expiredCount++
}
}
if expiredCount > 0 {
h.logger.Debug("Cleaned up expired response channels",
zap.Int("count", expiredCount),
zap.Int("remaining", len(h.channels)),
)
}
}
// Close stops the cleanup loop
func (h *ResponseHandler) Close() {
close(h.stopCh)
h.mu.Lock()
defer h.mu.Unlock()
// Close all remaining channels
for _, entry := range h.channels {
close(entry.ch)
}
h.channels = make(map[string]*responseChanEntry)
}

View File

@@ -0,0 +1,588 @@
package tcp
import (
"bufio"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"time"
"drip/internal/server/tunnel"
"drip/internal/shared/constants"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
// Connection represents a client TCP connection
type Connection struct {
conn net.Conn
authToken string
manager *tunnel.Manager
logger *zap.Logger
subdomain string
port int
domain string
publicPort int
portAlloc *PortAllocator
tunnelConn *tunnel.Connection
proxy *TunnelProxy
stopCh chan struct{}
once sync.Once
lastHeartbeat time.Time
mu sync.RWMutex
frameWriter *protocol.FrameWriter
httpHandler http.Handler
responseChans HTTPResponseHandler
tunnelType protocol.TunnelType // Track tunnel type
}
// HTTPResponseHandler interface for response channel operations
type HTTPResponseHandler interface {
CreateResponseChan(requestID string) chan *protocol.HTTPResponse
GetResponseChan(requestID string) <-chan *protocol.HTTPResponse
CleanupResponseChan(requestID string)
SendResponse(requestID string, resp *protocol.HTTPResponse)
}
// NewConnection creates a new connection handler
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, responseChans HTTPResponseHandler) *Connection {
return &Connection{
conn: conn,
authToken: authToken,
manager: manager,
logger: logger,
portAlloc: portAlloc,
domain: domain,
publicPort: publicPort,
httpHandler: httpHandler,
responseChans: responseChans,
stopCh: make(chan struct{}),
lastHeartbeat: time.Now(),
}
}
// Handle handles the connection lifecycle
func (c *Connection) Handle() error {
// Ensure cleanup of control connection, proxy, port, and registry on exit.
defer c.Close()
// Set initial read timeout for protocol detection
c.conn.SetReadDeadline(time.Now().Add(30 * time.Second))
// Use buffered reader to support peeking
reader := bufio.NewReader(c.conn)
// Peek first 8 bytes to detect protocol
peek, err := reader.Peek(8)
if err != nil {
return fmt.Errorf("failed to peek connection: %w", err)
}
// Check if this is an HTTP request
peekStr := string(peek)
httpMethods := []string{"GET ", "POST", "PUT ", "DELE", "HEAD", "OPTI", "PATC", "CONN", "TRAC"}
isHTTP := false
for _, method := range httpMethods {
if strings.HasPrefix(peekStr, method) {
isHTTP = true
break
}
}
if isHTTP {
c.logger.Info("Detected HTTP request on TCP port, handling as HTTP")
return c.handleHTTPRequest(reader)
}
// Continue with drip protocol
// Wait for registration frame
frame, err := protocol.ReadFrame(reader)
if err != nil {
return fmt.Errorf("failed to read registration frame: %w", err)
}
defer frame.Release() // Return pool buffer when done
if frame.Type != protocol.FrameTypeRegister {
return fmt.Errorf("expected register frame, got %s", frame.Type)
}
// Parse registration request
var req protocol.RegisterRequest
if err := json.Unmarshal(frame.Payload, &req); err != nil {
return fmt.Errorf("failed to parse registration request: %w", err)
}
// Store tunnel type
c.tunnelType = req.TunnelType
// Authenticate
if c.authToken != "" && req.Token != c.authToken {
c.sendError("authentication_failed", "Invalid authentication token")
return fmt.Errorf("authentication failed")
}
// Allocate TCP port only for TCP tunnels
if req.TunnelType == protocol.TunnelTypeTCP {
if c.portAlloc == nil {
return fmt.Errorf("port allocator not configured")
}
port, err := c.portAlloc.Allocate()
if err != nil {
c.sendError("port_allocation_failed", err.Error())
return fmt.Errorf("failed to allocate port: %w", err)
}
c.port = port
// For TCP tunnels, prefer deterministic subdomain tied to port when not provided by client.
if req.CustomSubdomain == "" {
req.CustomSubdomain = fmt.Sprintf("tcp-%d", port)
}
}
// Register tunnel
subdomain, err := c.manager.Register(nil, req.CustomSubdomain)
if err != nil {
c.sendError("registration_failed", err.Error())
c.portAlloc.Release(c.port)
c.port = 0
return fmt.Errorf("tunnel registration failed: %w", err)
}
c.subdomain = subdomain
// Get tunnel connection
tunnelConn, ok := c.manager.Get(subdomain)
if !ok {
return fmt.Errorf("failed to get registered tunnel")
}
c.tunnelConn = tunnelConn
// Store TCP connection reference and metadata for HTTP proxy routing
c.tunnelConn.Conn = nil // We're using TCP, not WebSocket
c.tunnelConn.SetTransport(c, req.TunnelType)
c.tunnelConn.SetTunnelType(req.TunnelType)
c.tunnelType = req.TunnelType
c.logger.Info("Tunnel registered",
zap.String("subdomain", subdomain),
zap.String("tunnel_type", string(req.TunnelType)),
zap.Int("local_port", req.LocalPort),
zap.Int("remote_port", c.port),
)
// Send registration acknowledgment
// Generate appropriate URL based on tunnel type
var tunnelURL string
if req.TunnelType == protocol.TunnelTypeHTTP || req.TunnelType == protocol.TunnelTypeHTTPS {
// HTTP/HTTPS tunnels use HTTPS with subdomain
// Use publicPort for URL generation (configured via --public-port flag)
if c.publicPort == 443 {
tunnelURL = fmt.Sprintf("https://%s.%s", subdomain, c.domain)
} else {
tunnelURL = fmt.Sprintf("https://%s.%s:%d", subdomain, c.domain, c.publicPort)
}
} else {
// TCP tunnels use tcp:// with port
tunnelURL = fmt.Sprintf("tcp://%s:%d", c.domain, c.port)
}
resp := protocol.RegisterResponse{
Subdomain: subdomain,
Port: c.port,
URL: tunnelURL,
Message: "Tunnel registered successfully",
}
respData, _ := json.Marshal(resp)
ackFrame := protocol.NewFrame(protocol.FrameTypeRegisterAck, respData)
// Send registration ack (sync write before frameWriter is created)
err = protocol.WriteFrame(c.conn, ackFrame)
if err != nil {
return fmt.Errorf("failed to send registration ack: %w", err)
}
// Create frame writer for async writes
c.frameWriter = protocol.NewFrameWriter(c.conn)
// Clear read deadline
c.conn.SetReadDeadline(time.Time{})
// Start TCP proxy only for TCP tunnels
if req.TunnelType == protocol.TunnelTypeTCP {
c.proxy = NewTunnelProxy(c.port, subdomain, c.conn, c.logger)
if err := c.proxy.Start(); err != nil {
return fmt.Errorf("failed to start TCP proxy: %w", err)
}
}
// Start heartbeat checker
go c.heartbeatChecker()
// Handle frames (pass reader for consistent buffering)
return c.handleFrames(reader)
}
// handleHTTPRequest handles HTTP requests that arrive on the TCP port
func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
// If no HTTP handler is configured, return error
if c.httpHandler == nil {
c.logger.Warn("HTTP request received but no HTTP handler configured")
response := "HTTP/1.1 503 Service Unavailable\r\n" +
"Content-Type: text/plain\r\n" +
"Content-Length: 47\r\n" +
"\r\n" +
"HTTP handler not configured for this TCP port\r\n"
c.conn.Write([]byte(response))
return fmt.Errorf("HTTP handler not configured")
}
// Clear read deadline for HTTP processing
c.conn.SetReadDeadline(time.Time{})
// Handle multiple HTTP requests on the same connection (HTTP/1.1 keep-alive)
for {
// Set a read deadline for each request to avoid hanging forever
c.conn.SetReadDeadline(time.Now().Add(60 * time.Second))
// Parse HTTP request
req, err := http.ReadRequest(reader)
if err != nil {
// EOF or timeout is normal when client closes connection or no more requests
if err == io.EOF || err == io.ErrUnexpectedEOF {
c.logger.Debug("Client closed HTTP connection")
return nil
}
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
c.logger.Debug("HTTP keep-alive timeout")
return nil
}
// Connection reset by peer is normal - client closed connection abruptly
errStr := err.Error()
if strings.Contains(errStr, "connection reset by peer") ||
strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "connection refused") {
c.logger.Debug("Client disconnected abruptly", zap.Error(err))
return nil
}
c.logger.Error("Failed to parse HTTP request", zap.Error(err))
return fmt.Errorf("failed to parse HTTP request: %w", err)
}
c.logger.Info("Processing HTTP request on TCP port",
zap.String("method", req.Method),
zap.String("url", req.URL.String()),
zap.String("host", req.Host),
)
// Create a response writer that writes directly to the connection
respWriter := &httpResponseWriter{
conn: c.conn,
header: make(http.Header),
}
// Handle the request
c.httpHandler.ServeHTTP(respWriter, req)
// Check if we should close the connection
// Close if: Connection: close header, or HTTP/1.0 without Connection: keep-alive
shouldClose := false
if req.Close {
shouldClose = true
} else if req.ProtoMajor == 1 && req.ProtoMinor == 0 {
// HTTP/1.0 defaults to close unless keep-alive is explicitly requested
if req.Header.Get("Connection") != "keep-alive" {
shouldClose = true
}
}
if shouldClose {
c.logger.Debug("Closing connection as requested by client")
return nil
}
// Continue to next request on the same connection
}
}
// handleFrames handles incoming frames
func (c *Connection) handleFrames(reader *bufio.Reader) error {
for {
select {
case <-c.stopCh:
return nil
default:
}
// Read frame with timeout
c.conn.SetReadDeadline(time.Now().Add(constants.RequestTimeout))
frame, err := protocol.ReadFrame(reader)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
c.logger.Warn("Read timeout, connection may be dead")
return fmt.Errorf("read timeout")
}
// EOF is normal when client closes connection gracefully
if err.Error() == "failed to read frame header: EOF" || err.Error() == "EOF" {
c.logger.Info("Client disconnected")
return nil
}
// Check if connection was closed (during shutdown)
select {
case <-c.stopCh:
// Connection was closed intentionally, don't log as error
c.logger.Debug("Connection closed during shutdown")
return nil
default:
return fmt.Errorf("failed to read frame: %w", err)
}
}
// Handle frame based on type
switch frame.Type {
case protocol.FrameTypeHeartbeat:
c.handleHeartbeat()
frame.Release()
case protocol.FrameTypeData:
// Data frame from client (response to forwarded request)
c.handleDataFrame(frame)
frame.Release() // Release after processing
case protocol.FrameTypeClose:
frame.Release()
c.logger.Info("Client requested close")
return nil
default:
frame.Release()
c.logger.Warn("Unexpected frame type",
zap.String("type", frame.Type.String()),
)
}
}
}
// handleHeartbeat handles heartbeat frame
func (c *Connection) handleHeartbeat() {
c.mu.Lock()
c.lastHeartbeat = time.Now()
c.mu.Unlock()
// Send heartbeat ack
ackFrame := protocol.NewFrame(protocol.FrameTypeHeartbeatAck, nil)
err := c.frameWriter.WriteFrame(ackFrame)
if err != nil {
c.logger.Error("Failed to send heartbeat ack", zap.Error(err))
}
}
// handleDataFrame handles data frame (response from client)
func (c *Connection) handleDataFrame(frame *protocol.Frame) {
// Decode payload (auto-detects protocol version)
header, data, err := protocol.DecodeDataPayload(frame.Payload)
if err != nil {
c.logger.Error("Failed to decode data payload",
zap.Error(err),
)
return
}
c.logger.Debug("Received data frame",
zap.String("stream_id", header.StreamID),
zap.String("type", header.Type),
zap.Int("data_size", len(data)),
)
switch header.Type {
case "response":
// TCP tunnel response, forward to proxy
if c.proxy != nil {
if err := c.proxy.HandleResponse(header.StreamID, data); err != nil {
c.logger.Error("Failed to handle response",
zap.String("stream_id", header.StreamID),
zap.Error(err),
)
}
}
case "http_response":
if c.responseChans == nil {
c.logger.Warn("No response channel handler for HTTP response",
zap.String("stream_id", header.StreamID),
)
return
}
// Decode HTTP response (auto-detects JSON vs msgpack)
httpResp, err := protocol.DecodeHTTPResponse(data)
if err != nil {
c.logger.Error("Failed to decode HTTP response",
zap.String("stream_id", header.StreamID),
zap.Error(err),
)
return
}
// Route by request ID when provided to keep request/response aligned.
reqID := header.RequestID
if reqID == "" {
reqID = header.StreamID
}
c.responseChans.SendResponse(reqID, httpResp)
c.logger.Debug("Routed HTTP response to channel",
zap.String("request_id", reqID),
)
case "close":
// Client is closing the stream
if c.proxy != nil {
c.proxy.CloseStream(header.StreamID)
}
default:
c.logger.Warn("Unknown data frame type",
zap.String("type", header.Type),
zap.String("stream_id", header.StreamID),
)
}
}
// heartbeatChecker checks for heartbeat timeout
func (c *Connection) heartbeatChecker() {
ticker := time.NewTicker(constants.HeartbeatInterval)
defer ticker.Stop()
for {
select {
case <-c.stopCh:
return
case <-ticker.C:
c.mu.RLock()
lastHB := c.lastHeartbeat
c.mu.RUnlock()
if time.Since(lastHB) > constants.HeartbeatTimeout {
c.logger.Warn("Heartbeat timeout",
zap.String("subdomain", c.subdomain),
zap.Duration("last_heartbeat", time.Since(lastHB)),
)
c.Close()
return
}
}
}
}
// SendFrame sends a frame to the client
func (c *Connection) SendFrame(frame *protocol.Frame) error {
if c.frameWriter == nil {
return protocol.WriteFrame(c.conn, frame)
}
return c.frameWriter.WriteFrame(frame)
}
// sendError sends an error frame to the client
func (c *Connection) sendError(code, message string) {
errMsg := protocol.ErrorMessage{
Code: code,
Message: message,
}
data, _ := json.Marshal(errMsg)
errFrame := protocol.NewFrame(protocol.FrameTypeError, data)
if c.frameWriter == nil {
// Fallback if frameWriter not initialized (early errors)
protocol.WriteFrame(c.conn, errFrame)
} else {
c.frameWriter.WriteFrame(errFrame)
}
}
// Close closes the connection
func (c *Connection) Close() {
c.once.Do(func() {
close(c.stopCh)
// Close frame writer
if c.frameWriter != nil {
c.frameWriter.Close()
}
// Stop TCP proxy
if c.proxy != nil {
c.proxy.Stop()
}
c.conn.Close()
// Release allocated port
if c.port > 0 && c.portAlloc != nil {
c.portAlloc.Release(c.port)
}
// Unregister tunnel
if c.subdomain != "" {
c.manager.Unregister(c.subdomain)
}
c.logger.Info("Connection closed",
zap.String("subdomain", c.subdomain),
)
})
}
// GetSubdomain returns the assigned subdomain
func (c *Connection) GetSubdomain() string {
return c.subdomain
}
// httpResponseWriter implements http.ResponseWriter for writing to a net.Conn
type httpResponseWriter struct {
conn net.Conn
header http.Header
statusCode int
headerWritten bool
}
func (w *httpResponseWriter) Header() http.Header {
return w.header
}
func (w *httpResponseWriter) WriteHeader(statusCode int) {
if w.headerWritten {
return
}
w.statusCode = statusCode
w.headerWritten = true
// Write status line
statusText := http.StatusText(statusCode)
if statusText == "" {
statusText = "Unknown"
}
fmt.Fprintf(w.conn, "HTTP/1.1 %d %s\r\n", statusCode, statusText)
// Write headers
for key, values := range w.header {
for _, value := range values {
fmt.Fprintf(w.conn, "%s: %s\r\n", key, value)
}
}
// Write empty line to end headers
fmt.Fprintf(w.conn, "\r\n")
}
func (w *httpResponseWriter) Write(data []byte) (int, error) {
if !w.headerWritten {
w.WriteHeader(http.StatusOK)
}
return w.conn.Write(data)
}

View File

@@ -0,0 +1,254 @@
package tcp
import (
"crypto/tls"
"fmt"
"net"
"net/http"
"strings"
"sync"
"time"
"drip/internal/server/tunnel"
"drip/internal/shared/pool"
"go.uber.org/zap"
)
// Listener handles TCP connections with TLS 1.3
type Listener struct {
address string
tlsConfig *tls.Config
authToken string
manager *tunnel.Manager
portAlloc *PortAllocator
logger *zap.Logger
domain string
publicPort int
httpHandler http.Handler
responseChans HTTPResponseHandler
listener net.Listener
stopCh chan struct{}
wg sync.WaitGroup
connections map[string]*Connection
connMu sync.RWMutex
workerPool *pool.WorkerPool // Worker pool for connection handling
}
// NewListener creates a new TCP listener
func NewListener(address string, tlsConfig *tls.Config, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, responseChans HTTPResponseHandler) *Listener {
// Create worker pool with 50 workers and queue size of 1000
// This reduces goroutine creation overhead for connection handling
workerPool := pool.NewWorkerPool(50, 1000)
return &Listener{
address: address,
tlsConfig: tlsConfig,
authToken: authToken,
manager: manager,
portAlloc: portAlloc,
logger: logger,
domain: domain,
publicPort: publicPort,
httpHandler: httpHandler,
responseChans: responseChans,
stopCh: make(chan struct{}),
connections: make(map[string]*Connection),
workerPool: workerPool,
}
}
// Start starts the TCP listener
func (l *Listener) Start() error {
var err error
// Create TLS listener
l.listener, err = tls.Listen("tcp", l.address, l.tlsConfig)
if err != nil {
return fmt.Errorf("failed to start TLS listener: %w", err)
}
l.logger.Info("TCP listener started",
zap.String("address", l.address),
zap.String("tls_version", "TLS 1.3"),
)
// Accept connections in background
l.wg.Add(1)
go l.acceptLoop()
return nil
}
// acceptLoop accepts incoming connections
func (l *Listener) acceptLoop() {
defer l.wg.Done()
for {
select {
case <-l.stopCh:
return
default:
}
// Set accept deadline to allow checking stopCh
if tcpListener, ok := l.listener.(*net.TCPListener); ok {
tcpListener.SetDeadline(time.Now().Add(1 * time.Second))
}
conn, err := l.listener.Accept()
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
continue // Timeout is expected due to deadline
}
select {
case <-l.stopCh:
return // Listener was stopped
default:
l.logger.Error("Failed to accept connection", zap.Error(err))
continue
}
}
// Handle connection using worker pool instead of creating new goroutine
// This reduces goroutine creation overhead and improves performance
l.wg.Add(1)
submitted := l.workerPool.Submit(func() {
l.handleConnection(conn)
})
// If pool is full or closed, fall back to direct goroutine
if !submitted {
go l.handleConnection(conn)
}
}
}
// handleConnection handles a single client connection
func (l *Listener) handleConnection(netConn net.Conn) {
defer l.wg.Done()
defer netConn.Close()
// Get TLS connection info
tlsConn, ok := netConn.(*tls.Conn)
if !ok {
l.logger.Error("Connection is not TLS")
return
}
// Force TLS handshake to complete
if err := tlsConn.Handshake(); err != nil {
// TLS handshake failures are common (HTTP clients, scanners, etc.)
// Log as WARN instead of ERROR
l.logger.Warn("TLS handshake failed",
zap.String("remote_addr", netConn.RemoteAddr().String()),
zap.Error(err),
)
return
}
// Log connection info
state := tlsConn.ConnectionState()
l.logger.Info("New connection",
zap.String("remote_addr", netConn.RemoteAddr().String()),
zap.Uint16("tls_version", state.Version),
zap.String("cipher_suite", tls.CipherSuiteName(state.CipherSuite)),
)
// Verify TLS 1.3
if state.Version != tls.VersionTLS13 {
l.logger.Warn("Connection not using TLS 1.3",
zap.Uint16("version", state.Version),
)
return
}
// Create connection handler
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.publicPort, l.httpHandler, l.responseChans)
// Store connection
connID := netConn.RemoteAddr().String()
l.connMu.Lock()
l.connections[connID] = conn
l.connMu.Unlock()
// Remove connection on exit
defer func() {
l.connMu.Lock()
delete(l.connections, connID)
l.connMu.Unlock()
}()
// Handle connection (blocking)
if err := conn.Handle(); err != nil {
errStr := err.Error()
// Client disconnection errors - normal network behavior, log as DEBUG
if strings.Contains(errStr, "connection reset by peer") ||
strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "connection refused") {
l.logger.Debug("Client disconnected",
zap.String("remote_addr", connID),
zap.Error(err),
)
return
}
// Protocol errors (invalid clients, scanners) are expected - log as WARN
if strings.Contains(errStr, "payload too large") ||
strings.Contains(errStr, "failed to read registration frame") ||
strings.Contains(errStr, "expected register frame") ||
strings.Contains(errStr, "failed to parse registration request") ||
strings.Contains(errStr, "failed to parse HTTP request") {
l.logger.Warn("Protocol validation failed",
zap.String("remote_addr", connID),
zap.Error(err),
)
} else {
// Legitimate errors (auth failures, registration failures, etc.)
l.logger.Error("Connection handling failed",
zap.String("remote_addr", connID),
zap.Error(err),
)
}
}
}
// Stop stops the listener and closes all connections
func (l *Listener) Stop() error {
l.logger.Info("Stopping TCP listener")
// Signal stop
close(l.stopCh)
// Close listener
if l.listener != nil {
if err := l.listener.Close(); err != nil {
l.logger.Error("Failed to close listener", zap.Error(err))
}
}
// Close all connections
l.connMu.Lock()
for _, conn := range l.connections {
conn.Close()
}
l.connMu.Unlock()
// Wait for all goroutines to finish
l.wg.Wait()
// Close worker pool
if l.workerPool != nil {
l.workerPool.Close()
}
l.logger.Info("TCP listener stopped")
return nil
}
// GetActiveConnections returns the number of active connections
func (l *Listener) GetActiveConnections() int {
l.connMu.RLock()
defer l.connMu.RUnlock()
return len(l.connections)
}

View File

@@ -0,0 +1,79 @@
package tcp
import (
"crypto/rand"
"fmt"
"math/big"
"net"
"sync"
)
// PortAllocator manages dynamic TCP port allocation within a configured range.
// It keeps an in-memory reservation map; ports are held until Release is called.
type PortAllocator struct {
min int
max int
used map[int]bool
mu sync.Mutex
}
// NewPortAllocator creates a new allocator with the given inclusive range.
func NewPortAllocator(min, max int) (*PortAllocator, error) {
if min <= 0 || max <= 0 || min >= max || max > 65535 {
return nil, fmt.Errorf("invalid port range %d-%d", min, max)
}
return &PortAllocator{
min: min,
max: max,
used: make(map[int]bool),
}, nil
}
// Allocate finds a free port, marks it as used, and ensures it's currently available.
func (p *PortAllocator) Allocate() (int, error) {
p.mu.Lock()
defer p.mu.Unlock()
total := p.max - p.min + 1
for attempts := 0; attempts < total; attempts++ {
port := p.randomPort()
if p.used[port] {
continue
}
// Probe the port to ensure it's not taken by the OS/other process.
ln, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", port))
if err != nil {
continue
}
ln.Close()
p.used[port] = true
return port, nil
}
return 0, fmt.Errorf("no available port in range %d-%d", p.min, p.max)
}
// Release frees a previously allocated port.
func (p *PortAllocator) Release(port int) {
p.mu.Lock()
defer p.mu.Unlock()
delete(p.used, port)
}
func (p *PortAllocator) randomPort() int {
n := p.max - p.min + 1
if n <= 0 {
return p.min
}
// crypto/rand for better distribution without needing a global seed.
randInt, err := rand.Int(rand.Reader, big.NewInt(int64(n)))
if err != nil {
return p.min
}
return p.min + int(randInt.Int64())
}

View File

@@ -0,0 +1,243 @@
package tcp
import (
"fmt"
"net"
"sync"
"time"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
// TunnelProxy handles TCP connections for a specific tunnel
type TunnelProxy struct {
port int
subdomain string
tcpConn net.Conn // The tunnel control connection
listener net.Listener
logger *zap.Logger
stopCh chan struct{}
wg sync.WaitGroup
clientAddr string
streams map[string]net.Conn // streamID -> external connection
streamMu sync.RWMutex
frameWriter *protocol.FrameWriter
bufferPool *pool.BufferPool
}
// NewTunnelProxy creates a new TCP tunnel proxy
func NewTunnelProxy(port int, subdomain string, tcpConn net.Conn, logger *zap.Logger) *TunnelProxy {
return &TunnelProxy{
port: port,
subdomain: subdomain,
tcpConn: tcpConn,
logger: logger,
stopCh: make(chan struct{}),
clientAddr: tcpConn.RemoteAddr().String(),
streams: make(map[string]net.Conn),
bufferPool: pool.NewBufferPool(),
frameWriter: protocol.NewFrameWriter(tcpConn),
}
}
// Start starts listening on the allocated port
func (p *TunnelProxy) Start() error {
addr := fmt.Sprintf("0.0.0.0:%d", p.port)
listener, err := net.Listen("tcp", addr)
if err != nil {
return fmt.Errorf("failed to listen on port %d: %w", p.port, err)
}
p.listener = listener
p.logger.Info("TCP proxy started",
zap.Int("port", p.port),
zap.String("subdomain", p.subdomain),
)
// Accept connections in background
p.wg.Add(1)
go p.acceptLoop()
return nil
}
// acceptLoop accepts incoming TCP connections
func (p *TunnelProxy) acceptLoop() {
defer p.wg.Done()
for {
select {
case <-p.stopCh:
return
default:
}
// Set accept deadline
p.listener.(*net.TCPListener).SetDeadline(time.Now().Add(1 * time.Second))
conn, err := p.listener.Accept()
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
continue
}
select {
case <-p.stopCh:
return
default:
continue
}
}
// Handle connection
p.wg.Add(1)
go p.handleConnection(conn)
}
}
func (p *TunnelProxy) handleConnection(conn net.Conn) {
defer p.wg.Done()
defer conn.Close()
streamID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), p.port)
p.streamMu.Lock()
p.streams[streamID] = conn
p.streamMu.Unlock()
defer func() {
p.streamMu.Lock()
delete(p.streams, streamID)
p.streamMu.Unlock()
}()
bufPtr := p.bufferPool.Get(pool.SizeMedium)
defer p.bufferPool.Put(bufPtr)
buffer := (*bufPtr)[:pool.SizeMedium]
for {
n, err := conn.Read(buffer)
if err != nil {
break
}
if n > 0 {
if err := p.sendDataToTunnel(streamID, buffer[:n]); err != nil {
p.logger.Error("Send to tunnel failed", zap.Error(err))
break
}
}
}
select {
case <-p.stopCh:
default:
p.sendCloseToTunnel(streamID)
}
}
func (p *TunnelProxy) sendDataToTunnel(streamID string, data []byte) error {
select {
case <-p.stopCh:
return fmt.Errorf("tunnel proxy stopped")
default:
}
header := protocol.DataHeader{
StreamID: streamID,
RequestID: streamID,
Type: "data",
IsLast: false,
}
payload, err := protocol.EncodeDataPayload(header, data)
if err != nil {
return fmt.Errorf("failed to encode payload: %w", err)
}
frame := protocol.NewFrame(protocol.FrameTypeData, payload)
err = p.frameWriter.WriteFrame(frame)
if err != nil {
return fmt.Errorf("failed to write frame: %w", err)
}
return nil
}
func (p *TunnelProxy) sendCloseToTunnel(streamID string) {
header := protocol.DataHeader{
StreamID: streamID,
RequestID: streamID,
Type: "close",
IsLast: true,
}
payload, err := protocol.EncodeDataPayload(header, nil)
if err != nil {
return
}
frame := protocol.NewFrame(protocol.FrameTypeData, payload)
p.frameWriter.WriteFrame(frame)
}
func (p *TunnelProxy) HandleResponse(streamID string, data []byte) error {
p.streamMu.RLock()
conn, ok := p.streams[streamID]
p.streamMu.RUnlock()
if !ok {
return fmt.Errorf("stream not found: %s", streamID)
}
if _, err := conn.Write(data); err != nil {
p.logger.Error("Write to client failed", zap.Error(err))
return err
}
return nil
}
// CloseStream closes a stream
func (p *TunnelProxy) CloseStream(streamID string) {
p.streamMu.RLock()
conn, ok := p.streams[streamID]
p.streamMu.RUnlock()
if ok {
conn.Close()
}
}
func (p *TunnelProxy) Stop() {
p.logger.Info("Stopping TCP proxy",
zap.Int("port", p.port),
zap.String("subdomain", p.subdomain),
)
close(p.stopCh)
if p.listener != nil {
p.listener.Close()
}
p.streamMu.Lock()
for _, conn := range p.streams {
conn.Close()
}
p.streams = make(map[string]net.Conn)
p.streamMu.Unlock()
p.wg.Wait()
if p.frameWriter != nil {
p.frameWriter.Close()
}
p.logger.Info("TCP proxy stopped", zap.Int("port", p.port))
}

View File

@@ -0,0 +1,73 @@
package tls
import (
"crypto/tls"
"net/http"
"os"
"path/filepath"
"go.uber.org/zap"
"golang.org/x/crypto/acme/autocert"
)
// AutoCertManager manages automatic certificate provisioning with Let's Encrypt
type AutoCertManager struct {
manager *autocert.Manager
logger *zap.Logger
}
// NewAutoCertManager creates a new AutoCert manager
func NewAutoCertManager(domain, cacheDir string, logger *zap.Logger) *AutoCertManager {
m := &autocert.Manager{
Prompt: autocert.AcceptTOS,
HostPolicy: autocert.HostWhitelist(domain, "*."+domain),
Cache: autocert.DirCache(cacheDir),
}
logger.Info("AutoTLS enabled",
zap.String("domain", domain),
zap.String("cache_dir", cacheDir),
)
return &AutoCertManager{
manager: m,
logger: logger,
}
}
// GetTLSConfig returns the TLS configuration
func (a *AutoCertManager) GetTLSConfig() *tls.Config {
return a.manager.TLSConfig()
}
// HTTPHandler returns the HTTP handler for ACME challenges
func (a *AutoCertManager) HTTPHandler() http.Handler {
return a.manager.HTTPHandler(nil)
}
// GetCertificate gets a certificate for the given ClientHelloInfo
func (a *AutoCertManager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := a.manager.GetCertificate(hello)
if err != nil {
a.logger.Error("Failed to get certificate",
zap.String("server_name", hello.ServerName),
zap.Error(err),
)
return nil, err
}
a.logger.Debug("Certificate obtained",
zap.String("server_name", hello.ServerName),
)
return cert, nil
}
// DefaultCacheDir returns the default cache directory for certificates
func DefaultCacheDir() string {
home := os.Getenv("HOME")
if home == "" {
home = "/tmp"
}
return filepath.Join(home, ".drip", "certs")
}

View File

@@ -0,0 +1,192 @@
package tunnel
import (
"sync"
"time"
"drip/internal/shared/protocol"
"github.com/gorilla/websocket"
"go.uber.org/zap"
)
// Transport represents the control channel to the client.
// It is implemented by the TCP control connection so the HTTP proxy
// can push frames directly to the client without depending on WebSockets.
type Transport interface {
SendFrame(frame *protocol.Frame) error
}
// Connection represents a tunnel connection from a client
type Connection struct {
Subdomain string
Conn *websocket.Conn
SendCh chan []byte
CloseCh chan struct{}
LastActive time.Time
mu sync.RWMutex
logger *zap.Logger
closed bool
transport Transport
tunnelType protocol.TunnelType
}
// NewConnection creates a new tunnel connection
func NewConnection(subdomain string, conn *websocket.Conn, logger *zap.Logger) *Connection {
return &Connection{
Subdomain: subdomain,
Conn: conn,
SendCh: make(chan []byte, 256),
CloseCh: make(chan struct{}),
LastActive: time.Now(),
logger: logger,
closed: false,
}
}
// Send sends data through the WebSocket connection
func (c *Connection) Send(data []byte) error {
c.mu.RLock()
if c.closed {
c.mu.RUnlock()
return ErrConnectionClosed
}
c.mu.RUnlock()
select {
case c.SendCh <- data:
return nil
case <-time.After(5 * time.Second):
return ErrSendTimeout
}
}
// UpdateActivity updates the last activity timestamp
func (c *Connection) UpdateActivity() {
c.mu.Lock()
defer c.mu.Unlock()
c.LastActive = time.Now()
}
// IsAlive checks if the connection is still alive based on last activity
func (c *Connection) IsAlive(timeout time.Duration) bool {
c.mu.RLock()
defer c.mu.RUnlock()
return time.Since(c.LastActive) < timeout
}
// Close closes the connection and all associated channels
func (c *Connection) Close() {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return
}
c.closed = true
close(c.CloseCh)
close(c.SendCh)
if c.Conn != nil {
// Send close message
c.Conn.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
c.Conn.Close()
}
c.logger.Info("Connection closed",
zap.String("subdomain", c.Subdomain),
)
}
// IsClosed returns whether the connection is closed
func (c *Connection) IsClosed() bool {
c.mu.RLock()
defer c.mu.RUnlock()
return c.closed
}
// SetTransport attaches the control transport and tunnel type.
func (c *Connection) SetTransport(t Transport, tType protocol.TunnelType) {
c.mu.Lock()
c.transport = t
c.tunnelType = tType
c.mu.Unlock()
}
// GetTransport returns the attached transport (if any).
func (c *Connection) GetTransport() Transport {
c.mu.RLock()
defer c.mu.RUnlock()
return c.transport
}
// SetTunnelType sets the tunnel type.
func (c *Connection) SetTunnelType(tType protocol.TunnelType) {
c.mu.Lock()
c.tunnelType = tType
c.mu.Unlock()
}
// GetTunnelType returns the tunnel type.
func (c *Connection) GetTunnelType() protocol.TunnelType {
c.mu.RLock()
defer c.mu.RUnlock()
return c.tunnelType
}
// StartWritePump starts the write pump for sending messages
func (c *Connection) StartWritePump() {
// Skip write pump for TCP-only connections (no WebSocket)
if c.Conn == nil {
c.logger.Debug("Skipping WritePump for TCP connection",
zap.String("subdomain", c.Subdomain),
)
// Still need to drain SendCh to prevent blocking
go func() {
for {
select {
case <-c.SendCh:
// Discard messages for TCP mode
case <-c.CloseCh:
return
}
}
}()
return
}
ticker := time.NewTicker(30 * time.Second)
defer func() {
ticker.Stop()
c.Close()
}()
for {
select {
case message, ok := <-c.SendCh:
if !ok {
return
}
c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := c.Conn.WriteMessage(websocket.TextMessage, message); err != nil {
c.logger.Error("Write error",
zap.String("subdomain", c.Subdomain),
zap.Error(err),
)
return
}
case <-ticker.C:
// Send ping to keep connection alive
c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
case <-c.CloseCh:
return
}
}
}

View File

@@ -0,0 +1,263 @@
package tunnel
import (
"testing"
"time"
"go.uber.org/zap"
)
func TestNewConnection(t *testing.T) {
subdomain := "test123"
logger := zap.NewNop()
// We can't create a real WebSocket connection in unit tests,
// so we'll just test with nil
conn := NewConnection(subdomain, nil, logger)
if conn == nil {
t.Fatal("NewConnection() returned nil")
}
if conn.Subdomain != subdomain {
t.Errorf("Subdomain = %v, want %v", conn.Subdomain, subdomain)
}
if conn.SendCh == nil {
t.Error("SendCh is nil")
}
if conn.CloseCh == nil {
t.Error("CloseCh is nil")
}
// Check that LastActive is recent (within last second)
now := time.Now()
if now.Sub(conn.LastActive) > time.Second {
t.Error("LastActive is not recent")
}
}
func TestConnectionUpdateActivity(t *testing.T) {
logger := zap.NewNop()
conn := NewConnection("test", nil, logger)
// Get initial LastActive
initial := conn.LastActive
// Wait a bit
time.Sleep(10 * time.Millisecond)
// Update activity
conn.UpdateActivity()
// Check that LastActive was updated
if !conn.LastActive.After(initial) {
t.Error("UpdateActivity() did not update LastActive")
}
// Check that it's recent
now := time.Now()
if now.Sub(conn.LastActive) > time.Second {
t.Error("UpdateActivity() did not set recent timestamp")
}
}
func TestConnectionIsAlive(t *testing.T) {
tests := []struct {
name string
lastActive time.Time
timeout time.Duration
want bool
}{
{
name: "fresh connection is alive",
lastActive: time.Now(),
timeout: 90 * time.Second,
want: true,
},
{
name: "stale connection is not alive",
lastActive: time.Now().Add(-2 * time.Minute),
timeout: 90 * time.Second,
want: false,
},
{
name: "exactly at timeout is not alive",
lastActive: time.Now().Add(-90 * time.Second),
timeout: 90 * time.Second,
want: false,
},
{
name: "just before timeout is alive",
lastActive: time.Now().Add(-89 * time.Second),
timeout: 90 * time.Second,
want: true,
},
}
logger := zap.NewNop()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
conn := NewConnection("test", nil, logger)
conn.mu.Lock()
conn.LastActive = tt.lastActive
conn.mu.Unlock()
got := conn.IsAlive(tt.timeout)
if got != tt.want {
t.Errorf("IsAlive() = %v, want %v (age: %v, timeout: %v)",
got, tt.want, time.Since(tt.lastActive), tt.timeout)
}
})
}
}
func TestConnectionSend(t *testing.T) {
logger := zap.NewNop()
conn := NewConnection("test", nil, logger)
data := []byte("test message")
// Test successful send
err := conn.Send(data)
if err != nil {
t.Errorf("Send() error = %v, want nil", err)
}
// Verify data was sent to channel
select {
case received := <-conn.SendCh:
if string(received) != string(data) {
t.Errorf("Received data = %v, want %v", string(received), string(data))
}
case <-time.After(100 * time.Millisecond):
t.Error("Send() did not send data to channel")
}
}
func TestConnectionSendTimeout(t *testing.T) {
logger := zap.NewNop()
conn := NewConnection("test", nil, logger)
// Fill the channel
for i := 0; i < 256; i++ {
conn.SendCh <- []byte("fill")
}
// Try to send when channel is full
data := []byte("test message")
err := conn.Send(data)
if err != ErrSendTimeout {
t.Errorf("Send() on full channel error = %v, want %v", err, ErrSendTimeout)
}
}
func TestConnectionClose(t *testing.T) {
logger := zap.NewNop()
conn := NewConnection("test", nil, logger)
// Close the connection
conn.Close()
// Verify CloseCh is closed
select {
case <-conn.CloseCh:
// Successfully received from closed channel
case <-time.After(100 * time.Millisecond):
t.Error("Close() did not close CloseCh")
}
// Try to close again (should not panic)
defer func() {
if r := recover(); r != nil {
t.Error("Close() panicked on second call")
}
}()
conn.Close()
}
func TestConnectionConcurrentUpdateActivity(t *testing.T) {
logger := zap.NewNop()
conn := NewConnection("test", nil, logger)
// Update activity concurrently
done := make(chan bool)
for i := 0; i < 100; i++ {
go func() {
conn.UpdateActivity()
done <- true
}()
}
// Wait for all goroutines
for i := 0; i < 100; i++ {
<-done
}
// Verify LastActive is recent
now := time.Now()
if now.Sub(conn.LastActive) > time.Second {
t.Error("Concurrent UpdateActivity() failed")
}
}
func TestConnectionConcurrentIsAlive(t *testing.T) {
logger := zap.NewNop()
conn := NewConnection("test", nil, logger)
// Check IsAlive concurrently
done := make(chan bool)
for i := 0; i < 100; i++ {
go func() {
conn.IsAlive(90 * time.Second)
done <- true
}()
}
// Wait for all goroutines
for i := 0; i < 100; i++ {
<-done
}
}
// Benchmark tests
func BenchmarkConnectionSend(b *testing.B) {
logger := zap.NewNop()
conn := NewConnection("test", nil, logger)
// Drain channel in background
go func() {
for range conn.SendCh {
}
}()
data := []byte("test message")
b.ResetTimer()
for i := 0; i < b.N; i++ {
conn.Send(data)
}
}
func BenchmarkConnectionUpdateActivity(b *testing.B) {
logger := zap.NewNop()
conn := NewConnection("test", nil, logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
conn.UpdateActivity()
}
}
func BenchmarkConnectionIsAlive(b *testing.B) {
logger := zap.NewNop()
conn := NewConnection("test", nil, logger)
timeout := 90 * time.Second
b.ResetTimer()
for i := 0; i < b.N; i++ {
conn.IsAlive(timeout)
}
}

View File

@@ -0,0 +1,23 @@
package tunnel
import "errors"
var (
// ErrConnectionClosed is returned when trying to use a closed connection
ErrConnectionClosed = errors.New("connection is closed")
// ErrSendTimeout is returned when send operation times out
ErrSendTimeout = errors.New("send operation timed out")
// ErrTunnelNotFound is returned when a tunnel is not found
ErrTunnelNotFound = errors.New("tunnel not found")
// ErrSubdomainTaken is returned when a subdomain is already in use
ErrSubdomainTaken = errors.New("subdomain is already taken")
// ErrInvalidSubdomain is returned when a subdomain is invalid
ErrInvalidSubdomain = errors.New("invalid subdomain format")
// ErrReservedSubdomain is returned when trying to use a reserved subdomain
ErrReservedSubdomain = errors.New("subdomain is reserved")
)

View File

@@ -0,0 +1,185 @@
package tunnel
import (
"sync"
"time"
"drip/internal/shared/utils"
"github.com/gorilla/websocket"
"go.uber.org/zap"
)
// Manager manages all active tunnel connections
type Manager struct {
tunnels map[string]*Connection // subdomain -> connection
mu sync.RWMutex
used map[string]bool // track used subdomains
logger *zap.Logger
}
// NewManager creates a new tunnel manager
func NewManager(logger *zap.Logger) *Manager {
return &Manager{
tunnels: make(map[string]*Connection),
used: make(map[string]bool),
logger: logger,
}
}
// Register registers a new tunnel connection
// Returns the assigned subdomain and any error
func (m *Manager) Register(conn *websocket.Conn, customSubdomain string) (string, error) {
m.mu.Lock()
defer m.mu.Unlock()
var subdomain string
if customSubdomain != "" {
// Validate custom subdomain
if !utils.ValidateSubdomain(customSubdomain) {
return "", ErrInvalidSubdomain
}
if utils.IsReserved(customSubdomain) {
return "", ErrReservedSubdomain
}
if m.used[customSubdomain] {
return "", ErrSubdomainTaken
}
subdomain = customSubdomain
} else {
// Generate unique random subdomain
subdomain = m.generateUniqueSubdomain()
}
// Create connection
tc := NewConnection(subdomain, conn, m.logger)
m.tunnels[subdomain] = tc
m.used[subdomain] = true
// Start write pump in background
go tc.StartWritePump()
m.logger.Info("Tunnel registered",
zap.String("subdomain", subdomain),
zap.Int("total_tunnels", len(m.tunnels)),
)
return subdomain, nil
}
// Unregister removes a tunnel connection
func (m *Manager) Unregister(subdomain string) {
m.mu.Lock()
defer m.mu.Unlock()
if tc, ok := m.tunnels[subdomain]; ok {
tc.Close()
delete(m.tunnels, subdomain)
delete(m.used, subdomain)
m.logger.Info("Tunnel unregistered",
zap.String("subdomain", subdomain),
zap.Int("total_tunnels", len(m.tunnels)),
)
}
}
// Get retrieves a tunnel connection by subdomain
func (m *Manager) Get(subdomain string) (*Connection, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
tc, ok := m.tunnels[subdomain]
return tc, ok
}
// List returns all active tunnel connections
func (m *Manager) List() []*Connection {
m.mu.RLock()
defer m.mu.RUnlock()
connections := make([]*Connection, 0, len(m.tunnels))
for _, tc := range m.tunnels {
connections = append(connections, tc)
}
return connections
}
// Count returns the number of active tunnels
func (m *Manager) Count() int {
m.mu.RLock()
defer m.mu.RUnlock()
return len(m.tunnels)
}
// CleanupStale removes stale connections that haven't been active
func (m *Manager) CleanupStale(timeout time.Duration) int {
m.mu.Lock()
defer m.mu.Unlock()
staleSubdomains := []string{}
for subdomain, tc := range m.tunnels {
if !tc.IsAlive(timeout) {
staleSubdomains = append(staleSubdomains, subdomain)
}
}
for _, subdomain := range staleSubdomains {
if tc, ok := m.tunnels[subdomain]; ok {
tc.Close()
delete(m.tunnels, subdomain)
delete(m.used, subdomain)
}
}
if len(staleSubdomains) > 0 {
m.logger.Info("Cleaned up stale tunnels",
zap.Int("count", len(staleSubdomains)),
)
}
return len(staleSubdomains)
}
// StartCleanupTask starts a background task to clean up stale connections
func (m *Manager) StartCleanupTask(interval, timeout time.Duration) {
ticker := time.NewTicker(interval)
go func() {
for range ticker.C {
m.CleanupStale(timeout)
}
}()
}
// generateUniqueSubdomain generates a unique random subdomain
func (m *Manager) generateUniqueSubdomain() string {
const maxAttempts = 10
for i := 0; i < maxAttempts; i++ {
subdomain := utils.GenerateSubdomain(6)
if !m.used[subdomain] && !utils.IsReserved(subdomain) {
return subdomain
}
}
// Fallback: use longer subdomain if collision persists
return utils.GenerateSubdomain(8)
}
// Shutdown gracefully shuts down all tunnels
func (m *Manager) Shutdown() {
m.mu.Lock()
defer m.mu.Unlock()
m.logger.Info("Shutting down tunnel manager",
zap.Int("active_tunnels", len(m.tunnels)),
)
for _, tc := range m.tunnels {
tc.Close()
}
m.tunnels = make(map[string]*Connection)
m.used = make(map[string]bool)
}

View File

@@ -0,0 +1,376 @@
package tunnel
import (
"sync"
"testing"
"time"
"go.uber.org/zap"
)
func TestNewManager(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
if manager == nil {
t.Fatal("NewManager() returned nil")
}
if manager.tunnels == nil {
t.Error("Manager tunnels map is nil")
}
if manager.used == nil {
t.Error("Manager used map is nil")
}
if manager.logger == nil {
t.Error("Manager logger is nil")
}
}
func TestManagerRegister(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
// Register with empty subdomain (auto-generate)
subdomain, err := manager.Register(nil, "")
if err != nil {
t.Errorf("Register() error = %v, want nil", err)
}
if subdomain == "" {
t.Error("Register() returned empty subdomain")
}
if len(subdomain) != 6 {
t.Errorf("Register() subdomain length = %d, want 6", len(subdomain))
}
// Verify connection is registered
_, ok := manager.Get(subdomain)
if !ok {
t.Error("Get() failed to retrieve registered connection")
}
}
func TestManagerRegisterCustomSubdomain(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
customSubdomain := "mytest"
// Register with custom subdomain
subdomain, err := manager.Register(nil, customSubdomain)
if err != nil {
t.Errorf("Register() error = %v, want nil", err)
}
if subdomain != customSubdomain {
t.Errorf("Register() subdomain = %v, want %v", subdomain, customSubdomain)
}
// Verify connection is registered
conn, ok := manager.Get(subdomain)
if !ok {
t.Error("Get() failed to retrieve registered connection")
}
if conn.Subdomain != customSubdomain {
t.Errorf("Connection subdomain = %v, want %v", conn.Subdomain, customSubdomain)
}
}
func TestManagerRegisterDuplicate(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
customSubdomain := "test123"
// Register first connection
_, err := manager.Register(nil, customSubdomain)
if err != nil {
t.Fatalf("First Register() error = %v, want nil", err)
}
// Try to register second connection with same subdomain
_, err = manager.Register(nil, customSubdomain)
if err != ErrSubdomainTaken {
t.Errorf("Register() error = %v, want %v", err, ErrSubdomainTaken)
}
}
func TestManagerRegisterInvalidSubdomain(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
tests := []struct {
name string
subdomain string
wantErr error
}{
{
name: "invalid uppercase",
subdomain: "TEST",
wantErr: ErrInvalidSubdomain,
},
{
name: "invalid special char",
subdomain: "test@123",
wantErr: ErrInvalidSubdomain,
},
{
name: "reserved www",
subdomain: "www",
wantErr: ErrReservedSubdomain,
},
{
name: "reserved api",
subdomain: "api",
wantErr: ErrReservedSubdomain,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := manager.Register(nil, tt.subdomain)
if err != tt.wantErr {
t.Errorf("Register() error = %v, want %v", err, tt.wantErr)
}
})
}
}
func TestManagerUnregister(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
// Register connection
subdomain, err := manager.Register(nil, "")
if err != nil {
t.Fatalf("Register() error = %v", err)
}
// Unregister connection
manager.Unregister(subdomain)
// Verify connection is removed
_, ok := manager.Get(subdomain)
if ok {
t.Error("Get() succeeded after Unregister(), want failure")
}
}
func TestManagerGet(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
customSubdomain := "test123"
// Test Get on non-existent connection
_, ok := manager.Get(customSubdomain)
if ok {
t.Error("Get() succeeded for non-existent connection")
}
// Register and test Get
subdomain, _ := manager.Register(nil, customSubdomain)
retrieved, ok := manager.Get(subdomain)
if !ok {
t.Error("Get() failed for existing connection")
}
if retrieved.Subdomain != customSubdomain {
t.Error("Get() returned wrong connection")
}
}
func TestManagerList(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
// Test empty manager
all := manager.List()
if len(all) != 0 {
t.Errorf("List() on empty manager returned %d connections, want 0", len(all))
}
// Add multiple connections
count := 5
for i := 0; i < count; i++ {
manager.Register(nil, "")
}
// Test List
all = manager.List()
if len(all) != count {
t.Errorf("List() returned %d connections, want %d", len(all), count)
}
}
func TestManagerCount(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
// Test empty manager
count := manager.Count()
if count != 0 {
t.Errorf("Count() on empty manager = %d, want 0", count)
}
// Add connections
numConns := 3
for i := 0; i < numConns; i++ {
manager.Register(nil, "")
}
count = manager.Count()
if count != numConns {
t.Errorf("Count() = %d, want %d", count, numConns)
}
}
func TestManagerGenerateSubdomain(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
// Generate subdomain via Register
subdomain1, err := manager.Register(nil, "")
if err != nil {
t.Fatalf("First Register() error = %v", err)
}
if subdomain1 == "" {
t.Error("Register() returned empty subdomain")
}
if len(subdomain1) != 6 {
t.Errorf("Register() subdomain length = %d, want 6", len(subdomain1))
}
// Generate another subdomain, should be different
subdomain2, err := manager.Register(nil, "")
if err != nil {
t.Fatalf("Second Register() error = %v", err)
}
if subdomain1 == subdomain2 {
t.Error("Register() generated duplicate subdomain")
}
}
func TestManagerGenerateSubdomainUniqueness(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
subdomains := make(map[string]bool)
count := 100
for i := 0; i < count; i++ {
subdomain, err := manager.Register(nil, "")
if err != nil {
t.Fatalf("Register() error = %v", err)
}
if subdomains[subdomain] {
t.Errorf("Register() generated duplicate: %s", subdomain)
}
subdomains[subdomain] = true
}
}
func TestManagerCleanupStale(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
// Create fresh connection
freshSubdomain, _ := manager.Register(nil, "fresh")
// Create stale connection
staleSubdomain, _ := manager.Register(nil, "stale")
// Manually set LastActive to be stale
if staleConn, ok := manager.Get(staleSubdomain); ok {
staleConn.mu.Lock()
staleConn.LastActive = time.Now().Add(-2 * time.Minute)
staleConn.mu.Unlock()
}
// Run cleanup with 90 second timeout
count := manager.CleanupStale(90 * time.Second)
if count != 1 {
t.Errorf("CleanupStale() returned %d, want 1", count)
}
// Fresh connection should still exist
_, ok := manager.Get(freshSubdomain)
if !ok {
t.Error("CleanupStale() removed fresh connection")
}
// Stale connection should be removed
_, ok = manager.Get(staleSubdomain)
if ok {
t.Error("CleanupStale() did not remove stale connection")
}
}
func TestManagerConcurrentAccess(t *testing.T) {
logger := zap.NewNop()
manager := NewManager(logger)
var wg sync.WaitGroup
count := 50 // Reduced from 100 to avoid potential issues
// Concurrent registrations
for i := 0; i < count; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
_, err := manager.Register(nil, "")
if err != nil {
t.Errorf("Concurrent Register() error = %v", err)
}
}(i)
}
wg.Wait()
// Verify all connections are registered
all := manager.List()
if len(all) != count {
t.Errorf("Expected %d connections, got %d", count, len(all))
}
}
// Benchmark tests
func BenchmarkManagerRegister(b *testing.B) {
logger := zap.NewNop()
manager := NewManager(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.Register(nil, "")
}
}
func BenchmarkManagerGet(b *testing.B) {
logger := zap.NewNop()
manager := NewManager(logger)
// Setup: register a connection
subdomain, _ := manager.Register(nil, "test123")
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.Get(subdomain)
}
}
func BenchmarkManagerGenerateSubdomain(b *testing.B) {
logger := zap.NewNop()
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager := NewManager(logger)
manager.Register(nil, "")
}
}

View File

@@ -0,0 +1,48 @@
package constants
import "time"
const (
// DefaultServerPort is the default port for the tunnel server
DefaultServerPort = 8080
// DefaultWSPort is the default WebSocket port
DefaultWSPort = 8080
// HeartbeatInterval is how often clients send heartbeat messages
HeartbeatInterval = 5 * time.Second
// HeartbeatTimeout is how long the server waits before considering a connection dead
HeartbeatTimeout = 15 * time.Second
// RequestTimeout is the maximum time to wait for a response from the client
RequestTimeout = 30 * time.Second
// MaxRequestBodySize is the maximum size of an HTTP request body (10MB)
MaxRequestBodySize = 10 * 1024 * 1024
// ReconnectBaseDelay is the initial delay for reconnection attempts
ReconnectBaseDelay = 1 * time.Second
// ReconnectMaxDelay is the maximum delay between reconnection attempts
ReconnectMaxDelay = 60 * time.Second
// MaxReconnectAttempts is the maximum number of reconnection attempts (0 = infinite)
MaxReconnectAttempts = 0
// DefaultTCPPortMin/Max define the default allocation range for TCP tunnels
DefaultTCPPortMin = 20000
DefaultTCPPortMax = 40000
// DefaultDomain is the default domain for tunnels
DefaultDomain = "tunnel.localhost"
)
// Error codes
const (
ErrCodeTunnelNotFound = "TUNNEL_NOT_FOUND"
ErrCodeTimeout = "TIMEOUT"
ErrCodeConnectionFailed = "CONNECTION_FAILED"
ErrCodeInvalidRequest = "INVALID_REQUEST"
ErrCodeAuthFailed = "AUTH_FAILED"
ErrCodeRateLimited = "RATE_LIMITED"
)

View File

@@ -0,0 +1,77 @@
package pool
import "sync"
const (
SizeSmall = 4 * 1024 // 4KB
SizeMedium = 32 * 1024 // 32KB
SizeLarge = 256 * 1024 // 256KB
)
type BufferPool struct {
small sync.Pool
medium sync.Pool
large sync.Pool
}
func NewBufferPool() *BufferPool {
return &BufferPool{
small: sync.Pool{
New: func() interface{} {
b := make([]byte, SizeSmall)
return &b
},
},
medium: sync.Pool{
New: func() interface{} {
b := make([]byte, SizeMedium)
return &b
},
},
large: sync.Pool{
New: func() interface{} {
b := make([]byte, SizeLarge)
return &b
},
},
}
}
func (p *BufferPool) Get(size int) *[]byte {
switch {
case size <= SizeSmall:
return p.small.Get().(*[]byte)
case size <= SizeMedium:
return p.medium.Get().(*[]byte)
default:
return p.large.Get().(*[]byte)
}
}
func (p *BufferPool) Put(buf *[]byte) {
if buf == nil {
return
}
size := cap(*buf)
*buf = (*buf)[:cap(*buf)]
switch size {
case SizeSmall:
p.small.Put(buf)
case SizeMedium:
p.medium.Put(buf)
case SizeLarge:
p.large.Put(buf)
}
}
var globalBufferPool = NewBufferPool()
func GetBuffer(size int) *[]byte {
return globalBufferPool.Get(size)
}
func PutBuffer(buf *[]byte) {
globalBufferPool.Put(buf)
}

View File

@@ -0,0 +1,93 @@
package pool
import (
"net/http"
"sync"
)
// HeaderPool manages a pool of http.Header objects for reuse
// This reduces GC pressure from repeated header map allocations
type HeaderPool struct {
pool sync.Pool
}
// NewHeaderPool creates a new header pool
func NewHeaderPool() *HeaderPool {
return &HeaderPool{
pool: sync.Pool{
New: func() interface{} {
// Pre-allocate with capacity for common header count (8-12 headers)
return make(http.Header, 12)
},
},
}
}
// Get retrieves a header from the pool
// Returns a clean, empty header ready for use
func (p *HeaderPool) Get() http.Header {
h := p.pool.Get().(http.Header)
// Clear any existing data (headers might be dirty from previous use)
for k := range h {
delete(h, k)
}
return h
}
// Put returns a header to the pool
// The header will be reused by future Get() calls
func (p *HeaderPool) Put(h http.Header) {
if h == nil {
return
}
// Note: We don't clear here, clearing is done in Get() for better performance
// (allows the GC to collect during idle time)
p.pool.Put(h)
}
// Clone creates a copy of src into dst, reusing dst's underlying storage
// This is more efficient than creating a new header from scratch
func (p *HeaderPool) Clone(dst, src http.Header) {
// Clear dst first
for k := range dst {
delete(dst, k)
}
// Copy all headers from src to dst
for k, vv := range src {
// Allocate new slice with exact capacity to avoid over-allocation
dst[k] = make([]string, len(vv))
copy(dst[k], vv)
}
}
// CloneWithExtra clones src into dst and adds/overwrites extra headers
// This is optimized for the common pattern of cloning + adding Host header
func (p *HeaderPool) CloneWithExtra(dst, src http.Header, extraKey, extraValue string) {
// Clear dst first
for k := range dst {
delete(dst, k)
}
// Copy all headers from src to dst
for k, vv := range src {
dst[k] = make([]string, len(vv))
copy(dst[k], vv)
}
// Set extra header (overwrite if exists)
dst.Set(extraKey, extraValue)
}
// globalHeaderPool is a package-level pool for convenience
var globalHeaderPool = NewHeaderPool()
// GetHeader retrieves a header from the global pool
func GetHeader() http.Header {
return globalHeaderPool.Get()
}
// PutHeader returns a header to the global pool
func PutHeader(h http.Header) {
globalHeaderPool.Put(h)
}

View File

@@ -0,0 +1,115 @@
package pool
import (
"sync"
)
// WorkerPool is a fixed-size goroutine pool for handling tasks
type WorkerPool struct {
workers int
jobQueue chan func()
wg sync.WaitGroup
once sync.Once
closed bool
mu sync.RWMutex
}
// NewWorkerPool creates a new worker pool with the specified number of workers
func NewWorkerPool(workers int, queueSize int) *WorkerPool {
if workers <= 0 {
workers = 50 // Default worker count
}
if queueSize <= 0 {
queueSize = 1000 // Default queue size
}
pool := &WorkerPool{
workers: workers,
jobQueue: make(chan func(), queueSize),
}
// Start worker goroutines
for i := 0; i < workers; i++ {
pool.wg.Add(1)
go pool.worker()
}
return pool
}
// worker is the worker goroutine that processes jobs from the queue
func (p *WorkerPool) worker() {
defer p.wg.Done()
for job := range p.jobQueue {
if job != nil {
job()
}
}
}
// Submit submits a job to the worker pool
// Returns false if the pool is closed or the queue is full
func (p *WorkerPool) Submit(job func()) bool {
if job == nil {
return false
}
p.mu.RLock()
if p.closed {
p.mu.RUnlock()
return false
}
p.mu.RUnlock()
// Non-blocking send
select {
case p.jobQueue <- job:
return true
default:
// Queue is full, fall back to direct execution
// This prevents blocking when pool is overloaded
go job()
return false
}
}
// SubmitWait submits a job and waits for it to complete
func (p *WorkerPool) SubmitWait(job func()) {
if job == nil {
return
}
done := make(chan struct{})
wrappedJob := func() {
defer close(done)
job()
}
if p.Submit(wrappedJob) {
<-done
} else {
// Pool is closed or full, execute directly
job()
}
}
// Close gracefully shuts down the worker pool
// It waits for all pending jobs to complete
func (p *WorkerPool) Close() {
p.once.Do(func() {
p.mu.Lock()
p.closed = true
p.mu.Unlock()
close(p.jobQueue)
p.wg.Wait()
})
}
// IsClosed returns true if the pool is closed
func (p *WorkerPool) IsClosed() bool {
p.mu.RLock()
defer p.mu.RUnlock()
return p.closed
}

View File

@@ -0,0 +1,166 @@
package protocol
import (
"encoding/binary"
"errors"
)
// DataHeaderV2 represents a binary-encoded data header (Protocol Version 2)
// This replaces JSON encoding to improve performance
type DataHeaderV2 struct {
Type DataType
IsLast bool
StreamID string
RequestID string
}
// DataType represents the type of data frame
type DataType uint8
const (
DataTypeData DataType = 0x00 // 000
DataTypeResponse DataType = 0x01 // 001
DataTypeClose DataType = 0x02 // 010
DataTypeHTTPRequest DataType = 0x03 // 011
DataTypeHTTPResponse DataType = 0x04 // 100
)
// String returns the string representation of DataType
func (t DataType) String() string {
switch t {
case DataTypeData:
return "data"
case DataTypeResponse:
return "response"
case DataTypeClose:
return "close"
case DataTypeHTTPRequest:
return "http_request"
case DataTypeHTTPResponse:
return "http_response"
default:
return "unknown"
}
}
// FromString converts a string to DataType
func DataTypeFromString(s string) DataType {
switch s {
case "data":
return DataTypeData
case "response":
return DataTypeResponse
case "close":
return DataTypeClose
case "http_request":
return DataTypeHTTPRequest
case "http_response":
return DataTypeHTTPResponse
default:
return DataTypeData
}
}
// Binary format:
// +--------+--------+--------+--------+--------+
// | Flags | StreamID Length | RequestID Len |
// | 1 byte | 2 bytes | 2 bytes |
// +--------+--------+--------+--------+--------+
// | StreamID (variable) |
// +--------+--------+--------+--------+--------+
// | RequestID (variable) |
// +--------+--------+--------+--------+--------+
//
// Flags (8 bits):
// - Bit 0-2: Type (3 bits)
// - Bit 3: IsLast (1 bit)
// - Bit 4-7: Reserved (4 bits)
const (
binaryHeaderMinSize = 5 // 1 byte flags + 2 bytes streamID len + 2 bytes requestID len
)
// MarshalBinary encodes the header to binary format
func (h *DataHeaderV2) MarshalBinary() []byte {
streamIDLen := len(h.StreamID)
requestIDLen := len(h.RequestID)
totalLen := binaryHeaderMinSize + streamIDLen + requestIDLen
buf := make([]byte, totalLen)
// Encode flags
flags := uint8(h.Type) & 0x07 // Type uses bits 0-2
if h.IsLast {
flags |= 0x08 // IsLast uses bit 3
}
buf[0] = flags
// Encode lengths (big-endian)
binary.BigEndian.PutUint16(buf[1:3], uint16(streamIDLen))
binary.BigEndian.PutUint16(buf[3:5], uint16(requestIDLen))
// Encode StreamID
offset := binaryHeaderMinSize
copy(buf[offset:], h.StreamID)
offset += streamIDLen
// Encode RequestID
copy(buf[offset:], h.RequestID)
return buf
}
// UnmarshalBinary decodes the header from binary format
func (h *DataHeaderV2) UnmarshalBinary(data []byte) error {
if len(data) < binaryHeaderMinSize {
return errors.New("invalid binary header: too short")
}
// Decode flags
flags := data[0]
h.Type = DataType(flags & 0x07) // Bits 0-2
h.IsLast = (flags & 0x08) != 0 // Bit 3
// Decode lengths
streamIDLen := int(binary.BigEndian.Uint16(data[1:3]))
requestIDLen := int(binary.BigEndian.Uint16(data[3:5]))
// Validate total length
expectedLen := binaryHeaderMinSize + streamIDLen + requestIDLen
if len(data) < expectedLen {
return errors.New("invalid binary header: length mismatch")
}
// Decode StreamID
offset := binaryHeaderMinSize
h.StreamID = string(data[offset : offset+streamIDLen])
offset += streamIDLen
// Decode RequestID
h.RequestID = string(data[offset : offset+requestIDLen])
return nil
}
// ToDataHeader converts binary header to JSON header (for compatibility)
func (h *DataHeaderV2) ToDataHeader() DataHeader {
return DataHeader{
StreamID: h.StreamID,
RequestID: h.RequestID,
Type: h.Type.String(),
IsLast: h.IsLast,
}
}
// FromDataHeader converts JSON header to binary header
func (h *DataHeaderV2) FromDataHeader(dh DataHeader) {
h.StreamID = dh.StreamID
h.RequestID = dh.RequestID
h.Type = DataTypeFromString(dh.Type)
h.IsLast = dh.IsLast
}
// Size returns the size of the binary-encoded header
func (h *DataHeaderV2) Size() int {
return binaryHeaderMinSize + len(h.StreamID) + len(h.RequestID)
}

View File

@@ -0,0 +1,136 @@
package protocol
import (
"encoding/binary"
"fmt"
"io"
"drip/internal/shared/pool"
)
const (
FrameHeaderSize = 5
MaxFrameSize = 10 * 1024 * 1024
)
// FrameType defines the type of frame
type FrameType byte
const (
FrameTypeRegister FrameType = 0x01
FrameTypeRegisterAck FrameType = 0x02
FrameTypeHeartbeat FrameType = 0x03
FrameTypeHeartbeatAck FrameType = 0x04
FrameTypeData FrameType = 0x05
FrameTypeClose FrameType = 0x06
FrameTypeError FrameType = 0x07
)
// String returns the string representation of frame type
func (t FrameType) String() string {
switch t {
case FrameTypeRegister:
return "Register"
case FrameTypeRegisterAck:
return "RegisterAck"
case FrameTypeHeartbeat:
return "Heartbeat"
case FrameTypeHeartbeatAck:
return "HeartbeatAck"
case FrameTypeData:
return "Data"
case FrameTypeClose:
return "Close"
case FrameTypeError:
return "Error"
default:
return fmt.Sprintf("Unknown(%d)", t)
}
}
type Frame struct {
Type FrameType
Payload []byte
poolBuffer *[]byte
}
func WriteFrame(w io.Writer, frame *Frame) error {
payloadLen := len(frame.Payload)
if payloadLen > MaxFrameSize {
return fmt.Errorf("payload too large: %d bytes (max %d)", payloadLen, MaxFrameSize)
}
lengthBuf := make([]byte, 4)
binary.BigEndian.PutUint32(lengthBuf, uint32(payloadLen))
if _, err := w.Write(lengthBuf); err != nil {
return fmt.Errorf("failed to write length: %w", err)
}
if _, err := w.Write([]byte{byte(frame.Type)}); err != nil {
return fmt.Errorf("failed to write type: %w", err)
}
if payloadLen > 0 {
if _, err := w.Write(frame.Payload); err != nil {
return fmt.Errorf("failed to write payload: %w", err)
}
}
return nil
}
func ReadFrame(r io.Reader) (*Frame, error) {
header := make([]byte, FrameHeaderSize)
if _, err := io.ReadFull(r, header); err != nil {
return nil, fmt.Errorf("failed to read frame header: %w", err)
}
payloadLen := binary.BigEndian.Uint32(header[0:4])
if payloadLen > MaxFrameSize {
return nil, fmt.Errorf("payload too large: %d bytes (max %d)", payloadLen, MaxFrameSize)
}
frameType := FrameType(header[4])
var payload []byte
var poolBuf *[]byte
if payloadLen > 0 {
if payloadLen > pool.SizeLarge {
payload = make([]byte, payloadLen)
if _, err := io.ReadFull(r, payload); err != nil {
return nil, fmt.Errorf("failed to read payload: %w", err)
}
} else {
poolBuf = pool.GetBuffer(int(payloadLen))
payload = (*poolBuf)[:payloadLen]
if _, err := io.ReadFull(r, payload); err != nil {
pool.PutBuffer(poolBuf)
return nil, fmt.Errorf("failed to read payload: %w", err)
}
}
}
return &Frame{
Type: frameType,
Payload: payload,
poolBuffer: poolBuf,
}, nil
}
func (f *Frame) Release() {
if f.poolBuffer != nil {
pool.PutBuffer(f.poolBuffer)
f.poolBuffer = nil
f.Payload = nil
}
}
// NewFrame creates a new frame
func NewFrame(frameType FrameType, payload []byte) *Frame {
return &Frame{
Type: frameType,
Payload: payload,
}
}

View File

@@ -0,0 +1,68 @@
package protocol
import (
"encoding/json"
"errors"
"github.com/vmihailenco/msgpack/v5"
)
// EncodeHTTPRequest encodes HTTPRequest using msgpack encoding (optimized)
func EncodeHTTPRequest(req *HTTPRequest) ([]byte, error) {
return msgpack.Marshal(req)
}
// DecodeHTTPRequest decodes HTTPRequest with automatic version detection
// Detects based on first byte: '{' = JSON, else = msgpack
func DecodeHTTPRequest(data []byte) (*HTTPRequest, error) {
if len(data) == 0 {
return nil, errors.New("empty data")
}
var req HTTPRequest
// Auto-detect: JSON starts with '{', msgpack starts with 0x80-0x8f (fixmap)
if data[0] == '{' {
// v1: JSON
if err := json.Unmarshal(data, &req); err != nil {
return nil, err
}
} else {
// v2: msgpack
if err := msgpack.Unmarshal(data, &req); err != nil {
return nil, err
}
}
return &req, nil
}
// EncodeHTTPResponse encodes HTTPResponse using msgpack encoding (optimized)
func EncodeHTTPResponse(resp *HTTPResponse) ([]byte, error) {
return msgpack.Marshal(resp)
}
// DecodeHTTPResponse decodes HTTPResponse with automatic version detection
// Detects based on first byte: '{' = JSON, else = msgpack
func DecodeHTTPResponse(data []byte) (*HTTPResponse, error) {
if len(data) == 0 {
return nil, errors.New("empty data")
}
var resp HTTPResponse
// Auto-detect: JSON starts with '{', msgpack starts with 0x80-0x8f (fixmap)
if data[0] == '{' {
// v1: JSON
if err := json.Unmarshal(data, &resp); err != nil {
return nil, err
}
} else {
// v2: msgpack
if err := msgpack.Unmarshal(data, &resp); err != nil {
return nil, err
}
}
return &resp, nil
}

View File

@@ -0,0 +1,55 @@
package protocol
// MessageType defines the type of tunnel message
type MessageType string
const (
// TypeRegister is sent when a client connects and gets a subdomain assigned
TypeRegister MessageType = "register"
// TypeRequest is sent from server to client when an HTTP request arrives
TypeRequest MessageType = "request"
// TypeResponse is sent from client to server with the HTTP response
TypeResponse MessageType = "response"
// TypeHeartbeat is sent periodically to keep the connection alive
TypeHeartbeat MessageType = "heartbeat"
// TypeError is sent when an error occurs
TypeError MessageType = "error"
)
// Message represents a tunnel protocol message
type Message struct {
Type MessageType `json:"type"`
ID string `json:"id,omitempty"`
Subdomain string `json:"subdomain,omitempty"`
Data map[string]interface{} `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}
// HTTPRequest represents an HTTP request to be forwarded
type HTTPRequest struct {
Method string `json:"method"`
URL string `json:"url"`
Headers map[string][]string `json:"headers"`
Body []byte `json:"body,omitempty"`
}
// HTTPResponse represents an HTTP response from the local service
type HTTPResponse struct {
StatusCode int `json:"status_code"`
Status string `json:"status"`
Headers map[string][]string `json:"headers"`
Body []byte `json:"body,omitempty"`
}
// RegisterData contains information sent when a tunnel is registered
type RegisterData struct {
Subdomain string `json:"subdomain"`
URL string `json:"url"`
Message string `json:"message"`
}
// ErrorData contains error information
type ErrorData struct {
Code string `json:"code"`
Message string `json:"message"`
}

View File

@@ -0,0 +1,307 @@
package protocol
import (
"encoding/json"
"reflect"
"testing"
)
func TestMessageType_Values(t *testing.T) {
tests := []struct {
name string
mt MessageType
want string
}{
{
name: "register",
mt: TypeRegister,
want: "register",
},
{
name: "request",
mt: TypeRequest,
want: "request",
},
{
name: "response",
mt: TypeResponse,
want: "response",
},
{
name: "heartbeat",
mt: TypeHeartbeat,
want: "heartbeat",
},
{
name: "error",
mt: TypeError,
want: "error",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := string(tt.mt)
if got != tt.want {
t.Errorf("MessageType value = %v, want %v", got, tt.want)
}
})
}
}
func TestMessage_JSON(t *testing.T) {
tests := []struct {
name string
message *Message
wantErr bool
}{
{
name: "simple message",
message: &Message{
Type: TypeRegister,
ID: "test-id-123",
},
wantErr: false,
},
{
name: "message with subdomain",
message: &Message{
Type: TypeRegister,
ID: "test-id-456",
Subdomain: "abc123",
},
wantErr: false,
},
{
name: "message with data",
message: &Message{
Type: TypeRequest,
ID: "test-id-789",
Data: map[string]interface{}{
"method": "GET",
"path": "/test",
"headers": map[string]interface{}{
"User-Agent": "Test",
},
},
},
wantErr: false,
},
{
name: "error message",
message: &Message{
Type: TypeError,
ID: "test-id-error",
Data: map[string]interface{}{
"error": "something went wrong",
"code": float64(500), // JSON unmarshals numbers as float64
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Marshal to JSON
data, err := json.Marshal(tt.message)
if (err != nil) != tt.wantErr {
t.Errorf("json.Marshal() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return
}
// Unmarshal back
var decoded Message
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Errorf("json.Unmarshal() error = %v", err)
return
}
// Compare
if decoded.Type != tt.message.Type {
t.Errorf("Type = %v, want %v", decoded.Type, tt.message.Type)
}
if decoded.ID != tt.message.ID {
t.Errorf("ID = %v, want %v", decoded.ID, tt.message.ID)
}
if decoded.Subdomain != tt.message.Subdomain {
t.Errorf("Subdomain = %v, want %v", decoded.Subdomain, tt.message.Subdomain)
}
// Deep compare Data if present
if tt.message.Data != nil {
if !reflect.DeepEqual(decoded.Data, tt.message.Data) {
t.Errorf("Data = %v, want %v", decoded.Data, tt.message.Data)
}
}
})
}
}
func TestHTTPRequest_JSON(t *testing.T) {
req := &HTTPRequest{
Method: "POST",
URL: "http://localhost:3000/api/test",
Headers: map[string][]string{
"Content-Type": {"application/json"},
"User-Agent": {"Test Agent"},
},
Body: []byte(`{"key":"value"}`),
}
// Marshal
data, err := json.Marshal(req)
if err != nil {
t.Fatalf("json.Marshal() error = %v", err)
}
// Unmarshal
var decoded HTTPRequest
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
// Compare
if decoded.Method != req.Method {
t.Errorf("Method = %v, want %v", decoded.Method, req.Method)
}
if decoded.URL != req.URL {
t.Errorf("URL = %v, want %v", decoded.URL, req.URL)
}
if !reflect.DeepEqual(decoded.Headers, req.Headers) {
t.Errorf("Headers = %v, want %v", decoded.Headers, req.Headers)
}
if string(decoded.Body) != string(req.Body) {
t.Errorf("Body = %v, want %v", string(decoded.Body), string(req.Body))
}
}
func TestHTTPResponse_JSON(t *testing.T) {
resp := &HTTPResponse{
StatusCode: 200,
Status: "200 OK",
Headers: map[string][]string{
"Content-Type": {"text/html"},
},
Body: []byte("<html>Test</html>"),
}
// Marshal
data, err := json.Marshal(resp)
if err != nil {
t.Fatalf("json.Marshal() error = %v", err)
}
// Unmarshal
var decoded HTTPResponse
err = json.Unmarshal(data, &decoded)
if err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
// Compare
if decoded.StatusCode != resp.StatusCode {
t.Errorf("StatusCode = %v, want %v", decoded.StatusCode, resp.StatusCode)
}
if decoded.Status != resp.Status {
t.Errorf("Status = %v, want %v", decoded.Status, resp.Status)
}
if !reflect.DeepEqual(decoded.Headers, resp.Headers) {
t.Errorf("Headers = %v, want %v", decoded.Headers, resp.Headers)
}
if string(decoded.Body) != string(resp.Body) {
t.Errorf("Body = %v, want %v", string(decoded.Body), string(resp.Body))
}
}
func TestMessage_ToMap(t *testing.T) {
msg := &Message{
Type: TypeRequest,
ID: "test-123",
Subdomain: "abc",
Data: map[string]interface{}{
"test": "value",
},
}
// Convert to map (simulated by marshaling and unmarshaling)
data, err := json.Marshal(msg)
if err != nil {
t.Fatalf("json.Marshal() error = %v", err)
}
var result map[string]interface{}
err = json.Unmarshal(data, &result)
if err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
// Verify fields exist
if result["type"] == nil {
t.Error("Map missing 'type' field")
}
if result["id"] == nil {
t.Error("Map missing 'id' field")
}
}
func TestNewMessage(t *testing.T) {
msgType := TypeRegister
id := "test-id"
msg := &Message{
Type: msgType,
ID: id,
}
if msg.Type != msgType {
t.Errorf("Type = %v, want %v", msg.Type, msgType)
}
if msg.ID != id {
t.Errorf("ID = %v, want %v", msg.ID, id)
}
}
// Benchmark tests
func BenchmarkMessageMarshal(b *testing.B) {
msg := &Message{
Type: TypeRequest,
ID: "test-id-123",
Subdomain: "abc123",
Data: map[string]interface{}{
"method": "GET",
"path": "/test",
},
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
json.Marshal(msg)
}
}
func BenchmarkMessageUnmarshal(b *testing.B) {
msg := &Message{
Type: TypeRequest,
ID: "test-id-123",
Subdomain: "abc123",
Data: map[string]interface{}{
"method": "GET",
"path": "/test",
},
}
data, _ := json.Marshal(msg)
b.ResetTimer()
for i := 0; i < b.N; i++ {
var decoded Message
json.Unmarshal(data, &decoded)
}
}

View File

@@ -0,0 +1,49 @@
package protocol
import "encoding/json"
// RegisterRequest is sent by client to register a tunnel
type RegisterRequest struct {
Token string `json:"token"` // Authentication token
CustomSubdomain string `json:"custom_subdomain"` // Optional custom subdomain
TunnelType TunnelType `json:"tunnel_type"` // http, tcp, udp
LocalPort int `json:"local_port"` // Local port to forward to
}
// RegisterResponse is sent by server after successful registration
type RegisterResponse struct {
Subdomain string `json:"subdomain"` // Assigned subdomain
Port int `json:"port,omitempty"` // Assigned TCP port (for TCP tunnels)
URL string `json:"url"` // Full tunnel URL
Message string `json:"message"` // Success message
}
// ErrorMessage represents an error
type ErrorMessage struct {
Code string `json:"code"` // Error code
Message string `json:"message"` // Error message
}
// DataHeader represents metadata for a data frame
type DataHeader struct {
StreamID string `json:"stream_id"` // Unique stream identifier
RequestID string `json:"request_id"` // Request identifier (for HTTP)
Type string `json:"type"` // "data", "response", "close", "http_request", "http_response"
IsLast bool `json:"is_last"` // Is this the last frame for this stream
}
// TCPData represents TCP tunnel data
type TCPData struct {
StreamID string `json:"stream_id"` // Stream identifier
Data []byte `json:"data"` // Raw TCP data
IsClose bool `json:"is_close"` // Close this stream
}
// Marshal helpers
func MarshalJSON(v interface{}) ([]byte, error) {
return json.Marshal(v)
}
func UnmarshalJSON(data []byte, v interface{}) error {
return json.Unmarshal(data, v)
}

View File

@@ -0,0 +1,129 @@
package protocol
import (
"encoding/json"
"errors"
)
// EncodeDataPayload encodes a data header and payload into a frame payload
// Uses binary encoding (optimized format)
func EncodeDataPayload(header DataHeader, data []byte) ([]byte, error) {
return EncodeDataPayloadV2(header, data)
}
// EncodeDataPayloadV1 encodes using JSON (legacy)
// Format: JSON_HEADER\nDATA
func EncodeDataPayloadV1(header DataHeader, data []byte) ([]byte, error) {
headerBytes, err := json.Marshal(header)
if err != nil {
return nil, err
}
// Combine: header + newline + data
payload := make([]byte, 0, len(headerBytes)+1+len(data))
payload = append(payload, headerBytes...)
payload = append(payload, '\n')
payload = append(payload, data...)
return payload, nil
}
// EncodeDataPayloadV2 encodes using binary format (optimized)
// Format: BINARY_HEADER + DATA
func EncodeDataPayloadV2(header DataHeader, data []byte) ([]byte, error) {
// Convert to binary header
var h2 DataHeaderV2
h2.FromDataHeader(header)
// Encode header to binary
headerBytes := h2.MarshalBinary()
// Combine: binary header + data
payload := make([]byte, 0, len(headerBytes)+len(data))
payload = append(payload, headerBytes...)
payload = append(payload, data...)
return payload, nil
}
// DecodeDataPayload decodes a frame payload into header and data
// Auto-detects protocol version
func DecodeDataPayload(payload []byte) (DataHeader, []byte, error) {
if len(payload) == 0 {
return DataHeader{}, nil, errors.New("empty payload")
}
// Try to detect version:
// - V1 (JSON): starts with '{'
// - V2 (Binary): first byte is flags (0x00-0x1F typically)
if payload[0] == '{' {
// V1: JSON format
return DecodeDataPayloadV1(payload)
}
// V2: Binary format
return DecodeDataPayloadV2(payload)
}
// DecodeDataPayloadV1 decodes JSON format (legacy)
// Format: JSON_HEADER\nDATA
func DecodeDataPayloadV1(payload []byte) (DataHeader, []byte, error) {
// Find newline separator
sepIdx := -1
for i, b := range payload {
if b == '\n' {
sepIdx = i
break
}
}
if sepIdx == -1 {
return DataHeader{}, nil, errors.New("invalid v1 payload: no newline separator")
}
// Parse JSON header
var header DataHeader
if err := json.Unmarshal(payload[:sepIdx], &header); err != nil {
return DataHeader{}, nil, err
}
// Extract data (after newline)
data := payload[sepIdx+1:]
return header, data, nil
}
// DecodeDataPayloadV2 decodes binary format (optimized)
// Format: BINARY_HEADER + DATA
func DecodeDataPayloadV2(payload []byte) (DataHeader, []byte, error) {
if len(payload) < binaryHeaderMinSize {
return DataHeader{}, nil, errors.New("invalid v2 payload: too short")
}
// Decode binary header
var h2 DataHeaderV2
if err := h2.UnmarshalBinary(payload); err != nil {
return DataHeader{}, nil, err
}
// Extract data (after header)
headerSize := h2.Size()
if len(payload) < headerSize {
return DataHeader{}, nil, errors.New("invalid v2 payload: data missing")
}
data := payload[headerSize:]
// Convert to DataHeader
header := h2.ToDataHeader()
return header, data, nil
}
// GetPayloadHeaderSize returns the size of the header in the payload
// This is useful for pre-allocating buffers
func GetPayloadHeaderSize(header DataHeader) int {
var h2 DataHeaderV2
h2.FromDataHeader(header)
return h2.Size()
}

View File

@@ -0,0 +1,30 @@
package protocol
// TunnelType defines the type of tunnel
type TunnelType string
const (
// TunnelTypeHTTP is for HTTP traffic
TunnelTypeHTTP TunnelType = "http"
// TunnelTypeHTTPS is for HTTPS traffic
TunnelTypeHTTPS TunnelType = "https"
// TunnelTypeTCP is for generic TCP traffic
TunnelTypeTCP TunnelType = "tcp"
// TunnelTypeUDP is for UDP traffic (future support)
TunnelTypeUDP TunnelType = "udp"
)
// String returns the string representation
func (t TunnelType) String() string {
return string(t)
}
// IsValid checks if tunnel type is valid
func (t TunnelType) IsValid() bool {
switch t {
case TunnelTypeHTTP, TunnelTypeHTTPS, TunnelTypeTCP, TunnelTypeUDP:
return true
default:
return false
}
}

View File

@@ -0,0 +1,126 @@
package protocol
import (
"errors"
"io"
"sync"
"time"
)
type FrameWriter struct {
conn io.Writer
queue chan *Frame
batch []*Frame
mu sync.Mutex
done chan struct{}
closed bool
maxBatch int
maxBatchWait time.Duration
}
func NewFrameWriter(conn io.Writer) *FrameWriter {
return NewFrameWriterWithConfig(conn, 128, 2*time.Millisecond, 1024)
}
func NewFrameWriterWithConfig(conn io.Writer, maxBatch int, maxBatchWait time.Duration, queueSize int) *FrameWriter {
w := &FrameWriter{
conn: conn,
queue: make(chan *Frame, queueSize),
batch: make([]*Frame, 0, maxBatch),
maxBatch: maxBatch,
maxBatchWait: maxBatchWait,
done: make(chan struct{}),
}
go w.writeLoop()
return w
}
func (w *FrameWriter) writeLoop() {
ticker := time.NewTicker(w.maxBatchWait)
defer ticker.Stop()
for {
select {
case frame, ok := <-w.queue:
if !ok {
w.mu.Lock()
w.flushBatchLocked()
w.mu.Unlock()
return
}
w.mu.Lock()
w.batch = append(w.batch, frame)
if len(w.batch) >= w.maxBatch {
w.flushBatchLocked()
}
w.mu.Unlock()
case <-ticker.C:
w.mu.Lock()
if len(w.batch) > 0 {
w.flushBatchLocked()
}
w.mu.Unlock()
case <-w.done:
w.mu.Lock()
w.flushBatchLocked()
w.mu.Unlock()
return
}
}
}
func (w *FrameWriter) flushBatchLocked() {
if len(w.batch) == 0 {
return
}
for _, frame := range w.batch {
_ = WriteFrame(w.conn, frame)
}
w.batch = w.batch[:0]
}
func (w *FrameWriter) WriteFrame(frame *Frame) error {
w.mu.Lock()
if w.closed {
w.mu.Unlock()
return errors.New("writer closed")
}
w.mu.Unlock()
select {
case w.queue <- frame:
return nil
case <-w.done:
return errors.New("writer closed")
default:
return WriteFrame(w.conn, frame)
}
}
func (w *FrameWriter) Close() error {
w.mu.Lock()
if w.closed {
w.mu.Unlock()
return nil
}
w.closed = true
w.mu.Unlock()
close(w.queue)
close(w.done)
return nil
}
func (w *FrameWriter) Flush() {
w.mu.Lock()
defer w.mu.Unlock()
w.flushBatchLocked()
}

View File

@@ -0,0 +1,31 @@
package utils
import (
"crypto/rand"
"encoding/hex"
"time"
)
// GenerateID generates a random unique ID
func GenerateID() string {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
// Fallback to timestamp-based ID if crypto/rand fails
return generateFallbackID()
}
return hex.EncodeToString(b)
}
// GenerateShortID generates a shorter random ID (8 chars)
func GenerateShortID() string {
b := make([]byte, 4)
if _, err := rand.Read(b); err != nil {
return generateFallbackID()[:8]
}
return hex.EncodeToString(b)
}
func generateFallbackID() string {
// Simple fallback using timestamp
return hex.EncodeToString([]byte(time.Now().String()))
}

View File

@@ -0,0 +1,158 @@
package utils
import (
"strings"
"testing"
)
func TestGenerateID(t *testing.T) {
tests := []struct {
name string
wantLength int // expected minimum length
}{
{
name: "generate valid ID",
wantLength: 16, // At least 16 characters for hex-encoded random bytes
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GenerateID()
// Check that ID is not empty
if got == "" {
t.Error("GenerateID() returned empty string")
}
// Check minimum length
if len(got) < tt.wantLength {
t.Errorf("GenerateID() length = %v, want at least %v", len(got), tt.wantLength)
}
// Check that it's a valid hex string
for _, char := range got {
if !isHexChar(char) {
t.Errorf("GenerateID() contains non-hex character: %c", char)
}
}
})
}
}
func TestGenerateIDUniqueness(t *testing.T) {
// Generate 10000 IDs and check for uniqueness
ids := make(map[string]bool)
count := 10000
for i := 0; i < count; i++ {
id := GenerateID()
if ids[id] {
t.Errorf("GenerateID() generated duplicate: %s", id)
}
ids[id] = true
}
if len(ids) != count {
t.Errorf("Expected %d unique IDs, got %d", count, len(ids))
}
}
func TestGenerateIDFormat(t *testing.T) {
id := GenerateID()
// Check that it's lowercase
if id != strings.ToLower(id) {
t.Errorf("GenerateID() is not lowercase: %s", id)
}
// Check that it doesn't contain special characters
for _, char := range id {
if !isHexChar(char) {
t.Errorf("GenerateID() contains invalid character: %c in %s", char, id)
}
}
}
func TestGenerateIDConsistency(t *testing.T) {
// Generate multiple IDs and ensure they all follow the same format
count := 100
firstID := GenerateID()
firstLen := len(firstID)
for i := 0; i < count; i++ {
id := GenerateID()
// All IDs should have the same length
if len(id) != firstLen {
t.Errorf("ID length inconsistency: first=%d, current=%d", firstLen, len(id))
}
// All IDs should be hex strings
for _, char := range id {
if !isHexChar(char) {
t.Errorf("Invalid hex character %c in ID: %s", char, id)
}
}
}
}
func TestGenerateIDNotEmpty(t *testing.T) {
// Generate 1000 IDs and ensure none are empty
for i := 0; i < 1000; i++ {
id := GenerateID()
if id == "" {
t.Error("GenerateID() returned empty string")
}
if len(id) == 0 {
t.Error("GenerateID() returned zero-length string")
}
}
}
// Helper function to check if a character is a valid hex character
func isHexChar(char rune) bool {
return (char >= '0' && char <= '9') || (char >= 'a' && char <= 'f')
}
// Benchmark tests
func BenchmarkGenerateID(b *testing.B) {
for i := 0; i < b.N; i++ {
GenerateID()
}
}
func BenchmarkGenerateIDParallel(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
GenerateID()
}
})
}
// Test for concurrent ID generation
func TestGenerateIDConcurrent(t *testing.T) {
count := 1000
ch := make(chan string, count)
// Generate IDs concurrently
for i := 0; i < count; i++ {
go func() {
ch <- GenerateID()
}()
}
// Collect all IDs
ids := make(map[string]bool)
for i := 0; i < count; i++ {
id := <-ch
if ids[id] {
t.Errorf("Concurrent GenerateID() generated duplicate: %s", id)
}
ids[id] = true
}
if len(ids) != count {
t.Errorf("Expected %d unique IDs, got %d", count, len(ids))
}
}

View File

@@ -0,0 +1,104 @@
package utils
import (
"os"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
var logger *zap.Logger
// InitLogger initializes the global logger for client
// verbose: if true, shows debug level logs; if false, shows error level only
func InitLogger(verbose bool) error {
var config zap.Config
if verbose {
// Verbose mode: show debug and above
config = zap.NewDevelopmentConfig()
config.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
} else {
// Production mode: only show errors
config = zap.NewProductionConfig()
config.Level = zap.NewAtomicLevelAt(zapcore.ErrorLevel)
}
config.OutputPaths = []string{"stdout"}
config.ErrorOutputPaths = []string{"stderr"}
var err error
logger, err = config.Build()
if err != nil {
return err
}
return nil
}
// InitServerLogger initializes logger for server with info level by default
func InitServerLogger(debug bool) error {
var config zap.Config
if debug {
// Debug mode: show all logs
config = zap.NewDevelopmentConfig()
config.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
} else {
// Production mode: show info and above
config = zap.NewProductionConfig()
}
config.OutputPaths = []string{"stdout"}
config.ErrorOutputPaths = []string{"stderr"}
var err error
logger, err = config.Build()
if err != nil {
return err
}
return nil
}
// GetLogger returns the global logger instance
func GetLogger() *zap.Logger {
if logger == nil {
// Fallback to a basic logger if not initialized
logger, _ = zap.NewProduction()
}
return logger
}
// Info logs an info message
func Info(msg string, fields ...zap.Field) {
GetLogger().Info(msg, fields...)
}
// Debug logs a debug message
func Debug(msg string, fields ...zap.Field) {
GetLogger().Debug(msg, fields...)
}
// Warn logs a warning message
func Warn(msg string, fields ...zap.Field) {
GetLogger().Warn(msg, fields...)
}
// Error logs an error message
func Error(msg string, fields ...zap.Field) {
GetLogger().Error(msg, fields...)
}
// Fatal logs a fatal message and exits
func Fatal(msg string, fields ...zap.Field) {
GetLogger().Fatal(msg, fields...)
os.Exit(1)
}
// Sync flushes any buffered log entries
func Sync() {
if logger != nil {
logger.Sync()
}
}

View File

@@ -0,0 +1,66 @@
package utils
import (
"crypto/rand"
"math/big"
"regexp"
)
const (
// SubdomainChars defines the allowed characters for subdomain generation
SubdomainChars = "abcdefghijklmnopqrstuvwxyz0123456789"
// DefaultSubdomainLength is the default length of generated subdomains
DefaultSubdomainLength = 6
)
var subdomainRegex = regexp.MustCompile(`^[a-z0-9][a-z0-9-]{1,61}[a-z0-9]$`)
// GenerateSubdomain generates a random subdomain
func GenerateSubdomain(length int) string {
if length <= 0 {
length = DefaultSubdomainLength
}
result := make([]byte, length)
charsLen := big.NewInt(int64(len(SubdomainChars)))
for i := 0; i < length; i++ {
num, err := rand.Int(rand.Reader, charsLen)
if err != nil {
// Fallback to simple random if crypto/rand fails
result[i] = SubdomainChars[i%len(SubdomainChars)]
continue
}
result[i] = SubdomainChars[num.Int64()]
}
return string(result)
}
// ValidateSubdomain checks if a subdomain is valid
func ValidateSubdomain(subdomain string) bool {
if len(subdomain) < 3 || len(subdomain) > 63 {
return false
}
return subdomainRegex.MatchString(subdomain)
}
// IsReserved checks if a subdomain is reserved
func IsReserved(subdomain string) bool {
reserved := map[string]bool{
"www": true,
"api": true,
"admin": true,
"app": true,
"mail": true,
"ftp": true,
"blog": true,
"shop": true,
"status": true,
"health": true,
"test": true,
"dev": true,
"staging": true,
}
return reserved[subdomain]
}

View File

@@ -0,0 +1,266 @@
package utils
import (
"strings"
"testing"
)
func TestGenerateSubdomain(t *testing.T) {
tests := []struct {
name string
length int
want int // expected length
}{
{
name: "default length 6",
length: 6,
want: 6,
},
{
name: "length 8",
length: 8,
want: 8,
},
{
name: "length 10",
length: 10,
want: 10,
},
{
name: "minimum length 4",
length: 4,
want: 4,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GenerateSubdomain(tt.length)
// Check length
if len(got) != tt.want {
t.Errorf("GenerateSubdomain() length = %v, want %v", len(got), tt.want)
}
// Check that it only contains alphanumeric characters
for _, char := range got {
if !isAlphanumeric(char) {
t.Errorf("GenerateSubdomain() contains non-alphanumeric character: %c", char)
}
}
// Check that it's lowercase
if got != strings.ToLower(got) {
t.Errorf("GenerateSubdomain() is not lowercase: %s", got)
}
})
}
}
func TestGenerateSubdomainUniqueness(t *testing.T) {
// Generate 1000 subdomains and check for uniqueness
subdomains := make(map[string]bool)
count := 1000
length := 6
for i := 0; i < count; i++ {
subdomain := GenerateSubdomain(length)
if subdomains[subdomain] {
t.Errorf("GenerateSubdomain() generated duplicate: %s", subdomain)
}
subdomains[subdomain] = true
}
if len(subdomains) != count {
t.Errorf("Expected %d unique subdomains, got %d", count, len(subdomains))
}
}
func TestValidateSubdomain(t *testing.T) {
tests := []struct {
name string
subdomain string
want bool
}{
{
name: "valid lowercase",
subdomain: "abc123",
want: true,
},
{
name: "valid all letters",
subdomain: "abcdef",
want: true,
},
{
name: "valid all numbers",
subdomain: "123456",
want: true,
},
{
name: "invalid uppercase",
subdomain: "ABC123",
want: false,
},
{
name: "valid with hyphen",
subdomain: "abc-123",
want: true,
},
{
name: "invalid starting with hyphen",
subdomain: "-abc123",
want: false,
},
{
name: "invalid ending with hyphen",
subdomain: "abc123-",
want: false,
},
{
name: "invalid with underscore",
subdomain: "abc_123",
want: false,
},
{
name: "invalid with dot",
subdomain: "abc.123",
want: false,
},
{
name: "invalid with space",
subdomain: "abc 123",
want: false,
},
{
name: "invalid empty",
subdomain: "",
want: false,
},
{
name: "invalid special characters",
subdomain: "abc@123",
want: false,
},
{
name: "valid minimum length",
subdomain: "abc",
want: true,
},
{
name: "invalid too short",
subdomain: "ab",
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ValidateSubdomain(tt.subdomain)
if got != tt.want {
t.Errorf("ValidateSubdomain(%q) = %v, want %v", tt.subdomain, got, tt.want)
}
})
}
}
func TestIsReserved(t *testing.T) {
tests := []struct {
name string
subdomain string
want bool
}{
{
name: "reserved www",
subdomain: "www",
want: true,
},
{
name: "reserved api",
subdomain: "api",
want: true,
},
{
name: "reserved admin",
subdomain: "admin",
want: true,
},
{
name: "reserved mail",
subdomain: "mail",
want: true,
},
{
name: "reserved ftp",
subdomain: "ftp",
want: true,
},
{
name: "reserved health",
subdomain: "health",
want: true,
},
{
name: "reserved test",
subdomain: "test",
want: true,
},
{
name: "reserved dev",
subdomain: "dev",
want: true,
},
{
name: "reserved staging",
subdomain: "staging",
want: true,
},
{
name: "not reserved random",
subdomain: "abc123",
want: false,
},
{
name: "not reserved user",
subdomain: "myapp",
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsReserved(tt.subdomain)
if got != tt.want {
t.Errorf("IsReserved(%q) = %v, want %v", tt.subdomain, got, tt.want)
}
})
}
}
// Helper function to check if a character is alphanumeric
func isAlphanumeric(char rune) bool {
return (char >= 'a' && char <= 'z') || (char >= '0' && char <= '9')
}
// Benchmark tests
func BenchmarkGenerateSubdomain(b *testing.B) {
for i := 0; i < b.N; i++ {
GenerateSubdomain(6)
}
}
func BenchmarkValidateSubdomain(b *testing.B) {
subdomain := "abc123"
b.ResetTimer()
for i := 0; i < b.N; i++ {
ValidateSubdomain(subdomain)
}
}
func BenchmarkIsReserved(b *testing.B) {
subdomain := "www"
b.ResetTimer()
for i := 0; i < b.N; i++ {
IsReserved(subdomain)
}
}

60
nginx.example.conf Normal file
View File

@@ -0,0 +1,60 @@
# Drip Tunnel Server - Nginx 配置
#
# 架构:外部用户 -> Nginx (443) -> Drip Server (8443) -> 客户端
#
# 前置条件:
# 1. 获取通配符 SSL 证书:
# certbot certonly --manual --preferred-challenges dns \
# -d "*.tunnel.example.com" -d "tunnel.example.com"
#
# 2. DNS 配置:
# A tunnel.example.com -> YOUR_SERVER_IP
# A *.tunnel.example.com -> YOUR_SERVER_IP
#
# 3. 启动 Drip Server
# ./bin/drip-server --port 8443 --domain tunnel.example.com \
# --tls-cert /etc/letsencrypt/live/tunnel.example.com/fullchain.pem \
# --tls-key /etc/letsencrypt/live/tunnel.example.com/privkey.pem
# HTTP 重定向到 HTTPS
server {
listen 80;
server_name tunnel.example.com *.tunnel.example.com;
return 301 https://$host$request_uri;
}
# HTTPS 代理到 Drip Server
server {
listen 443 ssl http2;
server_name tunnel.example.com *.tunnel.example.com;
# SSL 证书
ssl_certificate /etc/letsencrypt/live/tunnel.example.com/fullchain.pem;
ssl_certificate_key /etc/letsencrypt/live/tunnel.example.com/privkey.pem;
ssl_protocols TLSv1.2 TLSv1.3;
# 代理到 Drip Server
location / {
proxy_pass https://127.0.0.1:8443;
proxy_ssl_verify off;
proxy_http_version 1.1;
# 转发请求头
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
# 超时配置
proxy_connect_timeout 60s;
proxy_send_timeout 300s;
proxy_read_timeout 300s;
# 禁用缓冲
proxy_buffering off;
proxy_request_buffering off;
# 大文件支持
client_max_body_size 100m;
}
}

View File

@@ -0,0 +1,87 @@
package config
import (
"fmt"
"os"
"path/filepath"
"gopkg.in/yaml.v3"
)
// ClientConfig represents the client configuration
type ClientConfig struct {
Server string `yaml:"server"` // Server address (e.g., tunnel.example.com:443)
Token string `yaml:"token"` // Authentication token
TLS bool `yaml:"tls"` // Use TLS (always true for production)
}
// DefaultClientConfig returns the default configuration path
func DefaultClientConfigPath() string {
home, err := os.UserHomeDir()
if err != nil {
return ".drip/config.yaml"
}
return filepath.Join(home, ".drip", "config.yaml")
}
// LoadClientConfig loads configuration from file
func LoadClientConfig(path string) (*ClientConfig, error) {
if path == "" {
path = DefaultClientConfigPath()
}
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return nil, fmt.Errorf("config file not found at %s, please run 'drip config init' first", path)
}
return nil, fmt.Errorf("failed to read config file: %w", err)
}
var config ClientConfig
if err := yaml.Unmarshal(data, &config); err != nil {
return nil, fmt.Errorf("failed to parse config file: %w", err)
}
// Validate required fields
if config.Server == "" {
return nil, fmt.Errorf("server address is required in config")
}
return &config, nil
}
// SaveClientConfig saves configuration to file
func SaveClientConfig(config *ClientConfig, path string) error {
if path == "" {
path = DefaultClientConfigPath()
}
// Create directory if not exists
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0700); err != nil {
return fmt.Errorf("failed to create config directory: %w", err)
}
// Marshal to YAML
data, err := yaml.Marshal(config)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
// Write to file with secure permissions
if err := os.WriteFile(path, data, 0600); err != nil {
return fmt.Errorf("failed to write config file: %w", err)
}
return nil
}
// ConfigExists checks if config file exists
func ConfigExists(path string) bool {
if path == "" {
path = DefaultClientConfigPath()
}
_, err := os.Stat(path)
return err == nil
}

130
pkg/config/config.go Normal file
View File

@@ -0,0 +1,130 @@
package config
import (
"crypto/tls"
"fmt"
"os"
)
// ServerConfig holds the server configuration
type ServerConfig struct {
// Server settings
Port int
PublicPort int // Port to display in URLs (for reverse proxy scenarios)
Domain string
// TCP tunnel dynamic port allocation
TCPPortMin int
TCPPortMax int
// TLS/SSL settings
TLSEnabled bool
TLSCertFile string
TLSKeyFile string
AutoTLS bool // Automatic Let's Encrypt
// Security
AuthToken string
// Logging
Debug bool
}
// LegacyClientConfig holds the legacy client configuration
// Deprecated: Use config.ClientConfig from client_config.go instead
type LegacyClientConfig struct {
ServerURL string
LocalTarget string
AuthToken string
Subdomain string
Verbose bool
Insecure bool // Skip TLS verification (for testing only)
}
// LoadTLSConfig loads TLS configuration
func (c *ServerConfig) LoadTLSConfig() (*tls.Config, error) {
if !c.TLSEnabled {
return nil, nil
}
// Check if certificate files exist
if c.TLSCertFile == "" || c.TLSKeyFile == "" {
return nil, fmt.Errorf("TLS enabled but certificate files not specified")
}
if _, err := os.Stat(c.TLSCertFile); os.IsNotExist(err) {
return nil, fmt.Errorf("certificate file not found: %s", c.TLSCertFile)
}
if _, err := os.Stat(c.TLSKeyFile); os.IsNotExist(err) {
return nil, fmt.Errorf("key file not found: %s", c.TLSKeyFile)
}
// Load certificate
cert, err := tls.LoadX509KeyPair(c.TLSCertFile, c.TLSKeyFile)
if err != nil {
return nil, fmt.Errorf("failed to load certificate: %w", err)
}
// Force TLS 1.3 only
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS13, // Only TLS 1.3
MaxVersion: tls.VersionTLS13, // Only TLS 1.3
CipherSuites: []uint16{
tls.TLS_AES_128_GCM_SHA256,
tls.TLS_AES_256_GCM_SHA384,
tls.TLS_CHACHA20_POLY1305_SHA256,
},
}
return tlsConfig, nil
}
// GetClientTLSConfig returns TLS config for client connections
func GetClientTLSConfig(serverName string) *tls.Config {
return &tls.Config{
ServerName: serverName,
MinVersion: tls.VersionTLS13, // Only TLS 1.3
MaxVersion: tls.VersionTLS13, // Only TLS 1.3
CipherSuites: []uint16{
tls.TLS_AES_128_GCM_SHA256,
tls.TLS_AES_256_GCM_SHA384,
tls.TLS_CHACHA20_POLY1305_SHA256,
},
}
}
// GetClientTLSConfigInsecure returns TLS config for client with InsecureSkipVerify
// WARNING: Only use for testing!
func GetClientTLSConfigInsecure() *tls.Config {
return &tls.Config{
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS13, // Only TLS 1.3
MaxVersion: tls.VersionTLS13, // Only TLS 1.3
CipherSuites: []uint16{
tls.TLS_AES_128_GCM_SHA256,
tls.TLS_AES_256_GCM_SHA384,
tls.TLS_CHACHA20_POLY1305_SHA256,
},
}
}
// GetServerURL returns the server URL based on configuration
func (c *ServerConfig) GetServerURL() string {
protocol := "http"
if c.TLSEnabled {
protocol = "https"
}
if c.Port == 80 || (c.TLSEnabled && c.Port == 443) {
return fmt.Sprintf("%s://%s", protocol, c.Domain)
}
return fmt.Sprintf("%s://%s:%d", protocol, c.Domain, c.Port)
}
// GetTCPAddress returns the TCP address for tunnel connections
func (c *ServerConfig) GetTCPAddress() string {
return fmt.Sprintf("%s:%d", c.Domain, c.Port)
}

52
scripts/generate-cert.sh Executable file
View File

@@ -0,0 +1,52 @@
#!/bin/bash
# Generate self-signed certificate for development/testing using ECDSA
# ECDSA provides better performance and smaller key size than RSA
# WARNING: Do NOT use self-signed certificates in production!
set -e
CERT_DIR="${1:-./certs}"
DOMAIN="${2:-localhost}"
DAYS=365
echo "🔒 Generating self-signed certificate for development..."
echo " Domain: $DOMAIN"
echo " Output directory: $CERT_DIR"
echo ""
# Create directory if it doesn't exist
mkdir -p "$CERT_DIR"
# Generate ECDSA private key (using P-256 curve)
openssl ecparam -genkey -name prime256v1 -out "$CERT_DIR/server.key"
# Generate certificate signing request
openssl req -new -key "$CERT_DIR/server.key" -out "$CERT_DIR/server.csr" \
-subj "/C=US/ST=State/L=City/O=Organization/CN=$DOMAIN"
# Generate self-signed certificate
openssl x509 -req -days $DAYS \
-in "$CERT_DIR/server.csr" \
-signkey "$CERT_DIR/server.key" \
-out "$CERT_DIR/server.crt" \
-extfile <(printf "subjectAltName=DNS:$DOMAIN,DNS:*.$DOMAIN")
# Clean up CSR
rm "$CERT_DIR/server.csr"
echo "✅ Certificate generated successfully!"
echo ""
echo "Files created:"
echo " Certificate: $CERT_DIR/server.crt"
echo " Private Key: $CERT_DIR/server.key"
echo ""
echo "Usage:"
echo " ./bin/drip server \\"
echo " --domain $DOMAIN \\"
echo " --tls-cert $CERT_DIR/server.crt \\"
echo " --tls-key $CERT_DIR/server.key"
echo ""
echo "⚠️ WARNING: This is a self-signed certificate for development only!"
echo " Clients will need to skip certificate verification (insecure)."
echo " For production, use --auto-tls or get a certificate from Let's Encrypt."

1007
scripts/install-server.sh Executable file

File diff suppressed because it is too large Load Diff

642
scripts/install.sh Executable file
View File

@@ -0,0 +1,642 @@
#!/bin/bash
set -e
# ============================================================================
# Configuration
# ============================================================================
DOWNLOAD_BASE_URL="https://"
INSTALL_DIR="${INSTALL_DIR:-}"
VERSION="${VERSION:-latest}"
BINARY_NAME="drip"
# Colors
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
CYAN='\033[0;36m'
BOLD='\033[1m'
NC='\033[0m'
# Language (default: en)
LANG_CODE="${LANG_CODE:-en}"
# ============================================================================
# Internationalization
# ============================================================================
declare -A MSG_EN
declare -A MSG_ZH
# English messages
MSG_EN=(
["banner_title"]="Drip Client - One-Click Installer"
["select_lang"]="Select language / 选择语言"
["lang_en"]="English"
["lang_zh"]="中文"
["checking_os"]="Checking operating system..."
["detected_os"]="Detected OS"
["unsupported_os"]="Unsupported operating system"
["checking_arch"]="Checking system architecture..."
["detected_arch"]="Detected architecture"
["unsupported_arch"]="Unsupported architecture"
["checking_deps"]="Checking dependencies..."
["deps_ok"]="Dependencies check passed"
["downloading"]="Downloading Drip client"
["download_failed"]="Download failed"
["download_ok"]="Download completed"
["select_install_dir"]="Select installation directory"
["option_user"]="User directory (no sudo required)"
["option_system"]="System directory (requires sudo)"
["option_current"]="Current directory"
["option_custom"]="Custom path"
["enter_custom_path"]="Enter custom path"
["installing"]="Installing binary..."
["install_ok"]="Installation completed"
["updating_path"]="Updating PATH..."
["path_updated"]="PATH updated"
["path_note"]="Please restart your terminal or run: source ~/.bashrc"
["config_title"]="Client Configuration"
["configure_now"]="Configure client now?"
["enter_server"]="Enter server address (e.g., tunnel.example.com:8443)"
["server_required"]="Server address is required"
["enter_token"]="Enter authentication token"
["token_required"]="Token is required"
["skip_verify"]="Skip TLS certificate verification? (for self-signed certs)"
["config_saved"]="Configuration saved"
["install_complete"]="Installation completed!"
["usage_title"]="Usage"
["usage_http"]="Expose HTTP service on port 3000"
["usage_tcp"]="Expose TCP service on port 5432"
["usage_config"]="Show/modify configuration"
["usage_daemon"]="Run as background daemon"
["run_test"]="Test connection now?"
["test_running"]="Testing connection..."
["test_success"]="Connection successful"
["test_failed"]="Connection failed"
["yes"]="y"
["no"]="n"
["press_enter"]="Press Enter to continue..."
["windows_note"]="For Windows, please download the .exe file from GitHub Releases"
["already_installed"]="Drip is already installed"
["update_now"]="Update to the latest version?"
["updating"]="Updating..."
["update_ok"]="Update completed"
["verify_install"]="Verifying installation..."
["verify_ok"]="Verification passed"
["verify_failed"]="Verification failed"
["insecure_note"]="Only use --insecure for development/testing"
)
# Chinese messages
MSG_ZH=(
["banner_title"]="Drip 客户端 - 一键安装脚本"
["select_lang"]="Select language / 选择语言"
["lang_en"]="English"
["lang_zh"]="中文"
["checking_os"]="检查操作系统..."
["detected_os"]="检测到操作系统"
["unsupported_os"]="不支持的操作系统"
["checking_arch"]="检查系统架构..."
["detected_arch"]="检测到架构"
["unsupported_arch"]="不支持的架构"
["checking_deps"]="检查依赖..."
["deps_ok"]="依赖检查通过"
["downloading"]="下载 Drip 客户端"
["download_failed"]="下载失败"
["download_ok"]="下载完成"
["select_install_dir"]="选择安装目录"
["option_user"]="用户目录(无需 sudo"
["option_system"]="系统目录(需要 sudo"
["option_current"]="当前目录"
["option_custom"]="自定义路径"
["enter_custom_path"]="输入自定义路径"
["installing"]="安装二进制文件..."
["install_ok"]="安装完成"
["updating_path"]="更新 PATH..."
["path_updated"]="PATH 已更新"
["path_note"]="请重启终端或运行: source ~/.bashrc"
["config_title"]="客户端配置"
["configure_now"]="现在配置客户端?"
["enter_server"]="输入服务器地址例如tunnel.example.com:8443"
["server_required"]="服务器地址是必填项"
["enter_token"]="输入认证令牌"
["token_required"]="认证令牌是必填项"
["skip_verify"]="跳过 TLS 证书验证?(用于自签名证书)"
["config_saved"]="配置已保存"
["install_complete"]="安装完成!"
["usage_title"]="使用方法"
["usage_http"]="暴露本地 3000 端口的 HTTP 服务"
["usage_tcp"]="暴露本地 5432 端口的 TCP 服务"
["usage_config"]="显示/修改配置"
["usage_daemon"]="作为后台守护进程运行"
["run_test"]="现在测试连接?"
["test_running"]="正在测试连接..."
["test_success"]="连接成功"
["test_failed"]="连接失败"
["yes"]="y"
["no"]="n"
["press_enter"]="按 Enter 继续..."
["windows_note"]="Windows 用户请从 GitHub Releases 下载 .exe 文件"
["already_installed"]="Drip 已安装"
["update_now"]="是否更新到最新版本?"
["updating"]="正在更新..."
["update_ok"]="更新完成"
["verify_install"]="验证安装..."
["verify_ok"]="验证通过"
["verify_failed"]="验证失败"
["insecure_note"]="--insecure 仅用于开发/测试环境"
)
# Get message by key
msg() {
local key="$1"
if [[ "$LANG_CODE" == "zh" ]]; then
echo "${MSG_ZH[$key]:-$key}"
else
echo "${MSG_EN[$key]:-$key}"
fi
}
# ============================================================================
# Output functions
# ============================================================================
print_info() { echo -e "${BLUE}[INFO]${NC} $1"; }
print_success() { echo -e "${GREEN}[✓]${NC} $1"; }
print_warning() { echo -e "${YELLOW}[!]${NC} $1"; }
print_error() { echo -e "${RED}[✗]${NC} $1"; }
print_step() { echo -e "${CYAN}[→]${NC} $1"; }
# Print banner
print_banner() {
echo -e "${GREEN}"
cat << "EOF"
____ _
/ __ \_____(_)___
/ / / / ___/ / __ \
/ /_/ / / / / /_/ /
/_____/_/ /_/ .___/
/_/
EOF
echo -e "${BOLD}$(msg banner_title)${NC}"
echo ""
}
# ============================================================================
# Language selection
# ============================================================================
select_language() {
echo ""
echo -e "${CYAN}╔════════════════════════════════════════╗${NC}"
echo -e "${CYAN}$(msg select_lang)${NC}"
echo -e "${CYAN}╠════════════════════════════════════════╣${NC}"
echo -e "${CYAN}${NC} ${GREEN}1)${NC} English ${CYAN}${NC}"
echo -e "${CYAN}${NC} ${GREEN}2)${NC} 中文 ${CYAN}${NC}"
echo -e "${CYAN}╚════════════════════════════════════════╝${NC}"
echo ""
read -p "Select [1]: " lang_choice < /dev/tty
case "$lang_choice" in
2)
LANG_CODE="zh"
;;
*)
LANG_CODE="en"
;;
esac
echo ""
}
# ============================================================================
# System checks
# ============================================================================
check_os() {
print_step "$(msg checking_os)"
case "$(uname -s)" in
Linux*)
OS="linux"
;;
Darwin*)
OS="darwin"
;;
MINGW*|MSYS*|CYGWIN*)
OS="windows"
print_warning "$(msg windows_note)"
;;
*)
print_error "$(msg unsupported_os): $(uname -s)"
exit 1
;;
esac
print_success "$(msg detected_os): $OS"
}
check_arch() {
print_step "$(msg checking_arch)"
case "$(uname -m)" in
x86_64|amd64)
ARCH="amd64"
;;
aarch64|arm64)
ARCH="arm64"
;;
armv7l)
ARCH="arm"
;;
i386|i686)
ARCH="386"
;;
*)
print_error "$(msg unsupported_arch): $(uname -m)"
exit 1
;;
esac
print_success "$(msg detected_arch): $ARCH"
}
check_dependencies() {
print_step "$(msg checking_deps)"
# Check for download tool
if ! command -v curl &> /dev/null && ! command -v wget &> /dev/null; then
print_error "curl or wget is required"
exit 1
fi
print_success "$(msg deps_ok)"
}
get_remote_version() {
# Determine binary name based on OS and ARCH
local binary_file="drip-${OS}-${ARCH}"
if [[ "$OS" == "windows" ]]; then
binary_file="${binary_file}.exe"
fi
local download_url="${DOWNLOAD_BASE_URL}/${VERSION}/${binary_file}"
local tmp_file="/tmp/drip-check-$$"
if command -v curl &> /dev/null; then
curl -fsSL "$download_url" -o "$tmp_file" 2>/dev/null || return 1
else
wget -q "$download_url" -O "$tmp_file" 2>/dev/null || return 1
fi
chmod +x "$tmp_file"
local remote_version=$("$tmp_file" version 2>/dev/null | awk '/Version:/ {print $2}' || echo "unknown")
rm -f "$tmp_file"
echo "$remote_version"
}
check_existing_install() {
if command -v drip &> /dev/null; then
local current_path=$(command -v drip)
local current_version=$(drip version 2>/dev/null | awk '/Version:/ {print $2}' || echo "unknown")
print_warning "$(msg already_installed): $current_path"
print_info "$(msg current_version): $current_version"
# Check remote version
print_step "Checking for updates..."
local remote_version=$(get_remote_version)
if [[ "$remote_version" == "unknown" ]]; then
print_warning "Could not check remote version"
echo ""
read -p "$(msg update_now) [Y/n]: " update_choice < /dev/tty
elif [[ "$current_version" == "$remote_version" ]]; then
print_success "Already up to date ($current_version)"
exit 0
else
print_info "Latest version: $remote_version"
echo ""
read -p "$(msg update_now) [Y/n]: " update_choice < /dev/tty
fi
if [[ "$update_choice" =~ ^[Nn]$ ]]; then
exit 0
fi
INSTALL_DIR=$(dirname "$current_path")
IS_UPDATE=true
fi
}
# ============================================================================
# Download and install
# ============================================================================
get_download_url() {
local binary_name
if [[ "$OS" == "windows" ]]; then
binary_name="drip-windows-${ARCH}.exe"
else
binary_name="drip-${OS}-${ARCH}"
fi
echo "${DOWNLOAD_BASE_URL}/${VERSION}/${binary_name}"
}
download_binary() {
local url=$(get_download_url)
if [[ "$IS_UPDATE" == true ]]; then
print_step "$(msg updating)..."
else
print_step "$(msg downloading)..."
fi
local tmp_file="/tmp/drip-download"
if command -v curl &> /dev/null; then
# Use -# for progress bar instead of -s (silent)
if ! curl -f#L "$url" -o "$tmp_file"; then
print_error "$(msg download_failed): $url"
exit 1
fi
else
# Use --show-progress to display download progress
if ! wget --show-progress "$url" -O "$tmp_file" 2>&1 | grep -v "^$"; then
print_error "$(msg download_failed): $url"
exit 1
fi
fi
chmod +x "$tmp_file"
print_success "$(msg download_ok)"
}
select_install_dir() {
if [[ -n "$INSTALL_DIR" ]]; then
return
fi
echo ""
echo -e "${CYAN}╔════════════════════════════════════════╗${NC}"
echo -e "${CYAN}$(msg select_install_dir) ${CYAN}${NC}"
echo -e "${CYAN}╠════════════════════════════════════════╣${NC}"
echo -e "${CYAN}${NC} ${GREEN}1)${NC} ~/.local/bin $(msg option_user) ${CYAN}${NC}"
echo -e "${CYAN}${NC} ${GREEN}2)${NC} /usr/local/bin $(msg option_system) ${CYAN}${NC}"
echo -e "${CYAN}${NC} ${GREEN}3)${NC} ./ $(msg option_current) ${CYAN}${NC}"
echo -e "${CYAN}${NC} ${GREEN}4)${NC} $(msg option_custom) ${CYAN}${NC}"
echo -e "${CYAN}╚════════════════════════════════════════╝${NC}"
echo ""
read -p "Select [1]: " dir_choice < /dev/tty
case "$dir_choice" in
2)
INSTALL_DIR="/usr/local/bin"
NEED_SUDO=true
;;
3)
INSTALL_DIR="."
;;
4)
read -p "$(msg enter_custom_path): " INSTALL_DIR < /dev/tty
;;
*)
INSTALL_DIR="$HOME/.local/bin"
;;
esac
}
install_binary() {
print_step "$(msg installing)"
# Create directory if needed
if [[ ! -d "$INSTALL_DIR" ]]; then
if [[ "$NEED_SUDO" == true ]]; then
sudo mkdir -p "$INSTALL_DIR"
else
mkdir -p "$INSTALL_DIR"
fi
fi
local target_path="$INSTALL_DIR/$BINARY_NAME"
if [[ "$OS" == "windows" ]]; then
target_path="$INSTALL_DIR/$BINARY_NAME.exe"
fi
# Install binary
if [[ "$NEED_SUDO" == true ]]; then
sudo mv /tmp/drip-download "$target_path"
sudo chmod +x "$target_path"
else
mv /tmp/drip-download "$target_path"
chmod +x "$target_path"
fi
print_success "$(msg install_ok): $target_path"
}
update_path() {
# Skip if already in PATH
if command -v drip &> /dev/null; then
return
fi
# Skip for system directories (usually already in PATH)
if [[ "$INSTALL_DIR" == "/usr/local/bin" ]] || [[ "$INSTALL_DIR" == "/usr/bin" ]]; then
return
fi
print_step "$(msg updating_path)"
local shell_rc=""
local export_line="export PATH=\"\$PATH:$INSTALL_DIR\""
# Determine shell config file
if [[ -n "$ZSH_VERSION" ]] || [[ "$SHELL" == *"zsh"* ]]; then
shell_rc="$HOME/.zshrc"
elif [[ -n "$BASH_VERSION" ]] || [[ "$SHELL" == *"bash"* ]]; then
if [[ "$OS" == "darwin" ]]; then
shell_rc="$HOME/.bash_profile"
else
shell_rc="$HOME/.bashrc"
fi
elif [[ "$SHELL" == *"fish"* ]]; then
shell_rc="$HOME/.config/fish/config.fish"
export_line="set -gx PATH \$PATH $INSTALL_DIR"
fi
if [[ -n "$shell_rc" ]]; then
# Check if already added
if ! grep -q "$INSTALL_DIR" "$shell_rc" 2>/dev/null; then
echo "" >> "$shell_rc"
echo "# Drip client" >> "$shell_rc"
echo "$export_line" >> "$shell_rc"
print_success "$(msg path_updated): $shell_rc"
fi
fi
print_warning "$(msg path_note)"
}
verify_installation() {
print_step "$(msg verify_install)"
local binary_path="$INSTALL_DIR/$BINARY_NAME"
if [[ "$OS" == "windows" ]]; then
binary_path="$INSTALL_DIR/$BINARY_NAME.exe"
fi
if [[ -x "$binary_path" ]]; then
local version=$("$binary_path" version 2>/dev/null | awk '/Version:/ {print $2}' || echo "installed")
print_success "$(msg verify_ok): $version"
else
print_error "$(msg verify_failed)"
exit 1
fi
}
# ============================================================================
# Configuration
# ============================================================================
configure_client() {
echo ""
read -p "$(msg configure_now) [Y/n]: " config_choice < /dev/tty
if [[ "$config_choice" =~ ^[Nn]$ ]]; then
return
fi
echo ""
echo -e "${CYAN}╔════════════════════════════════════════╗${NC}"
echo -e "${CYAN}$(msg config_title) ${CYAN}${NC}"
echo -e "${CYAN}╚════════════════════════════════════════╝${NC}"
echo ""
local binary_path="$INSTALL_DIR/$BINARY_NAME"
# Server address
while true; do
read -p "$(msg enter_server): " SERVER < /dev/tty
if [[ -n "$SERVER" ]]; then
break
fi
print_error "$(msg server_required)"
done
# Token
while true; do
read -p "$(msg enter_token): " TOKEN < /dev/tty
if [[ -n "$TOKEN" ]]; then
break
fi
print_error "$(msg token_required)"
done
# Insecure mode
read -p "$(msg skip_verify) [y/N]: " insecure_choice < /dev/tty
INSECURE=""
if [[ "$insecure_choice" =~ ^[Yy]$ ]]; then
INSECURE="--insecure"
print_warning "$(msg insecure_note)"
fi
# Save configuration
"$binary_path" config set --server "$SERVER" --token "$TOKEN" $INSECURE 2>/dev/null || true
print_success "$(msg config_saved)"
}
# ============================================================================
# Test connection
# ============================================================================
test_connection() {
echo ""
read -p "$(msg run_test) [y/N]: " test_choice < /dev/tty
if [[ ! "$test_choice" =~ ^[Yy]$ ]]; then
return
fi
print_step "$(msg test_running)"
local binary_path="$INSTALL_DIR/$BINARY_NAME"
# Try to validate config
if "$binary_path" config validate 2>/dev/null; then
print_success "$(msg test_success)"
else
print_warning "$(msg test_failed)"
fi
}
# ============================================================================
# Final output
# ============================================================================
show_completion() {
local binary_path="$INSTALL_DIR/$BINARY_NAME"
echo ""
echo -e "${GREEN}╔════════════════════════════════════════════════════════════╗${NC}"
echo -e "${GREEN}$(msg install_complete) ${GREEN}${NC}"
echo -e "${GREEN}╚════════════════════════════════════════════════════════════╝${NC}"
echo ""
echo -e "${CYAN}$(msg usage_title):${NC}"
echo ""
echo -e " ${GREEN}# $(msg usage_http)${NC}"
echo -e " ${YELLOW}$BINARY_NAME http 3000${NC}"
echo ""
echo -e " ${GREEN}# $(msg usage_tcp)${NC}"
echo -e " ${YELLOW}$BINARY_NAME tcp 5432${NC}"
echo ""
echo -e " ${GREEN}# $(msg usage_config)${NC}"
echo -e " ${YELLOW}$BINARY_NAME config show${NC}"
echo -e " ${YELLOW}$BINARY_NAME config init${NC}"
echo ""
echo -e " ${GREEN}# $(msg usage_daemon)${NC}"
echo -e " ${YELLOW}$BINARY_NAME daemon start http 3000${NC}"
echo -e " ${YELLOW}$BINARY_NAME daemon list${NC}"
echo ""
}
# ============================================================================
# Main
# ============================================================================
main() {
clear
print_banner
select_language
echo -e "${BOLD}────────────────────────────────────────────${NC}"
check_os
check_arch
check_dependencies
check_existing_install
echo ""
download_binary
select_install_dir
install_binary
update_path
verify_installation
# Skip configuration for updates
if [[ "$IS_UPDATE" != true ]]; then
configure_client
test_connection
else
echo ""
local new_version=$("$INSTALL_DIR/$BINARY_NAME" version 2>/dev/null | awk '/Version:/ {print $2}' || echo "installed")
echo -e "${GREEN}╔════════════════════════════════════════════════════════════╗${NC}"
echo -e "${GREEN}$(msg update_ok) ${GREEN}${NC}"
echo -e "${GREEN}╚════════════════════════════════════════════════════════════╝${NC}"
echo ""
print_info "Version: $new_version"
echo ""
return
fi
show_completion
}
# Run
main "$@"

503
scripts/test/one-click-test.sh Executable file
View File

@@ -0,0 +1,503 @@
#!/bin/bash
# Drip One-Click Performance Test Script
# Automatically starts all services, runs tests, and generates reports
set -e
# Color definitions
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
# Configuration
RESULTS_DIR="benchmark-results"
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
LOG_DIR="/tmp/drip-test-${TIMESTAMP}"
REPORT_FILE="${RESULTS_DIR}/test-report-${TIMESTAMP}.txt"
# Port configuration
HTTP_TEST_PORT=3000
DRIP_SERVER_PORT=8443
PPROF_PORT=6060
# PID file
PIDS_FILE="${LOG_DIR}/pids.txt"
# Create directories
mkdir -p "$RESULTS_DIR"
mkdir -p "$LOG_DIR"
# ============================================
# Helper functions
# ============================================
# All logs go to stderr to avoid being consumed by command substitution $(...)
log_info() {
echo -e "${GREEN}[INFO]${NC} $1" >&2
}
log_warn() {
echo -e "${YELLOW}[WARN]${NC} $1" >&2
}
log_error() {
echo -e "${RED}[ERROR]${NC} $1" >&2
}
log_step() {
echo -e "\n${BLUE}==>${NC} $1\n" >&2
}
# Cleanup function
cleanup() {
log_step "Cleaning up test environment..."
if [ -f "$PIDS_FILE" ]; then
log_info "Stopping all test processes..."
while read -r pid; do
if ps -p "$pid" > /dev/null 2>&1; then
log_info "Stopping process $pid"
kill "$pid" 2>/dev/null || true
fi
done < "$PIDS_FILE"
rm -f "$PIDS_FILE"
fi
# Extra cleanup: ensure ports are released
pkill -f "python.*${HTTP_TEST_PORT}" 2>/dev/null || true
pkill -f "drip server.*${DRIP_SERVER_PORT}" 2>/dev/null || true
pkill -f "drip http ${HTTP_TEST_PORT}" 2>/dev/null || true
log_info "Cleanup completed"
}
# Register cleanup function
trap cleanup EXIT INT TERM
# Check dependencies
check_dependencies() {
log_step "Checking dependencies..."
local missing=""
if ! command -v wrk &> /dev/null; then
missing="${missing}\n - wrk (brew install wrk)"
fi
if ! command -v python3 &> /dev/null; then
missing="${missing}\n - python3"
fi
if ! command -v openssl &> /dev/null; then
missing="${missing}\n - openssl"
fi
if ! command -v nc &> /dev/null; then
missing="${missing}\n - nc (netcat)"
fi
if [ ! -f "./bin/drip" ]; then
log_error "Cannot find drip executable"
log_info "Please run: make build"
exit 1
fi
if [ -n "$missing" ]; then
log_error "Missing dependencies:${missing}"
exit 1
fi
log_info "✓ All dependencies satisfied"
}
# Generate self-signed ECDSA certificate for testing
generate_test_certs() {
log_step "Generating test TLS certificate (ECDSA)..."
local cert_dir="${LOG_DIR}/certs"
mkdir -p "$cert_dir"
# Generate ECDSA private key (prime256v1 = P-256)
openssl ecparam -name prime256v1 -genkey -noout \
-out "${cert_dir}/server.key" >/dev/null 2>&1
# Generate self-signed certificate with this private key
openssl req -new -x509 \
-key "${cert_dir}/server.key" \
-out "${cert_dir}/server.crt" \
-days 1 \
-subj "/C=US/ST=Test/L=Test/O=Test/CN=localhost" \
>/dev/null 2>&1
if [ -f "${cert_dir}/server.crt" ] && [ -f "${cert_dir}/server.key" ]; then
log_info "✓ ECDSA test certificate generated"
# Note: this echo is the "return value", stdout only outputs this line
echo "${cert_dir}/server.crt ${cert_dir}/server.key"
else
log_error "ECDSA certificate generation failed"
exit 1
fi
}
# Wait for port to be available
wait_for_port() {
local port=$1
local max_wait=${2:-30}
local waited=0
while ! nc -z localhost "$port" 2>/dev/null; do
if [ "$waited" -ge "$max_wait" ]; then
return 1
fi
sleep 1
waited=$((waited + 1))
done
return 0
}
# Start HTTP test server
start_http_server() {
log_step "Starting HTTP test server (port $HTTP_TEST_PORT)..."
# Create simple test server
cat > "${LOG_DIR}/test-server.py" << 'EOF'
import http.server
import socketserver
import json
from datetime import datetime
import sys
PORT = int(sys.argv[1]) if len(sys.argv) > 1 else 3000
class TestHandler(http.server.SimpleHTTPRequestHandler):
def do_GET(self):
response = {
"status": "ok",
"timestamp": datetime.now().isoformat(),
"message": "Test server response"
}
self.send_response(200)
self.send_header('Content-type', 'application/json')
self.end_headers()
self.wfile.write(json.dumps(response).encode())
def log_message(self, format, *args):
pass # Silent logging
with socketserver.TCPServer(("", PORT), TestHandler) as httpd:
print(f"Server started on port {PORT}", flush=True)
httpd.serve_forever()
EOF
python3 "${LOG_DIR}/test-server.py" "$HTTP_TEST_PORT" \
> "${LOG_DIR}/http-server.log" 2>&1 &
local pid=$!
echo "$pid" >> "$PIDS_FILE"
if wait_for_port "$HTTP_TEST_PORT" 10; then
log_info "✓ HTTP test server started (PID: $pid)"
else
log_error "HTTP test server failed to start"
exit 1
fi
}
# Start Drip server
start_drip_server() {
log_step "Starting Drip server (port $DRIP_SERVER_PORT)..."
local cert_path=$1
local key_path=$2
./bin/drip server \
--port "$DRIP_SERVER_PORT" \
--domain localhost \
--tls-cert "$cert_path" \
--tls-key "$key_path" \
> "${LOG_DIR}/drip-server.log" 2>&1 &
local pid=$!
echo "$pid" >> "$PIDS_FILE"
if wait_for_port "$DRIP_SERVER_PORT" 10; then
log_info "✓ Drip server started (PID: $pid)"
else
log_error "Drip server failed to start"
cat "${LOG_DIR}/drip-server.log"
exit 1
fi
}
# Start Drip client and extract URL
start_drip_client() {
log_step "Starting Drip client..."
./bin/drip http "$HTTP_TEST_PORT" \
--server "localhost:${DRIP_SERVER_PORT}" \
--insecure \
> "${LOG_DIR}/drip-client.log" 2>&1 &
local pid=$!
echo "$pid" >> "$PIDS_FILE"
# Wait for client to start and extract URL
log_info "Waiting for tunnel to establish..."
sleep 3
# Extract tunnel URL from logs
local tunnel_url=""
local max_attempts=10
local attempt=0
while [ "$attempt" -lt "$max_attempts" ]; do
# Use grep to extract URL starting with https:// and remove ANSI color codes
tunnel_url=$(grep -oE 'https://[a-zA-Z0-9.-]+:[0-9]+' "${LOG_DIR}/drip-client.log" 2>/dev/null | head -1)
if [ -n "$tunnel_url" ]; then
break
fi
sleep 1
attempt=$((attempt + 1))
done
if [ -z "$tunnel_url" ]; then
log_error "Cannot get tunnel URL"
log_info "Client logs:"
cat "${LOG_DIR}/drip-client.log"
exit 1
fi
log_info "✓ Drip client started (PID: $pid)"
log_info "✓ Tunnel URL: $tunnel_url"
# Return URL
echo "$tunnel_url"
}
# Verify connectivity
verify_connectivity() {
local url=$1
log_step "Verifying tunnel connectivity..."
local max_attempts=5
local attempt=0
while [ "$attempt" -lt "$max_attempts" ]; do
if curl -sk --max-time 5 "$url" > /dev/null 2>&1; then
log_info "✓ Tunnel connectivity normal"
return 0
fi
attempt=$((attempt + 1))
log_warn "Attempt $attempt/$max_attempts..."
sleep 2
done
log_error "Tunnel connectivity test failed"
return 1
}
# Run performance tests
run_performance_tests() {
local url=$1
log_step "Starting performance tests..."
# Test 1: Quick benchmark
log_info "[1/3] Quick benchmark (10s)..."
wrk -t 4 -c 50 -d 10s --latency "$url" \
> "${RESULTS_DIR}/quick-benchmark-${TIMESTAMP}.txt" 2>&1
# Test 2: Standard load test
log_info "[2/3] Standard load test (30s)..."
wrk -t 8 -c 100 -d 30s --latency "$url" \
> "${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt" 2>&1
# Test 3: High concurrency test
log_info "[3/3] High concurrency test (30s)..."
wrk -t 12 -c 400 -d 30s --latency "$url" \
> "${RESULTS_DIR}/high-concurrency-${TIMESTAMP}.txt" 2>&1
log_info "✓ Performance tests completed"
}
# Generate test report
generate_report() {
log_step "Generating test report..."
cat > "$REPORT_FILE" << EOF
========================================
Drip Performance Test Report
========================================
Test Time: $(date)
Test Version: $(./bin/drip version 2>/dev/null | head -1 || echo "unknown")
========================================
Test Environment
========================================
OS: $(uname -s)
CPU Cores: $(sysctl -n hw.ncpu 2>/dev/null || nproc 2>/dev/null || echo "unknown")
Memory: $(sysctl -n hw.memsize 2>/dev/null | awk '{print int($1/1024/1024/1024)"GB"}' || echo "unknown")
========================================
Test Results
========================================
EOF
# Parse and add results from each test
if [ -f "${RESULTS_DIR}/quick-benchmark-${TIMESTAMP}.txt" ]; then
{
echo "--- Quick Benchmark (10s, 50 connections) ---"
grep "Requests/sec:" "${RESULTS_DIR}/quick-benchmark-${TIMESTAMP}.txt"
grep "Transfer/sec:" "${RESULTS_DIR}/quick-benchmark-${TIMESTAMP}.txt"
echo ""
grep "Latency" "${RESULTS_DIR}/quick-benchmark-${TIMESTAMP}.txt" | head -1
grep -A 3 "Latency Distribution" "${RESULTS_DIR}/quick-benchmark-${TIMESTAMP}.txt"
echo ""
} >> "$REPORT_FILE"
fi
if [ -f "${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt" ]; then
{
echo "--- Standard Load Test (30s, 100 connections) ---"
grep "Requests/sec:" "${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt"
grep "Transfer/sec:" "${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt"
echo ""
grep "Latency" "${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt" | head -1
grep -A 3 "Latency Distribution" "${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt"
echo ""
} >> "$REPORT_FILE"
fi
if [ -f "${RESULTS_DIR}/high-concurrency-${TIMESTAMP}.txt" ]; then
{
echo "--- High Concurrency Test (30s, 400 connections) ---"
grep "Requests/sec:" "${RESULTS_DIR}/high-concurrency-${TIMESTAMP}.txt"
grep "Transfer/sec:" "${RESULTS_DIR}/high-concurrency-${TIMESTAMP}.txt"
echo ""
grep "Latency" "${RESULTS_DIR}/high-concurrency-${TIMESTAMP}.txt" | head -1
grep -A 3 "Latency Distribution" "${RESULTS_DIR}/high-concurrency-${TIMESTAMP}.txt"
echo ""
} >> "$REPORT_FILE"
fi
# Add performance evaluation
cat >> "$REPORT_FILE" << 'EOF'
========================================
Performance Evaluation Standards
========================================
Excellent (Phase 1 optimization target):
✓ QPS > 5000
✓ P99 latency < 50ms
✓ Error rate = 0%
Good:
✓ QPS > 2000
✓ P99 latency < 100ms
✓ Error rate < 0.1%
Needs Optimization:
✗ QPS < 1000
✗ P99 latency > 200ms
========================================
Detailed Result Files
========================================
EOF
{
echo "Quick test: ${RESULTS_DIR}/quick-benchmark-${TIMESTAMP}.txt"
echo "Standard test: ${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt"
echo "High concurrency test: ${RESULTS_DIR}/high-concurrency-${TIMESTAMP}.txt"
echo ""
echo "Log directory: ${LOG_DIR}/"
} >> "$REPORT_FILE"
log_info "✓ Test report generated: $REPORT_FILE"
}
# Show report summary
show_summary() {
log_step "Test Results Summary"
if [ -f "${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt" ]; then
echo ""
echo "========================================="
echo " Standard Load Test Results"
echo "========================================="
grep "Requests/sec:" "${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt"
grep "Transfer/sec:" "${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt"
echo ""
grep "50%" "${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt"
grep "99%" "${RESULTS_DIR}/standard-benchmark-${TIMESTAMP}.txt"
echo "========================================="
echo ""
fi
log_info "Full report: $REPORT_FILE"
log_info "Detailed results: $RESULTS_DIR/"
}
# ============================================
# Main flow
# ============================================
main() {
clear
echo "========================================="
echo " Drip One-Click Performance Test"
echo "========================================="
echo ""
# Check dependencies
check_dependencies
# Generate test certificate (ECDSA), only capture the last line of stdout with paths
CERT_PATHS=$(generate_test_certs)
CERT_FILE=$(echo "$CERT_PATHS" | awk '{print $1}')
KEY_FILE=$(echo "$CERT_PATHS" | awk '{print $2}')
log_info "Using certificate: $CERT_FILE"
log_info "Using private key: $KEY_FILE"
# Start all services
start_http_server
start_drip_server "$CERT_FILE" "$KEY_FILE"
TUNNEL_URL=$(start_drip_client)
# Verify connectivity
if ! verify_connectivity "$TUNNEL_URL"; then
log_error "Test aborted: tunnel not accessible"
exit 1
fi
# Warm up
log_info "Warming up tunnel (5s)..."
for _ in {1..5}; do
curl -sk "$TUNNEL_URL" > /dev/null 2>&1 || true
sleep 1
done
# Run tests
run_performance_tests "$TUNNEL_URL"
# Generate report
generate_report
# Show summary
show_summary
log_step "Testing completed!"
echo ""
echo "Press any key to view full report, or Ctrl+C to exit..."
read -n 1 -s
cat "$REPORT_FILE"
}
# Run main flow
main