mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-23 21:00:44 +00:00
enhancements
- Add adaptive HTTP response handling with automatic streaming for large responses (>1MB)
- Implement zero-copy streaming using buffer pools for better performance
- Add compression module for reduced bandwidth usage
- Add GitHub Container Registry workflow for automated Docker builds
- Add production-optimized Dockerfile and docker-compose configuration
- Simplify background mode with -d flag and improved daemon management
- Update documentation with new command syntax and deployment guides
- Clean up unused code and improve error handling
- Fix lipgloss style usage (remove unnecessary .Copy() calls)
574 lines
16 KiB
Go
574 lines
16 KiB
Go
package proxy
|
|
|
|
import (
	"crypto/subtle"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	json "github.com/goccy/go-json"

	"drip/internal/server/tunnel"
	"drip/internal/shared/pool"
	"drip/internal/shared/protocol"
	"drip/internal/shared/utils"

	"go.uber.org/zap"
)
|
|
|
|
type Handler struct {
|
|
manager *tunnel.Manager
|
|
logger *zap.Logger
|
|
responses *ResponseHandler
|
|
domain string
|
|
authToken string
|
|
headerPool *pool.HeaderPool
|
|
}
|
|
|
|
func NewHandler(manager *tunnel.Manager, logger *zap.Logger, responses *ResponseHandler, domain string, authToken string) *Handler {
|
|
return &Handler{
|
|
manager: manager,
|
|
logger: logger,
|
|
responses: responses,
|
|
domain: domain,
|
|
authToken: authToken,
|
|
headerPool: pool.NewHeaderPool(),
|
|
}
|
|
}
|
|
|
|
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
subdomain := h.extractSubdomain(r.Host)
|
|
|
|
if subdomain == "" {
|
|
h.serveHomePage(w, r)
|
|
return
|
|
}
|
|
|
|
conn, ok := h.manager.Get(subdomain)
|
|
if !ok {
|
|
http.Error(w, "Tunnel not found. The tunnel may have been closed.", http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
if conn.IsClosed() {
|
|
http.Error(w, "Tunnel connection closed", http.StatusBadGateway)
|
|
return
|
|
}
|
|
|
|
transport := conn.GetTransport()
|
|
if transport == nil {
|
|
http.Error(w, "Tunnel control channel not ready", http.StatusBadGateway)
|
|
return
|
|
}
|
|
|
|
tType := conn.GetTunnelType()
|
|
if tType != "" && tType != protocol.TunnelTypeHTTP && tType != protocol.TunnelTypeHTTPS {
|
|
http.Error(w, "Tunnel does not accept HTTP traffic", http.StatusBadGateway)
|
|
return
|
|
}
|
|
|
|
requestID := utils.GenerateID()
|
|
|
|
h.handleAdaptiveRequest(w, r, transport, requestID, subdomain)
|
|
}
|
|
|
|
func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string) {
|
|
const streamingThreshold int64 = 1 * 1024 * 1024
|
|
|
|
buffer := make([]byte, 0, streamingThreshold)
|
|
tempBuf := make([]byte, 32*1024)
|
|
|
|
var totalRead int64
|
|
var hitThreshold bool
|
|
|
|
for totalRead < streamingThreshold {
|
|
n, err := r.Body.Read(tempBuf)
|
|
if n > 0 {
|
|
buffer = append(buffer, tempBuf[:n]...)
|
|
totalRead += int64(n)
|
|
}
|
|
if err == io.EOF {
|
|
r.Body.Close()
|
|
h.sendBufferedRequest(w, r, transport, requestID, subdomain, buffer)
|
|
return
|
|
}
|
|
if err != nil {
|
|
r.Body.Close()
|
|
h.logger.Error("Read request body failed", zap.Error(err))
|
|
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
if totalRead >= streamingThreshold {
|
|
hitThreshold = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !hitThreshold {
|
|
r.Body.Close()
|
|
h.sendBufferedRequest(w, r, transport, requestID, subdomain, buffer)
|
|
return
|
|
}
|
|
|
|
h.streamLargeRequest(w, r, transport, requestID, subdomain, buffer)
|
|
}
|
|
|
|
func (h *Handler) sendBufferedRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, body []byte) {
|
|
headers := h.headerPool.Get()
|
|
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
|
|
|
|
httpReq := protocol.HTTPRequest{
|
|
Method: r.Method,
|
|
URL: r.URL.String(),
|
|
Headers: headers,
|
|
Body: body,
|
|
}
|
|
|
|
reqBytes, err := protocol.EncodeHTTPRequest(&httpReq)
|
|
h.headerPool.Put(headers)
|
|
|
|
if err != nil {
|
|
h.logger.Error("Encode HTTP request failed", zap.Error(err))
|
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
header := protocol.DataHeader{
|
|
StreamID: requestID,
|
|
RequestID: requestID,
|
|
Type: protocol.DataTypeHTTPRequest,
|
|
IsLast: true,
|
|
}
|
|
|
|
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, reqBytes)
|
|
if err != nil {
|
|
h.logger.Error("Encode data payload failed", zap.Error(err))
|
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
|
|
|
|
respChan := h.responses.CreateResponseChan(requestID)
|
|
streamingDone := h.responses.CreateStreamingResponse(requestID, w)
|
|
defer func() {
|
|
h.responses.CleanupResponseChan(requestID)
|
|
h.responses.CleanupStreamingResponse(requestID)
|
|
}()
|
|
|
|
if err := transport.SendFrame(frame); err != nil {
|
|
h.logger.Error("Send frame to tunnel failed", zap.Error(err))
|
|
http.Error(w, "Failed to forward request to tunnel", http.StatusBadGateway)
|
|
return
|
|
}
|
|
|
|
select {
|
|
case respMsg := <-respChan:
|
|
if respMsg == nil {
|
|
http.Error(w, "Internal server error: nil response", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
h.writeHTTPResponse(w, respMsg, subdomain, r)
|
|
case <-streamingDone:
|
|
// Streaming response has been fully written by SendStreamingChunk
|
|
case <-time.After(5 * time.Minute):
|
|
h.logger.Error("Request timeout",
|
|
zap.String("request_id", requestID),
|
|
zap.String("url", r.URL.String()),
|
|
)
|
|
http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
|
|
}
|
|
}
|
|
|
|
func (h *Handler) streamLargeRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, bufferedData []byte) {
|
|
headers := h.headerPool.Get()
|
|
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
|
|
|
|
httpReqHead := protocol.HTTPRequestHead{
|
|
Method: r.Method,
|
|
URL: r.URL.String(),
|
|
Headers: headers,
|
|
ContentLength: r.ContentLength,
|
|
}
|
|
|
|
headBytes, err := protocol.EncodeHTTPRequestHead(&httpReqHead)
|
|
h.headerPool.Put(headers)
|
|
|
|
if err != nil {
|
|
h.logger.Error("Encode HTTP request head failed", zap.Error(err))
|
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
headHeader := protocol.DataHeader{
|
|
StreamID: requestID,
|
|
RequestID: requestID,
|
|
Type: protocol.DataTypeHTTPHead, // shared streaming head type
|
|
IsLast: false,
|
|
}
|
|
|
|
headPayload, headPoolBuffer, err := protocol.EncodeDataPayloadPooled(headHeader, headBytes)
|
|
if err != nil {
|
|
h.logger.Error("Encode head payload failed", zap.Error(err))
|
|
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
headFrame := protocol.NewFramePooled(protocol.FrameTypeData, headPayload, headPoolBuffer)
|
|
|
|
respChan := h.responses.CreateResponseChan(requestID)
|
|
streamingDone := h.responses.CreateStreamingResponse(requestID, w)
|
|
defer func() {
|
|
h.responses.CleanupResponseChan(requestID)
|
|
h.responses.CleanupStreamingResponse(requestID)
|
|
}()
|
|
|
|
if err := transport.SendFrame(headFrame); err != nil {
|
|
h.logger.Error("Send head frame failed", zap.Error(err))
|
|
http.Error(w, "Failed to forward request to tunnel", http.StatusBadGateway)
|
|
return
|
|
}
|
|
|
|
if len(bufferedData) > 0 {
|
|
chunkHeader := protocol.DataHeader{
|
|
StreamID: requestID,
|
|
RequestID: requestID,
|
|
Type: protocol.DataTypeHTTPBodyChunk, // shared streaming body type
|
|
IsLast: false,
|
|
}
|
|
|
|
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, bufferedData)
|
|
if err != nil {
|
|
h.logger.Error("Encode buffered chunk failed", zap.Error(err))
|
|
|
|
finalHeader := protocol.DataHeader{
|
|
StreamID: requestID,
|
|
RequestID: requestID,
|
|
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
|
IsLast: true,
|
|
}
|
|
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
|
if ferr == nil {
|
|
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
|
transport.SendFrame(finalFrame)
|
|
}
|
|
return
|
|
}
|
|
|
|
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
|
|
if err := transport.SendFrame(chunkFrame); err != nil {
|
|
h.logger.Error("Send buffered chunk failed", zap.Error(err))
|
|
|
|
finalHeader := protocol.DataHeader{
|
|
StreamID: requestID,
|
|
RequestID: requestID,
|
|
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
|
IsLast: true,
|
|
}
|
|
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
|
if ferr == nil {
|
|
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
|
transport.SendFrame(finalFrame)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
buffer := make([]byte, 32*1024)
|
|
for {
|
|
n, readErr := r.Body.Read(buffer)
|
|
if n > 0 {
|
|
isLast := readErr == io.EOF
|
|
|
|
chunkHeader := protocol.DataHeader{
|
|
StreamID: requestID,
|
|
RequestID: requestID,
|
|
Type: protocol.DataTypeHTTPBodyChunk, // shared streaming body type
|
|
IsLast: isLast,
|
|
}
|
|
|
|
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, buffer[:n])
|
|
if err != nil {
|
|
h.logger.Error("Encode chunk payload failed", zap.Error(err))
|
|
|
|
finalHeader := protocol.DataHeader{
|
|
StreamID: requestID,
|
|
RequestID: requestID,
|
|
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
|
IsLast: true,
|
|
}
|
|
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
|
if ferr == nil {
|
|
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
|
transport.SendFrame(finalFrame)
|
|
}
|
|
return
|
|
}
|
|
|
|
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
|
|
if err := transport.SendFrame(chunkFrame); err != nil {
|
|
h.logger.Error("Send chunk frame failed", zap.Error(err))
|
|
|
|
finalHeader := protocol.DataHeader{
|
|
StreamID: requestID,
|
|
RequestID: requestID,
|
|
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
|
IsLast: true,
|
|
}
|
|
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
|
if ferr == nil {
|
|
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
|
transport.SendFrame(finalFrame)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
if readErr == io.EOF {
|
|
if n == 0 {
|
|
finalHeader := protocol.DataHeader{
|
|
StreamID: requestID,
|
|
RequestID: requestID,
|
|
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
|
IsLast: true,
|
|
}
|
|
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
|
if err == nil {
|
|
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
|
transport.SendFrame(finalFrame)
|
|
}
|
|
}
|
|
break
|
|
}
|
|
if readErr != nil {
|
|
h.logger.Error("Read request body failed", zap.Error(readErr))
|
|
|
|
finalHeader := protocol.DataHeader{
|
|
StreamID: requestID,
|
|
RequestID: requestID,
|
|
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
|
IsLast: true,
|
|
}
|
|
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
|
if err == nil {
|
|
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
|
transport.SendFrame(finalFrame)
|
|
}
|
|
|
|
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
r.Body.Close()
|
|
|
|
select {
|
|
case respMsg := <-respChan:
|
|
if respMsg == nil {
|
|
http.Error(w, "Internal server error: nil response", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
h.writeHTTPResponse(w, respMsg, subdomain, r)
|
|
case <-streamingDone:
|
|
// Streaming response has been fully written by SendStreamingChunk
|
|
case <-time.After(5 * time.Minute):
|
|
h.logger.Error("Streaming request timeout",
|
|
zap.String("request_id", requestID),
|
|
zap.String("url", r.URL.String()),
|
|
)
|
|
http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
|
|
}
|
|
}
|
|
|
|
func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPResponse, subdomain string, r *http.Request) {
|
|
if resp == nil {
|
|
http.Error(w, "Invalid response from tunnel", http.StatusBadGateway)
|
|
return
|
|
}
|
|
|
|
// For buffered responses, we have the complete body, so we can set Content-Length
|
|
// Skip ALL hop-by-hop headers - client should have already cleaned them
|
|
for key, values := range resp.Headers {
|
|
canonicalKey := http.CanonicalHeaderKey(key)
|
|
|
|
// Skip hop-by-hop headers completely using canonical key comparison
|
|
if canonicalKey == "Connection" ||
|
|
canonicalKey == "Keep-Alive" ||
|
|
canonicalKey == "Transfer-Encoding" ||
|
|
canonicalKey == "Upgrade" ||
|
|
canonicalKey == "Proxy-Connection" ||
|
|
canonicalKey == "Te" ||
|
|
canonicalKey == "Trailer" {
|
|
continue
|
|
}
|
|
|
|
if canonicalKey == "Location" && len(values) > 0 {
|
|
rewrittenLocation := h.rewriteLocationHeader(values[0], r.Host)
|
|
w.Header().Set("Location", rewrittenLocation)
|
|
continue
|
|
}
|
|
|
|
for _, value := range values {
|
|
w.Header().Add(key, value)
|
|
}
|
|
}
|
|
|
|
// For buffered mode, always set Content-Length with the actual body size
|
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(resp.Body)))
|
|
|
|
statusCode := resp.StatusCode
|
|
if statusCode == 0 {
|
|
statusCode = http.StatusOK
|
|
}
|
|
|
|
w.WriteHeader(statusCode)
|
|
|
|
if len(resp.Body) > 0 {
|
|
w.Write(resp.Body)
|
|
}
|
|
}
|
|
|
|
func (h *Handler) rewriteLocationHeader(location, proxyHost string) string {
|
|
if !strings.HasPrefix(location, "http://") && !strings.HasPrefix(location, "https://") {
|
|
return location
|
|
}
|
|
|
|
locationURL, err := url.Parse(location)
|
|
if err != nil {
|
|
return location
|
|
}
|
|
|
|
if locationURL.Host == "localhost" ||
|
|
strings.HasPrefix(locationURL.Host, "localhost:") ||
|
|
locationURL.Host == "127.0.0.1" ||
|
|
strings.HasPrefix(locationURL.Host, "127.0.0.1:") {
|
|
scheme := "https"
|
|
if strings.Contains(proxyHost, ":") && !strings.Contains(proxyHost, "https") {
|
|
parts := strings.Split(proxyHost, ":")
|
|
if len(parts) == 2 && parts[1] != "443" {
|
|
scheme = "https"
|
|
}
|
|
}
|
|
|
|
rewritten := fmt.Sprintf("%s://%s%s", scheme, proxyHost, locationURL.Path)
|
|
if locationURL.RawQuery != "" {
|
|
rewritten += "?" + locationURL.RawQuery
|
|
}
|
|
if locationURL.Fragment != "" {
|
|
rewritten += "#" + locationURL.Fragment
|
|
}
|
|
|
|
return rewritten
|
|
}
|
|
|
|
return location
|
|
}
|
|
|
|
func (h *Handler) extractSubdomain(host string) string {
|
|
if idx := strings.Index(host, ":"); idx != -1 {
|
|
host = host[:idx]
|
|
}
|
|
|
|
if host == h.domain {
|
|
return ""
|
|
}
|
|
|
|
suffix := "." + h.domain
|
|
if strings.HasSuffix(host, suffix) {
|
|
subdomain := strings.TrimSuffix(host, suffix)
|
|
return subdomain
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
func (h *Handler) serveHomePage(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path == "/health" {
|
|
h.serveHealth(w, r)
|
|
return
|
|
}
|
|
|
|
if r.URL.Path == "/stats" {
|
|
h.serveStats(w, r)
|
|
return
|
|
}
|
|
|
|
html := `<!DOCTYPE html>
|
|
<html>
|
|
<head>
|
|
<meta charset="UTF-8" />
|
|
<title>Drip - Your Tunnel, Your Domain, Anywhere</title>
|
|
<style>
|
|
body { font-family: Arial, sans-serif; max-width: 800px; margin: 50px auto; padding: 20px; }
|
|
h1 { color: #333; }
|
|
code { background: #f4f4f4; padding: 2px 6px; border-radius: 3px; }
|
|
.stats { background: #f9f9f9; padding: 15px; border-radius: 5px; margin: 20px 0; }
|
|
</style>
|
|
</head>
|
|
<body>
|
|
<h1>💧 Drip - Your Tunnel, Your Domain, Anywhere</h1>
|
|
<p>A self-hosted tunneling solution to securely expose your services to the internet.</p>
|
|
|
|
<h2>Quick Start</h2>
|
|
<p>Install the client:</p>
|
|
<code>bash <(curl -fsSL https://raw.githubusercontent.com/Gouryella/drip/main/scripts/install.sh)</code>
|
|
|
|
<p>Start a tunnel:</p>
|
|
<code>drip http 3000</code><br><br>
|
|
<code>drip https 443</code><br><br>
|
|
<code>drip tcp 5432</code>
|
|
<p><a href="/health">Health Check</a> | <a href="/stats">Statistics</a></p>
|
|
</body>
|
|
</html>`
|
|
|
|
w.Header().Set("Content-Type", "text/html")
|
|
w.Write([]byte(html))
|
|
}
|
|
|
|
func (h *Handler) serveHealth(w http.ResponseWriter, r *http.Request) {
|
|
health := map[string]interface{}{
|
|
"status": "ok",
|
|
"active_tunnels": h.manager.Count(),
|
|
"timestamp": time.Now().Unix(),
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(health)
|
|
}
|
|
|
|
func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
|
|
if h.authToken != "" {
|
|
token := r.URL.Query().Get("token")
|
|
if token == "" {
|
|
authHeader := r.Header.Get("Authorization")
|
|
if strings.HasPrefix(authHeader, "Bearer ") {
|
|
token = strings.TrimPrefix(authHeader, "Bearer ")
|
|
}
|
|
}
|
|
|
|
if token != h.authToken {
|
|
http.Error(w, "Unauthorized: invalid or missing token", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
}
|
|
|
|
connections := h.manager.List()
|
|
|
|
stats := map[string]interface{}{
|
|
"total_tunnels": len(connections),
|
|
"tunnels": []map[string]interface{}{},
|
|
}
|
|
|
|
for _, conn := range connections {
|
|
stats["tunnels"] = append(stats["tunnels"].([]map[string]interface{}), map[string]interface{}{
|
|
"subdomain": conn.Subdomain,
|
|
"last_active": conn.LastActive.Unix(),
|
|
})
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(stats)
|
|
}
|