Files
drip/internal/server/proxy/handler.go
Gouryella 89f67ab145 feat(client): Add bandwidth limit function support
- Implement the client bandwidth-limit parameter --bandwidth, accepting values such as 1M, 1MB and 1G
- Added a parseBandwidth function to parse and validate bandwidth values
- Added a bandwidth-limit option to the HTTP, HTTPS and TCP commands
- Pass the bandwidth configuration to the server via the protocol
- Add test cases covering the bandwidth parsing function

feat(server): implements server-side bandwidth limitation function

- Add bandwidth-limiting logic to connection handling, using a token bucket algorithm
- Implement an effective rate-limiting strategy that applies the lower of the client and server bandwidth limits
- Added a QoS limiter and a rate-limited connection wrapper
- Integrated bandwidth throttling in HTTP and WebSocket proxies
- Added global bandwidth limit and burst multiplier settings in server configuration

docs: Updated documentation to describe bandwidth limiting functionality

- Add 2025-02-14 version update instructions in README and README_CN
- Add bandwidth limit function description and usage examples
- Provide client and server configuration examples and parameter descriptions
2026-02-15 02:39:50 +08:00

452 lines
10 KiB
Go

package proxy
import (
"bufio"
"crypto/subtle"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
"go.uber.org/zap"
"drip/internal/server/tunnel"
"drip/internal/shared/httputil"
"drip/internal/shared/netutil"
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"drip/internal/shared/qos"
)
// bufioReaderPool recycles 32 KiB bufio.Readers used to parse tunnel
// responses on the request hot path, so each proxied request does not have to
// allocate a fresh reader buffer.
var bufioReaderPool = sync.Pool{
	New: func() interface{} {
		const readerSize = 32 << 10
		return bufio.NewReaderSize(nil, readerSize)
	},
}

// openStreamTimeout bounds how long openStreamWithTimeout waits for the
// tunnel to hand out a new stream.
const openStreamTimeout = 3 * time.Second
// HandlerConfig bundles the dependencies and settings passed to NewHandler.
type HandlerConfig struct {
	// Manager is the tunnel registry; ServeHTTP resolves tunnels through it
	// by subdomain.
	Manager *tunnel.Manager
	Logger  *zap.Logger

	// ServerDomain is the host that serves the home page.
	ServerDomain string
	// TunnelDomain is the parent domain under which tunnel subdomains live.
	TunnelDomain string

	AuthToken string
	// MetricsToken, when non-empty, must be presented as a Bearer token to
	// reach the metrics endpoints.
	MetricsToken string
}
// Handler serves drip's built-in HTTP endpoints and proxies all other
// requests to the tunnel selected by the Host header's subdomain.
type Handler struct {
	manager *tunnel.Manager
	logger  *zap.Logger

	serverDomain, tunnelDomain string
	authToken, metricsToken    string

	// publicPort is set via SetPublicPort and used for URL generation.
	publicPort int

	// WebSocket tunnel support.
	wsUpgrader    websocket.Upgrader
	wsConnHandler WSConnectionHandler

	// Server capabilities. An empty list means "no restriction".
	allowedTransports  []string
	allowedTunnelTypes []string
}
// WSConnectionHandler receives WebSocket tunnel connections, each presented
// as a net.Conn together with the remote address of the peer.
type WSConnectionHandler interface {
	HandleWSConnection(conn net.Conn, remoteAddr string)
}
// NewHandler builds a Handler from cfg. Its WebSocket upgrader uses 256 KiB
// read and write buffers and accepts every Origin.
func NewHandler(cfg HandlerConfig) *Handler {
	const wsBufferSize = 256 << 10

	upgrader := websocket.Upgrader{
		ReadBufferSize:  wsBufferSize,
		WriteBufferSize: wsBufferSize,
		// Allow all origins for tunnel connections.
		CheckOrigin: func(*http.Request) bool { return true },
	}

	return &Handler{
		manager:      cfg.Manager,
		logger:       cfg.Logger,
		serverDomain: cfg.ServerDomain,
		tunnelDomain: cfg.TunnelDomain,
		authToken:    cfg.AuthToken,
		metricsToken: cfg.MetricsToken,
		wsUpgrader:   upgrader,
	}
}
// SetWSConnectionHandler installs the handler for WebSocket tunnel connections.
func (h *Handler) SetWSConnectionHandler(wsHandler WSConnectionHandler) {
	h.wsConnHandler = wsHandler
}

