mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-23 21:00:44 +00:00
Merge pull request #3 from Gouryella/feat/major-performance-improvements
feat: Add HTTP streaming, compression support, and Docker deployment
This commit is contained in:
184
.github/workflows/docker.yml
vendored
Normal file
184
.github/workflows/docker.yml
vendored
Normal file
@@ -0,0 +1,184 @@
|
||||
name: Docker
|
||||
|
||||
on:
|
||||
# Trigger when a release is published (after assets are uploaded)
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
# Optional manual trigger
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: 'Release tag to use (e.g., v1.0.0 or latest)'
|
||||
required: false
|
||||
default: 'latest'
|
||||
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository }}
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
name: Build and Push Docker Image
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
# For release event, only build for tags like v1.2.3
|
||||
if: |
|
||||
github.event_name == 'workflow_dispatch' ||
|
||||
(github.event_name == 'release' && startsWith(github.event.release.tag_name, 'v'))
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
continue-on-error: true
|
||||
|
||||
# Resolve VERSION:
|
||||
# - release event: use release tag_name (e.g., v0.3.0)
|
||||
# - workflow_dispatch: use input version (default: latest)
|
||||
- name: Get version
|
||||
id: version
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" = "release" ]; then
|
||||
v="${{ github.event.release.tag_name }}"
|
||||
else
|
||||
v="${{ github.event.inputs.version }}"
|
||||
if [ -z "$v" ]; then
|
||||
v="latest"
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "VERSION=$v" >> "$GITHUB_OUTPUT"
|
||||
echo "Resolved VERSION=$v"
|
||||
|
||||
# Ensure release assets exist before building
|
||||
- name: Check release assets
|
||||
id: check_assets
|
||||
run: |
|
||||
VERSION="${{ steps.version.outputs.VERSION }}"
|
||||
REPO="${{ github.repository }}"
|
||||
|
||||
echo "Checking assets for $REPO, VERSION=$VERSION"
|
||||
|
||||
# For 'latest', we can only reliably ask the latest release API,
|
||||
# the asset names are still versioned (drip-vX.Y.Z-linux-arch).
|
||||
if [ "$VERSION" = "latest" ]; then
|
||||
API_URL="https://api.github.com/repos/${REPO}/releases/latest"
|
||||
echo "Using latest release API: $API_URL"
|
||||
json=$(curl -fsSL "$API_URL")
|
||||
|
||||
# Check that assets for both amd64 and arm64 exist
|
||||
echo "$json" | grep -q 'drip-.*linux-amd64' || missing_amd64=1
|
||||
echo "$json" | grep -q 'drip-.*linux-arm64' || missing_arm64=1
|
||||
|
||||
if [ "${missing_amd64:-0}" -eq 0 ] && [ "${missing_arm64:-0}" -eq 0 ]; then
|
||||
echo "assets_ready=true" >> "$GITHUB_OUTPUT"
|
||||
echo "Assets found for both linux-amd64 and linux-arm64 (latest)."
|
||||
else
|
||||
echo "assets_ready=false" >> "$GITHUB_OUTPUT"
|
||||
echo "Required assets for latest release are missing; build will be skipped."
|
||||
fi
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# For a specific version tag (e.g., v0.3.0) check direct download URLs
|
||||
archs="amd64 arm64"
|
||||
missing=0
|
||||
|
||||
for arch in $archs; do
|
||||
url="https://github.com/${REPO}/releases/download/${VERSION}/drip-${VERSION}-linux-${arch}"
|
||||
status=$(curl -o /dev/null -w "%{http_code}" -sL "$url")
|
||||
echo "[$arch] HTTP $status -> $url"
|
||||
if [ "$status" != "200" ]; then
|
||||
missing=1
|
||||
fi
|
||||
done
|
||||
|
||||
if [ "$missing" -eq 0 ]; then
|
||||
echo "assets_ready=true" >> "$GITHUB_OUTPUT"
|
||||
echo "All required assets exist. Proceeding with build."
|
||||
else
|
||||
echo "assets_ready=false" >> "$GITHUB_OUTPUT"
|
||||
echo "Required assets are missing; build will be skipped."
|
||||
fi
|
||||
|
||||
- name: Skip build (assets not ready)
|
||||
if: steps.check_assets.outputs.assets_ready != 'true'
|
||||
run: |
|
||||
echo "Release assets are not ready. Docker image build is skipped."
|
||||
echo "You must upload all required release files (drip-<version>-linux-amd64/arm64) first."
|
||||
|
||||
- name: Extract metadata (tags & labels)
|
||||
id: meta
|
||||
if: steps.check_assets.outputs.assets_ready == 'true'
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
${{ secrets.DOCKERHUB_USERNAME && format('docker.io/{0}/drip-server', secrets.DOCKERHUB_USERNAME) || '' }}
|
||||
tags: |
|
||||
# Main tag, e.g. v0.3.0 or latest
|
||||
type=raw,value=${{ steps.version.outputs.VERSION }}
|
||||
# Also tag 'latest' for convenience when using a specific version
|
||||
type=raw,value=latest,enable=${{ steps.version.outputs.VERSION != 'latest' }}
|
||||
|
||||
- name: Build and push
|
||||
if: steps.check_assets.outputs.assets_ready == 'true'
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
file: deployments/Dockerfile.release
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
build-args: |
|
||||
VERSION=${{ steps.version.outputs.VERSION }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
- name: Generate deployment summary
|
||||
if: steps.check_assets.outputs.assets_ready == 'true'
|
||||
run: |
|
||||
echo "## 🐳 Docker Image Published" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "**Version (GitHub Release tag or 'latest'):** \`${{ steps.version.outputs.VERSION }}\`" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "### Pull from GHCR" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "\`\`\`bash" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "docker pull ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.version.outputs.VERSION }}" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "\`\`\`" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "### Quick start" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "\`\`\`bash" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "docker run -d \\\\" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo " --name drip-server \\\\" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo " -p 443:443 \\\\" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo " -v /path/to/certs:/app/data/certs:ro \\\\" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo " ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.version.outputs.VERSION }} \\\\" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo " server --domain your.domain.com --port 443 \\\\" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo " --tls-cert /app/data/certs/fullchain.pem \\\\" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo " --tls-key /app/data/certs/privkey.pem" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "\`\`\`" >> "$GITHUB_STEP_SUMMARY"
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -51,4 +51,5 @@ tmp/
|
||||
temp/
|
||||
certs/
|
||||
.drip-server.env
|
||||
benchmark-results/
|
||||
benchmark-results/
|
||||
drip
|
||||
|
||||
77
README.md
77
README.md
@@ -66,6 +66,13 @@ bash <(curl -sL https://raw.githubusercontent.com/Gouryella/drip/main/scripts/in
|
||||
|
||||
## Usage
|
||||
|
||||
### First Time Setup
|
||||
|
||||
```bash
|
||||
# Configure server and token (only needed once)
|
||||
drip config init
|
||||
```
|
||||
|
||||
### Basic Tunnels
|
||||
|
||||
```bash
|
||||
@@ -76,7 +83,7 @@ drip http 3000
|
||||
drip https 443
|
||||
|
||||
# Pick your subdomain
|
||||
drip http 3000 --subdomain myapp
|
||||
drip http 3000 -n myapp
|
||||
# → https://myapp.your-domain.com
|
||||
|
||||
# Expose TCP service (database, SSH, etc.)
|
||||
@@ -89,28 +96,33 @@ Not just localhost - forward to any device on your network:
|
||||
|
||||
```bash
|
||||
# Forward to another machine on LAN
|
||||
drip http 8080 --address 192.168.1.100
|
||||
drip http 8080 -a 192.168.1.100
|
||||
|
||||
# Forward to Docker container
|
||||
drip http 3000 --address 172.17.0.2
|
||||
drip http 3000 -a 172.17.0.2
|
||||
|
||||
# Forward to specific interface
|
||||
drip http 3000 --address 10.0.0.5
|
||||
drip http 3000 -a 10.0.0.5
|
||||
```
|
||||
|
||||
### Daemon Mode
|
||||
### Background Mode
|
||||
|
||||
Run tunnels in the background:
|
||||
Run tunnels in the background with `-d`:
|
||||
|
||||
```bash
|
||||
# Start tunnel as daemon
|
||||
drip daemon start http 3000
|
||||
drip daemon start https 8443 --subdomain api
|
||||
# Start tunnel in background
|
||||
drip http 3000 -d
|
||||
drip https 8443 -n api -d
|
||||
|
||||
# Manage daemons
|
||||
drip daemon list
|
||||
drip daemon stop http-3000
|
||||
drip daemon logs http-3000
|
||||
# List running tunnels
|
||||
drip list
|
||||
|
||||
# View tunnel logs
|
||||
drip attach http 3000
|
||||
|
||||
# Stop tunnels
|
||||
drip stop http 3000
|
||||
drip stop all
|
||||
```
|
||||
|
||||
## Server Deployment
|
||||
@@ -214,25 +226,25 @@ sudo journalctl -u drip-server -f
|
||||
drip http 3000
|
||||
|
||||
# Test webhooks from services like Stripe
|
||||
drip http 8000 --subdomain webhooks
|
||||
drip http 8000 -n webhooks
|
||||
```
|
||||
|
||||
**Home Server Access**
|
||||
```bash
|
||||
# Access home NAS remotely
|
||||
drip http 5000 --address 192.168.1.50
|
||||
drip http 5000 -a 192.168.1.50
|
||||
|
||||
# Remote into home network
|
||||
# Remote into home network via SSH
|
||||
drip tcp 22
|
||||
```
|
||||
|
||||
**Docker & Containers**
|
||||
```bash
|
||||
# Expose containerized app
|
||||
drip http 8080 --address 172.17.0.3
|
||||
drip http 8080 -a 172.17.0.3
|
||||
|
||||
# Database access for debugging
|
||||
drip tcp 5432 --address db-container
|
||||
drip tcp 5432 -a db-container
|
||||
```
|
||||
|
||||
## Command Reference
|
||||
@@ -240,26 +252,29 @@ drip tcp 5432 --address db-container
|
||||
```bash
|
||||
# HTTP tunnel
|
||||
drip http <port> [flags]
|
||||
--subdomain, -n Custom subdomain
|
||||
--address, -a Target address (default: 127.0.0.1)
|
||||
--server Server address
|
||||
--token Auth token
|
||||
-n, --subdomain Custom subdomain
|
||||
-a, --address Target address (default: 127.0.0.1)
|
||||
-d, --daemon Run in background
|
||||
-s, --server Server address
|
||||
-t, --token Auth token
|
||||
|
||||
# HTTPS tunnel
|
||||
# HTTPS tunnel (same flags as http)
|
||||
drip https <port> [flags]
|
||||
|
||||
# TCP tunnel
|
||||
# TCP tunnel (same flags as http)
|
||||
drip tcp <port> [flags]
|
||||
|
||||
# Daemon commands
|
||||
drip daemon start <type> <port> [flags]
|
||||
drip daemon list
|
||||
drip daemon stop <name>
|
||||
drip daemon logs <name>
|
||||
# Background tunnel management
|
||||
drip list List running tunnels
|
||||
drip list -i Interactive mode
|
||||
drip attach [type] [port] View logs
|
||||
drip stop <type> <port> Stop tunnel
|
||||
drip stop all Stop all tunnels
|
||||
|
||||
# Configuration
|
||||
drip config init
|
||||
drip config show
|
||||
drip config init Set up server and token
|
||||
drip config show Show current config
|
||||
drip config set <key> <value>
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
77
README_CN.md
77
README_CN.md
@@ -66,6 +66,13 @@ bash <(curl -sL https://raw.githubusercontent.com/Gouryella/drip/main/scripts/in
|
||||
|
||||
## 使用
|
||||
|
||||
### 首次配置
|
||||
|
||||
```bash
|
||||
# 配置服务器地址和 token(只需一次)
|
||||
drip config init
|
||||
```
|
||||
|
||||
### 基础隧道
|
||||
|
||||
```bash
|
||||
@@ -76,7 +83,7 @@ drip http 3000
|
||||
drip https 443
|
||||
|
||||
# 选择你的子域名
|
||||
drip http 3000 --subdomain myapp
|
||||
drip http 3000 -n myapp
|
||||
# → https://myapp.your-domain.com
|
||||
|
||||
# 暴露 TCP 服务(数据库、SSH 等)
|
||||
@@ -89,28 +96,33 @@ drip tcp 5432
|
||||
|
||||
```bash
|
||||
# 转发到局域网其他机器
|
||||
drip http 8080 --address 192.168.1.100
|
||||
drip http 8080 -a 192.168.1.100
|
||||
|
||||
# 转发到 Docker 容器
|
||||
drip http 3000 --address 172.17.0.2
|
||||
drip http 3000 -a 172.17.0.2
|
||||
|
||||
# 转发到特定网卡
|
||||
drip http 3000 --address 10.0.0.5
|
||||
drip http 3000 -a 10.0.0.5
|
||||
```
|
||||
|
||||
### 守护模式
|
||||
### 后台模式
|
||||
|
||||
让隧道在后台运行:
|
||||
使用 `-d` 让隧道在后台运行:
|
||||
|
||||
```bash
|
||||
# 以守护进程启动隧道
|
||||
drip daemon start http 3000
|
||||
drip daemon start https 8443 --subdomain api
|
||||
# 后台启动隧道
|
||||
drip http 3000 -d
|
||||
drip https 8443 -n api -d
|
||||
|
||||
# 管理守护进程
|
||||
drip daemon list
|
||||
drip daemon stop http-3000
|
||||
drip daemon logs http-3000
|
||||
# 列出运行中的隧道
|
||||
drip list
|
||||
|
||||
# 查看隧道日志
|
||||
drip attach http 3000
|
||||
|
||||
# 停止隧道
|
||||
drip stop http 3000
|
||||
drip stop all
|
||||
```
|
||||
|
||||
## 服务端部署
|
||||
@@ -214,25 +226,25 @@ sudo journalctl -u drip-server -f
|
||||
drip http 3000
|
||||
|
||||
# 测试第三方 webhook(如 Stripe)
|
||||
drip http 8000 --subdomain webhooks
|
||||
drip http 8000 -n webhooks
|
||||
```
|
||||
|
||||
**家庭服务器访问**
|
||||
```bash
|
||||
# 远程访问家里的 NAS
|
||||
drip http 5000 --address 192.168.1.50
|
||||
drip http 5000 -a 192.168.1.50
|
||||
|
||||
# 远程进入家庭网络
|
||||
# 通过 SSH 远程进入家庭网络
|
||||
drip tcp 22
|
||||
```
|
||||
|
||||
**Docker 与容器**
|
||||
```bash
|
||||
# 暴露容器化应用
|
||||
drip http 8080 --address 172.17.0.3
|
||||
drip http 8080 -a 172.17.0.3
|
||||
|
||||
# 数据库调试
|
||||
drip tcp 5432 --address db-container
|
||||
drip tcp 5432 -a db-container
|
||||
```
|
||||
|
||||
## 命令参考
|
||||
@@ -240,26 +252,29 @@ drip tcp 5432 --address db-container
|
||||
```bash
|
||||
# HTTP 隧道
|
||||
drip http <端口> [参数]
|
||||
--subdomain, -n 自定义子域名
|
||||
--address, -a 目标地址(默认:127.0.0.1)
|
||||
--server 服务器地址
|
||||
--token 认证 token
|
||||
-n, --subdomain 自定义子域名
|
||||
-a, --address 目标地址(默认:127.0.0.1)
|
||||
-d, --daemon 后台运行
|
||||
-s, --server 服务器地址
|
||||
-t, --token 认证 token
|
||||
|
||||
# HTTPS 隧道
|
||||
# HTTPS 隧道(参数同 http)
|
||||
drip https <端口> [参数]
|
||||
|
||||
# TCP 隧道
|
||||
# TCP 隧道(参数同 http)
|
||||
drip tcp <端口> [参数]
|
||||
|
||||
# 守护进程命令
|
||||
drip daemon start <类型> <端口> [参数]
|
||||
drip daemon list
|
||||
drip daemon stop <名称>
|
||||
drip daemon logs <名称>
|
||||
# 后台隧道管理
|
||||
drip list 列出运行中的隧道
|
||||
drip list -i 交互模式
|
||||
drip attach [类型] [端口] 查看日志
|
||||
drip stop <类型> <端口> 停止隧道
|
||||
drip stop all 停止所有隧道
|
||||
|
||||
# 配置
|
||||
drip config init
|
||||
drip config show
|
||||
drip config init 设置服务器和 token
|
||||
drip config show 显示当前配置
|
||||
drip config set <键> <值>
|
||||
```
|
||||
|
||||
## 协议
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM golang:1.25-alpine AS builder
|
||||
FROM golang:1.23-alpine AS builder
|
||||
|
||||
RUN apk add --no-cache git ca-certificates tzdata
|
||||
|
||||
|
||||
43
deployments/Dockerfile.release
Normal file
43
deployments/Dockerfile.release
Normal file
@@ -0,0 +1,43 @@
|
||||
# Dockerfile for deploying drip-server from GitHub Release
|
||||
# Usage:
|
||||
# docker build -f deployments/Dockerfile.release -t drip-server .
|
||||
# docker build -f deployments/Dockerfile.release --build-arg VERSION=v1.0.0 -t drip-server:v1.0.0 .
|
||||
|
||||
FROM alpine:latest
|
||||
|
||||
# VERSION is passed from GitHub Actions (release tag or "latest")
|
||||
ARG VERSION=latest
|
||||
# TARGETARCH is automatically set by Docker Buildx (amd64, arm64, etc.)
|
||||
ARG TARGETARCH=amd64
|
||||
|
||||
RUN apk add --no-cache ca-certificates tzdata curl && update-ca-certificates
|
||||
|
||||
RUN addgroup -S drip && adduser -S -G drip drip
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Download binary from GitHub Releases
|
||||
RUN set -ex; \
|
||||
# Resolve "latest" to an actual tag
|
||||
if [ "$VERSION" = "latest" ]; then \
|
||||
VERSION=$(curl -sL https://api.github.com/repos/Gouryella/drip/releases/latest \
|
||||
| grep '"tag_name"' | cut -d'"' -f4); \
|
||||
fi; \
|
||||
echo "Resolved VERSION=${VERSION}"; \
|
||||
echo "Downloading drip ${VERSION} for linux-${TARGETARCH}"; \
|
||||
curl -fsSL "https://github.com/Gouryella/drip/releases/download/${VERSION}/drip-${VERSION}-linux-${TARGETARCH}" -o /app/drip; \
|
||||
chmod +x /app/drip; \
|
||||
/app/drip version --short
|
||||
|
||||
RUN mkdir -p /app/data/certs && \
|
||||
chown -R drip:drip /app
|
||||
|
||||
USER drip
|
||||
|
||||
EXPOSE 80 443 8080 20000-20100
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
||||
CMD curl -fsS "http://localhost:${PORT:-8080}/health" >/dev/null || exit 1
|
||||
|
||||
ENTRYPOINT ["/app/drip"]
|
||||
CMD ["server", "--port", "8080"]
|
||||
@@ -1,6 +1,35 @@
|
||||
# Docker Deployment
|
||||
|
||||
## Quick Start
|
||||
## Quick Start (Recommended)
|
||||
|
||||
Deploy drip-server using pre-built images from GitHub Container Registry:
|
||||
|
||||
```bash
|
||||
# Pull the latest image
|
||||
docker pull ghcr.io/gouryella/drip:latest
|
||||
|
||||
# Or use docker compose
|
||||
curl -fsSL https://raw.githubusercontent.com/Gouryella/drip/main/docker-compose.release.yml -o docker-compose.yml
|
||||
|
||||
# Create .env file
|
||||
cat > .env << EOF
|
||||
DOMAIN=tunnel.example.com
|
||||
AUTH_TOKEN=your-secret-token
|
||||
VERSION=latest
|
||||
EOF
|
||||
|
||||
# Place your TLS certificates
|
||||
mkdir -p certs
|
||||
cp /path/to/fullchain.pem certs/
|
||||
cp /path/to/privkey.pem certs/
|
||||
|
||||
# Start server
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
## Build from Source
|
||||
|
||||
If you prefer to build locally:
|
||||
|
||||
### Server (Production)
|
||||
|
||||
|
||||
72
docker-compose.release.yml
Normal file
72
docker-compose.release.yml
Normal file
@@ -0,0 +1,72 @@
|
||||
# Docker Compose for deploying drip-server from GitHub Release
|
||||
#
|
||||
# Usage:
|
||||
# 1. Copy this file to your server
|
||||
# 2. Create .env file with your settings (see .env.example below)
|
||||
# 3. Run: docker compose -f docker-compose.release.yml up -d
|
||||
#
|
||||
# Environment variables (.env.example):
|
||||
# DOMAIN=tunnel.example.com
|
||||
# AUTH_TOKEN=your-secret-token
|
||||
# VERSION=latest
|
||||
# TZ=UTC
|
||||
|
||||
services:
|
||||
drip-server:
|
||||
image: ghcr.io/gouryella/drip:${VERSION:-latest}
|
||||
container_name: drip-server
|
||||
restart: unless-stopped
|
||||
|
||||
ports:
|
||||
- "443:443"
|
||||
- "20000-20100:20000-20100" # TCP tunnel ports
|
||||
|
||||
volumes:
|
||||
- ./certs:/app/data/certs:ro
|
||||
- drip-data:/app/data
|
||||
|
||||
environment:
|
||||
TZ: ${TZ:-UTC}
|
||||
|
||||
command: >
|
||||
server
|
||||
--domain ${DOMAIN:-tunnel.localhost}
|
||||
--port 443
|
||||
--tls-cert /app/data/certs/fullchain.pem
|
||||
--tls-key /app/data/certs/privkey.pem
|
||||
--token ${AUTH_TOKEN:-}
|
||||
--tcp-port-min 20000
|
||||
--tcp-port-max 20100
|
||||
|
||||
networks:
|
||||
- drip-net
|
||||
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: 10m
|
||||
max-file: "3"
|
||||
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
cpus: '2'
|
||||
memory: 512M
|
||||
reservations:
|
||||
cpus: '0.25'
|
||||
memory: 64M
|
||||
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://localhost:443/health"]
|
||||
interval: 30s
|
||||
timeout: 3s
|
||||
retries: 3
|
||||
start_period: 10s
|
||||
|
||||
volumes:
|
||||
drip-data:
|
||||
driver: local
|
||||
|
||||
networks:
|
||||
drip-net:
|
||||
driver: bridge
|
||||
@@ -221,13 +221,6 @@ func runConfigValidate(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func enabledDisabled(value bool) string {
|
||||
if value {
|
||||
return "enabled"
|
||||
}
|
||||
return "disabled"
|
||||
}
|
||||
|
||||
func validateServerAddress(addr string) (bool, string) {
|
||||
addr = strings.TrimSpace(addr)
|
||||
if addr == "" {
|
||||
|
||||
@@ -153,12 +153,7 @@ func StartDaemon(tunnelType string, port int, args []string) error {
|
||||
|
||||
// Build command arguments (remove -D/--daemon flag to prevent recursion)
|
||||
var cleanArgs []string
|
||||
skipNext := false
|
||||
for i, arg := range args {
|
||||
if skipNext {
|
||||
skipNext = false
|
||||
continue
|
||||
}
|
||||
for _, arg := range args {
|
||||
// Skip -D or --daemon flags (but NOT --daemon-child)
|
||||
if arg == "-D" || arg == "--daemon" {
|
||||
continue
|
||||
@@ -167,8 +162,6 @@ func StartDaemon(tunnelType string, port int, args []string) error {
|
||||
if arg == "-d" {
|
||||
continue
|
||||
}
|
||||
// Skip if next arg would be a value for a removed flag (not applicable for boolean)
|
||||
_ = i
|
||||
cleanArgs = append(cleanArgs, arg)
|
||||
}
|
||||
|
||||
|
||||
@@ -203,7 +203,3 @@ func isNonRetryableError(err error) bool {
|
||||
strings.Contains(errStr, "authentication") ||
|
||||
strings.Contains(errStr, "Invalid authentication token")
|
||||
}
|
||||
|
||||
func isNonRetryableErrorTCP(err error) bool {
|
||||
return isNonRetryableError(err)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
// RenderConfigInit renders config initialization UI
|
||||
func RenderConfigInit() string {
|
||||
title := "Drip Configuration Setup"
|
||||
box := boxStyle.Copy().Width(50)
|
||||
box := boxStyle.Width(50)
|
||||
return "\n" + box.Render(titleStyle.Render(title)) + "\n"
|
||||
}
|
||||
|
||||
|
||||
@@ -6,19 +6,12 @@ import (
|
||||
|
||||
var (
|
||||
// Colors inspired by Vercel CLI
|
||||
primaryColor = lipgloss.Color("#0070F3")
|
||||
successColor = lipgloss.Color("#0070F3")
|
||||
warningColor = lipgloss.Color("#F5A623")
|
||||
errorColor = lipgloss.Color("#E00")
|
||||
mutedColor = lipgloss.Color("#888")
|
||||
highlightColor = lipgloss.Color("#0070F3")
|
||||
cyanColor = lipgloss.Color("#50E3C2")
|
||||
purpleColor = lipgloss.Color("#7928CA")
|
||||
|
||||
// Base styles
|
||||
baseStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(1).
|
||||
PaddingRight(1)
|
||||
|
||||
// Box styles - Vercel-like clean box
|
||||
boxStyle = lipgloss.NewStyle().
|
||||
@@ -27,14 +20,11 @@ var (
|
||||
MarginTop(1).
|
||||
MarginBottom(1)
|
||||
|
||||
successBoxStyle = boxStyle.Copy().
|
||||
BorderForeground(successColor)
|
||||
successBoxStyle = boxStyle.BorderForeground(successColor)
|
||||
|
||||
warningBoxStyle = boxStyle.Copy().
|
||||
BorderForeground(warningColor)
|
||||
warningBoxStyle = boxStyle.BorderForeground(warningColor)
|
||||
|
||||
errorBoxStyle = boxStyle.Copy().
|
||||
BorderForeground(errorColor)
|
||||
errorBoxStyle = boxStyle.BorderForeground(errorColor)
|
||||
|
||||
// Text styles
|
||||
titleStyle = lipgloss.NewStyle().
|
||||
@@ -85,10 +75,6 @@ var (
|
||||
|
||||
tableCellStyle = lipgloss.NewStyle().
|
||||
PaddingRight(2)
|
||||
|
||||
tableRowHighlight = lipgloss.NewStyle().
|
||||
Foreground(highlightColor).
|
||||
Bold(true)
|
||||
)
|
||||
|
||||
// Success returns a styled success message
|
||||
|
||||
@@ -68,7 +68,7 @@ func (t *Table) Render() string {
|
||||
// Header
|
||||
headerParts := make([]string, len(t.headers))
|
||||
for i, header := range t.headers {
|
||||
style := tableHeaderStyle.Copy().Width(colWidths[i])
|
||||
style := tableHeaderStyle.Width(colWidths[i])
|
||||
headerParts[i] = style.Render(header)
|
||||
}
|
||||
output.WriteString(strings.Join(headerParts, " "))
|
||||
@@ -87,7 +87,7 @@ func (t *Table) Render() string {
|
||||
rowParts := make([]string, len(t.headers))
|
||||
for i, cell := range row {
|
||||
if i < len(colWidths) {
|
||||
style := tableCellStyle.Copy().Width(colWidths[i])
|
||||
style := tableCellStyle.Width(colWidths[i])
|
||||
rowParts[i] = style.Render(cell)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,7 +59,7 @@ func RenderTunnelConnected(status *TunnelStatus) string {
|
||||
|
||||
urlLine := lipgloss.JoinHorizontal(
|
||||
lipgloss.Left,
|
||||
urlStyle.Copy().Foreground(accent).Render(status.URL),
|
||||
urlStyle.Foreground(accent).Render(status.URL),
|
||||
lipgloss.NewStyle().MarginLeft(1).Foreground(mutedColor).Render("(forwarded address)"),
|
||||
)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package tcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -18,19 +19,21 @@ import (
|
||||
|
||||
// FrameHandler handles data frames and forwards to local service
|
||||
type FrameHandler struct {
|
||||
conn net.Conn
|
||||
frameWriter *protocol.FrameWriter
|
||||
localHost string
|
||||
localPort int
|
||||
logger *zap.Logger
|
||||
streams map[string]*Stream
|
||||
streamMu sync.RWMutex
|
||||
tunnelType protocol.TunnelType
|
||||
httpClient *http.Client
|
||||
stats *TrafficStats
|
||||
isClosedCheck func() bool
|
||||
bufferPool *pool.BufferPool
|
||||
headerPool *pool.HeaderPool
|
||||
conn net.Conn
|
||||
frameWriter *protocol.FrameWriter
|
||||
localHost string
|
||||
localPort int
|
||||
logger *zap.Logger
|
||||
streams map[string]*Stream
|
||||
streamMu sync.RWMutex
|
||||
streamingRequests map[string]*StreamingRequest
|
||||
streamingReqMu sync.RWMutex
|
||||
tunnelType protocol.TunnelType
|
||||
httpClient *http.Client
|
||||
stats *TrafficStats
|
||||
isClosedCheck func() bool
|
||||
bufferPool *pool.BufferPool
|
||||
headerPool *pool.HeaderPool
|
||||
}
|
||||
|
||||
// Stream represents a single request/response stream
|
||||
@@ -39,6 +42,23 @@ type Stream struct {
|
||||
LocalConn net.Conn
|
||||
ResponseCh chan []byte
|
||||
Done chan struct{}
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// StreamingRequest represents a streaming upload request in progress
|
||||
type StreamingRequest struct {
|
||||
RequestID string
|
||||
Writer *io.PipeWriter
|
||||
Done chan struct{}
|
||||
chunkQueue chan *chunkData
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type chunkData struct {
|
||||
data []byte
|
||||
isLast bool
|
||||
}
|
||||
|
||||
func NewFrameHandler(conn net.Conn, frameWriter *protocol.FrameWriter, localHost string, localPort int, tunnelType protocol.TunnelType, logger *zap.Logger, isClosedCheck func() bool, bufferPool *pool.BufferPool) *FrameHandler {
|
||||
@@ -50,29 +70,30 @@ func NewFrameHandler(conn net.Conn, frameWriter *protocol.FrameWriter, localHost
|
||||
}
|
||||
|
||||
return &FrameHandler{
|
||||
conn: conn,
|
||||
frameWriter: frameWriter,
|
||||
localHost: localHost,
|
||||
localPort: localPort,
|
||||
logger: logger,
|
||||
streams: make(map[string]*Stream),
|
||||
tunnelType: tunnelType,
|
||||
stats: NewTrafficStats(),
|
||||
isClosedCheck: isClosedCheck,
|
||||
bufferPool: bufferPool,
|
||||
headerPool: pool.NewHeaderPool(),
|
||||
conn: conn,
|
||||
frameWriter: frameWriter,
|
||||
localHost: localHost,
|
||||
localPort: localPort,
|
||||
logger: logger,
|
||||
streams: make(map[string]*Stream),
|
||||
streamingRequests: make(map[string]*StreamingRequest),
|
||||
tunnelType: tunnelType,
|
||||
stats: NewTrafficStats(),
|
||||
isClosedCheck: isClosedCheck,
|
||||
bufferPool: bufferPool,
|
||||
headerPool: pool.NewHeaderPool(),
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
// No overall timeout - streaming responses can take arbitrary time
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 1000, // Optimized for both mid and high load scenarios
|
||||
MaxIdleConnsPerHost: 500, // Sufficient for 400+ concurrent connections
|
||||
MaxConnsPerHost: 0, // Unlimited
|
||||
MaxIdleConns: 1000,
|
||||
MaxIdleConnsPerHost: 500,
|
||||
MaxConnsPerHost: 0,
|
||||
IdleConnTimeout: 180 * time.Second,
|
||||
DisableCompression: true,
|
||||
DisableKeepAlives: false,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
TLSClientConfig: tlsConfig,
|
||||
ResponseHeaderTimeout: 15 * time.Second,
|
||||
ResponseHeaderTimeout: 30 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 5 * time.Second,
|
||||
@@ -99,6 +120,14 @@ func (h *FrameHandler) HandleDataFrame(frame *protocol.Frame) error {
|
||||
return h.handleHTTPFrame(header, data)
|
||||
}
|
||||
|
||||
if header.Type == protocol.DataTypeHTTPRequestHead || header.Type == protocol.DataTypeHTTPHead {
|
||||
return h.handleHTTPRequestHead(header, data)
|
||||
}
|
||||
|
||||
if header.Type == protocol.DataTypeHTTPRequestBodyChunk || header.Type == protocol.DataTypeHTTPBodyChunk {
|
||||
return h.handleHTTPRequestBodyChunk(header, data)
|
||||
}
|
||||
|
||||
if header.Type == protocol.DataTypeClose {
|
||||
h.closeStream(header.StreamID)
|
||||
return nil
|
||||
@@ -122,7 +151,7 @@ func (h *FrameHandler) getOrCreateStream(streamID string) (*Stream, error) {
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
localAddr := fmt.Sprintf("%s:%d", h.localHost, h.localPort)
|
||||
localAddr := net.JoinHostPort(h.localHost, fmt.Sprintf("%d", h.localPort))
|
||||
localConn, err := net.DialTimeout("tcp", localAddr, 5*time.Second)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to local service: %w", err)
|
||||
@@ -143,8 +172,25 @@ func (h *FrameHandler) getOrCreateStream(streamID string) (*Stream, error) {
|
||||
}
|
||||
|
||||
func (h *FrameHandler) forwardToLocal(stream *Stream, data []byte) {
|
||||
// Check if stream is closed using mutex
|
||||
stream.mu.Lock()
|
||||
if stream.closed {
|
||||
stream.mu.Unlock()
|
||||
return
|
||||
}
|
||||
stream.mu.Unlock()
|
||||
|
||||
// Double check with Done channel
|
||||
select {
|
||||
case <-stream.Done:
|
||||
// Stream already closed, ignore data
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if _, err := stream.LocalConn.Write(data); err != nil {
|
||||
h.logger.Error("Failed to write to local service",
|
||||
// Only log at debug level since connection close is often expected
|
||||
h.logger.Debug("Failed to write to local service",
|
||||
zap.String("stream_id", stream.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
@@ -160,6 +206,14 @@ func (h *FrameHandler) handleLocalResponse(stream *Stream) {
|
||||
buf := (*bufPtr)[:pool.SizeMedium]
|
||||
|
||||
for {
|
||||
// Check if stream is closed before reading
|
||||
stream.mu.Lock()
|
||||
closed := stream.closed
|
||||
stream.mu.Unlock()
|
||||
if closed {
|
||||
break
|
||||
}
|
||||
|
||||
n, err := stream.LocalConn.Read(buf)
|
||||
if err != nil {
|
||||
break
|
||||
@@ -170,6 +224,14 @@ func (h *FrameHandler) handleLocalResponse(stream *Stream) {
|
||||
break
|
||||
}
|
||||
|
||||
// Check again after read
|
||||
stream.mu.Lock()
|
||||
closed = stream.closed
|
||||
stream.mu.Unlock()
|
||||
if closed {
|
||||
break
|
||||
}
|
||||
|
||||
header := protocol.DataHeader{
|
||||
StreamID: stream.ID,
|
||||
Type: protocol.DataTypeResponse,
|
||||
@@ -178,14 +240,14 @@ func (h *FrameHandler) handleLocalResponse(stream *Stream) {
|
||||
|
||||
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, buf[:n])
|
||||
if err != nil {
|
||||
h.logger.Error("Encode payload failed", zap.Error(err))
|
||||
h.logger.Debug("Encode payload failed", zap.Error(err))
|
||||
break
|
||||
}
|
||||
|
||||
dataFrame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
|
||||
err = h.frameWriter.WriteFrame(dataFrame)
|
||||
if err != nil {
|
||||
h.logger.Error("Send frame failed", zap.Error(err))
|
||||
h.logger.Debug("Send frame failed", zap.Error(err))
|
||||
break
|
||||
}
|
||||
|
||||
@@ -261,19 +323,185 @@ func (h *FrameHandler) handleHTTPFrame(header protocol.DataHeader, payload []byt
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
// Threshold for switching from buffered to streaming mode
|
||||
const bufferThreshold int64 = 1 * 1024 * 1024 // 1MB
|
||||
|
||||
// If Content-Length is known and large, use streaming directly
|
||||
if resp.ContentLength > bufferThreshold {
|
||||
return h.streamHTTPResponse(header.StreamID, header.RequestID, resp)
|
||||
}
|
||||
|
||||
// For small or unknown size: try buffered first, switch to streaming if too large
|
||||
return h.adaptiveHTTPResponse(header.StreamID, header.RequestID, resp, bufferThreshold)
|
||||
}
|
||||
|
||||
// adaptiveHTTPResponse tries buffered mode first, switches to streaming if data exceeds threshold
|
||||
func (h *FrameHandler) adaptiveHTTPResponse(streamID, requestID string, resp *http.Response, threshold int64) error {
|
||||
if h.isClosedCheck != nil && h.isClosedCheck() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Buffer for initial read
|
||||
buffer := make([]byte, 0, threshold)
|
||||
tempBuf := make([]byte, 32*1024) // 32KB read chunks
|
||||
|
||||
var totalRead int64
|
||||
var hitThreshold bool
|
||||
|
||||
// Try to read up to threshold
|
||||
for totalRead < threshold {
|
||||
n, err := resp.Body.Read(tempBuf)
|
||||
if n > 0 {
|
||||
buffer = append(buffer, tempBuf[:n]...)
|
||||
totalRead += int64(n)
|
||||
}
|
||||
if err == io.EOF {
|
||||
// Response completed within threshold - use buffered mode
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return h.sendHTTPError(streamID, requestID, http.StatusBadGateway, fmt.Sprintf("read response: %v", err))
|
||||
}
|
||||
if totalRead >= threshold {
|
||||
hitThreshold = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hitThreshold {
|
||||
// Small response - send as buffered
|
||||
// Clean response headers - remove hop-by-hop headers that are invalid after proxying
|
||||
cleanedHeaders := h.cleanResponseHeaders(resp.Header)
|
||||
|
||||
httpResp := protocol.HTTPResponse{
|
||||
StatusCode: resp.StatusCode,
|
||||
Status: resp.Status,
|
||||
Headers: cleanedHeaders,
|
||||
Body: buffer,
|
||||
}
|
||||
return h.sendHTTPResponse(streamID, requestID, &httpResp)
|
||||
}
|
||||
|
||||
// Large response - switch to streaming mode
|
||||
// Clean response headers - remove hop-by-hop headers that are invalid after proxying
|
||||
cleanedHeaders := h.cleanResponseHeaders(resp.Header)
|
||||
|
||||
// First send headers
|
||||
httpHead := protocol.HTTPResponseHead{
|
||||
StatusCode: resp.StatusCode,
|
||||
Status: resp.Status,
|
||||
Headers: cleanedHeaders,
|
||||
ContentLength: resp.ContentLength, // -1 if unknown
|
||||
}
|
||||
|
||||
headBytes, err := protocol.EncodeHTTPResponseHead(&httpHead)
|
||||
if err != nil {
|
||||
return h.sendHTTPError(header.StreamID, header.RequestID, http.StatusBadGateway, fmt.Sprintf("read response: %v", err))
|
||||
return fmt.Errorf("encode http head: %w", err)
|
||||
}
|
||||
|
||||
httpResp := protocol.HTTPResponse{
|
||||
StatusCode: resp.StatusCode,
|
||||
Status: resp.Status,
|
||||
Headers: resp.Header,
|
||||
Body: body,
|
||||
headHeader := protocol.DataHeader{
|
||||
StreamID: streamID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPHead,
|
||||
IsLast: false,
|
||||
}
|
||||
|
||||
return h.sendHTTPResponse(header.StreamID, header.RequestID, &httpResp)
|
||||
headPayload, headPoolBuffer, err := protocol.EncodeDataPayloadPooled(headHeader, headBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode head payload: %w", err)
|
||||
}
|
||||
|
||||
headFrame := protocol.NewFramePooled(protocol.FrameTypeData, headPayload, headPoolBuffer)
|
||||
if err := h.frameWriter.WriteFrame(headFrame); err != nil {
|
||||
return err
|
||||
}
|
||||
h.frameWriter.Flush()
|
||||
h.stats.AddBytesOut(int64(len(headPayload)))
|
||||
|
||||
// Send buffered data as first chunk
|
||||
if len(buffer) > 0 {
|
||||
chunkHeader := protocol.DataHeader{
|
||||
StreamID: streamID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPBodyChunk,
|
||||
IsLast: false,
|
||||
}
|
||||
|
||||
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, buffer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode chunk payload: %w", err)
|
||||
}
|
||||
|
||||
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
|
||||
if err := h.frameWriter.WriteFrame(chunkFrame); err != nil {
|
||||
return err
|
||||
}
|
||||
h.stats.AddBytesOut(int64(len(chunkPayload)))
|
||||
}
|
||||
|
||||
// Clear buffer to free memory
|
||||
buffer = nil
|
||||
|
||||
// Continue streaming remaining data
|
||||
bufPtr := h.bufferPool.Get(pool.SizeMedium)
|
||||
defer h.bufferPool.Put(bufPtr)
|
||||
buf := (*bufPtr)[:pool.SizeMedium]
|
||||
|
||||
for {
|
||||
if h.isClosedCheck != nil && h.isClosedCheck() {
|
||||
return nil
|
||||
}
|
||||
|
||||
n, readErr := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
isLast := readErr == io.EOF
|
||||
|
||||
chunkHeader := protocol.DataHeader{
|
||||
StreamID: streamID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPBodyChunk,
|
||||
IsLast: isLast,
|
||||
}
|
||||
|
||||
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, buf[:n])
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode chunk payload: %w", err)
|
||||
}
|
||||
|
||||
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
|
||||
if err := h.frameWriter.WriteFrame(chunkFrame); err != nil {
|
||||
return err
|
||||
}
|
||||
h.stats.AddBytesOut(int64(len(chunkPayload)))
|
||||
}
|
||||
|
||||
if readErr == io.EOF {
|
||||
if n == 0 {
|
||||
// Send final empty chunk
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: streamID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode final payload: %w", err)
|
||||
}
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
if err := h.frameWriter.WriteFrame(finalFrame); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
h.frameWriter.Flush()
|
||||
break
|
||||
}
|
||||
if readErr != nil {
|
||||
return fmt.Errorf("read response body: %w", readErr)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *FrameHandler) sendHTTPError(streamID, requestID string, status int, message string) error {
|
||||
@@ -294,6 +522,111 @@ func (h *FrameHandler) sendHTTPError(streamID, requestID string, status int, mes
|
||||
return err
|
||||
}
|
||||
|
||||
// streamHTTPResponse streams HTTP response using zero-copy approach
|
||||
// First sends headers, then streams body chunks
|
||||
func (h *FrameHandler) streamHTTPResponse(streamID, requestID string, resp *http.Response) error {
|
||||
if h.isClosedCheck != nil && h.isClosedCheck() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clean response headers - remove hop-by-hop headers that are invalid after proxying
|
||||
cleanedHeaders := h.cleanResponseHeaders(resp.Header)
|
||||
|
||||
// Send HTTP headers first
|
||||
contentLength := resp.ContentLength // -1 if unknown
|
||||
httpHead := protocol.HTTPResponseHead{
|
||||
StatusCode: resp.StatusCode,
|
||||
Status: resp.Status,
|
||||
Headers: cleanedHeaders,
|
||||
ContentLength: contentLength,
|
||||
}
|
||||
|
||||
headBytes, err := protocol.EncodeHTTPResponseHead(&httpHead)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode http head: %w", err)
|
||||
}
|
||||
|
||||
headHeader := protocol.DataHeader{
|
||||
StreamID: streamID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPHead,
|
||||
IsLast: false,
|
||||
}
|
||||
|
||||
headPayload, headPoolBuffer, err := protocol.EncodeDataPayloadPooled(headHeader, headBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode head payload: %w", err)
|
||||
}
|
||||
|
||||
headFrame := protocol.NewFramePooled(protocol.FrameTypeData, headPayload, headPoolBuffer)
|
||||
if err := h.frameWriter.WriteFrame(headFrame); err != nil {
|
||||
return err
|
||||
}
|
||||
h.frameWriter.Flush()
|
||||
h.stats.AddBytesOut(int64(len(headPayload)))
|
||||
|
||||
// Stream body chunks - zero copy using buffer pool
|
||||
bufPtr := h.bufferPool.Get(pool.SizeMedium)
|
||||
defer h.bufferPool.Put(bufPtr)
|
||||
buf := (*bufPtr)[:pool.SizeMedium]
|
||||
|
||||
for {
|
||||
if h.isClosedCheck != nil && h.isClosedCheck() {
|
||||
return nil
|
||||
}
|
||||
|
||||
n, readErr := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
isLast := readErr == io.EOF
|
||||
|
||||
chunkHeader := protocol.DataHeader{
|
||||
StreamID: streamID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPBodyChunk,
|
||||
IsLast: isLast,
|
||||
}
|
||||
|
||||
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, buf[:n])
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode chunk payload: %w", err)
|
||||
}
|
||||
|
||||
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
|
||||
if err := h.frameWriter.WriteFrame(chunkFrame); err != nil {
|
||||
return err
|
||||
}
|
||||
h.stats.AddBytesOut(int64(len(chunkPayload)))
|
||||
}
|
||||
|
||||
if readErr == io.EOF {
|
||||
// Send final empty chunk with IsLast=true if we haven't already
|
||||
if n == 0 {
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: streamID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode final payload: %w", err)
|
||||
}
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
if err := h.frameWriter.WriteFrame(finalFrame); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
h.frameWriter.Flush()
|
||||
break
|
||||
}
|
||||
if readErr != nil {
|
||||
return fmt.Errorf("read response body: %w", readErr)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *FrameHandler) sendHTTPResponse(streamID, requestID string, resp *protocol.HTTPResponse) error {
|
||||
if h.isClosedCheck != nil && h.isClosedCheck() {
|
||||
return nil
|
||||
@@ -320,26 +653,44 @@ func (h *FrameHandler) sendHTTPResponse(streamID, requestID string, resp *protoc
|
||||
|
||||
h.stats.AddBytesOut(int64(len(payload)))
|
||||
|
||||
return h.frameWriter.WriteFrame(dataFrame)
|
||||
if err := h.frameWriter.WriteFrame(dataFrame); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Flush immediately to ensure the response is sent without batching delay
|
||||
h.frameWriter.Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *FrameHandler) closeStream(streamID string) {
|
||||
h.streamMu.Lock()
|
||||
defer h.streamMu.Unlock()
|
||||
|
||||
stream, ok := h.streams[streamID]
|
||||
if !ok {
|
||||
h.streamMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Use stream-level mutex to prevent race conditions
|
||||
stream.mu.Lock()
|
||||
if stream.closed {
|
||||
stream.mu.Unlock()
|
||||
h.streamMu.Unlock()
|
||||
return
|
||||
}
|
||||
stream.closed = true
|
||||
stream.mu.Unlock()
|
||||
|
||||
// Remove from map first to prevent concurrent access
|
||||
delete(h.streams, streamID)
|
||||
h.streamMu.Unlock()
|
||||
|
||||
// Now safe to close resources without holding the main lock
|
||||
if stream.LocalConn != nil {
|
||||
stream.LocalConn.Close()
|
||||
}
|
||||
|
||||
close(stream.Done)
|
||||
|
||||
delete(h.streams, streamID)
|
||||
|
||||
if h.isClosedCheck != nil && h.isClosedCheck() {
|
||||
return
|
||||
}
|
||||
@@ -364,15 +715,29 @@ func (h *FrameHandler) closeStream(streamID string) {
|
||||
// Close closes all streams
|
||||
func (h *FrameHandler) Close() {
|
||||
h.streamMu.Lock()
|
||||
defer h.streamMu.Unlock()
|
||||
|
||||
for streamID, stream := range h.streams {
|
||||
if stream.LocalConn != nil {
|
||||
stream.LocalConn.Close()
|
||||
stream.mu.Lock()
|
||||
if !stream.closed {
|
||||
stream.closed = true
|
||||
if stream.LocalConn != nil {
|
||||
stream.LocalConn.Close()
|
||||
}
|
||||
close(stream.Done)
|
||||
}
|
||||
close(stream.Done)
|
||||
stream.mu.Unlock()
|
||||
delete(h.streams, streamID)
|
||||
}
|
||||
h.streamMu.Unlock()
|
||||
|
||||
h.streamingReqMu.Lock()
|
||||
for requestID, streamingReq := range h.streamingRequests {
|
||||
h.closeStreamingRequest(requestID, streamingReq)
|
||||
if streamingReq.Writer != nil {
|
||||
streamingReq.Writer.CloseWithError(fmt.Errorf("tunnel connection closed"))
|
||||
}
|
||||
delete(h.streamingRequests, requestID)
|
||||
}
|
||||
h.streamingReqMu.Unlock()
|
||||
}
|
||||
|
||||
// GetStats returns the traffic stats tracker
|
||||
@@ -438,3 +803,276 @@ func (h *FrameHandler) isLocalAddress(addr string) bool {
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// cleanResponseHeaders removes hop-by-hop headers that should not be forwarded
|
||||
// Go's http.Client automatically handles chunked encoding, so we need to remove
|
||||
// the Transfer-Encoding header to avoid sending decoded body with chunked header
|
||||
func (h *FrameHandler) cleanResponseHeaders(headers http.Header) http.Header {
|
||||
cleaned := make(http.Header)
|
||||
|
||||
// List of hop-by-hop headers to remove (RFC 2616)
|
||||
hopByHopHeaders := map[string]bool{
|
||||
"Connection": true,
|
||||
"Keep-Alive": true,
|
||||
"Proxy-Authenticate": true,
|
||||
"Proxy-Authorization": true,
|
||||
"Te": true,
|
||||
"Trailers": true,
|
||||
"Transfer-Encoding": true,
|
||||
"Upgrade": true,
|
||||
"Proxy-Connection": true,
|
||||
}
|
||||
|
||||
for key, values := range headers {
|
||||
canonicalKey := http.CanonicalHeaderKey(key)
|
||||
|
||||
if hopByHopHeaders[canonicalKey] {
|
||||
continue
|
||||
}
|
||||
|
||||
// Also check if this header is listed in Connection header
|
||||
connectionHeaders := headers.Get("Connection")
|
||||
if connectionHeaders != "" {
|
||||
tokens := strings.Split(connectionHeaders, ",")
|
||||
skip := false
|
||||
for _, token := range tokens {
|
||||
if strings.TrimSpace(token) == key {
|
||||
skip = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if skip {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
cleaned.Add(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
return cleaned
|
||||
}
|
||||
|
||||
func (h *FrameHandler) handleHTTPRequestHead(header protocol.DataHeader, payload []byte) error {
|
||||
httpReqHead, err := protocol.DecodeHTTPRequestHead(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode HTTP request head: %w", err)
|
||||
}
|
||||
|
||||
requestID := header.RequestID
|
||||
if requestID == "" {
|
||||
requestID = header.StreamID
|
||||
}
|
||||
|
||||
targetURL := httpReqHead.URL
|
||||
if !strings.HasPrefix(targetURL, "http://") && !strings.HasPrefix(targetURL, "https://") {
|
||||
scheme := "http"
|
||||
if h.tunnelType == protocol.TunnelTypeHTTPS {
|
||||
scheme = "https"
|
||||
}
|
||||
targetURL = fmt.Sprintf("%s://%s:%d%s", scheme, h.localHost, h.localPort, targetURL)
|
||||
}
|
||||
|
||||
pipeReader, pipeWriter := io.Pipe()
|
||||
|
||||
req, err := http.NewRequest(httpReqHead.Method, targetURL, pipeReader)
|
||||
if err != nil {
|
||||
pipeWriter.Close()
|
||||
return h.sendHTTPError(header.StreamID, requestID, http.StatusBadGateway, fmt.Sprintf("build request: %v", err))
|
||||
}
|
||||
|
||||
origHost := ""
|
||||
for key, values := range httpReqHead.Headers {
|
||||
if key == "Content-Length" {
|
||||
continue
|
||||
}
|
||||
for _, value := range values {
|
||||
req.Header.Add(key, value)
|
||||
}
|
||||
}
|
||||
if host := req.Header.Get("Host"); host != "" {
|
||||
origHost = host
|
||||
}
|
||||
|
||||
req.ContentLength = -1
|
||||
|
||||
isLocalTarget := h.isLocalAddress(h.localHost)
|
||||
if isLocalTarget {
|
||||
if origHost != "" {
|
||||
req.Host = origHost
|
||||
req.Header.Set("Host", origHost)
|
||||
} else {
|
||||
localHostPort := fmt.Sprintf("%s:%d", h.localHost, h.localPort)
|
||||
req.Host = localHostPort
|
||||
req.Header.Set("Host", localHostPort)
|
||||
}
|
||||
if origHost != "" {
|
||||
req.Header.Set("X-Forwarded-Host", origHost)
|
||||
}
|
||||
} else {
|
||||
targetHost := h.localHost
|
||||
if h.localPort != 443 && h.localPort != 80 {
|
||||
targetHost = fmt.Sprintf("%s:%d", h.localHost, h.localPort)
|
||||
}
|
||||
req.Host = targetHost
|
||||
req.Header.Set("Host", targetHost)
|
||||
if origHost != "" {
|
||||
req.Header.Set("X-Forwarded-Host", origHost)
|
||||
}
|
||||
}
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
streamingReq := &StreamingRequest{
|
||||
RequestID: requestID,
|
||||
Writer: pipeWriter,
|
||||
Done: make(chan struct{}),
|
||||
chunkQueue: make(chan *chunkData, 512), // deeper buffer for bursty body chunks
|
||||
}
|
||||
|
||||
h.streamingReqMu.Lock()
|
||||
h.streamingRequests[requestID] = streamingReq
|
||||
h.streamingReqMu.Unlock()
|
||||
|
||||
go func() {
|
||||
defer pipeWriter.Close()
|
||||
|
||||
timeout := time.NewTimer(5 * time.Minute) // Timeout for receiving body chunks
|
||||
defer timeout.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case chunk, ok := <-streamingReq.chunkQueue:
|
||||
if !ok || chunk == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Reset timeout on each chunk
|
||||
if !timeout.Stop() {
|
||||
select {
|
||||
case <-timeout.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
timeout.Reset(5 * time.Minute)
|
||||
|
||||
if len(chunk.data) > 0 {
|
||||
if _, err := pipeWriter.Write(chunk.data); err != nil {
|
||||
h.logger.Error("Failed to write to pipe",
|
||||
zap.String("request_id", requestID),
|
||||
zap.Error(err),
|
||||
)
|
||||
pipeWriter.CloseWithError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if chunk.isLast {
|
||||
return
|
||||
}
|
||||
case <-streamingReq.Done:
|
||||
return
|
||||
case <-timeout.C:
|
||||
h.logger.Warn("Timeout waiting for request body chunks",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
pipeWriter.CloseWithError(fmt.Errorf("timeout waiting for body chunks"))
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
h.closeStreamingRequest(requestID, streamingReq)
|
||||
h.streamingReqMu.Lock()
|
||||
delete(h.streamingRequests, requestID)
|
||||
h.streamingReqMu.Unlock()
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
reqWithCtx := req.WithContext(ctx)
|
||||
|
||||
resp, err := h.httpClient.Do(reqWithCtx)
|
||||
if err != nil {
|
||||
h.sendHTTPError(header.StreamID, requestID, http.StatusBadGateway, fmt.Sprintf("local request failed: %v", err))
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
const bufferThreshold int64 = 1 * 1024 * 1024
|
||||
if resp.ContentLength > bufferThreshold {
|
||||
h.streamHTTPResponse(header.StreamID, requestID, resp)
|
||||
} else {
|
||||
h.adaptiveHTTPResponse(header.StreamID, requestID, resp, bufferThreshold)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *FrameHandler) handleHTTPRequestBodyChunk(header protocol.DataHeader, data []byte) error {
|
||||
requestID := header.RequestID
|
||||
if requestID == "" {
|
||||
requestID = header.StreamID
|
||||
}
|
||||
|
||||
h.streamingReqMu.RLock()
|
||||
streamingReq, exists := h.streamingRequests[requestID]
|
||||
h.streamingReqMu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
h.logger.Warn("Streaming request not found for body chunk",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
streamingReq.mu.Lock()
|
||||
if streamingReq.closed {
|
||||
streamingReq.mu.Unlock()
|
||||
h.logger.Debug("Streaming request already closed",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
streamingReq.mu.Unlock()
|
||||
|
||||
chunk := &chunkData{
|
||||
data: make([]byte, len(data)),
|
||||
isLast: header.IsLast,
|
||||
}
|
||||
copy(chunk.data, data)
|
||||
|
||||
select {
|
||||
case streamingReq.chunkQueue <- chunk:
|
||||
case <-streamingReq.Done:
|
||||
h.logger.Debug("Streaming request already closed",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
if header.IsLast {
|
||||
h.closeStreamingRequest(requestID, streamingReq)
|
||||
h.streamingReqMu.Lock()
|
||||
delete(h.streamingRequests, requestID)
|
||||
h.streamingReqMu.Unlock()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// closeStreamingRequest marks a streaming request closed and signals goroutines.
|
||||
func (h *FrameHandler) closeStreamingRequest(requestID string, streamingReq *StreamingRequest) {
|
||||
streamingReq.mu.Lock()
|
||||
if streamingReq.closed {
|
||||
streamingReq.mu.Unlock()
|
||||
return
|
||||
}
|
||||
streamingReq.closed = true
|
||||
close(streamingReq.Done)
|
||||
streamingReq.mu.Unlock()
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -12,7 +11,6 @@ import (
|
||||
json "github.com/goccy/go-json"
|
||||
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/constants"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/utils"
|
||||
@@ -71,17 +69,53 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
limitedReader := io.LimitReader(r.Body, constants.MaxRequestBodySize)
|
||||
body, err := io.ReadAll(limitedReader)
|
||||
if err != nil {
|
||||
h.logger.Error("Read request body failed", zap.Error(err))
|
||||
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
requestID := utils.GenerateID()
|
||||
|
||||
h.handleAdaptiveRequest(w, r, transport, requestID, subdomain)
|
||||
}
|
||||
|
||||
func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string) {
|
||||
const streamingThreshold int64 = 1 * 1024 * 1024
|
||||
|
||||
buffer := make([]byte, 0, streamingThreshold)
|
||||
tempBuf := make([]byte, 32*1024)
|
||||
|
||||
var totalRead int64
|
||||
var hitThreshold bool
|
||||
|
||||
for totalRead < streamingThreshold {
|
||||
n, err := r.Body.Read(tempBuf)
|
||||
if n > 0 {
|
||||
buffer = append(buffer, tempBuf[:n]...)
|
||||
totalRead += int64(n)
|
||||
}
|
||||
if err == io.EOF {
|
||||
r.Body.Close()
|
||||
h.sendBufferedRequest(w, r, transport, requestID, subdomain, buffer)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
r.Body.Close()
|
||||
h.logger.Error("Read request body failed", zap.Error(err))
|
||||
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if totalRead >= streamingThreshold {
|
||||
hitThreshold = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hitThreshold {
|
||||
r.Body.Close()
|
||||
h.sendBufferedRequest(w, r, transport, requestID, subdomain, buffer)
|
||||
return
|
||||
}
|
||||
|
||||
h.streamLargeRequest(w, r, transport, requestID, subdomain, buffer)
|
||||
}
|
||||
|
||||
func (h *Handler) sendBufferedRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, body []byte) {
|
||||
headers := h.headerPool.Get()
|
||||
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
|
||||
|
||||
@@ -93,7 +127,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
reqBytes, err := protocol.EncodeHTTPRequest(&httpReq)
|
||||
|
||||
h.headerPool.Put(headers)
|
||||
|
||||
if err != nil {
|
||||
@@ -119,7 +152,11 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
|
||||
|
||||
respChan := h.responses.CreateResponseChan(requestID)
|
||||
defer h.responses.CleanupResponseChan(requestID)
|
||||
streamingDone := h.responses.CreateStreamingResponse(requestID, w)
|
||||
defer func() {
|
||||
h.responses.CleanupResponseChan(requestID)
|
||||
h.responses.CleanupStreamingResponse(requestID)
|
||||
}()
|
||||
|
||||
if err := transport.SendFrame(frame); err != nil {
|
||||
h.logger.Error("Send frame to tunnel failed", zap.Error(err))
|
||||
@@ -127,14 +164,220 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), constants.RequestTimeout)
|
||||
defer cancel()
|
||||
select {
|
||||
case respMsg := <-respChan:
|
||||
if respMsg == nil {
|
||||
http.Error(w, "Internal server error: nil response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
h.writeHTTPResponse(w, respMsg, subdomain, r)
|
||||
case <-streamingDone:
|
||||
// Streaming response has been fully written by SendStreamingChunk
|
||||
case <-time.After(5 * time.Minute):
|
||||
h.logger.Error("Request timeout",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("url", r.URL.String()),
|
||||
)
|
||||
http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) streamLargeRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, bufferedData []byte) {
|
||||
headers := h.headerPool.Get()
|
||||
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
|
||||
|
||||
httpReqHead := protocol.HTTPRequestHead{
|
||||
Method: r.Method,
|
||||
URL: r.URL.String(),
|
||||
Headers: headers,
|
||||
ContentLength: r.ContentLength,
|
||||
}
|
||||
|
||||
headBytes, err := protocol.EncodeHTTPRequestHead(&httpReqHead)
|
||||
h.headerPool.Put(headers)
|
||||
|
||||
if err != nil {
|
||||
h.logger.Error("Encode HTTP request head failed", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
headHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPHead, // shared streaming head type
|
||||
IsLast: false,
|
||||
}
|
||||
|
||||
headPayload, headPoolBuffer, err := protocol.EncodeDataPayloadPooled(headHeader, headBytes)
|
||||
if err != nil {
|
||||
h.logger.Error("Encode head payload failed", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
headFrame := protocol.NewFramePooled(protocol.FrameTypeData, headPayload, headPoolBuffer)
|
||||
|
||||
respChan := h.responses.CreateResponseChan(requestID)
|
||||
streamingDone := h.responses.CreateStreamingResponse(requestID, w)
|
||||
defer func() {
|
||||
h.responses.CleanupResponseChan(requestID)
|
||||
h.responses.CleanupStreamingResponse(requestID)
|
||||
}()
|
||||
|
||||
if err := transport.SendFrame(headFrame); err != nil {
|
||||
h.logger.Error("Send head frame failed", zap.Error(err))
|
||||
http.Error(w, "Failed to forward request to tunnel", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
if len(bufferedData) > 0 {
|
||||
chunkHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPBodyChunk, // shared streaming body type
|
||||
IsLast: false,
|
||||
}
|
||||
|
||||
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, bufferedData)
|
||||
if err != nil {
|
||||
h.logger.Error("Encode buffered chunk failed", zap.Error(err))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if ferr == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
|
||||
if err := transport.SendFrame(chunkFrame); err != nil {
|
||||
h.logger.Error("Send buffered chunk failed", zap.Error(err))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if ferr == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
buffer := make([]byte, 32*1024)
|
||||
for {
|
||||
n, readErr := r.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
isLast := readErr == io.EOF
|
||||
|
||||
chunkHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPBodyChunk, // shared streaming body type
|
||||
IsLast: isLast,
|
||||
}
|
||||
|
||||
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, buffer[:n])
|
||||
if err != nil {
|
||||
h.logger.Error("Encode chunk payload failed", zap.Error(err))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if ferr == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
|
||||
if err := transport.SendFrame(chunkFrame); err != nil {
|
||||
h.logger.Error("Send chunk frame failed", zap.Error(err))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if ferr == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if readErr == io.EOF {
|
||||
if n == 0 {
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if err == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
if readErr != nil {
|
||||
h.logger.Error("Read request body failed", zap.Error(readErr))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if err == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
|
||||
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
r.Body.Close()
|
||||
|
||||
select {
|
||||
case respMsg := <-respChan:
|
||||
if respMsg == nil {
|
||||
http.Error(w, "Internal server error: nil response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
h.writeHTTPResponse(w, respMsg, subdomain, r)
|
||||
|
||||
case <-ctx.Done():
|
||||
case <-streamingDone:
|
||||
// Streaming response has been fully written by SendStreamingChunk
|
||||
case <-time.After(5 * time.Minute):
|
||||
h.logger.Error("Streaming request timeout",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("url", r.URL.String()),
|
||||
)
|
||||
http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
|
||||
}
|
||||
}
|
||||
@@ -145,12 +388,23 @@ func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPRe
|
||||
return
|
||||
}
|
||||
|
||||
// For buffered responses, we have the complete body, so we can set Content-Length
|
||||
// Skip ALL hop-by-hop headers - client should have already cleaned them
|
||||
for key, values := range resp.Headers {
|
||||
if key == "Connection" || key == "Keep-Alive" || key == "Transfer-Encoding" || key == "Upgrade" {
|
||||
canonicalKey := http.CanonicalHeaderKey(key)
|
||||
|
||||
// Skip hop-by-hop headers completely using canonical key comparison
|
||||
if canonicalKey == "Connection" ||
|
||||
canonicalKey == "Keep-Alive" ||
|
||||
canonicalKey == "Transfer-Encoding" ||
|
||||
canonicalKey == "Upgrade" ||
|
||||
canonicalKey == "Proxy-Connection" ||
|
||||
canonicalKey == "Te" ||
|
||||
canonicalKey == "Trailer" {
|
||||
continue
|
||||
}
|
||||
|
||||
if key == "Location" && len(values) > 0 {
|
||||
if canonicalKey == "Location" && len(values) > 0 {
|
||||
rewrittenLocation := h.rewriteLocationHeader(values[0], r.Host)
|
||||
w.Header().Set("Location", rewrittenLocation)
|
||||
continue
|
||||
@@ -161,9 +415,8 @@ func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPRe
|
||||
}
|
||||
}
|
||||
|
||||
if w.Header().Get("Content-Length") == "" && len(resp.Body) > 0 {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(resp.Body)))
|
||||
}
|
||||
// For buffered mode, always set Content-Length with the actual body size
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(resp.Body)))
|
||||
|
||||
statusCode := resp.StatusCode
|
||||
if statusCode == 0 {
|
||||
@@ -171,6 +424,7 @@ func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPRe
|
||||
}
|
||||
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
if len(resp.Body) > 0 {
|
||||
w.Write(resp.Body)
|
||||
}
|
||||
@@ -284,19 +538,6 @@ func (h *Handler) serveHealth(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(health)
|
||||
}
|
||||
|
||||
func cloneHeadersWithHost(src http.Header, host string) http.Header {
|
||||
dst := make(http.Header, len(src)+1)
|
||||
for k, v := range src {
|
||||
copied := make([]string, len(v))
|
||||
copy(copied, v)
|
||||
dst[k] = copied
|
||||
}
|
||||
if host != "" {
|
||||
dst.Set("Host", host)
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
|
||||
if h.authToken != "" {
|
||||
token := r.URL.Query().Get("token")
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -14,20 +18,33 @@ type responseChanEntry struct {
|
||||
createdAt time.Time
|
||||
}
|
||||
|
||||
// streamingResponseEntry holds a streaming response writer
|
||||
type streamingResponseEntry struct {
|
||||
w http.ResponseWriter
|
||||
flusher http.Flusher
|
||||
createdAt time.Time
|
||||
lastActivityAt time.Time
|
||||
headersSent bool
|
||||
done chan struct{}
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// ResponseHandler manages response channels for HTTP requests over TCP/Frame protocol
|
||||
type ResponseHandler struct {
|
||||
channels map[string]*responseChanEntry
|
||||
mu sync.RWMutex
|
||||
logger *zap.Logger
|
||||
stopCh chan struct{}
|
||||
channels map[string]*responseChanEntry
|
||||
streamingChannels map[string]*streamingResponseEntry
|
||||
mu sync.RWMutex
|
||||
logger *zap.Logger
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewResponseHandler creates a new response handler
|
||||
func NewResponseHandler(logger *zap.Logger) *ResponseHandler {
|
||||
h := &ResponseHandler{
|
||||
channels: make(map[string]*responseChanEntry),
|
||||
logger: logger,
|
||||
stopCh: make(chan struct{}),
|
||||
channels: make(map[string]*responseChanEntry),
|
||||
streamingChannels: make(map[string]*streamingResponseEntry),
|
||||
logger: logger,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start single cleanup goroutine instead of one per request
|
||||
@@ -50,6 +67,25 @@ func (h *ResponseHandler) CreateResponseChan(requestID string) chan *protocol.HT
|
||||
return ch
|
||||
}
|
||||
|
||||
// CreateStreamingResponse creates a streaming response entry for a request ID
|
||||
func (h *ResponseHandler) CreateStreamingResponse(requestID string, w http.ResponseWriter) chan struct{} {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
flusher, _ := w.(http.Flusher)
|
||||
done := make(chan struct{})
|
||||
now := time.Now()
|
||||
h.streamingChannels[requestID] = &streamingResponseEntry{
|
||||
w: w,
|
||||
flusher: flusher,
|
||||
createdAt: now,
|
||||
lastActivityAt: now,
|
||||
done: done,
|
||||
}
|
||||
|
||||
return done
|
||||
}
|
||||
|
||||
// GetResponseChan gets the response channel for a request ID
|
||||
func (h *ResponseHandler) GetResponseChan(requestID string) <-chan *protocol.HTTPResponse {
|
||||
h.mu.RLock()
|
||||
@@ -67,25 +103,168 @@ func (h *ResponseHandler) SendResponse(requestID string, resp *protocol.HTTPResp
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists || entry == nil {
|
||||
h.logger.Warn("Response channel not found",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case entry.ch <- resp:
|
||||
h.logger.Debug("Response sent to channel",
|
||||
zap.String("request_id", requestID),
|
||||
)
|
||||
case <-time.After(5 * time.Second):
|
||||
h.logger.Warn("Timeout sending response to channel",
|
||||
case <-time.After(30 * time.Second):
|
||||
h.logger.Error("Timeout sending response to channel - handler may have abandoned",
|
||||
zap.String("request_id", requestID),
|
||||
zap.Int("status_code", resp.StatusCode),
|
||||
zap.Int("body_size", len(resp.Body)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// CleanupResponseChan removes and closes a response channel
|
||||
func (h *ResponseHandler) SendStreamingHead(requestID string, head *protocol.HTTPResponseHead) error {
|
||||
h.mu.RLock()
|
||||
entry, exists := h.streamingChannels[requestID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists || entry == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
entry.mu.Lock()
|
||||
defer entry.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-entry.done:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
if entry.headersSent {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Copy headers, removing hop-by-hop headers that were already handled by client
|
||||
// Client's cleanResponseHeaders already removed Transfer-Encoding, Connection, etc.
|
||||
// But we need to check again in case they slipped through
|
||||
hasContentLength := false
|
||||
|
||||
for key, values := range head.Headers {
|
||||
canonicalKey := http.CanonicalHeaderKey(key)
|
||||
|
||||
// Skip ALL hop-by-hop headers
|
||||
if canonicalKey == "Connection" ||
|
||||
canonicalKey == "Keep-Alive" ||
|
||||
canonicalKey == "Transfer-Encoding" ||
|
||||
canonicalKey == "Upgrade" ||
|
||||
canonicalKey == "Proxy-Connection" ||
|
||||
canonicalKey == "Te" ||
|
||||
canonicalKey == "Trailer" {
|
||||
continue
|
||||
}
|
||||
|
||||
if canonicalKey == "Content-Length" {
|
||||
hasContentLength = true
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
entry.w.Header().Add(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// For streaming responses, decide how to indicate message length
|
||||
if head.ContentLength >= 0 && !hasContentLength {
|
||||
entry.w.Header().Set("Content-Length", fmt.Sprintf("%d", head.ContentLength))
|
||||
}
|
||||
|
||||
statusCode := head.StatusCode
|
||||
if statusCode == 0 {
|
||||
statusCode = http.StatusOK
|
||||
}
|
||||
|
||||
entry.w.WriteHeader(statusCode)
|
||||
entry.headersSent = true
|
||||
entry.lastActivityAt = time.Now()
|
||||
|
||||
if entry.flusher != nil {
|
||||
entry.flusher.Flush()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) SendStreamingChunk(requestID string, chunk []byte, isLast bool) error {
|
||||
h.mu.RLock()
|
||||
entry, exists := h.streamingChannels[requestID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists || entry == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
entry.mu.Lock()
|
||||
defer entry.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-entry.done:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
if len(chunk) > 0 {
|
||||
_, err := entry.w.Write(chunk)
|
||||
if err != nil {
|
||||
if isClientDisconnectError(err) {
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
entry.lastActivityAt = time.Now()
|
||||
|
||||
if entry.flusher != nil {
|
||||
entry.flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if isLast {
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isClientDisconnectError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if netErr, ok := err.(*net.OpError); ok {
|
||||
if netErr.Err != nil {
|
||||
errStr := netErr.Err.Error()
|
||||
if strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection reset") ||
|
||||
strings.Contains(errStr, "connection refused") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection reset") ||
|
||||
strings.Contains(errStr, "use of closed network connection")
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) CleanupResponseChan(requestID string) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
@@ -96,15 +275,26 @@ func (h *ResponseHandler) CleanupResponseChan(requestID string) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetPendingCount returns the number of pending responses
|
||||
func (h *ResponseHandler) CleanupStreamingResponse(requestID string) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if entry, exists := h.streamingChannels[requestID]; exists {
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
delete(h.streamingChannels, requestID)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) GetPendingCount() int {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
return len(h.channels)
|
||||
return len(h.channels) + len(h.streamingChannels)
|
||||
}
|
||||
|
||||
// cleanupLoop periodically cleans up expired response channels
|
||||
// This replaces the per-request goroutine approach with a single cleanup goroutine
|
||||
func (h *ResponseHandler) cleanupLoop() {
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
@@ -119,10 +309,10 @@ func (h *ResponseHandler) cleanupLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupExpiredChannels removes channels older than 30 seconds
|
||||
func (h *ResponseHandler) cleanupExpiredChannels() {
|
||||
now := time.Now()
|
||||
timeout := 30 * time.Second
|
||||
streamingTimeout := 5 * time.Minute
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
@@ -136,24 +326,43 @@ func (h *ResponseHandler) cleanupExpiredChannels() {
|
||||
}
|
||||
}
|
||||
|
||||
for requestID, entry := range h.streamingChannels {
|
||||
if now.Sub(entry.lastActivityAt) > streamingTimeout {
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
delete(h.streamingChannels, requestID)
|
||||
expiredCount++
|
||||
}
|
||||
}
|
||||
|
||||
if expiredCount > 0 {
|
||||
h.logger.Debug("Cleaned up expired response channels",
|
||||
zap.Int("count", expiredCount),
|
||||
zap.Int("remaining", len(h.channels)),
|
||||
zap.Int("remaining", len(h.channels)+len(h.streamingChannels)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the cleanup loop
|
||||
func (h *ResponseHandler) Close() {
|
||||
close(h.stopCh)
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
// Close all remaining channels
|
||||
for _, entry := range h.channels {
|
||||
close(entry.ch)
|
||||
}
|
||||
h.channels = make(map[string]*responseChanEntry)
|
||||
|
||||
for _, entry := range h.streamingChannels {
|
||||
select {
|
||||
case <-entry.done:
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
}
|
||||
h.streamingChannels = make(map[string]*streamingResponseEntry)
|
||||
}
|
||||
|
||||
@@ -49,6 +49,9 @@ type HTTPResponseHandler interface {
|
||||
GetResponseChan(requestID string) <-chan *protocol.HTTPResponse
|
||||
CleanupResponseChan(requestID string)
|
||||
SendResponse(requestID string, resp *protocol.HTTPResponse)
|
||||
// Streaming response methods
|
||||
SendStreamingHead(requestID string, head *protocol.HTTPResponseHead) error
|
||||
SendStreamingChunk(requestID string, chunk []byte, isLast bool) error
|
||||
}
|
||||
|
||||
// NewConnection creates a new connection handler
|
||||
@@ -273,6 +276,15 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
|
||||
c.logger.Debug("Client disconnected abruptly", zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
// Check if it looks like garbage data (not a valid HTTP request)
|
||||
if strings.Contains(errStr, "malformed HTTP") {
|
||||
c.logger.Warn("Received malformed HTTP request, possibly due to pipelined requests or protocol error",
|
||||
zap.Error(err),
|
||||
zap.String("error_snippet", errStr[:min(len(errStr), 100)]),
|
||||
)
|
||||
// Close connection on malformed request to prevent further errors
|
||||
return nil
|
||||
}
|
||||
c.logger.Error("Failed to parse HTTP request", zap.Error(err))
|
||||
return fmt.Errorf("failed to parse HTTP request: %w", err)
|
||||
}
|
||||
@@ -289,9 +301,21 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
|
||||
header: make(http.Header),
|
||||
}
|
||||
|
||||
// Handle the request
|
||||
// Handle the request - this blocks until response is complete
|
||||
c.httpHandler.ServeHTTP(respWriter, req)
|
||||
|
||||
// Ensure response is flushed to client
|
||||
if tcpConn, ok := c.conn.(*net.TCPConn); ok {
|
||||
// Force flush TCP buffers
|
||||
tcpConn.SetNoDelay(true)
|
||||
tcpConn.SetNoDelay(false)
|
||||
}
|
||||
|
||||
c.logger.Debug("HTTP request processing completed",
|
||||
zap.String("method", req.Method),
|
||||
zap.String("url", req.URL.String()),
|
||||
)
|
||||
|
||||
// Check if we should close the connection
|
||||
// Close if: Connection: close header, or HTTP/1.0 without Connection: keep-alive
|
||||
shouldClose := false
|
||||
@@ -304,8 +328,13 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Also check if response indicated connection should close
|
||||
if respWriter.headerWritten && respWriter.header.Get("Connection") == "close" {
|
||||
shouldClose = true
|
||||
}
|
||||
|
||||
if shouldClose {
|
||||
c.logger.Debug("Closing connection as requested by client")
|
||||
c.logger.Debug("Closing connection as requested by client or server")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -313,6 +342,13 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
|
||||
}
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// handleFrames handles incoming frames
|
||||
func (c *Connection) handleFrames(reader *bufio.Reader) error {
|
||||
for {
|
||||
@@ -439,10 +475,55 @@ func (c *Connection) handleDataFrame(frame *protocol.Frame) {
|
||||
}
|
||||
|
||||
c.responseChans.SendResponse(reqID, httpResp)
|
||||
case protocol.DataTypeHTTPHead:
|
||||
// Streaming HTTP response headers
|
||||
if c.responseChans == nil {
|
||||
c.logger.Warn("No response handler for streaming HTTP head",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
c.logger.Debug("Routed HTTP response to channel",
|
||||
zap.String("request_id", reqID),
|
||||
)
|
||||
httpHead, err := protocol.DecodeHTTPResponseHead(data)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to decode HTTP response head",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
reqID := header.RequestID
|
||||
if reqID == "" {
|
||||
reqID = header.StreamID
|
||||
}
|
||||
|
||||
if err := c.responseChans.SendStreamingHead(reqID, httpHead); err != nil {
|
||||
c.logger.Error("Failed to send streaming head",
|
||||
zap.String("request_id", reqID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
case protocol.DataTypeHTTPBodyChunk:
|
||||
// Streaming HTTP response body chunk
|
||||
if c.responseChans == nil {
|
||||
c.logger.Warn("No response handler for streaming HTTP chunk",
|
||||
zap.String("stream_id", header.StreamID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
reqID := header.RequestID
|
||||
if reqID == "" {
|
||||
reqID = header.StreamID
|
||||
}
|
||||
|
||||
if err := c.responseChans.SendStreamingChunk(reqID, data, header.IsLast); err != nil {
|
||||
c.logger.Error("Failed to send streaming chunk",
|
||||
zap.String("request_id", reqID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
case protocol.DataTypeClose:
|
||||
// Client is closing the stream
|
||||
if c.proxy != nil {
|
||||
@@ -487,7 +568,12 @@ func (c *Connection) SendFrame(frame *protocol.Frame) error {
|
||||
if c.frameWriter == nil {
|
||||
return protocol.WriteFrame(c.conn, frame)
|
||||
}
|
||||
return c.frameWriter.WriteFrame(frame)
|
||||
if err := c.frameWriter.WriteFrame(frame); err != nil {
|
||||
return err
|
||||
}
|
||||
// Flush immediately to ensure the frame is sent without batching delay
|
||||
c.frameWriter.Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendError sends an error frame to the client
|
||||
|
||||
@@ -21,12 +21,19 @@ type TunnelProxy struct {
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
clientAddr string
|
||||
streams map[string]net.Conn // streamID -> external connection
|
||||
streams map[string]*proxyStream // streamID -> stream info
|
||||
streamMu sync.RWMutex
|
||||
frameWriter *protocol.FrameWriter
|
||||
bufferPool *pool.BufferPool
|
||||
}
|
||||
|
||||
// proxyStream holds connection info with close state
|
||||
type proxyStream struct {
|
||||
conn net.Conn
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewTunnelProxy creates a new TCP tunnel proxy
|
||||
func NewTunnelProxy(port int, subdomain string, tcpConn net.Conn, logger *zap.Logger) *TunnelProxy {
|
||||
return &TunnelProxy{
|
||||
@@ -36,7 +43,7 @@ func NewTunnelProxy(port int, subdomain string, tcpConn net.Conn, logger *zap.Lo
|
||||
logger: logger,
|
||||
stopCh: make(chan struct{}),
|
||||
clientAddr: tcpConn.RemoteAddr().String(),
|
||||
streams: make(map[string]net.Conn),
|
||||
streams: make(map[string]*proxyStream),
|
||||
bufferPool: pool.NewBufferPool(),
|
||||
frameWriter: protocol.NewFrameWriter(tcpConn),
|
||||
}
|
||||
@@ -101,8 +108,13 @@ func (p *TunnelProxy) handleConnection(conn net.Conn) {
|
||||
|
||||
streamID := fmt.Sprintf("%d-%d", time.Now().UnixNano(), p.port)
|
||||
|
||||
stream := &proxyStream{
|
||||
conn: conn,
|
||||
closed: false,
|
||||
}
|
||||
|
||||
p.streamMu.Lock()
|
||||
p.streams[streamID] = conn
|
||||
p.streams[streamID] = stream
|
||||
p.streamMu.Unlock()
|
||||
|
||||
defer func() {
|
||||
@@ -117,6 +129,14 @@ func (p *TunnelProxy) handleConnection(conn net.Conn) {
|
||||
buffer := (*bufPtr)[:pool.SizeMedium]
|
||||
|
||||
for {
|
||||
// Check if stream is closed
|
||||
stream.mu.Lock()
|
||||
closed := stream.closed
|
||||
stream.mu.Unlock()
|
||||
if closed {
|
||||
break
|
||||
}
|
||||
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
break
|
||||
@@ -124,7 +144,7 @@ func (p *TunnelProxy) handleConnection(conn net.Conn) {
|
||||
|
||||
if n > 0 {
|
||||
if err := p.sendDataToTunnel(streamID, buffer[:n]); err != nil {
|
||||
p.logger.Error("Send to tunnel failed", zap.Error(err))
|
||||
p.logger.Debug("Send to tunnel failed", zap.Error(err))
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -185,15 +205,24 @@ func (p *TunnelProxy) sendCloseToTunnel(streamID string) {
|
||||
|
||||
func (p *TunnelProxy) HandleResponse(streamID string, data []byte) error {
|
||||
p.streamMu.RLock()
|
||||
conn, ok := p.streams[streamID]
|
||||
stream, ok := p.streams[streamID]
|
||||
p.streamMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return fmt.Errorf("stream not found: %s", streamID)
|
||||
// Stream may have been closed by client, this is normal
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := conn.Write(data); err != nil {
|
||||
p.logger.Error("Write to client failed", zap.Error(err))
|
||||
// Check if stream is closed
|
||||
stream.mu.Lock()
|
||||
if stream.closed {
|
||||
stream.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
stream.mu.Unlock()
|
||||
|
||||
if _, err := stream.conn.Write(data); err != nil {
|
||||
p.logger.Debug("Write to client failed", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -203,12 +232,24 @@ func (p *TunnelProxy) HandleResponse(streamID string, data []byte) error {
|
||||
// CloseStream closes a stream
|
||||
func (p *TunnelProxy) CloseStream(streamID string) {
|
||||
p.streamMu.RLock()
|
||||
conn, ok := p.streams[streamID]
|
||||
stream, ok := p.streams[streamID]
|
||||
p.streamMu.RUnlock()
|
||||
|
||||
if ok {
|
||||
conn.Close()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Mark as closed first
|
||||
stream.mu.Lock()
|
||||
if stream.closed {
|
||||
stream.mu.Unlock()
|
||||
return
|
||||
}
|
||||
stream.closed = true
|
||||
stream.mu.Unlock()
|
||||
|
||||
// Now close the connection
|
||||
stream.conn.Close()
|
||||
}
|
||||
|
||||
func (p *TunnelProxy) Stop() {
|
||||
@@ -224,10 +265,13 @@ func (p *TunnelProxy) Stop() {
|
||||
}
|
||||
|
||||
p.streamMu.Lock()
|
||||
for _, conn := range p.streams {
|
||||
conn.Close()
|
||||
for _, stream := range p.streams {
|
||||
stream.mu.Lock()
|
||||
stream.closed = true
|
||||
stream.mu.Unlock()
|
||||
stream.conn.Close()
|
||||
}
|
||||
p.streams = make(map[string]net.Conn)
|
||||
p.streams = make(map[string]*proxyStream)
|
||||
p.streamMu.Unlock()
|
||||
|
||||
p.wg.Wait()
|
||||
|
||||
280
internal/shared/compression/hpack/decoder.go
Normal file
280
internal/shared/compression/hpack/decoder.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package hpack
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Decoder decompresses HPACK-encoded headers
|
||||
// Each connection MUST have its own decoder instance to maintain correct state
|
||||
type Decoder struct {
|
||||
mu sync.Mutex
|
||||
dynamicTable *DynamicTable
|
||||
staticTable *StaticTable
|
||||
maxTableSize uint32
|
||||
}
|
||||
|
||||
// NewDecoder creates a new HPACK decoder
|
||||
func NewDecoder(maxTableSize uint32) *Decoder {
|
||||
if maxTableSize == 0 {
|
||||
maxTableSize = DefaultDynamicTableSize
|
||||
}
|
||||
|
||||
return &Decoder{
|
||||
dynamicTable: NewDynamicTable(maxTableSize),
|
||||
staticTable: GetStaticTable(),
|
||||
maxTableSize: maxTableSize,
|
||||
}
|
||||
}
|
||||
|
||||
// Decode decodes HPACK-encoded headers
|
||||
func (d *Decoder) Decode(data []byte) (http.Header, error) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
if len(data) == 0 {
|
||||
return http.Header{}, nil
|
||||
}
|
||||
|
||||
headers := make(http.Header)
|
||||
buf := bytes.NewReader(data)
|
||||
|
||||
for buf.Len() > 0 {
|
||||
b, err := buf.ReadByte()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read header byte: %w", err)
|
||||
}
|
||||
|
||||
// Unread the byte so we can process it properly
|
||||
if err := buf.UnreadByte(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var name, value string
|
||||
|
||||
if b&0x80 != 0 {
|
||||
// Indexed header field (10xxxxxx)
|
||||
name, value, err = d.decodeIndexedHeader(buf)
|
||||
} else if b&0x40 != 0 {
|
||||
// Literal with incremental indexing (01xxxxxx)
|
||||
name, value, err = d.decodeLiteralWithIndexing(buf)
|
||||
} else {
|
||||
// Literal without indexing (0000xxxx)
|
||||
name, value, err = d.decodeLiteralWithoutIndexing(buf)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
headers.Add(name, value)
|
||||
}
|
||||
|
||||
return headers, nil
|
||||
}
|
||||
|
||||
// decodeIndexedHeader decodes an indexed header field
|
||||
func (d *Decoder) decodeIndexedHeader(buf *bytes.Reader) (string, string, error) {
|
||||
index, err := d.readInteger(buf, 7)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("read index: %w", err)
|
||||
}
|
||||
|
||||
if index == 0 {
|
||||
return "", "", errors.New("invalid index: 0")
|
||||
}
|
||||
|
||||
staticSize := uint32(d.staticTable.Size())
|
||||
|
||||
if index <= staticSize {
|
||||
// Static table
|
||||
return d.staticTable.Get(index - 1)
|
||||
}
|
||||
|
||||
// Dynamic table (indices start after static table)
|
||||
dynamicIndex := index - staticSize - 1
|
||||
return d.dynamicTable.Get(dynamicIndex)
|
||||
}
|
||||
|
||||
// decodeLiteralWithIndexing decodes a literal header with incremental indexing
|
||||
func (d *Decoder) decodeLiteralWithIndexing(buf *bytes.Reader) (string, string, error) {
|
||||
nameIndex, err := d.readInteger(buf, 6)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
var name string
|
||||
if nameIndex == 0 {
|
||||
// Name is literal
|
||||
name, err = d.readString(buf)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("read name: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Name is indexed
|
||||
staticSize := uint32(d.staticTable.Size())
|
||||
if nameIndex <= staticSize {
|
||||
name, _, err = d.staticTable.Get(nameIndex - 1)
|
||||
} else {
|
||||
dynamicIndex := nameIndex - staticSize - 1
|
||||
name, _, err = d.dynamicTable.Get(dynamicIndex)
|
||||
}
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("get indexed name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Value is always literal
|
||||
value, err := d.readString(buf)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("read value: %w", err)
|
||||
}
|
||||
|
||||
// Add to dynamic table
|
||||
d.dynamicTable.Add(name, value)
|
||||
|
||||
return name, value, nil
|
||||
}
|
||||
|
||||
// decodeLiteralWithoutIndexing decodes a literal header without indexing
|
||||
func (d *Decoder) decodeLiteralWithoutIndexing(buf *bytes.Reader) (string, string, error) {
|
||||
nameIndex, err := d.readInteger(buf, 4)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
var name string
|
||||
if nameIndex == 0 {
|
||||
// Name is literal
|
||||
name, err = d.readString(buf)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("read name: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Name is indexed
|
||||
staticSize := uint32(d.staticTable.Size())
|
||||
if nameIndex <= staticSize {
|
||||
name, _, err = d.staticTable.Get(nameIndex - 1)
|
||||
} else {
|
||||
dynamicIndex := nameIndex - staticSize - 1
|
||||
name, _, err = d.dynamicTable.Get(dynamicIndex)
|
||||
}
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("get indexed name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Value is always literal
|
||||
value, err := d.readString(buf)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("read value: %w", err)
|
||||
}
|
||||
|
||||
// Do NOT add to dynamic table
|
||||
|
||||
return name, value, nil
|
||||
}
|
||||
|
||||
// readInteger reads an HPACK integer
|
||||
func (d *Decoder) readInteger(buf *bytes.Reader, prefixBits int) (uint32, error) {
|
||||
if prefixBits < 1 || prefixBits > 8 {
|
||||
return 0, fmt.Errorf("invalid prefix bits: %d", prefixBits)
|
||||
}
|
||||
|
||||
b, err := buf.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
maxPrefix := uint32((1 << prefixBits) - 1)
|
||||
mask := byte(maxPrefix)
|
||||
|
||||
value := uint32(b & mask)
|
||||
if value < maxPrefix {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// Multi-byte integer
|
||||
m := uint32(0)
|
||||
for {
|
||||
b, err := buf.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
value += uint32(b&0x7f) << m
|
||||
m += 7
|
||||
|
||||
if b&0x80 == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
if m > 28 {
|
||||
return 0, errors.New("integer overflow")
|
||||
}
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// readString reads an HPACK string
|
||||
func (d *Decoder) readString(buf *bytes.Reader) (string, error) {
|
||||
b, err := buf.ReadByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := buf.UnreadByte(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
huffmanEncoded := (b & 0x80) != 0
|
||||
|
||||
length, err := d.readInteger(buf, 7)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read string length: %w", err)
|
||||
}
|
||||
|
||||
if length == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if length > uint32(buf.Len()) {
|
||||
return "", fmt.Errorf("string length %d exceeds buffer size %d", length, buf.Len())
|
||||
}
|
||||
|
||||
data := make([]byte, length)
|
||||
n, err := buf.Read(data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if n != int(length) {
|
||||
return "", fmt.Errorf("expected %d bytes, read %d", length, n)
|
||||
}
|
||||
|
||||
if huffmanEncoded {
|
||||
// TODO: Implement Huffman decoding if needed
|
||||
return "", errors.New("huffman decoding not implemented")
|
||||
}
|
||||
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// SetMaxTableSize updates the dynamic table size
|
||||
func (d *Decoder) SetMaxTableSize(size uint32) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
d.maxTableSize = size
|
||||
d.dynamicTable.SetMaxSize(size)
|
||||
}
|
||||
|
||||
// Reset clears the dynamic table
|
||||
func (d *Decoder) Reset() {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
d.dynamicTable = NewDynamicTable(d.maxTableSize)
|
||||
}
|
||||
124
internal/shared/compression/hpack/dynamic_table.go
Normal file
124
internal/shared/compression/hpack/dynamic_table.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package hpack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// DynamicTable implements the HPACK dynamic table (RFC 7541 Section 2.3.2)
|
||||
// The dynamic table is a FIFO queue where new entries are added at the beginning
|
||||
// and old entries are evicted when the table size exceeds the maximum
|
||||
type DynamicTable struct {
|
||||
entries []HeaderField
|
||||
size uint32 // Current size in bytes
|
||||
maxSize uint32 // Maximum size in bytes
|
||||
}
|
||||
|
||||
// HeaderField represents a header name-value pair
|
||||
type HeaderField struct {
|
||||
Name string
|
||||
Value string
|
||||
}
|
||||
|
||||
// Size returns the size of this header field in bytes
|
||||
// RFC 7541: size = len(name) + len(value) + 32
|
||||
func (h *HeaderField) Size() uint32 {
|
||||
return uint32(len(h.Name) + len(h.Value) + 32)
|
||||
}
|
||||
|
||||
// NewDynamicTable creates a new dynamic table with the specified maximum size
|
||||
func NewDynamicTable(maxSize uint32) *DynamicTable {
|
||||
return &DynamicTable{
|
||||
entries: make([]HeaderField, 0, 32),
|
||||
size: 0,
|
||||
maxSize: maxSize,
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a header field to the dynamic table
|
||||
// New entries are added at the beginning (index 0)
|
||||
func (dt *DynamicTable) Add(name, value string) {
|
||||
field := HeaderField{Name: name, Value: value}
|
||||
fieldSize := field.Size()
|
||||
|
||||
// If the field is larger than maxSize, don't add it
|
||||
if fieldSize > dt.maxSize {
|
||||
dt.evictAll()
|
||||
return
|
||||
}
|
||||
|
||||
// Evict entries if necessary to make room
|
||||
for dt.size+fieldSize > dt.maxSize && len(dt.entries) > 0 {
|
||||
dt.evictOldest()
|
||||
}
|
||||
|
||||
// Add new entry at the beginning
|
||||
dt.entries = append([]HeaderField{field}, dt.entries...)
|
||||
dt.size += fieldSize
|
||||
}
|
||||
|
||||
// Get retrieves a header field by index (0-based)
|
||||
// Index 0 is the most recently added entry
|
||||
func (dt *DynamicTable) Get(index uint32) (string, string, error) {
|
||||
if index >= uint32(len(dt.entries)) {
|
||||
return "", "", fmt.Errorf("index %d out of range (table size: %d)", index, len(dt.entries))
|
||||
}
|
||||
|
||||
field := dt.entries[index]
|
||||
return field.Name, field.Value, nil
|
||||
}
|
||||
|
||||
// FindExact searches for an exact match (name and value)
|
||||
// Returns the index (0-based) and true if found
|
||||
func (dt *DynamicTable) FindExact(name, value string) (uint32, bool) {
|
||||
for i, field := range dt.entries {
|
||||
if field.Name == name && field.Value == value {
|
||||
return uint32(i), true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// FindName searches for a name match
|
||||
// Returns the index (0-based) and true if found
|
||||
func (dt *DynamicTable) FindName(name string) (uint32, bool) {
|
||||
for i, field := range dt.entries {
|
||||
if field.Name == name {
|
||||
return uint32(i), true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// SetMaxSize updates the maximum table size
|
||||
// If the new size is smaller, entries are evicted
|
||||
func (dt *DynamicTable) SetMaxSize(maxSize uint32) {
|
||||
dt.maxSize = maxSize
|
||||
|
||||
// Evict entries if current size exceeds new max
|
||||
for dt.size > dt.maxSize && len(dt.entries) > 0 {
|
||||
dt.evictOldest()
|
||||
}
|
||||
}
|
||||
|
||||
// CurrentSize returns the current size of the table in bytes
|
||||
func (dt *DynamicTable) CurrentSize() uint32 {
|
||||
return dt.size
|
||||
}
|
||||
|
||||
// evictOldest removes the oldest entry (last in the slice)
|
||||
func (dt *DynamicTable) evictOldest() {
|
||||
if len(dt.entries) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
lastIndex := len(dt.entries) - 1
|
||||
evicted := dt.entries[lastIndex]
|
||||
dt.entries = dt.entries[:lastIndex]
|
||||
dt.size -= evicted.Size()
|
||||
}
|
||||
|
||||
// evictAll removes all entries
|
||||
func (dt *DynamicTable) evictAll() {
|
||||
dt.entries = dt.entries[:0]
|
||||
dt.size = 0
|
||||
}
|
||||
200
internal/shared/compression/hpack/encoder.go
Normal file
200
internal/shared/compression/hpack/encoder.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package hpack
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultDynamicTableSize is the default size of the dynamic table (4KB)
|
||||
DefaultDynamicTableSize = 4096
|
||||
|
||||
// IndexedHeaderField represents a fully indexed header field
|
||||
indexedHeaderField = 0x80 // 10xxxxxx
|
||||
|
||||
// LiteralHeaderFieldWithIndexing represents a literal with incremental indexing
|
||||
literalHeaderFieldWithIndexing = 0x40 // 01xxxxxx
|
||||
)
|
||||
|
||||
// Encoder compresses HTTP headers using HPACK
|
||||
// Each connection MUST have its own encoder instance to avoid state corruption
|
||||
type Encoder struct {
|
||||
mu sync.Mutex
|
||||
dynamicTable *DynamicTable
|
||||
staticTable *StaticTable
|
||||
maxTableSize uint32
|
||||
}
|
||||
|
||||
// NewEncoder creates a new HPACK encoder with the specified dynamic table size
|
||||
// This encoder is NOT thread-safe and should be used by a single connection
|
||||
func NewEncoder(maxTableSize uint32) *Encoder {
|
||||
if maxTableSize == 0 {
|
||||
maxTableSize = DefaultDynamicTableSize
|
||||
}
|
||||
|
||||
return &Encoder{
|
||||
dynamicTable: NewDynamicTable(maxTableSize),
|
||||
staticTable: GetStaticTable(),
|
||||
maxTableSize: maxTableSize,
|
||||
}
|
||||
}
|
||||
|
||||
// Encode encodes HTTP headers into HPACK binary format
|
||||
// This method is safe to call concurrently within the same encoder instance
|
||||
func (e *Encoder) Encode(headers http.Header) ([]byte, error) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
if headers == nil {
|
||||
return nil, errors.New("headers cannot be nil")
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
for name, values := range headers {
|
||||
for _, value := range values {
|
||||
if err := e.encodeHeaderField(buf, name, value); err != nil {
|
||||
return nil, fmt.Errorf("encode header %s: %w", name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// encodeHeaderField encodes a single header field
|
||||
func (e *Encoder) encodeHeaderField(buf *bytes.Buffer, name, value string) error {
|
||||
// HTTP/2 requires header names to be lowercase (RFC 7540 Section 8.1.2)
|
||||
// Convert to lowercase for table lookups and storage
|
||||
nameLower := strings.ToLower(name)
|
||||
|
||||
// Try to find in static table first
|
||||
if index, found := e.staticTable.FindExact(nameLower, value); found {
|
||||
return e.writeIndexedHeader(buf, index+1)
|
||||
}
|
||||
|
||||
// Check if name exists in static table (for literal with name reference)
|
||||
if index, found := e.staticTable.FindName(nameLower); found {
|
||||
return e.writeLiteralWithIndexing(buf, index+1, value, true)
|
||||
}
|
||||
|
||||
// Try dynamic table
|
||||
if index, found := e.dynamicTable.FindExact(nameLower, value); found {
|
||||
// Dynamic table indices start after static table
|
||||
dynamicIndex := uint32(e.staticTable.Size()) + index + 1
|
||||
return e.writeIndexedHeader(buf, dynamicIndex)
|
||||
}
|
||||
|
||||
if index, found := e.dynamicTable.FindName(nameLower); found {
|
||||
dynamicIndex := uint32(e.staticTable.Size()) + index + 1
|
||||
return e.writeLiteralWithIndexing(buf, dynamicIndex, value, true)
|
||||
}
|
||||
|
||||
// Not found anywhere - literal with indexing and new name
|
||||
// Write literal flag
|
||||
buf.WriteByte(literalHeaderFieldWithIndexing)
|
||||
|
||||
// Write name as literal string (must come before value)
|
||||
// Use lowercase name for consistency
|
||||
if err := e.writeString(buf, nameLower, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write value as literal string
|
||||
if err := e.writeString(buf, value, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add to dynamic table with lowercase name
|
||||
e.dynamicTable.Add(nameLower, value)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeIndexedHeader writes an indexed header field (10xxxxxx)
|
||||
func (e *Encoder) writeIndexedHeader(buf *bytes.Buffer, index uint32) error {
|
||||
return e.writeInteger(buf, index, 7, indexedHeaderField)
|
||||
}
|
||||
|
||||
// writeLiteralWithIndexing writes a literal header with incremental indexing (01xxxxxx)
|
||||
func (e *Encoder) writeLiteralWithIndexing(buf *bytes.Buffer, nameIndex uint32, value string, hasIndex bool) error {
|
||||
if hasIndex {
|
||||
// Write name as index
|
||||
if err := e.writeInteger(buf, nameIndex, 6, literalHeaderFieldWithIndexing); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Write literal flag
|
||||
buf.WriteByte(literalHeaderFieldWithIndexing)
|
||||
}
|
||||
|
||||
// Write value as literal string
|
||||
return e.writeString(buf, value, false)
|
||||
}
|
||||
|
||||
// writeInteger writes an integer using HPACK integer representation
|
||||
func (e *Encoder) writeInteger(buf *bytes.Buffer, value uint32, prefixBits int, prefix byte) error {
|
||||
if prefixBits < 1 || prefixBits > 8 {
|
||||
return fmt.Errorf("invalid prefix bits: %d", prefixBits)
|
||||
}
|
||||
|
||||
maxPrefix := uint32((1 << prefixBits) - 1)
|
||||
|
||||
if value < maxPrefix {
|
||||
buf.WriteByte(prefix | byte(value))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value >= maxPrefix, need multiple bytes
|
||||
buf.WriteByte(prefix | byte(maxPrefix))
|
||||
value -= maxPrefix
|
||||
|
||||
for value >= 128 {
|
||||
buf.WriteByte(byte(value%128) | 0x80)
|
||||
value /= 128
|
||||
}
|
||||
buf.WriteByte(byte(value))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeString writes a string using HPACK string representation
|
||||
func (e *Encoder) writeString(buf *bytes.Buffer, str string, huffmanEncode bool) error {
|
||||
// For simplicity, we don't use Huffman encoding in this implementation
|
||||
// Huffman flag is bit 7, followed by length in remaining 7 bits
|
||||
|
||||
length := uint32(len(str))
|
||||
if huffmanEncode {
|
||||
// TODO: Implement Huffman encoding if needed
|
||||
return errors.New("huffman encoding not implemented")
|
||||
}
|
||||
|
||||
// Write length with H=0 (no Huffman)
|
||||
if err := e.writeInteger(buf, length, 7, 0x00); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write string bytes
|
||||
buf.WriteString(str)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetMaxTableSize updates the dynamic table size
|
||||
func (e *Encoder) SetMaxTableSize(size uint32) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
e.maxTableSize = size
|
||||
e.dynamicTable.SetMaxSize(size)
|
||||
}
|
||||
|
||||
// Reset clears the dynamic table
|
||||
func (e *Encoder) Reset() {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.dynamicTable = NewDynamicTable(e.maxTableSize)
|
||||
}
|
||||
150
internal/shared/compression/hpack/static_table.go
Normal file
150
internal/shared/compression/hpack/static_table.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package hpack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// StaticTable implements the HPACK static table (RFC 7541 Appendix A)
|
||||
// The static table is predefined and never changes
|
||||
type StaticTable struct {
|
||||
entries []HeaderField
|
||||
nameMap map[string][]uint32 // Maps name to list of indices
|
||||
}
|
||||
|
||||
var (
|
||||
staticTableInstance *StaticTable
|
||||
staticTableOnce sync.Once
|
||||
)
|
||||
|
||||
// GetStaticTable returns the singleton static table instance
|
||||
func GetStaticTable() *StaticTable {
|
||||
staticTableOnce.Do(func() {
|
||||
staticTableInstance = newStaticTable()
|
||||
})
|
||||
return staticTableInstance
|
||||
}
|
||||
|
||||
// newStaticTable creates and initializes the static table
|
||||
func newStaticTable() *StaticTable {
|
||||
// RFC 7541 Appendix A - Static Table Definition
|
||||
// We include the most common headers for HTTP
|
||||
entries := []HeaderField{
|
||||
{Name: ":authority", Value: ""},
|
||||
{Name: ":method", Value: "GET"},
|
||||
{Name: ":method", Value: "POST"},
|
||||
{Name: ":path", Value: "/"},
|
||||
{Name: ":path", Value: "/index.html"},
|
||||
{Name: ":scheme", Value: "http"},
|
||||
{Name: ":scheme", Value: "https"},
|
||||
{Name: ":status", Value: "200"},
|
||||
{Name: ":status", Value: "204"},
|
||||
{Name: ":status", Value: "206"},
|
||||
{Name: ":status", Value: "304"},
|
||||
{Name: ":status", Value: "400"},
|
||||
{Name: ":status", Value: "404"},
|
||||
{Name: ":status", Value: "500"},
|
||||
{Name: "accept-charset", Value: ""},
|
||||
{Name: "accept-encoding", Value: "gzip, deflate"},
|
||||
{Name: "accept-language", Value: ""},
|
||||
{Name: "accept-ranges", Value: ""},
|
||||
{Name: "accept", Value: ""},
|
||||
{Name: "access-control-allow-origin", Value: ""},
|
||||
{Name: "age", Value: ""},
|
||||
{Name: "allow", Value: ""},
|
||||
{Name: "authorization", Value: ""},
|
||||
{Name: "cache-control", Value: ""},
|
||||
{Name: "content-disposition", Value: ""},
|
||||
{Name: "content-encoding", Value: ""},
|
||||
{Name: "content-language", Value: ""},
|
||||
{Name: "content-length", Value: ""},
|
||||
{Name: "content-location", Value: ""},
|
||||
{Name: "content-range", Value: ""},
|
||||
{Name: "content-type", Value: ""},
|
||||
{Name: "cookie", Value: ""},
|
||||
{Name: "date", Value: ""},
|
||||
{Name: "etag", Value: ""},
|
||||
{Name: "expect", Value: ""},
|
||||
{Name: "expires", Value: ""},
|
||||
{Name: "from", Value: ""},
|
||||
{Name: "host", Value: ""},
|
||||
{Name: "if-match", Value: ""},
|
||||
{Name: "if-modified-since", Value: ""},
|
||||
{Name: "if-none-match", Value: ""},
|
||||
{Name: "if-range", Value: ""},
|
||||
{Name: "if-unmodified-since", Value: ""},
|
||||
{Name: "last-modified", Value: ""},
|
||||
{Name: "link", Value: ""},
|
||||
{Name: "location", Value: ""},
|
||||
{Name: "max-forwards", Value: ""},
|
||||
{Name: "proxy-authenticate", Value: ""},
|
||||
{Name: "proxy-authorization", Value: ""},
|
||||
{Name: "range", Value: ""},
|
||||
{Name: "referer", Value: ""},
|
||||
{Name: "refresh", Value: ""},
|
||||
{Name: "retry-after", Value: ""},
|
||||
{Name: "server", Value: ""},
|
||||
{Name: "set-cookie", Value: ""},
|
||||
{Name: "strict-transport-security", Value: ""},
|
||||
{Name: "transfer-encoding", Value: ""},
|
||||
{Name: "user-agent", Value: ""},
|
||||
{Name: "vary", Value: ""},
|
||||
{Name: "via", Value: ""},
|
||||
{Name: "www-authenticate", Value: ""},
|
||||
}
|
||||
|
||||
// Build name index map
|
||||
nameMap := make(map[string][]uint32)
|
||||
for i, entry := range entries {
|
||||
nameMap[entry.Name] = append(nameMap[entry.Name], uint32(i))
|
||||
}
|
||||
|
||||
return &StaticTable{
|
||||
entries: entries,
|
||||
nameMap: nameMap,
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a header field by index (0-based)
|
||||
func (st *StaticTable) Get(index uint32) (string, string, error) {
|
||||
if index >= uint32(len(st.entries)) {
|
||||
return "", "", fmt.Errorf("index %d out of range (static table size: %d)", index, len(st.entries))
|
||||
}
|
||||
|
||||
field := st.entries[index]
|
||||
return field.Name, field.Value, nil
|
||||
}
|
||||
|
||||
// FindExact searches for an exact match (name and value)
|
||||
// Returns the index (0-based) and true if found
|
||||
func (st *StaticTable) FindExact(name, value string) (uint32, bool) {
|
||||
indices, exists := st.nameMap[name]
|
||||
if !exists {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
for _, index := range indices {
|
||||
field := st.entries[index]
|
||||
if field.Value == value {
|
||||
return index, true
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// FindName searches for a name match
|
||||
// Returns the first matching index (0-based) and true if found
|
||||
func (st *StaticTable) FindName(name string) (uint32, bool) {
|
||||
indices, exists := st.nameMap[name]
|
||||
if !exists || len(indices) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return indices[0], true
|
||||
}
|
||||
|
||||
// Size returns the number of entries in the static table
|
||||
func (st *StaticTable) Size() int {
|
||||
return len(st.entries)
|
||||
}
|
||||
@@ -18,9 +18,6 @@ const (
|
||||
// RequestTimeout is the maximum time to wait for a response from the client
|
||||
RequestTimeout = 30 * time.Second
|
||||
|
||||
// MaxRequestBodySize is the maximum size of an HTTP request body (10MB)
|
||||
MaxRequestBodySize = 10 * 1024 * 1024
|
||||
|
||||
// ReconnectBaseDelay is the initial delay for reconnection attempts
|
||||
ReconnectBaseDelay = 1 * time.Second
|
||||
|
||||
|
||||
@@ -18,11 +18,17 @@ type DataHeader struct {
|
||||
type DataType uint8
|
||||
|
||||
const (
|
||||
DataTypeData DataType = 0x00 // 000
|
||||
DataTypeResponse DataType = 0x01 // 001
|
||||
DataTypeClose DataType = 0x02 // 010
|
||||
DataTypeHTTPRequest DataType = 0x03 // 011
|
||||
DataTypeHTTPResponse DataType = 0x04 // 100
|
||||
DataTypeData DataType = 0x00 // 000
|
||||
DataTypeResponse DataType = 0x01 // 001
|
||||
DataTypeClose DataType = 0x02 // 010
|
||||
DataTypeHTTPRequest DataType = 0x03 // 011
|
||||
DataTypeHTTPResponse DataType = 0x04 // 100
|
||||
DataTypeHTTPHead DataType = 0x05 // 101 - streaming headers (shared)
|
||||
DataTypeHTTPBodyChunk DataType = 0x06 // 110 - streaming body chunks (shared)
|
||||
|
||||
// Reuse the same type codes for request streaming to stay within 3 bits.
|
||||
DataTypeHTTPRequestHead DataType = DataTypeHTTPHead
|
||||
DataTypeHTTPRequestBodyChunk DataType = DataTypeHTTPBodyChunk
|
||||
)
|
||||
|
||||
// String returns the string representation of DataType
|
||||
@@ -38,6 +44,10 @@ func (t DataType) String() string {
|
||||
return "http_request"
|
||||
case DataTypeHTTPResponse:
|
||||
return "http_response"
|
||||
case DataTypeHTTPHead:
|
||||
return "http_head"
|
||||
case DataTypeHTTPBodyChunk:
|
||||
return "http_body_chunk"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
@@ -56,6 +66,10 @@ func DataTypeFromString(s string) DataType {
|
||||
return DataTypeHTTPRequest
|
||||
case "http_response":
|
||||
return DataTypeHTTPResponse
|
||||
case "http_head":
|
||||
return DataTypeHTTPHead
|
||||
case "http_body_chunk":
|
||||
return DataTypeHTTPBodyChunk
|
||||
default:
|
||||
return DataTypeData
|
||||
}
|
||||
@@ -118,8 +132,8 @@ func (h *DataHeader) UnmarshalBinary(data []byte) error {
|
||||
|
||||
// Decode flags
|
||||
flags := data[0]
|
||||
h.Type = DataType(flags & 0x07) // Bits 0-2
|
||||
h.IsLast = (flags & 0x08) != 0 // Bit 3
|
||||
h.Type = DataType(flags & 0x07) // Bits 0-2
|
||||
h.IsLast = (flags & 0x08) != 0 // Bit 3
|
||||
|
||||
// Decode lengths
|
||||
streamIDLen := int(binary.BigEndian.Uint16(data[1:3]))
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
json "github.com/goccy/go-json"
|
||||
"errors"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
)
|
||||
|
||||
@@ -37,6 +38,31 @@ func DecodeHTTPRequest(data []byte) (*HTTPRequest, error) {
|
||||
return &req, nil
|
||||
}
|
||||
|
||||
// EncodeHTTPRequestHead encodes HTTP request headers for streaming
|
||||
func EncodeHTTPRequestHead(head *HTTPRequestHead) ([]byte, error) {
|
||||
return msgpack.Marshal(head)
|
||||
}
|
||||
|
||||
// DecodeHTTPRequestHead decodes HTTP request headers for streaming
|
||||
func DecodeHTTPRequestHead(data []byte) (*HTTPRequestHead, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("empty data")
|
||||
}
|
||||
|
||||
var head HTTPRequestHead
|
||||
if data[0] == '{' {
|
||||
if err := json.Unmarshal(data, &head); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if err := msgpack.Unmarshal(data, &head); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &head, nil
|
||||
}
|
||||
|
||||
// EncodeHTTPResponse encodes HTTPResponse using msgpack encoding (optimized)
|
||||
func EncodeHTTPResponse(resp *HTTPResponse) ([]byte, error) {
|
||||
return msgpack.Marshal(resp)
|
||||
@@ -66,3 +92,28 @@ func DecodeHTTPResponse(data []byte) (*HTTPResponse, error) {
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// EncodeHTTPResponseHead encodes HTTP response headers for streaming
|
||||
func EncodeHTTPResponseHead(head *HTTPResponseHead) ([]byte, error) {
|
||||
return msgpack.Marshal(head)
|
||||
}
|
||||
|
||||
// DecodeHTTPResponseHead decodes HTTP response headers for streaming
|
||||
func DecodeHTTPResponseHead(data []byte) (*HTTPResponseHead, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("empty data")
|
||||
}
|
||||
|
||||
var head HTTPResponseHead
|
||||
if data[0] == '{' {
|
||||
if err := json.Unmarshal(data, &head); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if err := msgpack.Unmarshal(data, &head); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &head, nil
|
||||
}
|
||||
|
||||
@@ -33,6 +33,14 @@ type HTTPRequest struct {
|
||||
Body []byte `json:"body,omitempty"`
|
||||
}
|
||||
|
||||
// HTTPRequestHead represents HTTP request headers for streaming (no body)
|
||||
type HTTPRequestHead struct {
|
||||
Method string `json:"method"`
|
||||
URL string `json:"url"`
|
||||
Headers map[string][]string `json:"headers"`
|
||||
ContentLength int64 `json:"content_length"` // -1 for unknown/chunked
|
||||
}
|
||||
|
||||
// HTTPResponse represents an HTTP response from the local service
|
||||
type HTTPResponse struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
@@ -41,6 +49,14 @@ type HTTPResponse struct {
|
||||
Body []byte `json:"body,omitempty"`
|
||||
}
|
||||
|
||||
// HTTPResponseHead represents HTTP response headers for streaming (no body)
|
||||
type HTTPResponseHead struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
Status string `json:"status"`
|
||||
Headers map[string][]string `json:"headers"`
|
||||
ContentLength int64 `json:"content_length"` // -1 for unknown/chunked
|
||||
}
|
||||
|
||||
// RegisterData contains information sent when a tunnel is registered
|
||||
type RegisterData struct {
|
||||
Subdomain string `json:"subdomain"`
|
||||
|
||||
@@ -7,9 +7,8 @@ import (
|
||||
"drip/internal/shared/pool"
|
||||
)
|
||||
|
||||
// EncodeDataPayload encodes a data header and payload into a frame payload.
|
||||
// Deprecated: Use EncodeDataPayloadPooled for better performance.
|
||||
func EncodeDataPayload(header DataHeader, data []byte) ([]byte, error) {
|
||||
// encodeDataPayload encodes a data header and payload into a frame payload.
|
||||
func encodeDataPayload(header DataHeader, data []byte) ([]byte, error) {
|
||||
streamIDLen := len(header.StreamID)
|
||||
requestIDLen := len(header.RequestID)
|
||||
|
||||
@@ -37,11 +36,6 @@ func EncodeDataPayload(header DataHeader, data []byte) ([]byte, error) {
|
||||
|
||||
// EncodeDataPayloadPooled encodes with adaptive allocation based on load.
|
||||
// Returns payload slice and pool buffer pointer (may be nil).
|
||||
//
|
||||
// Adaptive strategy:
|
||||
// - Mid-load (<150 conn): 256KB threshold, pool disabled → max QPS
|
||||
// - High-load (≥300 conn): 32KB threshold, pool enabled → stable latency
|
||||
// - Transition (150-300): Hysteresis to prevent flapping
|
||||
func EncodeDataPayloadPooled(header DataHeader, data []byte) (payload []byte, poolBuffer *[]byte, err error) {
|
||||
streamIDLen := len(header.StreamID)
|
||||
requestIDLen := len(header.RequestID)
|
||||
@@ -50,12 +44,12 @@ func EncodeDataPayloadPooled(header DataHeader, data []byte) (payload []byte, po
|
||||
dynamicThreshold := GetAdaptiveThreshold()
|
||||
|
||||
if totalLen < dynamicThreshold {
|
||||
regularPayload, err := EncodeDataPayload(header, data)
|
||||
regularPayload, err := encodeDataPayload(header, data)
|
||||
return regularPayload, nil, err
|
||||
}
|
||||
|
||||
if totalLen > pool.SizeLarge {
|
||||
regularPayload, err := EncodeDataPayload(header, data)
|
||||
regularPayload, err := encodeDataPayload(header, data)
|
||||
return regularPayload, nil, err
|
||||
}
|
||||
|
||||
@@ -100,7 +94,3 @@ func DecodeDataPayload(payload []byte) (DataHeader, []byte, error) {
|
||||
data := payload[headerSize:]
|
||||
return header, data, nil
|
||||
}
|
||||
|
||||
func GetPayloadHeaderSize(header DataHeader) int {
|
||||
return header.Size()
|
||||
}
|
||||
|
||||
@@ -172,8 +172,27 @@ func (w *FrameWriter) Close() error {
|
||||
|
||||
func (w *FrameWriter) Flush() {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.closed {
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// First, drain the queue into batch
|
||||
for {
|
||||
select {
|
||||
case frame, ok := <-w.queue:
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
w.batch = append(w.batch, frame)
|
||||
default:
|
||||
goto done
|
||||
}
|
||||
}
|
||||
done:
|
||||
// Then flush the batch
|
||||
w.flushBatchLocked()
|
||||
w.mu.Unlock()
|
||||
}
|
||||
|
||||
func (w *FrameWriter) EnableHeartbeat(interval time.Duration, callback func() *Frame) {
|
||||
|
||||
@@ -30,17 +30,6 @@ type ServerConfig struct {
|
||||
Debug bool
|
||||
}
|
||||
|
||||
// LegacyClientConfig holds the legacy client configuration
|
||||
// Deprecated: Use config.ClientConfig from client_config.go instead
|
||||
type LegacyClientConfig struct {
|
||||
ServerURL string
|
||||
LocalTarget string
|
||||
AuthToken string
|
||||
Subdomain string
|
||||
Verbose bool
|
||||
Insecure bool // Skip TLS verification (for testing only)
|
||||
}
|
||||
|
||||
// LoadTLSConfig loads TLS configuration
|
||||
func (c *ServerConfig) LoadTLSConfig() (*tls.Config, error) {
|
||||
if !c.TLSEnabled {
|
||||
|
||||
Reference in New Issue
Block a user