Merge pull request #3 from Gouryella/feat/major-performance-improvements

feat: Add HTTP streaming, compression support, and Docker deployment
This commit is contained in:
Gouryella
2025-12-05 22:21:46 +08:00
committed by GitHub
31 changed files with 2647 additions and 272 deletions

184
.github/workflows/docker.yml vendored Normal file
View File

@@ -0,0 +1,184 @@
name: Docker

on:
  # Trigger when a release is published (after assets are uploaded)
  release:
    types: [published]
  # Optional manual trigger
  workflow_dispatch:
    inputs:
      version:
        description: 'Release tag to use (e.g., v1.0.0 or latest)'
        required: false
        default: 'latest'

env:
  REGISTRY: ghcr.io
  IMAGE_NAME: ${{ github.repository }}

permissions:
  contents: read
  packages: write

jobs:
  build-and-push:
    name: Build and Push Docker Image
    runs-on: ubuntu-latest
    # For release event, only build for tags like v1.2.3
    if: |
      github.event_name == 'workflow_dispatch' ||
      (github.event_name == 'release' && startsWith(github.event.release.tag_name, 'v'))
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up QEMU
        uses: docker/setup-qemu-action@v3

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Log in to GitHub Container Registry
        uses: docker/login-action@v3
        with:
          registry: ${{ env.REGISTRY }}
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}

      # Best-effort login: continue-on-error lets the workflow proceed when
      # the DOCKERHUB_* secrets are not configured for this repository.
      - name: Log in to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
        continue-on-error: true

      # Resolve VERSION:
      #   - release event: use release tag_name (e.g., v0.3.0)
      #   - workflow_dispatch: use input version (default: latest)
      - name: Get version
        id: version
        run: |
          if [ "${{ github.event_name }}" = "release" ]; then
            v="${{ github.event.release.tag_name }}"
          else
            v="${{ github.event.inputs.version }}"
            if [ -z "$v" ]; then
              v="latest"
            fi
          fi
          echo "VERSION=$v" >> "$GITHUB_OUTPUT"
          echo "Resolved VERSION=$v"

      # Ensure release assets exist before building
      - name: Check release assets
        id: check_assets
        run: |
          VERSION="${{ steps.version.outputs.VERSION }}"
          REPO="${{ github.repository }}"
          echo "Checking assets for $REPO, VERSION=$VERSION"

          # For 'latest', we can only reliably ask the latest release API;
          # the asset names are still versioned (drip-vX.Y.Z-linux-arch).
          if [ "$VERSION" = "latest" ]; then
            API_URL="https://api.github.com/repos/${REPO}/releases/latest"
            echo "Using latest release API: $API_URL"
            json=$(curl -fsSL "$API_URL")

            # Check that assets for both amd64 and arm64 exist
            echo "$json" | grep -q 'drip-.*linux-amd64' || missing_amd64=1
            echo "$json" | grep -q 'drip-.*linux-arm64' || missing_arm64=1

            if [ "${missing_amd64:-0}" -eq 0 ] && [ "${missing_arm64:-0}" -eq 0 ]; then
              echo "assets_ready=true" >> "$GITHUB_OUTPUT"
              echo "Assets found for both linux-amd64 and linux-arm64 (latest)."
            else
              echo "assets_ready=false" >> "$GITHUB_OUTPUT"
              echo "Required assets for latest release are missing; build will be skipped."
            fi
            exit 0
          fi

          # For a specific version tag (e.g., v0.3.0) check direct download URLs
          archs="amd64 arm64"
          missing=0
          for arch in $archs; do
            url="https://github.com/${REPO}/releases/download/${VERSION}/drip-${VERSION}-linux-${arch}"
            status=$(curl -o /dev/null -w "%{http_code}" -sL "$url")
            echo "[$arch] HTTP $status -> $url"
            if [ "$status" != "200" ]; then
              missing=1
            fi
          done

          if [ "$missing" -eq 0 ]; then
            echo "assets_ready=true" >> "$GITHUB_OUTPUT"
            echo "All required assets exist. Proceeding with build."
          else
            echo "assets_ready=false" >> "$GITHUB_OUTPUT"
            echo "Required assets are missing; build will be skipped."
          fi

      - name: Skip build (assets not ready)
        if: steps.check_assets.outputs.assets_ready != 'true'
        run: |
          echo "Release assets are not ready. Docker image build is skipped."
          echo "You must upload all required release files (drip-<version>-linux-amd64/arm64) first."

      # Tag rules:
      #   - main tag is the resolved VERSION (e.g. v0.3.0, or 'latest')
      #   - additionally tag 'latest' when building a specific version
      # NOTE: the comments live here (not inside the `tags: |` scalar) because
      # block-scalar content is passed verbatim to docker/metadata-action,
      # which parses each line as a tag rule, not as a YAML comment.
      - name: Extract metadata (tags & labels)
        id: meta
        if: steps.check_assets.outputs.assets_ready == 'true'
        uses: docker/metadata-action@v5
        with:
          images: |
            ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
            ${{ secrets.DOCKERHUB_USERNAME && format('docker.io/{0}/drip-server', secrets.DOCKERHUB_USERNAME) || '' }}
          tags: |
            type=raw,value=${{ steps.version.outputs.VERSION }}
            type=raw,value=latest,enable=${{ steps.version.outputs.VERSION != 'latest' }}

      - name: Build and push
        if: steps.check_assets.outputs.assets_ready == 'true'
        uses: docker/build-push-action@v5
        with:
          context: .
          file: deployments/Dockerfile.release
          platforms: linux/amd64,linux/arm64
          push: true
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
          build-args: |
            VERSION=${{ steps.version.outputs.VERSION }}
          cache-from: type=gha
          cache-to: type=gha,mode=max

      - name: Generate deployment summary
        if: steps.check_assets.outputs.assets_ready == 'true'
        run: |
          echo "## 🐳 Docker Image Published" >> "$GITHUB_STEP_SUMMARY"
          echo "" >> "$GITHUB_STEP_SUMMARY"
          echo "**Version (GitHub Release tag or 'latest'):** \`${{ steps.version.outputs.VERSION }}\`" >> "$GITHUB_STEP_SUMMARY"
          echo "" >> "$GITHUB_STEP_SUMMARY"
          echo "### Pull from GHCR" >> "$GITHUB_STEP_SUMMARY"
          echo "\`\`\`bash" >> "$GITHUB_STEP_SUMMARY"
          echo "docker pull ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.version.outputs.VERSION }}" >> "$GITHUB_STEP_SUMMARY"
          echo "\`\`\`" >> "$GITHUB_STEP_SUMMARY"
          echo "" >> "$GITHUB_STEP_SUMMARY"
          echo "### Quick start" >> "$GITHUB_STEP_SUMMARY"
          echo "\`\`\`bash" >> "$GITHUB_STEP_SUMMARY"
          echo "docker run -d \\\\" >> "$GITHUB_STEP_SUMMARY"
          echo "  --name drip-server \\\\" >> "$GITHUB_STEP_SUMMARY"
          echo "  -p 443:443 \\\\" >> "$GITHUB_STEP_SUMMARY"
          echo "  -v /path/to/certs:/app/data/certs:ro \\\\" >> "$GITHUB_STEP_SUMMARY"
          echo "  ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.version.outputs.VERSION }} \\\\" >> "$GITHUB_STEP_SUMMARY"
          echo "  server --domain your.domain.com --port 443 \\\\" >> "$GITHUB_STEP_SUMMARY"
          echo "  --tls-cert /app/data/certs/fullchain.pem \\\\" >> "$GITHUB_STEP_SUMMARY"
          echo "  --tls-key /app/data/certs/privkey.pem" >> "$GITHUB_STEP_SUMMARY"
          echo "\`\`\`" >> "$GITHUB_STEP_SUMMARY"

3
.gitignore vendored
View File

@@ -51,4 +51,5 @@ tmp/
temp/
certs/
.drip-server.env
benchmark-results/
benchmark-results/
drip

View File