// SetPublicPort records the public port used when generating URLs.
func (h *Handler) SetPublicPort(publicPort int) {
	h.publicPort = publicPort
}

// SetAllowedTransports restricts which transport protocols are accepted.
// An empty list accepts every transport. The slice is stored, not copied.
func (h *Handler) SetAllowedTransports(allowed []string) {
	h.allowedTransports = allowed
}

// SetAllowedTunnelTypes restricts which tunnel types are accepted.
// An empty list accepts every tunnel type. The slice is stored, not copied.
func (h *Handler) SetAllowedTunnelTypes(allowed []string) {
	h.allowedTunnelTypes = allowed
}
// IsTransportAllowed reports whether transport is permitted, comparing
// case-insensitively. When no allow-list is configured, everything is permitted.
func (h *Handler) IsTransportAllowed(transport string) bool {
	permitted := len(h.allowedTransports) == 0
	for i := 0; !permitted && i < len(h.allowedTransports); i++ {
		permitted = strings.EqualFold(h.allowedTransports[i], transport)
	}
	return permitted
}

// IsTunnelTypeAllowed reports whether tunnelType is permitted, comparing
// case-insensitively. When no allow-list is configured, everything is permitted.
func (h *Handler) IsTunnelTypeAllowed(tunnelType string) bool {
	permitted := len(h.allowedTunnelTypes) == 0
	for i := 0; !permitted && i < len(h.allowedTunnelTypes); i++ {
		permitted = strings.EqualFold(h.allowedTunnelTypes[i], tunnelType)
	}
	return permitted
}
// GetPreferredTransport returns the transport clients should try first when
// auto-detecting.
//
//   - No allow-list configured: "tcp".
//   - Exactly one allowed transport: that transport, as configured.
//   - Several allowed transports: "tcp" if it is among them (case-insensitive),
//     otherwise the first allowed transport.
//
// The result is therefore always a transport IsTransportAllowed accepts;
// previously "tcp" was returned for any multi-entry list, even one that did
// not contain it.
func (h *Handler) GetPreferredTransport() string {
	switch len(h.allowedTransports) {
	case 0:
		return "tcp"
	case 1:
		return h.allowedTransports[0]
	}
	for _, t := range h.allowedTransports {
		if strings.EqualFold(t, "tcp") {
			return "tcp"
		}
	}
	return h.allowedTransports[0]
}
// ServeHTTP is the public HTTP entry point. Built-in endpoints are matched by
// exact path first. Any other request is routed by the subdomain in its Host
// header to a tunnel and proxied through it.
//
// For tunnel traffic the checks run in this order:
//   - the tunnel must exist and be open;
//   - the client IP must pass the tunnel's IP access control;
//   - the tunnel's proxy authentication must succeed (bearer token or login
//     page, with failed attempts rate limited per client IP);
//   - the tunnel must be an HTTP/HTTPS tunnel.
//
// CONNECT is rejected and WebSocket upgrades go to handleWebSocket. Everything
// else is forwarded over a new tunnel stream, throttled by the tunnel's
// bandwidth limiter when one is active.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	switch r.URL.Path {
	case "/_drip/discover":
		// Discovery endpoint for client auto-detection.
		h.serveDiscovery(w, r)
		return
	case "/_drip/ws":
		// WebSocket tunnel endpoint; must be matched before tunnel routing.
		h.handleTunnelWebSocket(w, r)
		return
	case "/health":
		h.serveHealth(w, r)
		return
	case "/stats":
		h.serveStats(w, r)
		return
	case "/metrics":
		h.serveMetrics(w, r)
		return
	}

	subdomain, result := h.extractSubdomain(r.Host)
	switch result {
	case subdomainHome:
		h.serveHomePage(w, r)
		return
	case subdomainNotFound:
		h.serveTunnelNotFound(w, r)
		return
	}

	tconn, ok := h.manager.Get(subdomain)
	if !ok || tconn == nil {
		h.serveTunnelNotFound(w, r)
		return
	}
	if tconn.IsClosed() {
		http.Error(w, "Tunnel connection closed", http.StatusBadGateway)
		return
	}

	if tconn.HasIPAccessControl() {
		clientIP := netutil.ExtractClientIP(r)
		if !tconn.IsIPAllowed(clientIP) {
			http.Error(w, "Access denied: your IP is not allowed", http.StatusForbidden)
			return
		}
	}

	if auth := tconn.GetProxyAuth(); auth != nil && auth.Enabled {
		clientIP := netutil.ExtractClientIP(r)
		if authLimiter.isRateLimited(clientIP) {
			w.Header().Set("Retry-After", "60")
			http.Error(w, "Too many failed authentication attempts. Please try again later.", http.StatusTooManyRequests)
			return
		}
		if isBearerProxyAuth(auth) {
			if !h.isBearerAuthenticated(r, auth) {
				authLimiter.recordFailure(clientIP)
				h.serveBearerAuthRequired(w, "drip")
				return
			}
			authLimiter.resetFailures(clientIP)
		} else {
			// Login-page authentication: /_drip/login handles the form.
			if r.URL.Path == "/_drip/login" {
				h.handleProxyLoginWithRateLimit(w, r, tconn, subdomain, clientIP)
				return
			}
			if !h.isProxyAuthenticated(r, subdomain) {
				h.serveLoginPage(w, r, subdomain, "")
				return
			}
		}
	}

	tType := tconn.GetTunnelType()
	if tType != "" && tType != protocol.TunnelTypeHTTP && tType != protocol.TunnelTypeHTTPS {
		http.Error(w, "Tunnel does not accept HTTP traffic", http.StatusBadGateway)
		return
	}
	if r.Method == http.MethodConnect {
		http.Error(w, "CONNECT not supported for HTTP tunnels", http.StatusMethodNotAllowed)
		return
	}
	if h.isWebSocketUpgrade(r) {
		h.handleWebSocket(w, r, tconn)
		return
	}

	stream, err := h.openStreamWithTimeout(tconn)
	if err != nil {
		httputil.SetCloseConnection(w)
		http.Error(w, "Tunnel unavailable", http.StatusBadGateway)
		return
	}
	// Several paths may close the stream: client cancellation, response
	// cleanup and function exit. Route them all through a single Close.
	var closeOnce sync.Once
	closeStream := func() { closeOnce.Do(func() { _ = stream.Close() }) }
	defer closeStream()

	tconn.IncActiveConnections()
	defer tconn.DecActiveConnections()

	// Close the stream as soon as the client goes away, so that a blocked
	// request write, response read or body copy is aborted instead of waiting
	// on the tunnel. The watcher exits when ServeHTTP returns.
	watchDone := make(chan struct{})
	defer close(watchDone)
	go func() {
		select {
		case <-r.Context().Done():
			closeStream()
		case <-watchDone:
		}
	}()

	var limitedStream net.Conn = stream
	if limiter := tconn.GetLimiter(); limiter != nil && limiter.IsLimited() {
		if l, ok := limiter.(*qos.Limiter); ok {
			limitedStream = qos.NewLimitedConn(r.Context(), stream, l)
		}
	}
	countingStream := netutil.NewCountingConn(limitedStream,
		tconn.AddBytesOut,
		tconn.AddBytesIn,
	)

	if err := r.Write(countingStream); err != nil {
		httputil.SetCloseConnection(w)
		_ = r.Body.Close()
		http.Error(w, "Forward failed", http.StatusBadGateway)
		return
	}

	reader := bufioReaderPool.Get().(*bufio.Reader)
	reader.Reset(countingStream)
	// Detach the reader from the stream before pooling it, so an idle pooled
	// reader does not keep this request's stream and its wrappers reachable.
	recycleReader := func() {
		reader.Reset(nil)
		bufioReaderPool.Put(reader)
	}

	resp, err := http.ReadResponse(reader, r)
	if err != nil {
		recycleReader()
		httputil.SetCloseConnection(w)
		http.Error(w, "Read response failed", http.StatusBadGateway)
		return
	}
	defer func() {
		// Close the stream before the body. The stream is never reused, and
		// closing a partially read body would otherwise drain the remainder
		// from the tunnel just to discard it (e.g. after a client disconnect).
		closeStream()
		_ = resp.Body.Close()
		recycleReader()
	}()

	h.copyResponseHeaders(w.Header(), resp.Header, r.Host)
	if resp.ContentLength >= 0 {
		httputil.SetContentLength(w, resp.ContentLength)
	} else {
		w.Header().Del("Content-Length")
	}
	statusCode := resp.StatusCode
	if statusCode == 0 {
		statusCode = http.StatusOK
	}
	w.WriteHeader(statusCode)

	// HEAD responses and 204/304 responses carry no body.
	if r.Method == http.MethodHead || statusCode == http.StatusNoContent || statusCode == http.StatusNotModified {
		return
	}

	// Reuse a pooled copy buffer instead of allocating one per request.
	buf := pool.GetBuffer(pool.SizeLarge)
	defer pool.PutBuffer(buf)
	_, _ = io.CopyBuffer(w, resp.Body, (*buf)[:])
}
// openStreamWithTimeout opens a new stream on tconn, giving up after
// openStreamTimeout. If the deadline passes first, a background goroutine
// waits for the pending OpenStream call and closes any stream it eventually
// yields, so that late stream is not leaked.
func (h *Handler) openStreamWithTimeout(tconn *tunnel.Connection) (net.Conn, error) {
	type openResult struct {
		conn net.Conn
		err  error
	}

	// Buffered so the opener can always deliver its result and exit, even
	// when nobody is waiting for it any more.
	results := make(chan openResult, 1)
	go func() {
		conn, err := tconn.OpenStream()
		results <- openResult{conn: conn, err: err}
	}()

	deadline := time.NewTimer(openStreamTimeout)
	defer deadline.Stop()

	select {
	case res := <-results:
		return res.conn, res.err
	case <-deadline.C:
	}

	// Timed out: reap the late result in the background and release it.
	go func() {
		if late := <-results; late.conn != nil {
			late.conn.Close()
		}
	}()
	return nil, fmt.Errorf("open stream timeout")
}
// copyResponseHeaders copies the upstream response headers from src into dst
// and rewrites the first Location value with rewriteLocationHeader for
// proxyHost.
//
// Hop-by-hop headers are not forwarded. These are the fixed set below plus
// every header named in the upstream Connection header, which a proxy must
// remove (RFC 9110 §7.6.1).
func (h *Handler) copyResponseHeaders(dst http.Header, src http.Header, proxyHost string) {
	// Collect the header names the upstream declared connection-specific.
	var nominated map[string]struct{}
	for key, values := range src {
		if http.CanonicalHeaderKey(key) != "Connection" {
			continue
		}
		for _, value := range values {
			for _, name := range strings.Split(value, ",") {
				if name = strings.TrimSpace(name); name == "" {
					continue
				}
				if nominated == nil {
					nominated = make(map[string]struct{})
				}
				nominated[http.CanonicalHeaderKey(name)] = struct{}{}
			}
		}
	}

	for key, values := range src {
		canonicalKey := http.CanonicalHeaderKey(key)
		switch canonicalKey {
		case "Connection", "Keep-Alive", "Transfer-Encoding", "Upgrade",
			"Proxy-Connection", "Te", "Trailer":
			continue
		}
		if _, ok := nominated[canonicalKey]; ok {
			continue
		}
		if canonicalKey == "Location" && len(values) > 0 {
			dst.Set("Location", h.rewriteLocationHeader(values[0], proxyHost))
			continue
		}
		for _, value := range values {
			dst.Add(key, value)
		}
	}
}
// rewriteLocationHeader rewrites an absolute http:// or https:// Location that
// points at a loopback host (localhost, 127.0.0.1 or ::1, any port) so that it
// points at proxyHost over https instead.
//
// Path, query and fragment keep their original percent-encoding, and any
// userinfo is dropped. Relative locations, unparsable values and non-loopback
// hosts are returned unchanged.
func (h *Handler) rewriteLocationHeader(location, proxyHost string) string {
	if !strings.HasPrefix(location, "http://") && !strings.HasPrefix(location, "https://") {
		return location
	}
	locationURL, err := url.Parse(location)
	if err != nil {
		return location
	}

	// Hostname strips the port and IPv6 brackets ("[::1]:8080" -> "::1").
	hostname := locationURL.Hostname()
	isLoopback := strings.EqualFold(hostname, "localhost") ||
		hostname == "127.0.0.1" ||
		hostname == "::1"
	if !isLoopback {
		return location
	}

	// Rebuild from the parsed URL rather than concatenating Path/Fragment.
	// Those fields hold decoded text, so "%20" would become a raw space and
	// "%2F" a path separator. String() re-escapes them and keeps RawPath.
	rewritten := *locationURL
	rewritten.Scheme = "https"
	rewritten.Host = proxyHost
	rewritten.User = nil
	return rewritten.String()
}
// subdomainResult classifies a request Host for routing.
type subdomainResult int