@@ -66,6 +66,13 @@ bash <(curl -sL https://raw.githubusercontent.com/Gouryella/drip/main/scripts/in
## Usage
### First Time Setup
```bash
# Configure server and token (only needed once)
drip config init
```
### Basic Tunnels
```bash
@@ -76,7 +83,7 @@ drip http 3000
drip https 443
# Pick your subdomain
drip http 3000 --subdomain myapp
drip http 3000 -n myapp
# → https://myapp.your-domain.com
# Expose TCP service (database, SSH, etc.)
@@ -89,28 +96,33 @@ Not just localhost - forward to any device on your network:
```bash
# Forward to another machine on LAN
drip http 8080 --address 192.168.1.100
drip http 8080 -a 192.168.1.100
# Forward to Docker container
drip http 3000 --address 172.17.0.2
drip http 3000 -a 172.17.0.2
# Forward to specific interface
drip http 3000 --address 10.0.0.5
drip http 3000 -a 10.0.0.5
```
### Daemon Mode
### Background Mode
Run tunnels in the background:
Run tunnels in the background with `-d`:
```bash
# Start tunnel as daemon
drip daemon start http 3000
drip daemon start https 8443 --subdomain api
# Start tunnel in background
drip http 3000 -d
drip https 8443 -n api -d
# Manage daemons
drip daemon list
drip daemon stop http-3000
drip daemon logs http-3000
# List running tunnels
drip list
# View tunnel logs
drip attach http 3000
# Stop tunnels
drip stop http 3000
drip stop all
```
## Server Deployment
@@ -214,25 +226,25 @@ sudo journalctl -u drip-server -f
drip http 3000
# Test webhooks from services like Stripe
drip http 8000 --subdomain webhooks
drip http 8000 -n webhooks
```
**Home Server Access**
```bash
# Access home NAS remotely
drip http 5000 --address 192.168.1.50
drip http 5000 -a 192.168.1.50
# Remote into home network
# Remote into home network via SSH
drip tcp 22
```
**Docker & Containers**
```bash
# Expose containerized app
drip http 8080 --address 172.17.0.3
drip http 8080 -a 172.17.0.3
# Database access for debugging
drip tcp 5432 --address db-container
drip tcp 5432 -a db-container
```
## Command Reference
@@ -240,26 +252,29 @@ drip tcp 5432 --address db-container
```bash
# HTTP tunnel
drip http <port> [flags]
--subdomain, -n Custom subdomain
--address, -a Target address (default: 127.0.0.1)
--server Server address
--token Auth token
-n, --subdomain Custom subdomain
-a, --address Target address (default: 127.0.0.1)
-d, --daemon Run in background
-s, --server Server address
-t, --token Auth token
# HTTPS tunnel
# HTTPS tunnel (same flags as http)
drip https <port> [flags]
# TCP tunnel
# TCP tunnel (same flags as http)
drip tcp <port> [flags]
# Daemon commands
drip daemon start <type> <port> [flags]
drip daemon list
drip daemon stop <name>
drip daemon logs <name>
# Background tunnel management
drip list List running tunnels
drip list -i Interactive mode
drip attach [type] [port] View logs
drip stop <type> <port> Stop tunnel
drip stop all Stop all tunnels
# Configuration
drip config init
drip config show
drip config init Set up server and token
drip config show Show current config
drip config set <key> <value>
```
## License

View File

@@ -66,6 +66,13 @@ bash <(curl -sL https://raw.githubusercontent.com/Gouryella/drip/main/scripts/in
## 使用
### 首次配置
```bash
# 配置服务器地址和 token只需一次
drip config init
```
### 基础隧道
```bash
@@ -76,7 +83,7 @@ drip http 3000
drip https 443
# 选择你的子域名
drip http 3000 --subdomain myapp
drip http 3000 -n myapp
# → https://myapp.your-domain.com
# 暴露 TCP 服务数据库、SSH 等)
@@ -89,28 +96,33 @@ drip tcp 5432
```bash
# 转发到局域网其他机器
drip http 8080 --address 192.168.1.100
drip http 8080 -a 192.168.1.100
# 转发到 Docker 容器
drip http 3000 --address 172.17.0.2
drip http 3000 -a 172.17.0.2
# 转发到特定网卡
drip http 3000 --address 10.0.0.5
drip http 3000 -a 10.0.0.5
```
### 守护模式
### 后台模式
让隧道在后台运行:
使用 `-d` 让隧道在后台运行:
```bash
# 以守护进程启动隧道
drip daemon start http 3000
drip daemon start https 8443 --subdomain api
# 后台启动隧道
drip http 3000 -d
drip https 8443 -n api -d
# 管理守护进程
drip daemon list
drip daemon stop http-3000
drip daemon logs http-3000
# 列出运行中的隧道
drip list
# 查看隧道日志
drip attach http 3000
# 停止隧道
drip stop http 3000
drip stop all
```
## 服务端部署
@@ -214,25 +226,25 @@ sudo journalctl -u drip-server -f
drip http 3000
# 测试第三方 webhook如 Stripe
drip http 8000 --subdomain webhooks
drip http 8000 -n webhooks
```
**家庭服务器访问**
```bash
# 远程访问家里的 NAS
drip http 5000 --address 192.168.1.50
drip http 5000 -a 192.168.1.50
# 远程进入家庭网络
# 通过 SSH 远程进入家庭网络
drip tcp 22
```
**Docker 与容器**
```bash
# 暴露容器化应用
drip http 8080 --address 172.17.0.3
drip http 8080 -a 172.17.0.3
# 数据库调试
drip tcp 5432 --address db-container
drip tcp 5432 -a db-container
```
## 命令参考
@@ -240,26 +252,29 @@ drip tcp 5432 --address db-container
```bash
# HTTP 隧道
drip http <端口> [参数]
--subdomain, -n 自定义子域名
--address, -a 目标地址默认127.0.0.1
--server 服务器地址
--token 认证 token
-n, --subdomain 自定义子域名
-a, --address 目标地址默认127.0.0.1
-d, --daemon 后台运行
-s, --server 服务器地址
-t, --token 认证 token
# HTTPS 隧道
# HTTPS 隧道(参数同 http
drip https <端口> [参数]
# TCP 隧道
# TCP 隧道(参数同 http
drip tcp <端口> [参数]
# 守护进程命令
drip daemon start <类型> <端口> [参数]
drip daemon list
drip daemon stop <名称>
drip daemon logs <名称>
# 后台隧道管理
drip list 列出运行中的隧道
drip list -i 交互模式
drip attach [类型] [端口] 查看日志
drip stop <类型> <端口> 停止隧道
drip stop all 停止所有隧道
# 配置
drip config init
drip config show
drip config init 设置服务器和 token
drip config show 显示当前配置
drip config set <键> <值>
```
## 协议

View File

@@ -1,4 +1,4 @@
FROM golang:1.25-alpine AS builder
FROM golang:1.23-alpine AS builder
RUN apk add --no-cache git ca-certificates tzdata

View File

@@ -0,0 +1,43 @@
# Dockerfile for deploying drip-server from GitHub Release
# Usage:
#   docker build -f deployments/Dockerfile.release -t drip-server .
#   docker build -f deployments/Dockerfile.release --build-arg VERSION=v1.0.0 -t drip-server:v1.0.0 .

FROM alpine:latest

# VERSION is passed from GitHub Actions (release tag or "latest")
ARG VERSION=latest
# TARGETARCH is automatically set by Docker Buildx (amd64, arm64, etc.)
ARG TARGETARCH=amd64

# curl is required both to download the release binary below and for the
# HEALTHCHECK at runtime; update-ca-certificates enables TLS verification.
RUN apk add --no-cache ca-certificates tzdata curl && update-ca-certificates

# Run as a non-root system user ("drip") for defense in depth.
RUN addgroup -S drip && adduser -S -G drip drip

WORKDIR /app

# Download binary from GitHub Releases
RUN set -ex; \
    # Resolve "latest" to an actual tag
    if [ "$VERSION" = "latest" ]; then \
        VERSION=$(curl -sL https://api.github.com/repos/Gouryella/drip/releases/latest \
            | grep '"tag_name"' | cut -d'"' -f4); \
    fi; \
    echo "Resolved VERSION=${VERSION}"; \
    echo "Downloading drip ${VERSION} for linux-${TARGETARCH}"; \
    curl -fsSL "https://github.com/Gouryella/drip/releases/download/${VERSION}/drip-${VERSION}-linux-${TARGETARCH}" -o /app/drip; \
    chmod +x /app/drip; \
    # Smoke-test the downloaded binary (also fails the build on a bad download)
    /app/drip version --short

# /app/data/certs is where compose mounts TLS certificates (read-only).
RUN mkdir -p /app/data/certs && \
    chown -R drip:drip /app

USER drip

# 80/443: HTTP(S) entry points; 8080: default server port; 20000-20100: TCP tunnels.
EXPOSE 80 443 8080 20000-20100

# NOTE(review): no ENV PORT is declared in this image, so the probe defaults
# to 8080 and assumes a plain-HTTP /health endpoint. If the container is run
# with a different --port or with TLS enabled, set PORT accordingly or
# override the healthcheck — confirm against the server's /health behavior.
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
    CMD curl -fsS "http://localhost:${PORT:-8080}/health" >/dev/null || exit 1

ENTRYPOINT ["/app/drip"]
CMD ["server", "--port", "8080"]

View File

@@ -1,6 +1,35 @@
# Docker Deployment
## Quick Start
## Quick Start (Recommended)
Deploy drip-server using pre-built images from GitHub Container Registry:
```bash
# Pull the latest image
docker pull ghcr.io/gouryella/drip:latest
# Or use docker compose
curl -fsSL https://raw.githubusercontent.com/Gouryella/drip/main/docker-compose.release.yml -o docker-compose.yml
# Create .env file
cat > .env << EOF
DOMAIN=tunnel.example.com
AUTH_TOKEN=your-secret-token
VERSION=latest
EOF
# Place your TLS certificates
mkdir -p certs
cp /path/to/fullchain.pem certs/
cp /path/to/privkey.pem certs/
# Start server
docker compose up -d
```
## Build from Source
If you prefer to build locally:
### Server (Production)

View File

@@ -0,0 +1,72 @@
# Docker Compose for deploying drip-server from GitHub Release
#
# Usage:
#   1. Copy this file to your server
#   2. Create .env file with your settings (see .env.example below)
#   3. Run: docker compose -f docker-compose.release.yml up -d
#
# Environment variables (.env.example):
#   DOMAIN=tunnel.example.com
#   AUTH_TOKEN=your-secret-token
#   VERSION=latest
#   TZ=UTC

services:
  drip-server:
    image: ghcr.io/gouryella/drip:${VERSION:-latest}
    container_name: drip-server
    restart: unless-stopped
    ports:
      - "443:443"
      - "20000-20100:20000-20100"  # TCP tunnel ports
    volumes:
      # TLS certificates are mounted read-only; data volume holds server state.
      - ./certs:/app/data/certs:ro
      - drip-data:/app/data
    environment:
      TZ: ${TZ:-UTC}
    command: >
      server
      --domain ${DOMAIN:-tunnel.localhost}
      --port 443
      --tls-cert /app/data/certs/fullchain.pem
      --tls-key /app/data/certs/privkey.pem
      --token ${AUTH_TOKEN:-}
      --tcp-port-min 20000
      --tcp-port-max 20100
    networks:
      - drip-net
    logging:
      driver: json-file
      options:
        # Quoted so YAML keeps them as strings rather than re-typing them.
        max-size: "10m"
        max-file: "3"
    # NOTE(review): deploy.resources handling differs between Swarm and plain
    # `docker compose up` — verify the limits take effect in your environment.
    deploy:
      resources:
        limits:
          cpus: '2'
          memory: 512M
        reservations:
          cpus: '0.25'
          memory: 64M
    healthcheck:
      # The server is started with --tls-cert/--tls-key above, so port 443
      # serves HTTPS: probing http://localhost:443 would always fail. curl is
      # installed in the release image; -k skips certificate verification
      # because "localhost" is not a name on the mounted certificate.
      test: ["CMD", "curl", "-fsSk", "https://localhost:443/health"]
      interval: 30s
      timeout: 3s
      retries: 3
      start_period: 10s

volumes:
  drip-data:
    driver: local

networks:
  drip-net:
    driver: bridge

View File

@@ -221,13 +221,6 @@ func runConfigValidate(cmd *cobra.Command, args []string) error {
return nil
}
func enabledDisabled(value bool) string {
if value {
return "enabled"
}
return "disabled"
}
func validateServerAddress(addr string) (bool, string) {
addr = strings.TrimSpace(addr)
if addr == "" {

View File

@@ -153,12 +153,7 @@ func StartDaemon(tunnelType string, port int, args []string) error {
// Build command arguments (remove -D/--daemon flag to prevent recursion)
var cleanArgs []string
skipNext := false
for i, arg := range args {
if skipNext {
skipNext = false
continue
}
for _, arg := range args {
// Skip -D or --daemon flags (but NOT --daemon-child)
if arg == "-D" || arg == "--daemon" {
continue
@@ -167,8 +162,6 @@ func StartDaemon(tunnelType string, port int, args []string) error {
if arg == "-d" {
continue
}
// Skip if next arg would be a value for a removed flag (not applicable for boolean)
_ = i
cleanArgs = append(cleanArgs, arg)
}

View File

@@ -203,7 +203,3 @@ func isNonRetryableError(err error) bool {
strings.Contains(errStr, "authentication") ||
strings.Contains(errStr, "Invalid authentication token")
}
func isNonRetryableErrorTCP(err error) bool {
return isNonRetryableError(err)
}

View File

@@ -7,7 +7,7 @@ import (
// RenderConfigInit renders config initialization UI
func RenderConfigInit() string {
title := "Drip Configuration Setup"
box := boxStyle.Copy().Width(50)
box := boxStyle.Width(50)
return "\n" + box.Render(titleStyle.Render(title)) + "\n"
}

View File

@@ -6,19 +6,12 @@ import (
var (
// Colors inspired by Vercel CLI
primaryColor = lipgloss.Color("#0070F3")
successColor = lipgloss.Color("#0070F3")
warningColor = lipgloss.Color("#F5A623")
errorColor = lipgloss.Color("#E00")
mutedColor = lipgloss.Color("#888")
highlightColor = lipgloss.Color("#0070F3")
cyanColor = lipgloss.Color("#50E3C2")
purpleColor = lipgloss.Color("#7928CA")
// Base styles
baseStyle = lipgloss.NewStyle().
PaddingLeft(1).
PaddingRight(1)
// Box styles - Vercel-like clean box
boxStyle = lipgloss.NewStyle().
@@ -27,14 +20,11 @@ var (
MarginTop(1).
MarginBottom(1)
successBoxStyle = boxStyle.Copy().
BorderForeground(successColor)
successBoxStyle = boxStyle.BorderForeground(successColor)
warningBoxStyle = boxStyle.Copy().
BorderForeground(warningColor)
warningBoxStyle = boxStyle.BorderForeground(warningColor)
errorBoxStyle = boxStyle.Copy().
BorderForeground(errorColor)
errorBoxStyle = boxStyle.BorderForeground(errorColor)
// Text styles
titleStyle = lipgloss.NewStyle().
@@ -85,10 +75,6 @@ var (
tableCellStyle = lipgloss.NewStyle().
PaddingRight(2)
tableRowHighlight = lipgloss.NewStyle().
Foreground(highlightColor).
Bold(true)
)
// Success returns a styled success message

View File

@@ -68,7 +68,7 @@ func (t *Table) Render() string {
// Header
headerParts := make([]string, len(t.headers))
for i, header := range t.headers {
style := tableHeaderStyle.Copy().Width(colWidths[i])
style := tableHeaderStyle.Width(colWidths[i])
headerParts[i] = style.Render(header)
}
output.WriteString(strings.Join(headerParts, " "))
@@ -87,7 +87,7 @@ func (t *Table) Render() string {
rowParts := make([]string, len(t.headers))
for i, cell := range row {
if i < len(colWidths) {
style := tableCellStyle.Copy().Width(colWidths[i])
style := tableCellStyle.Width(colWidths[i])
rowParts[i] = style.Render(cell)
}
}

View File

@@ -59,7 +59,7 @@ func RenderTunnelConnected(status *TunnelStatus) string {
urlLine := lipgloss.JoinHorizontal(
lipgloss.Left,
urlStyle.Copy().Foreground(accent).Render(status.URL),
urlStyle.Foreground(accent).Render(status.URL),
lipgloss.NewStyle().MarginLeft(1).Foreground(mutedColor).Render("(forwarded address)"),
)

View File

@@ -2,6 +2,7 @@ package tcp
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
@@ -18,19 +19,21 @@ import (
// FrameHandler handles data frames and forwards to local service
type FrameHandler struct {
conn net.Conn
frameWriter *protocol.FrameWriter
localHost string
localPort int
logger *zap.Logger
streams map[string]*Stream
streamMu sync.RWMutex
tunnelType protocol.TunnelType
httpClient *http.Client
stats *TrafficStats
isClosedCheck func() bool
bufferPool *pool.BufferPool
headerPool *pool.HeaderPool
conn net.Conn
frameWriter *protocol.FrameWriter
localHost string
localPort int
logger *zap.Logger
streams map[string]*Stream
streamMu sync.RWMutex
streamingRequests map[string]*StreamingRequest
streamingReqMu sync.RWMutex
tunnelType protocol.TunnelType
httpClient *http.Client
stats *TrafficStats
isClosedCheck func() bool
bufferPool *pool.BufferPool
headerPool *pool.HeaderPool
}
// Stream represents a single request/response stream
@@ -39,6 +42,23 @@ type Stream struct {
LocalConn net.Conn
ResponseCh chan []byte
Done chan struct{}
closed bool
mu sync.Mutex
}
// StreamingRequest represents a streaming upload request in progress
type StreamingRequest struct {
RequestID string
Writer *io.PipeWriter
Done chan struct{}
chunkQueue chan *chunkData
closed bool
mu sync.Mutex
}
type chunkData struct {
data []byte
isLast bool
}
func NewFrameHandler(conn net.Conn, frameWriter *protocol.FrameWriter, localHost string, localPort int, tunnelType protocol.TunnelType, logger *zap.Logger, isClosedCheck func() bool, bufferPool *pool.BufferPool) *FrameHandler {
@@ -50,29 +70,30 @@ func NewFrameHandler(conn net.Conn, frameWriter *protocol.FrameWriter, localHost
}
return &FrameHandler{
conn: conn,
frameWriter: frameWriter,
localHost: localHost,
localPort: localPort,
logger: logger,
streams: make(map[string]*Stream),
tunnelType: tunnelType,
stats: NewTrafficStats(),
isClosedCheck: isClosedCheck,
bufferPool: bufferPool,
headerPool: pool.NewHeaderPool(),
conn: conn,
frameWriter: frameWriter,
localHost: localHost,
localPort: localPort,
logger: logger,
streams: make(map[string]*Stream),
streamingRequests: make(map[string]*StreamingRequest),
tunnelType: tunnelType,
stats: NewTrafficStats(),
isClosedCheck: isClosedCheck,
bufferPool: bufferPool,
headerPool: pool.NewHeaderPool(),
httpClient: &http.Client{
Timeout: 30 * time.Second,
// No overall timeout - streaming responses can take arbitrary time
Transport: &http.Transport{
MaxIdleConns: 1000, // Optimized for both mid and high load scenarios
MaxIdleConnsPerHost: 500, // Sufficient for 400+ concurrent connections
MaxConnsPerHost: 0, // Unlimited
MaxIdleConns: 1000,
MaxIdleConnsPerHost: 500,
MaxConnsPerHost: 0,
IdleConnTimeout: 180 * time.Second,
DisableCompression: true,
DisableKeepAlives: false,
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: tlsConfig,
ResponseHeaderTimeout: 15 * time.Second,
ResponseHeaderTimeout: 30 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DialContext: (&net.Dialer{
Timeout: 5 * time.Second,
@@ -99,6 +120,14 @@ func (h *FrameHandler) HandleDataFrame(frame *protocol.Frame) error {
return h.handleHTTPFrame(header, data)
}
if header.Type == protocol.DataTypeHTTPRequestHead || header.Type == protocol.DataTypeHTTPHead {
return h.handleHTTPRequestHead(header, data)
}
if header.Type == protocol.DataTypeHTTPRequestBodyChunk || header.Type == protocol.DataTypeHTTPBodyChunk {
return h.handleHTTPRequestBodyChunk(header, data)
}
if header.Type == protocol.DataTypeClose {
h.closeStream(header.StreamID)
return nil
@@ -122,7 +151,7 @@ func (h *FrameHandler) getOrCreateStream(streamID string) (*Stream, error) {
return stream, nil
}
localAddr := fmt.Sprintf("%s:%d", h.localHost, h.localPort)
localAddr := net.JoinHostPort(h.localHost, fmt.Sprintf("%d", h.localPort))
localConn, err := net.DialTimeout("tcp", localAddr, 5*time.Second)
if err != nil {
return nil, fmt.Errorf("failed to connect to local service: %w", err)
@@ -143,8 +172,25 @@ func (h *FrameHandler) getOrCreateStream(streamID string) (*Stream, error) {
}
func (h *FrameHandler) forwardToLocal(stream *Stream, data []byte) {
// Check if stream is closed using mutex
stream.mu.Lock()
if stream.closed {
stream.mu.Unlock()
return
}
stream.mu.Unlock()
// Double check with Done channel
select {
case <-stream.Done:
// Stream already closed, ignore data
return
default:
}
if _, err := stream.LocalConn.Write(data); err != nil {
h.logger.Error("Failed to write to local service",
// Only log at debug level since connection close is often expected
h.logger.Debug("Failed to write to local service",
zap.String("stream_id", stream.ID),
zap.Error(err),
)
@@ -160,6 +206,14 @@ func (h *FrameHandler) handleLocalResponse(stream *Stream) {
buf := (*bufPtr)[:pool.SizeMedium]
for {
// Check if stream is closed before reading
stream.mu.Lock()
closed := stream.closed
stream.mu.Unlock()
if closed {
break
}
n, err := stream.LocalConn.Read(buf)
if err != nil {
break
@@ -170,6 +224,14 @@ func (h *FrameHandler) handleLocalResponse(stream *Stream) {
break
}
// Check again after read
stream.mu.Lock()
closed = stream.closed
stream.mu.Unlock()
if closed {
break
}
header := protocol.DataHeader{
StreamID: stream.ID,
Type: protocol.DataTypeResponse,
@@ -178,14 +240,14 @@ func (h *FrameHandler) handleLocalResponse(stream *Stream) {
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, buf[:n])
if err != nil {
h.logger.Error("Encode payload failed", zap.Error(err))
h.logger.Debug("Encode payload failed", zap.Error(err))
break
}
dataFrame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
err = h.frameWriter.WriteFrame(dataFrame)
if err != nil {
h.logger.Error("Send frame failed", zap.Error(err))
h.logger.Debug("Send frame failed", zap.Error(err))
break
}
@@ -261,19 +323,185 @@ func (h *FrameHandler) handleHTTPFrame(header protocol.DataHeader, payload []byt
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
// Threshold for switching from buffered to streaming mode
const bufferThreshold int64 = 1 * 1024 * 1024 // 1MB
// If Content-Length is known and large, use streaming directly
if resp.ContentLength > bufferThreshold {
return h.streamHTTPResponse(header.StreamID, header.RequestID, resp)
}
// For small or unknown size: try buffered first, switch to streaming if too large
return h.adaptiveHTTPResponse(header.StreamID, header.RequestID, resp, bufferThreshold)
}
// adaptiveHTTPResponse tries buffered mode first, switches to streaming if data exceeds threshold
func (h *FrameHandler) adaptiveHTTPResponse(streamID, requestID string, resp *http.Response, threshold int64) error {
if h.isClosedCheck != nil && h.isClosedCheck() {
return nil
}
// Buffer for initial read
buffer := make([]byte, 0, threshold)
tempBuf := make([]byte, 32*1024) // 32KB read chunks
var totalRead int64
var hitThreshold bool
// Try to read up to threshold
for totalRead < threshold {
n, err := resp.Body.Read(tempBuf)
if n > 0 {
buffer = append(buffer, tempBuf[:n]...)
totalRead += int64(n)
}
if err == io.EOF {
// Response completed within threshold - use buffered mode
break
}
if err != nil {
return h.sendHTTPError(streamID, requestID, http.StatusBadGateway, fmt.Sprintf("read response: %v", err))
}
if totalRead >= threshold {
hitThreshold = true
break
}
}
if !hitThreshold {
// Small response - send as buffered
// Clean response headers - remove hop-by-hop headers that are invalid after proxying
cleanedHeaders := h.cleanResponseHeaders(resp.Header)
httpResp := protocol.HTTPResponse{
StatusCode: resp.StatusCode,
Status: resp.Status,
Headers: cleanedHeaders,
Body: buffer,
}
return h.sendHTTPResponse(streamID, requestID, &httpResp)
}
// Large response - switch to streaming mode
// Clean response headers - remove hop-by-hop headers that are invalid after proxying
cleanedHeaders := h.cleanResponseHeaders(resp.Header)
// First send headers
httpHead := protocol.HTTPResponseHead{
StatusCode: resp.StatusCode,
Status: resp.Status,
Headers: cleanedHeaders,
ContentLength: resp.ContentLength, // -1 if unknown
}
headBytes, err := protocol.EncodeHTTPResponseHead(&httpHead)
if err != nil {
return h.sendHTTPError(header.StreamID, header.RequestID, http.StatusBadGateway, fmt.Sprintf("read response: %v", err))
return fmt.Errorf("encode http head: %w", err)
}
httpResp := protocol.HTTPResponse{
StatusCode: resp.StatusCode,
Status: resp.Status,
Headers: resp.Header,
Body: body,
headHeader := protocol.DataHeader{
StreamID: streamID,
RequestID: requestID,
Type: protocol.DataTypeHTTPHead,
IsLast: false,
}
return h.sendHTTPResponse(header.StreamID, header.RequestID, &httpResp)
headPayload, headPoolBuffer, err := protocol.EncodeDataPayloadPooled(headHeader, headBytes)
if err != nil {
return fmt.Errorf("encode head payload: %w", err)
}
headFrame := protocol.NewFramePooled(protocol.FrameTypeData, headPayload, headPoolBuffer)
if err := h.frameWriter.WriteFrame(headFrame); err != nil {
return err
}
h.frameWriter.Flush()
h.stats.AddBytesOut(int64(len(headPayload)))
// Send buffered data as first chunk
if len(buffer) > 0 {
chunkHeader := protocol.DataHeader{
StreamID: streamID,
RequestID: requestID,
Type: protocol.DataTypeHTTPBodyChunk,
IsLast: false,
}
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, buffer)
if err != nil {
return fmt.Errorf("encode chunk payload: %w", err)
}
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
if err := h.frameWriter.WriteFrame(chunkFrame); err != nil {
return err
}
h.stats.AddBytesOut(int64(len(chunkPayload)))
}
// Clear buffer to free memory
buffer = nil
// Continue streaming remaining data
bufPtr := h.bufferPool.Get(pool.SizeMedium)
defer h.bufferPool.Put(bufPtr)
buf := (*bufPtr)[:pool.SizeMedium]
for {
if h.isClosedCheck != nil && h.isClosedCheck() {
return nil
}
n, readErr := resp.Body.Read(buf)
if n > 0 {
isLast := readErr == io.EOF
chunkHeader := protocol.DataHeader{
StreamID: streamID,
RequestID: requestID,
Type: protocol.DataTypeHTTPBodyChunk,
IsLast: isLast,
}
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, buf[:n])
if err != nil {
return fmt.Errorf("encode chunk payload: %w", err)
}
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
if err := h.frameWriter.WriteFrame(chunkFrame); err != nil {
return err
}
h.stats.AddBytesOut(int64(len(chunkPayload)))
}
if readErr == io.EOF {
if n == 0 {
// Send final empty chunk
finalHeader := protocol.DataHeader{
StreamID: streamID,
RequestID: requestID,
Type: protocol.DataTypeHTTPBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if err != nil {
return fmt.Errorf("encode final payload: %w", err)
}
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
if err := h.frameWriter.WriteFrame(finalFrame); err != nil {
return err
}
}
h.frameWriter.Flush()
break
}
if readErr != nil {
return fmt.Errorf("read response body: %w", readErr)
}
}
return nil
}
func (h *FrameHandler) sendHTTPError(streamID, requestID string, status int, message string) error {
@@ -294,6 +522,111 @@ func (h *FrameHandler) sendHTTPError(streamID, requestID string, status int, mes
return err
}
// streamHTTPResponse streams HTTP response using zero-copy approach
// First sends headers, then streams body chunks
// streamHTTPResponse forwards a local HTTP response back over the tunnel as a
// stream: one DataTypeHTTPHead frame carrying the cleaned status line and
// headers, followed by DataTypeHTTPBodyChunk frames until the body is
// exhausted. The final body frame carries IsLast=true so the receiving side
// knows the response is complete.
func (h *FrameHandler) streamHTTPResponse(streamID, requestID string, resp *http.Response) error {
	// Tunnel already shut down: silently drop the response.
	if h.isClosedCheck != nil && h.isClosedCheck() {
		return nil
	}
	// Clean response headers - remove hop-by-hop headers that are invalid after proxying
	cleanedHeaders := h.cleanResponseHeaders(resp.Header)
	// Send HTTP headers first
	contentLength := resp.ContentLength // -1 if unknown
	httpHead := protocol.HTTPResponseHead{
		StatusCode:    resp.StatusCode,
		Status:        resp.Status,
		Headers:       cleanedHeaders,
		ContentLength: contentLength,
	}
	headBytes, err := protocol.EncodeHTTPResponseHead(&httpHead)
	if err != nil {
		return fmt.Errorf("encode http head: %w", err)
	}
	headHeader := protocol.DataHeader{
		StreamID:  streamID,
		RequestID: requestID,
		Type:      protocol.DataTypeHTTPHead,
		IsLast:    false,
	}
	headPayload, headPoolBuffer, err := protocol.EncodeDataPayloadPooled(headHeader, headBytes)
	if err != nil {
		return fmt.Errorf("encode head payload: %w", err)
	}
	headFrame := protocol.NewFramePooled(protocol.FrameTypeData, headPayload, headPoolBuffer)
	if err := h.frameWriter.WriteFrame(headFrame); err != nil {
		return err
	}
	// Flush so the server can write response headers to the client right away.
	h.frameWriter.Flush()
	h.stats.AddBytesOut(int64(len(headPayload)))
	// Stream body chunks - zero copy using buffer pool
	bufPtr := h.bufferPool.Get(pool.SizeMedium)
	defer h.bufferPool.Put(bufPtr)
	buf := (*bufPtr)[:pool.SizeMedium]
	for {
		// Bail out mid-stream if the tunnel was torn down.
		if h.isClosedCheck != nil && h.isClosedCheck() {
			return nil
		}
		n, readErr := resp.Body.Read(buf)
		if n > 0 {
			// Piggyback IsLast on the data chunk when Read reports EOF
			// together with the final bytes, avoiding an extra empty frame.
			isLast := readErr == io.EOF
			chunkHeader := protocol.DataHeader{
				StreamID:  streamID,
				RequestID: requestID,
				Type:      protocol.DataTypeHTTPBodyChunk,
				IsLast:    isLast,
			}
			chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, buf[:n])
			if err != nil {
				return fmt.Errorf("encode chunk payload: %w", err)
			}
			chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
			if err := h.frameWriter.WriteFrame(chunkFrame); err != nil {
				return err
			}
			h.stats.AddBytesOut(int64(len(chunkPayload)))
		}
		if readErr == io.EOF {
			// Send final empty chunk with IsLast=true if we haven't already
			if n == 0 {
				finalHeader := protocol.DataHeader{
					StreamID:  streamID,
					RequestID: requestID,
					Type:      protocol.DataTypeHTTPBodyChunk,
					IsLast:    true,
				}
				finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
				if err != nil {
					return fmt.Errorf("encode final payload: %w", err)
				}
				finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
				if err := h.frameWriter.WriteFrame(finalFrame); err != nil {
					return err
				}
			}
			h.frameWriter.Flush()
			break
		}
		if readErr != nil {
			// NOTE(review): on a mid-body read error no IsLast frame is sent,
			// so the receiver relies on its own inactivity timeout to end the
			// response - confirm this truncation behavior is intended.
			return fmt.Errorf("read response body: %w", readErr)
		}
	}
	return nil
}
func (h *FrameHandler) sendHTTPResponse(streamID, requestID string, resp *protocol.HTTPResponse) error {
if h.isClosedCheck != nil && h.isClosedCheck() {
return nil
@@ -320,26 +653,44 @@ func (h *FrameHandler) sendHTTPResponse(streamID, requestID string, resp *protoc
h.stats.AddBytesOut(int64(len(payload)))
return h.frameWriter.WriteFrame(dataFrame)
if err := h.frameWriter.WriteFrame(dataFrame); err != nil {
return err
}
// Flush immediately to ensure the response is sent without batching delay
h.frameWriter.Flush()
return nil
}
func (h *FrameHandler) closeStream(streamID string) {
h.streamMu.Lock()
defer h.streamMu.Unlock()
stream, ok := h.streams[streamID]
if !ok {
h.streamMu.Unlock()
return
}
// Use stream-level mutex to prevent race conditions
stream.mu.Lock()
if stream.closed {
stream.mu.Unlock()
h.streamMu.Unlock()
return
}
stream.closed = true
stream.mu.Unlock()
// Remove from map first to prevent concurrent access
delete(h.streams, streamID)
h.streamMu.Unlock()
// Now safe to close resources without holding the main lock
if stream.LocalConn != nil {
stream.LocalConn.Close()
}
close(stream.Done)
delete(h.streams, streamID)
if h.isClosedCheck != nil && h.isClosedCheck() {
return
}
@@ -364,15 +715,29 @@ func (h *FrameHandler) closeStream(streamID string) {
// Close closes all streams
func (h *FrameHandler) Close() {
h.streamMu.Lock()
defer h.streamMu.Unlock()
for streamID, stream := range h.streams {
if stream.LocalConn != nil {
stream.LocalConn.Close()
stream.mu.Lock()
if !stream.closed {
stream.closed = true
if stream.LocalConn != nil {
stream.LocalConn.Close()
}
close(stream.Done)
}
close(stream.Done)
stream.mu.Unlock()
delete(h.streams, streamID)
}
h.streamMu.Unlock()
h.streamingReqMu.Lock()
for requestID, streamingReq := range h.streamingRequests {
h.closeStreamingRequest(requestID, streamingReq)
if streamingReq.Writer != nil {
streamingReq.Writer.CloseWithError(fmt.Errorf("tunnel connection closed"))
}
delete(h.streamingRequests, requestID)
}
h.streamingReqMu.Unlock()
}
// GetStats returns the traffic stats tracker
@@ -438,3 +803,276 @@ func (h *FrameHandler) isLocalAddress(addr string) bool {
return false
}
// cleanResponseHeaders removes hop-by-hop headers that should not be forwarded
// Go's http.Client automatically handles chunked encoding, so we need to remove
// the Transfer-Encoding header to avoid sending decoded body with chunked header
// cleanResponseHeaders returns a copy of headers with hop-by-hop headers
// removed (RFC 7230 section 6.1). Go's http.Client automatically decodes
// chunked bodies, so Transfer-Encoding in particular must not be forwarded
// alongside the already-decoded body. Headers named by the Connection header
// are treated as hop-by-hop as well.
func (h *FrameHandler) cleanResponseHeaders(headers http.Header) http.Header {
	// Static hop-by-hop set, keyed by canonical header name.
	hopByHop := map[string]bool{
		"Connection":          true,
		"Keep-Alive":          true,
		"Proxy-Authenticate":  true,
		"Proxy-Authorization": true,
		"Te":                  true,
		"Trailers":            true,
		"Transfer-Encoding":   true,
		"Upgrade":             true,
		"Proxy-Connection":    true,
	}
	// The Connection header may name additional hop-by-hop headers. Parse it
	// once up front (the previous version re-split it for every header) and
	// compare canonically: header names are case-insensitive, so the old
	// exact-string comparison against the raw key missed differently-cased
	// Connection tokens.
	if conn := headers.Get("Connection"); conn != "" {
		for _, token := range strings.Split(conn, ",") {
			if name := strings.TrimSpace(token); name != "" {
				hopByHop[http.CanonicalHeaderKey(name)] = true
			}
		}
	}
	cleaned := make(http.Header)
	for key, values := range headers {
		if hopByHop[http.CanonicalHeaderKey(key)] {
			continue
		}
		for _, value := range values {
			cleaned.Add(key, value)
		}
	}
	return cleaned
}
// handleHTTPRequestHead starts a streamed inbound HTTP request. It decodes the
// request head, builds an http.Request whose body is fed through an io.Pipe,
// and registers a StreamingRequest so subsequent body-chunk frames can be
// piped in (see handleHTTPRequestBodyChunk). The request is executed against
// the local service in a background goroutine; the response is streamed or
// buffered back depending on its Content-Length.
func (h *FrameHandler) handleHTTPRequestHead(header protocol.DataHeader, payload []byte) error {
	httpReqHead, err := protocol.DecodeHTTPRequestHead(payload)
	if err != nil {
		return fmt.Errorf("failed to decode HTTP request head: %w", err)
	}
	// Frames may carry only a StreamID; fall back to it as the request ID.
	requestID := header.RequestID
	if requestID == "" {
		requestID = header.StreamID
	}
	// Absolute URLs pass through untouched; relative paths are resolved
	// against the configured local service (scheme follows the tunnel type).
	targetURL := httpReqHead.URL
	if !strings.HasPrefix(targetURL, "http://") && !strings.HasPrefix(targetURL, "https://") {
		scheme := "http"
		if h.tunnelType == protocol.TunnelTypeHTTPS {
			scheme = "https"
		}
		targetURL = fmt.Sprintf("%s://%s:%d%s", scheme, h.localHost, h.localPort, targetURL)
	}
	// Pipe: the writer side receives body chunks from the tunnel; the reader
	// side becomes the outgoing request body.
	pipeReader, pipeWriter := io.Pipe()
	req, err := http.NewRequest(httpReqHead.Method, targetURL, pipeReader)
	if err != nil {
		pipeWriter.Close()
		return h.sendHTTPError(header.StreamID, requestID, http.StatusBadGateway, fmt.Sprintf("build request: %v", err))
	}
	// Copy forwarded headers; Content-Length is dropped because the body is
	// streamed with unknown length (req.ContentLength is set to -1 below).
	origHost := ""
	for key, values := range httpReqHead.Headers {
		if key == "Content-Length" {
			continue
		}
		for _, value := range values {
			req.Header.Add(key, value)
		}
	}
	if host := req.Header.Get("Host"); host != "" {
		origHost = host
	}
	req.ContentLength = -1
	// Host handling: for a local target keep the original Host (or synthesize
	// host:port); for a remote target use the target's own host and expose
	// the original via X-Forwarded-Host.
	// NOTE(review): Go's http client transmits req.Host, not the "Host" key in
	// req.Header, so the Header.Set("Host", ...) calls below are redundant -
	// confirm nothing downstream reads the header copy before removing them.
	isLocalTarget := h.isLocalAddress(h.localHost)
	if isLocalTarget {
		if origHost != "" {
			req.Host = origHost
			req.Header.Set("Host", origHost)
		} else {
			localHostPort := fmt.Sprintf("%s:%d", h.localHost, h.localPort)
			req.Host = localHostPort
			req.Header.Set("Host", localHostPort)
		}
		if origHost != "" {
			req.Header.Set("X-Forwarded-Host", origHost)
		}
	} else {
		// Omit default ports so virtual hosting and TLS SNI match.
		targetHost := h.localHost
		if h.localPort != 443 && h.localPort != 80 {
			targetHost = fmt.Sprintf("%s:%d", h.localHost, h.localPort)
		}
		req.Host = targetHost
		req.Header.Set("Host", targetHost)
		if origHost != "" {
			req.Header.Set("X-Forwarded-Host", origHost)
		}
	}
	req.Header.Set("X-Forwarded-Proto", "https")
	// Register the in-flight request so body-chunk frames can find it.
	streamingReq := &StreamingRequest{
		RequestID:  requestID,
		Writer:     pipeWriter,
		Done:       make(chan struct{}),
		chunkQueue: make(chan *chunkData, 512), // deeper buffer for bursty body chunks
	}
	h.streamingReqMu.Lock()
	h.streamingRequests[requestID] = streamingReq
	h.streamingReqMu.Unlock()
	// Feeder goroutine: drains chunkQueue into the pipe until the last chunk,
	// a Done signal, or an inactivity timeout.
	go func() {
		defer pipeWriter.Close()
		timeout := time.NewTimer(5 * time.Minute) // Timeout for receiving body chunks
		defer timeout.Stop()
		for {
			select {
			case chunk, ok := <-streamingReq.chunkQueue:
				if !ok || chunk == nil {
					return
				}
				// Reset timeout on each chunk
				if !timeout.Stop() {
					select {
					case <-timeout.C:
					default:
					}
				}
				timeout.Reset(5 * time.Minute)
				if len(chunk.data) > 0 {
					if _, err := pipeWriter.Write(chunk.data); err != nil {
						h.logger.Error("Failed to write to pipe",
							zap.String("request_id", requestID),
							zap.Error(err),
						)
						// Propagate the error to the http.Client reading the pipe.
						pipeWriter.CloseWithError(err)
						return
					}
				}
				if chunk.isLast {
					return
				}
			case <-streamingReq.Done:
				return
			case <-timeout.C:
				h.logger.Warn("Timeout waiting for request body chunks",
					zap.String("request_id", requestID),
				)
				pipeWriter.CloseWithError(fmt.Errorf("timeout waiting for body chunks"))
				return
			}
		}
	}()
	// Executor goroutine: runs the local request and relays the response.
	go func() {
		// Always unregister and signal Done, whatever the outcome.
		defer func() {
			h.closeStreamingRequest(requestID, streamingReq)
			h.streamingReqMu.Lock()
			delete(h.streamingRequests, requestID)
			h.streamingReqMu.Unlock()
		}()
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
		defer cancel()
		reqWithCtx := req.WithContext(ctx)
		resp, err := h.httpClient.Do(reqWithCtx)
		if err != nil {
			h.sendHTTPError(header.StreamID, requestID, http.StatusBadGateway, fmt.Sprintf("local request failed: %v", err))
			return
		}
		defer resp.Body.Close()
		// Large (or unknown-length, -1) responses take the streaming path;
		// small ones go through the adaptive/buffered path.
		const bufferThreshold int64 = 1 * 1024 * 1024
		if resp.ContentLength > bufferThreshold {
			h.streamHTTPResponse(header.StreamID, requestID, resp)
		} else {
			h.adaptiveHTTPResponse(header.StreamID, requestID, resp, bufferThreshold)
		}
	}()
	return nil
}
// handleHTTPRequestBodyChunk feeds one inbound request-body chunk into the
// matching in-flight streaming request. Chunks for unknown or already-closed
// requests are dropped silently; the final chunk (IsLast) closes and
// unregisters the request.
func (h *FrameHandler) handleHTTPRequestBodyChunk(header protocol.DataHeader, data []byte) error {
	reqID := header.RequestID
	if reqID == "" {
		reqID = header.StreamID
	}

	h.streamingReqMu.RLock()
	sr, ok := h.streamingRequests[reqID]
	h.streamingReqMu.RUnlock()
	if !ok {
		h.logger.Warn("Streaming request not found for body chunk",
			zap.String("request_id", reqID),
		)
		return nil
	}

	sr.mu.Lock()
	alreadyClosed := sr.closed
	sr.mu.Unlock()
	if alreadyClosed {
		h.logger.Debug("Streaming request already closed",
			zap.String("request_id", reqID),
		)
		return nil
	}

	// Copy the payload: the frame's buffer is pooled and reused by the caller.
	chunk := &chunkData{
		data:   append(make([]byte, 0, len(data)), data...),
		isLast: header.IsLast,
	}

	select {
	case sr.chunkQueue <- chunk:
	case <-sr.Done:
		h.logger.Debug("Streaming request already closed",
			zap.String("request_id", reqID),
		)
		return nil
	}

	if header.IsLast {
		// Final chunk delivered: close and unregister the request.
		h.closeStreamingRequest(reqID, sr)
		h.streamingReqMu.Lock()
		delete(h.streamingRequests, reqID)
		h.streamingReqMu.Unlock()
	}
	return nil
}
// closeStreamingRequest marks a streaming request closed and signals waiting
// goroutines by closing Done. Safe to call more than once: only the first
// call closes the channel.
func (h *FrameHandler) closeStreamingRequest(requestID string, streamingReq *StreamingRequest) {
	streamingReq.mu.Lock()
	defer streamingReq.mu.Unlock()
	if !streamingReq.closed {
		streamingReq.closed = true
		close(streamingReq.Done)
	}
}

View File

@@ -1,7 +1,6 @@
package proxy
import (
"context"
"fmt"
"io"
"net/http"
@@ -12,7 +11,6 @@ import (
json "github.com/goccy/go-json"
"drip/internal/server/tunnel"
"drip/internal/shared/constants"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"drip/internal/shared/utils"
@@ -71,17 +69,53 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
limitedReader := io.LimitReader(r.Body, constants.MaxRequestBodySize)
body, err := io.ReadAll(limitedReader)
if err != nil {
h.logger.Error("Read request body failed", zap.Error(err))
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
return
}
defer r.Body.Close()
requestID := utils.GenerateID()
h.handleAdaptiveRequest(w, r, transport, requestID, subdomain)
}
// handleAdaptiveRequest reads the request body up to streamingThreshold
// bytes. Bodies that end before the threshold are forwarded in a single
// buffered frame (sendBufferedRequest); larger bodies switch to chunked
// streaming (streamLargeRequest), which replays the buffered prefix first.
func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string) {
	const streamingThreshold int64 = 1 * 1024 * 1024
	// Grow the buffer on demand instead of pre-allocating the full 1 MiB
	// threshold for every request - most request bodies are small or empty.
	buffer := make([]byte, 0, 32*1024)
	tempBuf := make([]byte, 32*1024)
	var totalRead int64
	// Loop exits in exactly three ways: EOF (buffered path), read error, or
	// the threshold being reached (streaming path). The previous version also
	// carried an unreachable post-loop "!hitThreshold" branch; it is gone.
	for totalRead < streamingThreshold {
		n, err := r.Body.Read(tempBuf)
		if n > 0 {
			buffer = append(buffer, tempBuf[:n]...)
			totalRead += int64(n)
		}
		if err == io.EOF {
			// Entire body fits under the threshold: buffered path.
			r.Body.Close()
			h.sendBufferedRequest(w, r, transport, requestID, subdomain, buffer)
			return
		}
		if err != nil {
			r.Body.Close()
			h.logger.Error("Read request body failed", zap.Error(err))
			http.Error(w, "Failed to read request body", http.StatusInternalServerError)
			return
		}
	}
	// Threshold reached: stream the rest, sending the buffered prefix first.
	// (streamLargeRequest closes r.Body itself when it finishes reading.)
	h.streamLargeRequest(w, r, transport, requestID, subdomain, buffer)
}
func (h *Handler) sendBufferedRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, body []byte) {
headers := h.headerPool.Get()
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
@@ -93,7 +127,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
reqBytes, err := protocol.EncodeHTTPRequest(&httpReq)
h.headerPool.Put(headers)
if err != nil {
@@ -119,7 +152,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
respChan := h.responses.CreateResponseChan(requestID)
defer h.responses.CleanupResponseChan(requestID)
streamingDone := h.responses.CreateStreamingResponse(requestID, w)
defer func() {
h.responses.CleanupResponseChan(requestID)
h.responses.CleanupStreamingResponse(requestID)
}()
if err := transport.SendFrame(frame); err != nil {
h.logger.Error("Send frame to tunnel failed", zap.Error(err))
@@ -127,14 +164,220 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
ctx, cancel := context.WithTimeout(context.Background(), constants.RequestTimeout)
defer cancel()
select {
case respMsg := <-respChan:
if respMsg == nil {
http.Error(w, "Internal server error: nil response", http.StatusInternalServerError)
return
}
h.writeHTTPResponse(w, respMsg, subdomain, r)
case <-streamingDone:
// Streaming response has been fully written by SendStreamingChunk
case <-time.After(5 * time.Minute):
h.logger.Error("Request timeout",
zap.String("request_id", requestID),
zap.String("url", r.URL.String()),
)
http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
}
}
func (h *Handler) streamLargeRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, bufferedData []byte) {
headers := h.headerPool.Get()
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
httpReqHead := protocol.HTTPRequestHead{
Method: r.Method,
URL: r.URL.String(),
Headers: headers,
ContentLength: r.ContentLength,
}
headBytes, err := protocol.EncodeHTTPRequestHead(&httpReqHead)
h.headerPool.Put(headers)
if err != nil {
h.logger.Error("Encode HTTP request head failed", zap.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
headHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPHead, // shared streaming head type
IsLast: false,
}
headPayload, headPoolBuffer, err := protocol.EncodeDataPayloadPooled(headHeader, headBytes)
if err != nil {
h.logger.Error("Encode head payload failed", zap.Error(err))
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
headFrame := protocol.NewFramePooled(protocol.FrameTypeData, headPayload, headPoolBuffer)
respChan := h.responses.CreateResponseChan(requestID)
streamingDone := h.responses.CreateStreamingResponse(requestID, w)
defer func() {
h.responses.CleanupResponseChan(requestID)
h.responses.CleanupStreamingResponse(requestID)
}()
if err := transport.SendFrame(headFrame); err != nil {
h.logger.Error("Send head frame failed", zap.Error(err))
http.Error(w, "Failed to forward request to tunnel", http.StatusBadGateway)
return
}
if len(bufferedData) > 0 {
chunkHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPBodyChunk, // shared streaming body type
IsLast: false,
}
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, bufferedData)
if err != nil {
h.logger.Error("Encode buffered chunk failed", zap.Error(err))
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if ferr == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
return
}
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
if err := transport.SendFrame(chunkFrame); err != nil {
h.logger.Error("Send buffered chunk failed", zap.Error(err))
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if ferr == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
return
}
}
buffer := make([]byte, 32*1024)
for {
n, readErr := r.Body.Read(buffer)
if n > 0 {
isLast := readErr == io.EOF
chunkHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPBodyChunk, // shared streaming body type
IsLast: isLast,
}
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, buffer[:n])
if err != nil {
h.logger.Error("Encode chunk payload failed", zap.Error(err))
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if ferr == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
return
}
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
if err := transport.SendFrame(chunkFrame); err != nil {
h.logger.Error("Send chunk frame failed", zap.Error(err))
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if ferr == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
return
}
}
if readErr == io.EOF {
if n == 0 {
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if err == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
}
break
}
if readErr != nil {
h.logger.Error("Read request body failed", zap.Error(readErr))
finalHeader := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeHTTPRequestBodyChunk,
IsLast: true,
}
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
if err == nil {
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
transport.SendFrame(finalFrame)
}
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
return
}
}
r.Body.Close()
select {
case respMsg := <-respChan:
if respMsg == nil {
http.Error(w, "Internal server error: nil response", http.StatusInternalServerError)
return
}
h.writeHTTPResponse(w, respMsg, subdomain, r)
case <-ctx.Done():
case <-streamingDone:
// Streaming response has been fully written by SendStreamingChunk
case <-time.After(5 * time.Minute):
h.logger.Error("Streaming request timeout",
zap.String("request_id", requestID),
zap.String("url", r.URL.String()),
)
http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
}
}
@@ -145,12 +388,23 @@ func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPRe
return
}
// For buffered responses, we have the complete body, so we can set Content-Length
// Skip ALL hop-by-hop headers - client should have already cleaned them
for key, values := range resp.Headers {
if key == "Connection" || key == "Keep-Alive" || key == "Transfer-Encoding" || key == "Upgrade" {
canonicalKey := http.CanonicalHeaderKey(key)
// Skip hop-by-hop headers completely using canonical key comparison
if canonicalKey == "Connection" ||
canonicalKey == "Keep-Alive" ||
canonicalKey == "Transfer-Encoding" ||
canonicalKey == "Upgrade" ||
canonicalKey == "Proxy-Connection" ||
canonicalKey == "Te" ||
canonicalKey == "Trailer" {
continue
}
if key == "Location" && len(values) > 0 {
if canonicalKey == "Location" && len(values) > 0 {
rewrittenLocation := h.rewriteLocationHeader(values[0], r.Host)
w.Header().Set("Location", rewrittenLocation)
continue
@@ -161,9 +415,8 @@ func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPRe
}
}
if w.Header().Get("Content-Length") == "" && len(resp.Body) > 0 {
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(resp.Body)))
}
// For buffered mode, always set Content-Length with the actual body size
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(resp.Body)))
statusCode := resp.StatusCode
if statusCode == 0 {
@@ -171,6 +424,7 @@ func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPRe
}
w.WriteHeader(statusCode)
if len(resp.Body) > 0 {
w.Write(resp.Body)
}
@@ -284,19 +538,6 @@ func (h *Handler) serveHealth(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(health)
}
func cloneHeadersWithHost(src http.Header, host string) http.Header {
dst := make(http.Header, len(src)+1)
for k, v := range src {
copied := make([]string, len(v))
copy(copied, v)
dst[k] = copied
}
if host != "" {
dst.Set("Host", host)
}
return dst
}
func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
if h.authToken != "" {
token := r.URL.Query().Get("token")

View File

@@ -1,6 +1,10 @@
package proxy
import (
"fmt"
"net"
"net/http"
"strings"
"sync"
"time"
@@ -14,20 +18,33 @@ type responseChanEntry struct {
createdAt time.Time
}
// streamingResponseEntry holds a streaming response writer
type streamingResponseEntry struct {
w http.ResponseWriter
flusher http.Flusher
createdAt time.Time
lastActivityAt time.Time
headersSent bool
done chan struct{}
mu sync.Mutex
}
// ResponseHandler manages response channels for HTTP requests over TCP/Frame protocol
type ResponseHandler struct {
channels map[string]*responseChanEntry
mu sync.RWMutex
logger *zap.Logger
stopCh chan struct{}
channels map[string]*responseChanEntry
streamingChannels map[string]*streamingResponseEntry
mu sync.RWMutex
logger *zap.Logger
stopCh chan struct{}
}
// NewResponseHandler creates a new response handler
func NewResponseHandler(logger *zap.Logger) *ResponseHandler {
h := &ResponseHandler{
channels: make(map[string]*responseChanEntry),
logger: logger,
stopCh: make(chan struct{}),
channels: make(map[string]*responseChanEntry),
streamingChannels: make(map[string]*streamingResponseEntry),
logger: logger,
stopCh: make(chan struct{}),
}
// Start single cleanup goroutine instead of one per request
@@ -50,6 +67,25 @@ func (h *ResponseHandler) CreateResponseChan(requestID string) chan *protocol.HT
return ch
}
// CreateStreamingResponse registers w as the streaming sink for requestID and
// returns a channel that is closed when the streaming response finishes
// (last chunk written, client disconnect, cleanup, or timeout).
func (h *ResponseHandler) CreateStreamingResponse(requestID string, w http.ResponseWriter) chan struct{} {
	h.mu.Lock()
	defer h.mu.Unlock()

	done := make(chan struct{})
	ts := time.Now()
	f, _ := w.(http.Flusher) // nil when the writer cannot flush incrementally
	h.streamingChannels[requestID] = &streamingResponseEntry{
		w:              w,
		flusher:        f,
		createdAt:      ts,
		lastActivityAt: ts,
		done:           done,
	}
	return done
}
// GetResponseChan gets the response channel for a request ID
func (h *ResponseHandler) GetResponseChan(requestID string) <-chan *protocol.HTTPResponse {
h.mu.RLock()
@@ -67,25 +103,168 @@ func (h *ResponseHandler) SendResponse(requestID string, resp *protocol.HTTPResp
h.mu.RUnlock()
if !exists || entry == nil {
h.logger.Warn("Response channel not found",
zap.String("request_id", requestID),
)
return
}
select {
case entry.ch <- resp:
h.logger.Debug("Response sent to channel",
zap.String("request_id", requestID),
)
case <-time.After(5 * time.Second):
h.logger.Warn("Timeout sending response to channel",
case <-time.After(30 * time.Second):
h.logger.Error("Timeout sending response to channel - handler may have abandoned",
zap.String("request_id", requestID),
zap.Int("status_code", resp.StatusCode),
zap.Int("body_size", len(resp.Body)),
)
}
}
// CleanupResponseChan removes and closes a response channel
// SendStreamingHead writes the status code and headers of a streaming
// response to the http.ResponseWriter registered for requestID. It is a no-op
// when no streaming entry exists, the stream is already done, or headers have
// already been sent (WriteHeader may only be called once). Always returns nil.
func (h *ResponseHandler) SendStreamingHead(requestID string, head *protocol.HTTPResponseHead) error {
	h.mu.RLock()
	entry, exists := h.streamingChannels[requestID]
	h.mu.RUnlock()
	if !exists || entry == nil {
		return nil
	}
	// entry.mu serializes header/body writes against chunk delivery.
	entry.mu.Lock()
	defer entry.mu.Unlock()
	// Stream already finished (client gone or final chunk seen): drop.
	select {
	case <-entry.done:
		return nil
	default:
	}
	if entry.headersSent {
		return nil
	}
	// Copy headers, removing hop-by-hop headers that were already handled by client
	// Client's cleanResponseHeaders already removed Transfer-Encoding, Connection, etc.
	// But we need to check again in case they slipped through
	hasContentLength := false
	for key, values := range head.Headers {
		canonicalKey := http.CanonicalHeaderKey(key)
		// Skip ALL hop-by-hop headers
		if canonicalKey == "Connection" ||
			canonicalKey == "Keep-Alive" ||
			canonicalKey == "Transfer-Encoding" ||
			canonicalKey == "Upgrade" ||
			canonicalKey == "Proxy-Connection" ||
			canonicalKey == "Te" ||
			canonicalKey == "Trailer" {
			continue
		}
		if canonicalKey == "Content-Length" {
			hasContentLength = true
		}
		for _, value := range values {
			entry.w.Header().Add(key, value)
		}
	}
	// For streaming responses, decide how to indicate message length: forward
	// a known ContentLength; otherwise leave it unset so the HTTP server
	// falls back to chunked transfer encoding.
	if head.ContentLength >= 0 && !hasContentLength {
		entry.w.Header().Set("Content-Length", fmt.Sprintf("%d", head.ContentLength))
	}
	statusCode := head.StatusCode
	if statusCode == 0 {
		statusCode = http.StatusOK
	}
	entry.w.WriteHeader(statusCode)
	entry.headersSent = true
	entry.lastActivityAt = time.Now()
	// Flush immediately so the client sees headers before the first body chunk.
	if entry.flusher != nil {
		entry.flusher.Flush()
	}
	return nil
}
func (h *ResponseHandler) SendStreamingChunk(requestID string, chunk []byte, isLast bool) error {
h.mu.RLock()
entry, exists := h.streamingChannels[requestID]
h.mu.RUnlock()
if !exists || entry == nil {
return nil
}
entry.mu.Lock()
defer entry.mu.Unlock()
select {
case <-entry.done:
return nil
default:
}
if len(chunk) > 0 {
_, err := entry.w.Write(chunk)
if err != nil {
if isClientDisconnectError(err) {
select {
case <-entry.done:
default:
close(entry.done)
}
return nil
}
select {
case <-entry.done:
default:
close(entry.done)
}
return nil
}
entry.lastActivityAt = time.Now()
if entry.flusher != nil {
entry.flusher.Flush()
}
}
if isLast {
select {
case <-entry.done:
default:
close(entry.done)
}
}
return nil
}
func isClientDisconnectError(err error) bool {
if err == nil {
return false
}
if netErr, ok := err.(*net.OpError); ok {
if netErr.Err != nil {
errStr := netErr.Err.Error()
if strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "connection reset") ||
strings.Contains(errStr, "connection refused") {
return true
}
}
}
errStr := err.Error()
return strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "connection reset") ||
strings.Contains(errStr, "use of closed network connection")
}
func (h *ResponseHandler) CleanupResponseChan(requestID string) {
h.mu.Lock()
defer h.mu.Unlock()
@@ -96,15 +275,26 @@ func (h *ResponseHandler) CleanupResponseChan(requestID string) {
}
}
// GetPendingCount returns the number of pending responses
// CleanupStreamingResponse signals completion (if not already signalled) and
// removes the streaming entry for requestID. No-op for unknown IDs.
func (h *ResponseHandler) CleanupStreamingResponse(requestID string) {
	h.mu.Lock()
	defer h.mu.Unlock()

	entry, ok := h.streamingChannels[requestID]
	if !ok {
		return
	}
	// Close done at most once - it may have been closed by the chunk writer.
	select {
	case <-entry.done:
	default:
		close(entry.done)
	}
	delete(h.streamingChannels, requestID)
}
func (h *ResponseHandler) GetPendingCount() int {
h.mu.RLock()
defer h.mu.RUnlock()
return len(h.channels)
return len(h.channels) + len(h.streamingChannels)
}
// cleanupLoop periodically cleans up expired response channels
// This replaces the per-request goroutine approach with a single cleanup goroutine
func (h *ResponseHandler) cleanupLoop() {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
@@ -119,10 +309,10 @@ func (h *ResponseHandler) cleanupLoop() {
}
}
// cleanupExpiredChannels removes channels older than 30 seconds
func (h *ResponseHandler) cleanupExpiredChannels() {
now := time.Now()
timeout := 30 * time.Second
streamingTimeout := 5 * time.Minute
h.mu.Lock()
defer h.mu.Unlock()
@@ -136,24 +326,43 @@ func (h *ResponseHandler) cleanupExpiredChannels() {
}
}
for requestID, entry := range h.streamingChannels {
if now.Sub(entry.lastActivityAt) > streamingTimeout {
select {
case <-entry.done:
default:
close(entry.done)
}
delete(h.streamingChannels, requestID)
expiredCount++
}
}
if expiredCount > 0 {
h.logger.Debug("Cleaned up expired response channels",
zap.Int("count", expiredCount),
zap.Int("remaining", len(h.channels)),
zap.Int("remaining", len(h.channels)+len(h.streamingChannels)),
)
}
}
// Close stops the cleanup loop and releases every pending buffered channel
// and streaming entry. Call at most once: stopCh cannot be closed twice.
func (h *ResponseHandler) Close() {
	close(h.stopCh)

	h.mu.Lock()
	defer h.mu.Unlock()

	// Wake any goroutine blocked receiving on a buffered response channel.
	for _, pending := range h.channels {
		close(pending.ch)
	}
	h.channels = make(map[string]*responseChanEntry)

	// Signal completion for all in-flight streaming responses (close once).
	for _, stream := range h.streamingChannels {
		select {
		case <-stream.done:
		default:
			close(stream.done)
		}
	}
	h.streamingChannels = make(map[string]*streamingResponseEntry)
}

View File

@@ -49,6 +49,9 @@ type HTTPResponseHandler interface {
GetResponseChan(requestID string) <-chan *protocol.HTTPResponse
CleanupResponseChan(requestID string)
SendResponse(requestID string, resp *protocol.HTTPResponse)
// Streaming response methods
SendStreamingHead(requestID string, head *protocol.HTTPResponseHead) error
SendStreamingChunk(requestID string, chunk []byte, isLast bool) error
}
// NewConnection creates a new connection handler
@@ -273,6 +276,15 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
c.logger.Debug("Client disconnected abruptly", zap.Error(err))
return nil
}
// Check if it looks like garbage data (not a valid HTTP request)
if strings.Contains(errStr, "malformed HTTP") {
c.logger.Warn("Received malformed HTTP request, possibly due to pipelined requests or protocol error",
zap.Error(err),
zap.String("error_snippet", errStr[:min(len(errStr), 100)]),
)
// Close connection on malformed request to prevent further errors
return nil
}
c.logger.Error("Failed to parse HTTP request", zap.Error(err))
return fmt.Errorf("failed to parse HTTP request: %w", err)
}
@@ -289,9 +301,21 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
header: make(http.Header),
}
// Handle the request
// Handle the request - this blocks until response is complete
c.httpHandler.ServeHTTP(respWriter, req)
// Ensure response is flushed to client
if tcpConn, ok := c.conn.(*net.TCPConn); ok {
// Force flush TCP buffers
tcpConn.SetNoDelay(true)
tcpConn.SetNoDelay(false)
}
c.logger.Debug("HTTP request processing completed",
zap.String("method", req.Method),
zap.String("url", req.URL.String()),
)
// Check if we should close the connection
// Close if: Connection: close header, or HTTP/1.0 without Connection: keep-alive
shouldClose := false
@@ -304,8 +328,13 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
}
}
// Also check if response indicated connection should close
if respWriter.headerWritten && respWriter.header.Get("Connection") == "close" {
shouldClose = true
}
if shouldClose {
c.logger.Debug("Closing connection as requested by client")
c.logger.Debug("Closing connection as requested by client or server")
return nil
}
@@ -313,6 +342,13 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
}
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
// handleFrames handles incoming frames
func (c *Connection) handleFrames(reader *bufio.Reader) error {
for {
@@ -439,10 +475,55 @@ func (c *Connection) handleDataFrame(frame *protocol.Frame) {
}
c.responseChans.SendResponse(reqID, httpResp)
case protocol.DataTypeHTTPHead:
// Streaming HTTP response headers
if c.responseChans == nil {
c.logger.Warn("No response handler for streaming HTTP head",
zap.String("stream_id", header.StreamID),
)
return
}
c.logger.Debug("Routed HTTP response to channel",
zap.String("request_id", reqID),
)
httpHead, err := protocol.DecodeHTTPResponseHead(data)
if err != nil {
c.logger.Error("Failed to decode HTTP response head",
zap.String("stream_id", header.StreamID),
zap.Error(err),
)
return
}
reqID := header.RequestID
if reqID == "" {
reqID = header.StreamID
}
if err := c.responseChans.SendStreamingHead(reqID, httpHead); err != nil {
c.logger.Error("Failed to send streaming head",
zap.String("request_id", reqID),
zap.Error(err),
)
}
case protocol.DataTypeHTTPBodyChunk:
// Streaming HTTP response body chunk
if c.responseChans == nil {
c.logger.Warn("No response handler for streaming HTTP chunk",
zap.String("stream_id", header.StreamID),
)
return
}
reqID := header.RequestID
if reqID == "" {
reqID = header.StreamID
}
if err := c.responseChans.SendStreamingChunk(reqID, data, header.IsLast); err != nil {
c.logger.Error("Failed to send streaming chunk",
zap.String("request_id", reqID),
zap.Error(err),
)
}
case protocol.DataTypeClose:
// Client is closing the stream
if c.proxy != nil {
@@ -487,7 +568,12 @@ func (c *Connection) SendFrame(frame *protocol.Frame) error {
if c.frameWriter == nil {
return protocol.WriteFrame(c.conn, frame)
}
return c.frameWriter.WriteFrame(frame)
if err := c.frameWriter.WriteFrame(frame); err != nil {
return err
}
// Flush immediately to ensure the frame is sent without batching delay
c.frameWriter.Flush()
return nil
}
// sendError sends an error frame to the client

View File

@@ -21,12 +21,19 @@ type TunnelProxy struct {
stopCh chan struct{}
wg sync.WaitGroup
clientAddr string
streams map[string]net.Conn // streamID -> external connection
streams map[string]*proxyStream // streamID -> stream info
streamMu sync.RWMutex
frameWriter *protocol.FrameWriter
bufferPool *pool.BufferPool
}
// proxyStream holds connection info with close state
type proxyStream struct {
conn net.Conn
closed bool
mu sync.Mutex
}
// NewTunnelProxy creates a new TCP tunnel proxy
func NewTunnelProxy(port int, subdomain string, tcpConn net.Conn, logger *zap.Logger) *TunnelProxy {
return &TunnelProxy{
@@ -36,7 +43,7 @@ func NewTunnelProxy(port int, subdomain string, tcpConn net.Conn, logger *zap.Lo
logger: logger,
stopCh: make(chan struct{}),
clientAddr: tcpConn.RemoteAddr().String(),
streams: make(map[string]net.Conn),
streams: make(map[string]*proxyStream),
bufferPool: pool.NewBufferPool(),
frameWriter: protocol.NewFrameWriter(tcpConn),
}
@@ -101,8 +108,13 @@ func (p *TunnelProxy) handleConnection(conn net.Conn) {
streamID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), p.port)
stream := &proxyStream{
conn: conn,
closed: false,
}
p.streamMu.Lock()
p.streams[streamID] = conn
p.streams[streamID] = stream
p.streamMu.Unlock()
defer func() {
@@ -117,6 +129,14 @@ func (p *TunnelProxy) handleConnection(conn net.Conn) {
buffer := (*bufPtr)[:pool.SizeMedium]
for {
// Check if stream is closed
stream.mu.Lock()
closed := stream.closed
stream.mu.Unlock()
if closed {
break
}
n, err := conn.Read(buffer)
if err != nil {
break
@@ -124,7 +144,7 @@ func (p *TunnelProxy) handleConnection(conn net.Conn) {
if n > 0 {
if err := p.sendDataToTunnel(streamID, buffer[:n]); err != nil {
p.logger.Error("Send to tunnel failed", zap.Error(err))
p.logger.Debug("Send to tunnel failed", zap.Error(err))
break
}
}
@@ -185,15 +205,24 @@ func (p *TunnelProxy) sendCloseToTunnel(streamID string) {
func (p *TunnelProxy) HandleResponse(streamID string, data []byte) error {
p.streamMu.RLock()
conn, ok := p.streams[streamID]
stream, ok := p.streams[streamID]
p.streamMu.RUnlock()
if !ok {
return fmt.Errorf("stream not found: %s", streamID)
// Stream may have been closed by client, this is normal
return nil
}
if _, err := conn.Write(data); err != nil {
p.logger.Error("Write to client failed", zap.Error(err))
// Check if stream is closed
stream.mu.Lock()
if stream.closed {
stream.mu.Unlock()
return nil
}
stream.mu.Unlock()
if _, err := stream.conn.Write(data); err != nil {
p.logger.Debug("Write to client failed", zap.Error(err))
return err
}
@@ -203,12 +232,24 @@ func (p *TunnelProxy) HandleResponse(streamID string, data []byte) error {
// CloseStream closes a stream
func (p *TunnelProxy) CloseStream(streamID string) {
p.streamMu.RLock()
conn, ok := p.streams[streamID]
stream, ok := p.streams[streamID]
p.streamMu.RUnlock()
if ok {
conn.Close()
if !ok {
return
}
// Mark as closed first
stream.mu.Lock()
if stream.closed {
stream.mu.Unlock()
return
}
stream.closed = true
stream.mu.Unlock()
// Now close the connection
stream.conn.Close()
}
func (p *TunnelProxy) Stop() {
@@ -224,10 +265,13 @@ func (p *TunnelProxy) Stop() {
}
p.streamMu.Lock()
for _, conn := range p.streams {
conn.Close()
for _, stream := range p.streams {
stream.mu.Lock()
stream.closed = true
stream.mu.Unlock()
stream.conn.Close()
}
p.streams = make(map[string]net.Conn)
p.streams = make(map[string]*proxyStream)
p.streamMu.Unlock()
p.wg.Wait()

View File

@@ -0,0 +1,280 @@
package hpack
import (
"bytes"
"errors"
"fmt"
"net/http"
"sync"
)
// Decoder decompresses HPACK-encoded headers.
//
// Each connection MUST have its own decoder instance: the dynamic table is
// per-connection state that must stay in lockstep with the peer's encoder,
// and sharing a decoder across connections corrupts that state.
type Decoder struct {
	mu           sync.Mutex    // guards dynamicTable and maxTableSize
	dynamicTable *DynamicTable // mutated as incrementally-indexed fields are decoded
	staticTable  *StaticTable  // shared, read-only RFC 7541 Appendix A table
	maxTableSize uint32        // configured cap for the dynamic table
}
// NewDecoder returns an HPACK decoder whose dynamic table is capped at
// maxTableSize bytes; DefaultDynamicTableSize is used when 0 is given.
// Each connection must own its decoder (the dynamic table is stateful).
func NewDecoder(maxTableSize uint32) *Decoder {
	size := maxTableSize
	if size == 0 {
		size = DefaultDynamicTableSize
	}
	return &Decoder{
		dynamicTable: NewDynamicTable(size),
		staticTable:  GetStaticTable(),
		maxTableSize: size,
	}
}
// Decode decodes an HPACK-encoded header block into an http.Header.
//
// Supported representations (RFC 7541 §6):
//   - 1xxxxxxx  indexed header field
//   - 01xxxxxx  literal with incremental indexing (inserted into the dynamic table)
//   - 001xxxxx  dynamic table size update (no header field emitted)
//   - 0000xxxx / 0001xxxx  literal without indexing / never indexed
//
// Decoding mutates the dynamic table, so the method is serialized with the
// decoder's mutex. An empty input decodes to an empty (non-nil) header map.
func (d *Decoder) Decode(data []byte) (http.Header, error) {
	d.mu.Lock()
	defer d.mu.Unlock()

	if len(data) == 0 {
		return http.Header{}, nil
	}

	headers := make(http.Header)
	buf := bytes.NewReader(data)

	for buf.Len() > 0 {
		b, err := buf.ReadByte()
		if err != nil {
			return nil, fmt.Errorf("read header byte: %w", err)
		}
		// Peek only: put the byte back so each representation decoder can
		// re-read it together with its prefix-encoded integer.
		if err := buf.UnreadByte(); err != nil {
			return nil, err
		}

		// Dynamic table size update (001xxxxx, RFC 7541 §6.3) carries no
		// header field; previously this pattern fell through to the
		// literal-without-indexing branch and was misparsed.
		if b&0xe0 == 0x20 {
			size, err := d.readInteger(buf, 5)
			if err != nil {
				return nil, fmt.Errorf("read table size update: %w", err)
			}
			if size > d.maxTableSize {
				return nil, fmt.Errorf("table size update %d exceeds maximum %d", size, d.maxTableSize)
			}
			d.dynamicTable.SetMaxSize(size)
			continue
		}

		var name, value string
		if b&0x80 != 0 {
			// Indexed header field (1xxxxxxx)
			name, value, err = d.decodeIndexedHeader(buf)
		} else if b&0x40 != 0 {
			// Literal with incremental indexing (01xxxxxx)
			name, value, err = d.decodeLiteralWithIndexing(buf)
		} else {
			// Literal without indexing / never indexed (0000xxxx / 0001xxxx):
			// both use the same 4-bit-prefix layout.
			name, value, err = d.decodeLiteralWithoutIndexing(buf)
		}
		if err != nil {
			return nil, err
		}

		headers.Add(name, value)
	}

	return headers, nil
}
// decodeIndexedHeader resolves a fully indexed header field (1xxxxxxx).
// The 7-bit prefixed index is 1-based: 1..len(static) addresses the static
// table, larger values address the dynamic table (newest entry first).
func (d *Decoder) decodeIndexedHeader(buf *bytes.Reader) (string, string, error) {
	index, err := d.readInteger(buf, 7)
	if err != nil {
		return "", "", fmt.Errorf("read index: %w", err)
	}

	staticCount := uint32(d.staticTable.Size())
	switch {
	case index == 0:
		// Index 0 is reserved and never refers to an entry.
		return "", "", errors.New("invalid index: 0")
	case index <= staticCount:
		return d.staticTable.Get(index - 1)
	default:
		// Dynamic table indices begin immediately after the static table.
		return d.dynamicTable.Get(index - staticCount - 1)
	}
}
// decodeLiteralWithIndexing decodes a literal header field with incremental
// indexing (01xxxxxx). The name is either a 6-bit prefixed table index or a
// literal string; the value is always literal. The decoded pair is inserted
// into the dynamic table, mirroring the insertion the encoder performs.
func (d *Decoder) decodeLiteralWithIndexing(buf *bytes.Reader) (string, string, error) {
	nameIndex, err := d.readInteger(buf, 6)
	if err != nil {
		return "", "", err
	}

	var name string
	if nameIndex == 0 {
		// Index 0 means the name follows as a literal string.
		if name, err = d.readString(buf); err != nil {
			return "", "", fmt.Errorf("read name: %w", err)
		}
	} else {
		// Indexed name: resolve against the static table first, then the
		// dynamic table (whose indices start right after the static table).
		staticCount := uint32(d.staticTable.Size())
		if nameIndex <= staticCount {
			name, _, err = d.staticTable.Get(nameIndex - 1)
		} else {
			name, _, err = d.dynamicTable.Get(nameIndex - staticCount - 1)
		}
		if err != nil {
			return "", "", fmt.Errorf("get indexed name: %w", err)
		}
	}

	value, err := d.readString(buf)
	if err != nil {
		return "", "", fmt.Errorf("read value: %w", err)
	}

	// Incremental indexing: the entry joins the dynamic table so later
	// fields can reference it by index.
	d.dynamicTable.Add(name, value)
	return name, value, nil
}
// decodeLiteralWithoutIndexing decodes a literal header field that must NOT
// be added to the dynamic table. Both the "without indexing" (0000xxxx) and
// "never indexed" (0001xxxx) representations share this 4-bit-prefix layout.
func (d *Decoder) decodeLiteralWithoutIndexing(buf *bytes.Reader) (string, string, error) {
	nameIndex, err := d.readInteger(buf, 4)
	if err != nil {
		return "", "", err
	}

	var name string
	if nameIndex == 0 {
		// Index 0 means the name follows as a literal string.
		if name, err = d.readString(buf); err != nil {
			return "", "", fmt.Errorf("read name: %w", err)
		}
	} else {
		// Indexed name: static table first, then dynamic.
		staticCount := uint32(d.staticTable.Size())
		if nameIndex <= staticCount {
			name, _, err = d.staticTable.Get(nameIndex - 1)
		} else {
			name, _, err = d.dynamicTable.Get(nameIndex - staticCount - 1)
		}
		if err != nil {
			return "", "", fmt.Errorf("get indexed name: %w", err)
		}
	}

	value, err := d.readString(buf)
	if err != nil {
		return "", "", fmt.Errorf("read value: %w", err)
	}

	// Deliberately NOT inserted into the dynamic table.
	return name, value, nil
}
// readInteger reads an HPACK prefix-encoded integer (RFC 7541 §5.1).
//
// prefixBits is the number of value bits in the first byte (1..8). When the
// prefix is saturated, continuation bytes follow in little-endian 7-bit
// groups, each with the high bit set except the last.
//
// The previous implementation checked for overflow only after shifting, so
// a 5th continuation byte could silently wrap uint32; we now accumulate in
// uint64 and reject any value that does not fit in uint32.
func (d *Decoder) readInteger(buf *bytes.Reader, prefixBits int) (uint32, error) {
	if prefixBits < 1 || prefixBits > 8 {
		return 0, fmt.Errorf("invalid prefix bits: %d", prefixBits)
	}

	b, err := buf.ReadByte()
	if err != nil {
		return 0, err
	}

	maxPrefix := uint32((1 << prefixBits) - 1)
	value := uint32(b) & maxPrefix
	if value < maxPrefix {
		// Fits entirely in the prefix.
		return value, nil
	}

	// Saturated prefix: read continuation bytes.
	total := uint64(value)
	shift := uint(0)
	for {
		b, err := buf.ReadByte()
		if err != nil {
			return 0, err
		}
		if shift > 28 {
			// A 6th continuation byte cannot contribute to a uint32.
			return 0, errors.New("integer overflow")
		}
		total += uint64(b&0x7f) << shift
		if total > 1<<32-1 {
			return 0, errors.New("integer overflow")
		}
		if b&0x80 == 0 {
			break
		}
		shift += 7
	}
	return uint32(total), nil
}
// readString reads an HPACK string literal (RFC 7541 §5.2): a 1-bit Huffman
// flag and 7-bit prefixed length, followed by that many octets.
// Huffman-coded strings are rejected as unsupported.
func (d *Decoder) readString(buf *bytes.Reader) (string, error) {
	first, err := buf.ReadByte()
	if err != nil {
		return "", err
	}
	// Peek only: the length integer re-reads this byte.
	if err := buf.UnreadByte(); err != nil {
		return "", err
	}
	isHuffman := first&0x80 != 0

	length, err := d.readInteger(buf, 7)
	if err != nil {
		return "", fmt.Errorf("read string length: %w", err)
	}
	if length == 0 {
		return "", nil
	}
	// Guard the allocation against a length that overruns the input.
	if length > uint32(buf.Len()) {
		return "", fmt.Errorf("string length %d exceeds buffer size %d", length, buf.Len())
	}

	raw := make([]byte, length)
	n, err := buf.Read(raw)
	if err != nil {
		return "", err
	}
	if n != int(length) {
		return "", fmt.Errorf("expected %d bytes, read %d", length, n)
	}

	if isHuffman {
		// TODO: Huffman decoding is not implemented.
		return "", errors.New("huffman decoding not implemented")
	}
	return string(raw), nil
}
// SetMaxTableSize updates the dynamic table size cap. If the table currently
// exceeds the new cap, the oldest entries are evicted immediately (via
// DynamicTable.SetMaxSize). Must track the peer encoder's table size or
// indices will diverge.
func (d *Decoder) SetMaxTableSize(size uint32) {
	d.mu.Lock()
	defer d.mu.Unlock()
	d.maxTableSize = size
	d.dynamicTable.SetMaxSize(size)
}
// Reset discards all dynamic table state, returning the decoder to its
// initial condition at the configured maximum size. Only safe if the peer's
// encoder state is reset in tandem; otherwise indices will skew.
func (d *Decoder) Reset() {
	d.mu.Lock()
	defer d.mu.Unlock()
	d.dynamicTable = NewDynamicTable(d.maxTableSize)
}

View File

@@ -0,0 +1,124 @@
package hpack
import (
"fmt"
)
// DynamicTable implements the HPACK dynamic table (RFC 7541 Section 2.3.2).
//
// Entries form a FIFO: index 0 is always the most recently inserted field,
// and the oldest entries are evicted whenever the accounted size would
// exceed the configured maximum.
type DynamicTable struct {
	entries []HeaderField // newest first
	size    uint32        // accounted size in bytes (RFC 7541 §4.1)
	maxSize uint32        // eviction threshold in bytes
}

// HeaderField is a single header name/value pair stored in an HPACK table.
type HeaderField struct {
	Name  string
	Value string
}

// Size reports the RFC 7541 §4.1 accounted size of this field:
// len(name) + len(value) + 32 bytes of per-entry overhead.
func (h *HeaderField) Size() uint32 {
	return uint32(len(h.Name)) + uint32(len(h.Value)) + 32
}

// NewDynamicTable creates an empty dynamic table capped at maxSize bytes.
func NewDynamicTable(maxSize uint32) *DynamicTable {
	return &DynamicTable{
		entries: make([]HeaderField, 0, 32),
		maxSize: maxSize,
	}
}

// Add inserts a header field at index 0, evicting the oldest entries as
// needed to stay within the cap. A field larger than the entire table
// empties the table and is not stored (RFC 7541 §4.4).
func (dt *DynamicTable) Add(name, value string) {
	entry := HeaderField{Name: name, Value: value}
	cost := entry.Size()

	if cost > dt.maxSize {
		dt.evictAll()
		return
	}

	for len(dt.entries) > 0 && dt.size+cost > dt.maxSize {
		dt.evictOldest()
	}

	// Newest entry lives at the front of the slice.
	dt.entries = append([]HeaderField{entry}, dt.entries...)
	dt.size += cost
}

// Get returns the field at the given 0-based index, where index 0 is the
// most recently added entry.
func (dt *DynamicTable) Get(index uint32) (string, string, error) {
	if index >= uint32(len(dt.entries)) {
		return "", "", fmt.Errorf("index %d out of range (table size: %d)", index, len(dt.entries))
	}
	entry := dt.entries[index]
	return entry.Name, entry.Value, nil
}

// FindExact returns the 0-based index of the first entry matching both
// name and value, scanning from newest to oldest.
func (dt *DynamicTable) FindExact(name, value string) (uint32, bool) {
	for i := range dt.entries {
		if dt.entries[i].Name == name && dt.entries[i].Value == value {
			return uint32(i), true
		}
	}
	return 0, false
}

// FindName returns the 0-based index of the first entry whose name matches,
// scanning from newest to oldest.
func (dt *DynamicTable) FindName(name string) (uint32, bool) {
	for i := range dt.entries {
		if dt.entries[i].Name == name {
			return uint32(i), true
		}
	}
	return 0, false
}

// SetMaxSize changes the eviction threshold; if the table no longer fits,
// the oldest entries are evicted immediately.
func (dt *DynamicTable) SetMaxSize(maxSize uint32) {
	dt.maxSize = maxSize
	for len(dt.entries) > 0 && dt.size > dt.maxSize {
		dt.evictOldest()
	}
}

// CurrentSize reports the accounted size of all stored entries in bytes.
func (dt *DynamicTable) CurrentSize() uint32 {
	return dt.size
}

// evictOldest drops the least recently added entry (the tail of the slice).
func (dt *DynamicTable) evictOldest() {
	n := len(dt.entries)
	if n == 0 {
		return
	}
	dt.size -= dt.entries[n-1].Size()
	dt.entries = dt.entries[:n-1]
}

// evictAll drops every entry, keeping the backing array for reuse.
func (dt *DynamicTable) evictAll() {
	dt.entries = dt.entries[:0]
	dt.size = 0
}

View File

@@ -0,0 +1,200 @@
package hpack
import (
"bytes"
"errors"
"fmt"
"net/http"
"strings"
"sync"
)
const (
	// DefaultDynamicTableSize is the default size of the dynamic table (4KB),
	// the same default HTTP/2 uses for SETTINGS_HEADER_TABLE_SIZE.
	DefaultDynamicTableSize = 4096

	// indexedHeaderField is the wire flag for a fully indexed field.
	indexedHeaderField = 0x80 // 10xxxxxx

	// literalHeaderFieldWithIndexing is the wire flag for a literal field
	// with incremental indexing (the entry is added to the dynamic table).
	literalHeaderFieldWithIndexing = 0x40 // 01xxxxxx
)
// Encoder compresses HTTP headers using HPACK.
//
// Each connection MUST have its own encoder instance: the dynamic table is
// per-connection state that must stay in lockstep with the peer's decoder.
// Method calls are serialized internally with the mutex.
type Encoder struct {
	mu           sync.Mutex    // guards dynamicTable and maxTableSize
	dynamicTable *DynamicTable // mutated as incrementally-indexed fields are emitted
	staticTable  *StaticTable  // shared, read-only RFC 7541 Appendix A table
	maxTableSize uint32        // configured cap for the dynamic table
}
// NewEncoder returns an HPACK encoder whose dynamic table is capped at
// maxTableSize bytes; DefaultDynamicTableSize is used when 0 is given.
// Each connection must own its encoder (the dynamic table is stateful).
func NewEncoder(maxTableSize uint32) *Encoder {
	size := maxTableSize
	if size == 0 {
		size = DefaultDynamicTableSize
	}
	return &Encoder{
		dynamicTable: NewDynamicTable(size),
		staticTable:  GetStaticTable(),
		maxTableSize: size,
	}
}
// Encode serializes HTTP headers into an HPACK header block.
// Calls are serialized with the encoder's mutex, so it is safe to invoke
// concurrently on the same instance. A nil header map is rejected.
// Note: map iteration order is not deterministic, so repeated encodings of
// the same headers may produce different (but equivalent) byte sequences.
func (e *Encoder) Encode(headers http.Header) ([]byte, error) {
	e.mu.Lock()
	defer e.mu.Unlock()

	if headers == nil {
		return nil, errors.New("headers cannot be nil")
	}

	var out bytes.Buffer
	for name, values := range headers {
		for _, v := range values {
			if err := e.encodeHeaderField(&out, name, v); err != nil {
				return nil, fmt.Errorf("encode header %s: %w", name, err)
			}
		}
	}
	return out.Bytes(), nil
}
// encodeHeaderField encodes a single header field, preferring the smallest
// representation: fully indexed, then literal with an indexed name, then a
// fully literal field.
//
// Header names are lowercased for table lookups and storage, as required by
// RFC 7540 Section 8.1.2.
//
// Bug fix: every literal-with-incremental-indexing representation must add
// the (name, value) pair to the encoder's dynamic table, because the peer's
// decoder does the same (see decodeLiteralWithIndexing). The previous code
// only added the entry on the "new name" path, so after an indexed-name
// literal the two dynamic tables diverged and later dynamic indices resolved
// to the wrong entries. Exact dynamic matches are now also checked before
// static name-only matches, so a repeated (name, value) pair is emitted as a
// single index instead of a literal.
func (e *Encoder) encodeHeaderField(buf *bytes.Buffer, name, value string) error {
	nameLower := strings.ToLower(name)

	// Exact matches first: a pure index is the shortest form and does not
	// mutate the dynamic table.
	if index, found := e.staticTable.FindExact(nameLower, value); found {
		return e.writeIndexedHeader(buf, index+1)
	}
	if index, found := e.dynamicTable.FindExact(nameLower, value); found {
		// Dynamic table indices start after the static table.
		dynamicIndex := uint32(e.staticTable.Size()) + index + 1
		return e.writeIndexedHeader(buf, dynamicIndex)
	}

	// Name-only match: literal with incremental indexing, name by index.
	if index, found := e.staticTable.FindName(nameLower); found {
		if err := e.writeLiteralWithIndexing(buf, index+1, value, true); err != nil {
			return err
		}
		// RFC 7541 §6.2.1: incremental indexing inserts the entry into the
		// dynamic table; the decoder mirrors this, so the encoder must too.
		e.dynamicTable.Add(nameLower, value)
		return nil
	}
	if index, found := e.dynamicTable.FindName(nameLower); found {
		dynamicIndex := uint32(e.staticTable.Size()) + index + 1
		if err := e.writeLiteralWithIndexing(buf, dynamicIndex, value, true); err != nil {
			return err
		}
		e.dynamicTable.Add(nameLower, value)
		return nil
	}

	// Unknown name: literal with incremental indexing, literal name + value.
	buf.WriteByte(literalHeaderFieldWithIndexing)
	if err := e.writeString(buf, nameLower, false); err != nil {
		return err
	}
	if err := e.writeString(buf, value, false); err != nil {
		return err
	}
	e.dynamicTable.Add(nameLower, value)
	return nil
}
// writeIndexedHeader emits a fully indexed header field (10xxxxxx) carrying
// a 7-bit prefixed, 1-based table index.
func (e *Encoder) writeIndexedHeader(buf *bytes.Buffer, index uint32) error {
	return e.writeInteger(buf, index, 7, indexedHeaderField)
}
// writeLiteralWithIndexing writes a literal header field with incremental
// indexing (01xxxxxx), where the name is given as a 1-based table index.
//
// NOTE(review): the hasIndex=false branch emits the flag byte and the value
// but no name string, which does not match the HPACK literal layout; all
// current callers pass hasIndex=true, so the branch is dead — confirm before
// relying on it.
func (e *Encoder) writeLiteralWithIndexing(buf *bytes.Buffer, nameIndex uint32, value string, hasIndex bool) error {
	if hasIndex {
		// Write name as index
		// (the name index occupies the low 6 bits of the flag byte).
		if err := e.writeInteger(buf, nameIndex, 6, literalHeaderFieldWithIndexing); err != nil {
			return err
		}
	} else {
		// Write literal flag
		buf.WriteByte(literalHeaderFieldWithIndexing)
	}

	// Write value as literal string
	return e.writeString(buf, value, false)
}
// writeInteger emits value using HPACK prefix-integer encoding (RFC 7541
// §5.1): prefix supplies the representation's flag bits, the low prefixBits
// bits hold either the whole value or a saturation marker followed by 7-bit
// little-endian continuation bytes.
func (e *Encoder) writeInteger(buf *bytes.Buffer, value uint32, prefixBits int, prefix byte) error {
	if prefixBits < 1 || prefixBits > 8 {
		return fmt.Errorf("invalid prefix bits: %d", prefixBits)
	}

	limit := uint32(1)<<prefixBits - 1
	if value < limit {
		// Fits entirely in the prefix byte.
		buf.WriteByte(prefix | byte(value))
		return nil
	}

	// Saturate the prefix, then emit the remainder 7 bits at a time.
	buf.WriteByte(prefix | byte(limit))
	remainder := value - limit
	for remainder >= 0x80 {
		buf.WriteByte(byte(remainder&0x7f) | 0x80)
		remainder >>= 7
	}
	buf.WriteByte(byte(remainder))
	return nil
}
// writeString emits an HPACK string literal (RFC 7541 §5.2): a Huffman flag
// bit plus 7-bit prefixed length, then the raw octets. Huffman coding is not
// supported by this implementation.
func (e *Encoder) writeString(buf *bytes.Buffer, str string, huffmanEncode bool) error {
	if huffmanEncode {
		// TODO: Implement Huffman encoding if needed
		return errors.New("huffman encoding not implemented")
	}

	// Length with H=0 (no Huffman) in the high bit.
	if err := e.writeInteger(buf, uint32(len(str)), 7, 0x00); err != nil {
		return err
	}
	buf.WriteString(str)
	return nil
}
// SetMaxTableSize updates the dynamic table size cap. If the table currently
// exceeds the new cap, the oldest entries are evicted immediately (via
// DynamicTable.SetMaxSize). Must be coordinated with the peer decoder's
// table size or indices will diverge.
func (e *Encoder) SetMaxTableSize(size uint32) {
	e.mu.Lock()
	defer e.mu.Unlock()
	e.maxTableSize = size
	e.dynamicTable.SetMaxSize(size)
}
// Reset discards all dynamic table state, returning the encoder to its
// initial condition at the configured maximum size. Only safe if the peer's
// decoder state is reset in tandem; otherwise indices will skew.
func (e *Encoder) Reset() {
	e.mu.Lock()
	defer e.mu.Unlock()
	e.dynamicTable = NewDynamicTable(e.maxTableSize)
}

View File

@@ -0,0 +1,150 @@
package hpack
import (
"fmt"
"sync"
)
// StaticTable implements the HPACK static table (RFC 7541 Appendix A).
// The static table is predefined and never changes after construction, so a
// single instance is shared by every connection and is safe for concurrent
// reads.
type StaticTable struct {
	entries []HeaderField
	nameMap map[string][]uint32 // Maps name to list of indices (ascending)
}

var (
	// staticTableInstance is the lazily built, process-wide singleton.
	staticTableInstance *StaticTable
	staticTableOnce     sync.Once
)

// GetStaticTable returns the singleton static table instance, building it on
// first use.
func GetStaticTable() *StaticTable {
	staticTableOnce.Do(func() {
		staticTableInstance = newStaticTable()
	})
	return staticTableInstance
}
// newStaticTable creates and initializes the static table.
//
// The entries follow RFC 7541 Appendix A. Their ORDER is protocol state:
// wire indices are 1-based (entry i in this 0-based slice is wire index
// i+1), and both peers must resolve the same index to the same field, so
// entries must never be reordered, inserted, or removed.
func newStaticTable() *StaticTable {
	// RFC 7541 Appendix A - Static Table Definition
	// We include the most common headers for HTTP
	entries := []HeaderField{
		// Pseudo-headers and common status codes (wire indices 1-14).
		{Name: ":authority", Value: ""},
		{Name: ":method", Value: "GET"},
		{Name: ":method", Value: "POST"},
		{Name: ":path", Value: "/"},
		{Name: ":path", Value: "/index.html"},
		{Name: ":scheme", Value: "http"},
		{Name: ":scheme", Value: "https"},
		{Name: ":status", Value: "200"},
		{Name: ":status", Value: "204"},
		{Name: ":status", Value: "206"},
		{Name: ":status", Value: "304"},
		{Name: ":status", Value: "400"},
		{Name: ":status", Value: "404"},
		{Name: ":status", Value: "500"},
		// Regular header names (wire indices 15-61), mostly empty-valued.
		{Name: "accept-charset", Value: ""},
		{Name: "accept-encoding", Value: "gzip, deflate"},
		{Name: "accept-language", Value: ""},
		{Name: "accept-ranges", Value: ""},
		{Name: "accept", Value: ""},
		{Name: "access-control-allow-origin", Value: ""},
		{Name: "age", Value: ""},
		{Name: "allow", Value: ""},
		{Name: "authorization", Value: ""},
		{Name: "cache-control", Value: ""},
		{Name: "content-disposition", Value: ""},
		{Name: "content-encoding", Value: ""},
		{Name: "content-language", Value: ""},
		{Name: "content-length", Value: ""},
		{Name: "content-location", Value: ""},
		{Name: "content-range", Value: ""},
		{Name: "content-type", Value: ""},
		{Name: "cookie", Value: ""},
		{Name: "date", Value: ""},
		{Name: "etag", Value: ""},
		{Name: "expect", Value: ""},
		{Name: "expires", Value: ""},
		{Name: "from", Value: ""},
		{Name: "host", Value: ""},
		{Name: "if-match", Value: ""},
		{Name: "if-modified-since", Value: ""},
		{Name: "if-none-match", Value: ""},
		{Name: "if-range", Value: ""},
		{Name: "if-unmodified-since", Value: ""},
		{Name: "last-modified", Value: ""},
		{Name: "link", Value: ""},
		{Name: "location", Value: ""},
		{Name: "max-forwards", Value: ""},
		{Name: "proxy-authenticate", Value: ""},
		{Name: "proxy-authorization", Value: ""},
		{Name: "range", Value: ""},
		{Name: "referer", Value: ""},
		{Name: "refresh", Value: ""},
		{Name: "retry-after", Value: ""},
		{Name: "server", Value: ""},
		{Name: "set-cookie", Value: ""},
		{Name: "strict-transport-security", Value: ""},
		{Name: "transfer-encoding", Value: ""},
		{Name: "user-agent", Value: ""},
		{Name: "vary", Value: ""},
		{Name: "via", Value: ""},
		{Name: "www-authenticate", Value: ""},
	}

	// Build name index map
	// (each name maps to its 0-based slice indices, in ascending order, so
	// FindName can return the lowest matching index cheaply).
	nameMap := make(map[string][]uint32)
	for i, entry := range entries {
		nameMap[entry.Name] = append(nameMap[entry.Name], uint32(i))
	}

	return &StaticTable{
		entries: entries,
		nameMap: nameMap,
	}
}
// Get returns the static table field at the given 0-based index
// (wire index minus one).
func (st *StaticTable) Get(index uint32) (string, string, error) {
	if index >= uint32(len(st.entries)) {
		return "", "", fmt.Errorf("index %d out of range (static table size: %d)", index, len(st.entries))
	}
	entry := st.entries[index]
	return entry.Name, entry.Value, nil
}
// FindExact returns the 0-based index of the entry matching both name and
// value, using the precomputed name index to limit the scan.
func (st *StaticTable) FindExact(name, value string) (uint32, bool) {
	// Ranging over a missing key yields an empty slice, so no explicit
	// existence check is needed.
	for _, idx := range st.nameMap[name] {
		if st.entries[idx].Value == value {
			return idx, true
		}
	}
	return 0, false
}
// FindName returns the lowest 0-based index of an entry with the given name
// (nameMap stores indices in ascending order).
func (st *StaticTable) FindName(name string) (uint32, bool) {
	indices := st.nameMap[name]
	if len(indices) == 0 {
		return 0, false
	}
	return indices[0], true
}
// Size returns the number of entries in the static table. Index arithmetic
// in the encoder/decoder relies on this as the boundary between static and
// dynamic wire indices.
func (st *StaticTable) Size() int {
	return len(st.entries)
}

View File

@@ -18,9 +18,6 @@ const (
// RequestTimeout is the maximum time to wait for a response from the client
RequestTimeout = 30 * time.Second
// MaxRequestBodySize is the maximum size of an HTTP request body (10MB)
MaxRequestBodySize = 10 * 1024 * 1024
// ReconnectBaseDelay is the initial delay for reconnection attempts
ReconnectBaseDelay = 1 * time.Second

View File

@@ -18,11 +18,17 @@ type DataHeader struct {
type DataType uint8
const (
DataTypeData DataType = 0x00 // 000
DataTypeResponse DataType = 0x01 // 001
DataTypeClose DataType = 0x02 // 010
DataTypeHTTPRequest DataType = 0x03 // 011
DataTypeHTTPResponse DataType = 0x04 // 100
DataTypeData DataType = 0x00 // 000
DataTypeResponse DataType = 0x01 // 001
DataTypeClose DataType = 0x02 // 010
DataTypeHTTPRequest DataType = 0x03 // 011
DataTypeHTTPResponse DataType = 0x04 // 100
DataTypeHTTPHead DataType = 0x05 // 101 - streaming headers (shared)
DataTypeHTTPBodyChunk DataType = 0x06 // 110 - streaming body chunks (shared)
// Reuse the same type codes for request streaming to stay within 3 bits.
DataTypeHTTPRequestHead DataType = DataTypeHTTPHead
DataTypeHTTPRequestBodyChunk DataType = DataTypeHTTPBodyChunk
)
// String returns the string representation of DataType
@@ -38,6 +44,10 @@ func (t DataType) String() string {
return "http_request"
case DataTypeHTTPResponse:
return "http_response"
case DataTypeHTTPHead:
return "http_head"
case DataTypeHTTPBodyChunk:
return "http_body_chunk"
default:
return "unknown"
}
@@ -56,6 +66,10 @@ func DataTypeFromString(s string) DataType {
return DataTypeHTTPRequest
case "http_response":
return DataTypeHTTPResponse
case "http_head":
return DataTypeHTTPHead
case "http_body_chunk":
return DataTypeHTTPBodyChunk
default:
return DataTypeData
}
@@ -118,8 +132,8 @@ func (h *DataHeader) UnmarshalBinary(data []byte) error {
// Decode flags
flags := data[0]
h.Type = DataType(flags & 0x07) // Bits 0-2
h.IsLast = (flags & 0x08) != 0 // Bit 3
h.Type = DataType(flags & 0x07) // Bits 0-2
h.IsLast = (flags & 0x08) != 0 // Bit 3
// Decode lengths
streamIDLen := int(binary.BigEndian.Uint16(data[1:3]))

View File

@@ -1,9 +1,10 @@
package protocol
import (
json "github.com/goccy/go-json"
"errors"
json "github.com/goccy/go-json"
"github.com/vmihailenco/msgpack/v5"
)
@@ -37,6 +38,31 @@ func DecodeHTTPRequest(data []byte) (*HTTPRequest, error) {
return &req, nil
}
// EncodeHTTPRequestHead encodes streaming HTTP request headers as msgpack.
// The request body is not included; it is transmitted separately as body
// chunk frames.
func EncodeHTTPRequestHead(head *HTTPRequestHead) ([]byte, error) {
	return msgpack.Marshal(head)
}
// DecodeHTTPRequestHead decodes streaming HTTP request headers.
//
// Accepts both msgpack (the format EncodeHTTPRequestHead produces) and JSON:
// a leading '{' byte is treated as JSON. NOTE(review): this sniff assumes a
// msgpack-encoded head never begins with 0x7b ('{') — presumably true for a
// struct encoded as a msgpack map, but confirm against the msgpack codec.
func DecodeHTTPRequestHead(data []byte) (*HTTPRequestHead, error) {
	if len(data) == 0 {
		return nil, errors.New("empty data")
	}
	var head HTTPRequestHead
	if data[0] == '{' {
		// Looks like JSON — use the JSON codec.
		if err := json.Unmarshal(data, &head); err != nil {
			return nil, err
		}
	} else {
		// Otherwise assume msgpack.
		if err := msgpack.Unmarshal(data, &head); err != nil {
			return nil, err
		}
	}
	return &head, nil
}
// EncodeHTTPResponse encodes HTTPResponse using msgpack encoding (optimized)
func EncodeHTTPResponse(resp *HTTPResponse) ([]byte, error) {
return msgpack.Marshal(resp)
@@ -66,3 +92,28 @@ func DecodeHTTPResponse(data []byte) (*HTTPResponse, error) {
return &resp, nil
}
// EncodeHTTPResponseHead encodes streaming HTTP response headers as msgpack.
// The response body is not included; it is transmitted separately as body
// chunk frames.
func EncodeHTTPResponseHead(head *HTTPResponseHead) ([]byte, error) {
	return msgpack.Marshal(head)
}
// DecodeHTTPResponseHead decodes streaming HTTP response headers.
//
// Accepts both msgpack (the format EncodeHTTPResponseHead produces) and
// JSON: a leading '{' byte is treated as JSON. NOTE(review): this sniff
// assumes a msgpack-encoded head never begins with 0x7b ('{') — presumably
// true for a struct encoded as a msgpack map, but confirm against the
// msgpack codec.
func DecodeHTTPResponseHead(data []byte) (*HTTPResponseHead, error) {
	if len(data) == 0 {
		return nil, errors.New("empty data")
	}
	var head HTTPResponseHead
	if data[0] == '{' {
		// Looks like JSON — use the JSON codec.
		if err := json.Unmarshal(data, &head); err != nil {
			return nil, err
		}
	} else {
		// Otherwise assume msgpack.
		if err := msgpack.Unmarshal(data, &head); err != nil {
			return nil, err
		}
	}
	return &head, nil
}

View File

@@ -33,6 +33,14 @@ type HTTPRequest struct {
Body []byte `json:"body,omitempty"`
}
// HTTPRequestHead represents HTTP request headers for streaming (no body).
// The body travels separately as chunk frames; ContentLength lets the
// receiver know how much body to expect, or -1 for unknown/chunked.
type HTTPRequestHead struct {
	Method        string              `json:"method"`
	URL           string              `json:"url"`
	Headers       map[string][]string `json:"headers"`
	ContentLength int64               `json:"content_length"` // -1 for unknown/chunked
}
// HTTPResponse represents an HTTP response from the local service
type HTTPResponse struct {
StatusCode int `json:"status_code"`
@@ -41,6 +49,14 @@ type HTTPResponse struct {
Body []byte `json:"body,omitempty"`
}
// HTTPResponseHead represents HTTP response headers for streaming (no body).
// The body travels separately as chunk frames; ContentLength lets the
// receiver know how much body to expect, or -1 for unknown/chunked.
type HTTPResponseHead struct {
	StatusCode    int                 `json:"status_code"`
	Status        string              `json:"status"`
	Headers       map[string][]string `json:"headers"`
	ContentLength int64               `json:"content_length"` // -1 for unknown/chunked
}
// RegisterData contains information sent when a tunnel is registered
type RegisterData struct {
Subdomain string `json:"subdomain"`

View File

@@ -7,9 +7,8 @@ import (
"drip/internal/shared/pool"
)
// EncodeDataPayload encodes a data header and payload into a frame payload.
// Deprecated: Use EncodeDataPayloadPooled for better performance.
func EncodeDataPayload(header DataHeader, data []byte) ([]byte, error) {
// encodeDataPayload encodes a data header and payload into a frame payload.
func encodeDataPayload(header DataHeader, data []byte) ([]byte, error) {
streamIDLen := len(header.StreamID)
requestIDLen := len(header.RequestID)
@@ -37,11 +36,6 @@ func EncodeDataPayload(header DataHeader, data []byte) ([]byte, error) {
// EncodeDataPayloadPooled encodes with adaptive allocation based on load.
// Returns payload slice and pool buffer pointer (may be nil).
//
// Adaptive strategy:
// - Mid-load (<150 conn): 256KB threshold, pool disabled → max QPS
// - High-load (≥300 conn): 32KB threshold, pool enabled → stable latency
// - Transition (150-300): Hysteresis to prevent flapping
func EncodeDataPayloadPooled(header DataHeader, data []byte) (payload []byte, poolBuffer *[]byte, err error) {
streamIDLen := len(header.StreamID)
requestIDLen := len(header.RequestID)
@@ -50,12 +44,12 @@ func EncodeDataPayloadPooled(header DataHeader, data []byte) (payload []byte, po
dynamicThreshold := GetAdaptiveThreshold()
if totalLen < dynamicThreshold {
regularPayload, err := EncodeDataPayload(header, data)
regularPayload, err := encodeDataPayload(header, data)
return regularPayload, nil, err
}
if totalLen > pool.SizeLarge {
regularPayload, err := EncodeDataPayload(header, data)
regularPayload, err := encodeDataPayload(header, data)
return regularPayload, nil, err
}
@@ -100,7 +94,3 @@ func DecodeDataPayload(payload []byte) (DataHeader, []byte, error) {
data := payload[headerSize:]
return header, data, nil
}
func GetPayloadHeaderSize(header DataHeader) int {
return header.Size()
}

View File

@@ -172,8 +172,27 @@ func (w *FrameWriter) Close() error {
func (w *FrameWriter) Flush() {
w.mu.Lock()
defer w.mu.Unlock()
if w.closed {
w.mu.Unlock()
return
}
// First, drain the queue into batch
for {
select {
case frame, ok := <-w.queue:
if !ok {
break
}
w.batch = append(w.batch, frame)
default:
goto done
}
}
done:
// Then flush the batch
w.flushBatchLocked()
w.mu.Unlock()
}
func (w *FrameWriter) EnableHeartbeat(interval time.Duration, callback func() *Frame) {

View File

@@ -30,17 +30,6 @@ type ServerConfig struct {
Debug bool
}
// LegacyClientConfig holds the legacy client configuration
// Deprecated: Use config.ClientConfig from client_config.go instead
type LegacyClientConfig struct {
ServerURL string
LocalTarget string
AuthToken string
Subdomain string
Verbose bool
Insecure bool // Skip TLS verification (for testing only)
}
// LoadTLSConfig loads TLS configuration
func (c *ServerConfig) LoadTLSConfig() (*tls.Config, error) {
if !c.TLSEnabled {