const (
	// subdomainHome: the Host is the server's own domain.
	subdomainHome subdomainResult = iota
	// subdomainFound: the Host is <label>.<tunnelDomain>; the label is returned.
	subdomainFound
	// subdomainNotFound: the Host matches neither domain.
	subdomainNotFound
)

// extractSubdomain classifies host (an HTTP Host value, possibly with a
// ":port" suffix) against the configured server and tunnel domains. For
// subdomainFound it also returns the part of host before ".<tunnelDomain>",
// with its original case preserved.
//
// Domain comparisons are case-insensitive, because host names are. A host
// with an empty label (".<tunnelDomain>") is reported as not found.
func (h *Handler) extractSubdomain(host string) (string, subdomainResult) {
	// Strip the port, if any.
	if idx := strings.Index(host, ":"); idx != -1 {
		host = host[:idx]
	}

	if strings.EqualFold(host, h.serverDomain) {
		return "", subdomainHome
	}

	suffix := "." + h.tunnelDomain
	if len(host) > len(suffix) && strings.EqualFold(host[len(host)-len(suffix):], suffix) {
		return host[:len(host)-len(suffix)], subdomainFound
	}
	return "", subdomainNotFound
}
// validateMetricsAuth reports whether r may access a metrics endpoint.
//
// If no metrics token is configured, every request passes. Otherwise the
// bearer token from the Authorization header must equal the configured token,
// compared in constant time. On a mismatch, a 401 with a Bearer
// WWW-Authenticate challenge for realm is written to w and false is returned.
func (h *Handler) validateMetricsAuth(w http.ResponseWriter, r *http.Request, realm string) bool {
	expected := h.metricsToken
	if expected == "" {
		return true
	}

	presented := extractBearerToken(r.Header.Get("Authorization"))
	if subtle.ConstantTimeCompare([]byte(presented), []byte(expected)) == 1 {
		return true
	}

	challenge := fmt.Sprintf(`Bearer realm="%s"`, realm)
	w.Header().Set("WWW-Authenticate", challenge)
	http.Error(w, "Unauthorized: provide metrics token via 'Authorization: Bearer <token>' header", http.StatusUnauthorized)
	return false
}