package proxy

import (
	"bufio"
	"context"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"

	json "github.com/goccy/go-json"
	"go.uber.org/zap"

	"drip/internal/server/tunnel"
	"drip/internal/shared/httputil"
	"drip/internal/shared/netutil"
	"drip/internal/shared/protocol"
)

// openStreamTimeout bounds how long a request waits for a tunnel to open a
// new stream before the proxy gives up and answers 502.
const openStreamTimeout = 10 * time.Second

// Handler is the public-facing HTTP entry point. It serves /health and
// /stats locally, shows a home page when the host does not select a tunnel,
// and forwards requests for <subdomain>.<domain> over streams of the
// matching tunnel connection.
type Handler struct {
	manager   *tunnel.Manager
	logger    *zap.Logger
	domain    string // base domain; tunnels are addressed as <subdomain>.<domain>
	authToken string // when non-empty, required to read /stats
}

// NewHandler returns a Handler that looks tunnels up in manager and routes
// hosts under domain. authToken protects /stats; pass "" to leave it open.
func NewHandler(manager *tunnel.Manager, logger *zap.Logger, domain string, authToken string) *Handler {
	return &Handler{
		manager:   manager,
		logger:    logger,
		domain:    domain,
		authToken: authToken,
	}
}

// ServeHTTP routes an incoming request.
//
// /health and /stats are answered locally on every host. A host that does not
// name a tunnel gets the home page. Otherwise the subdomain selects a tunnel:
// WebSocket upgrades are handed to handleWebSocket, and every other request is
// written over a fresh tunnel stream and its response relayed back.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Always handle /health and /stats directly, regardless of subdomain.
	switch r.URL.Path {
	case "/health":
		h.serveHealth(w, r)
		return
	case "/stats":
		h.serveStats(w, r)
		return
	}

	subdomain := h.extractSubdomain(r.Host)
	if subdomain == "" {
		h.serveHomePage(w, r)
		return
	}

	tconn, ok := h.manager.Get(subdomain)
	if !ok || tconn == nil {
		http.Error(w, "Tunnel not found. The tunnel may have been closed.", http.StatusNotFound)
		return
	}
	if tconn.IsClosed() {
		http.Error(w, "Tunnel connection closed", http.StatusBadGateway)
		return
	}

	tType := tconn.GetTunnelType()
	if tType != "" && tType != protocol.TunnelTypeHTTP && tType != protocol.TunnelTypeHTTPS {
		http.Error(w, "Tunnel does not accept HTTP traffic", http.StatusBadGateway)
		return
	}

	// Check for WebSocket upgrade
	if httputil.IsWebSocketUpgrade(r) {
		h.handleWebSocket(w, r, tconn)
		return
	}

	// Open stream with timeout
	stream, err := h.openStreamWithTimeout(tconn)
	if err != nil {
		h.logger.Debug("Open tunnel stream failed", zap.String("subdomain", subdomain), zap.Error(err))
		w.Header().Set("Connection", "close")
		http.Error(w, "Tunnel unavailable", http.StatusBadGateway)
		return
	}
	defer stream.Close()

	// Track active connections
	tconn.IncActiveConnections()
	defer tconn.DecActiveConnections()

	// Wrap stream with counting for traffic stats
	countingStream := netutil.NewCountingConn(stream,
		tconn.AddBytesOut, // Data read from stream = bytes out to client
		tconn.AddBytesIn,  // Data written to stream = bytes in from client
	)

	// 1) Write request over the stream (net/http handles large bodies correctly).
	if err := r.Write(countingStream); err != nil {
		w.Header().Set("Connection", "close")
		_ = r.Body.Close()
		http.Error(w, "Forward failed", http.StatusBadGateway)
		return
	}

	// 2) Read the final response from the stream, skipping interim 1xx ones.
	resp, err := readFinalResponse(bufio.NewReaderSize(countingStream, 32*1024), r)
	if err != nil {
		w.Header().Set("Connection", "close")
		http.Error(w, "Read response failed", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	// 3) Copy headers (strip hop-by-hop).
	h.copyResponseHeaders(w.Header(), resp.Header, r.Host)

	statusCode := resp.StatusCode
	if statusCode == 0 {
		statusCode = http.StatusOK
	}

	// Ensure message delimiting works with our custom ResponseWriter:
	// - Bodiless responses carry only a Content-Length where allowed.
	// - If Content-Length is known, send it.
	// - Otherwise, re-chunk the decoded body ourselves.
	switch {
	case statusCode == http.StatusNoContent:
		// RFC 9110 §8.6: a server MUST NOT send Content-Length in a 204.
		w.Header().Del("Content-Length")
		w.WriteHeader(statusCode)
		return
	case r.Method == http.MethodHead || statusCode == http.StatusNotModified:
		if resp.ContentLength >= 0 {
			w.Header().Set("Content-Length", strconv.FormatInt(resp.ContentLength, 10))
		} else {
			w.Header().Del("Content-Length")
		}
		w.WriteHeader(statusCode)
		return
	}

	ctx := r.Context()

	if resp.ContentLength >= 0 {
		w.Header().Set("Content-Length", strconv.FormatInt(resp.ContentLength, 10))
		w.WriteHeader(statusCode)
		if err := copyUntilCanceled(ctx, stream, func() error {
			_, err := io.Copy(w, resp.Body)
			return err
		}); err != nil {
			h.logger.Debug("Write response body failed", zap.Error(err))
		}
		return
	}

	w.Header().Del("Content-Length")
	w.Header().Set("Transfer-Encoding", "chunked")
	if len(resp.Trailer) > 0 {
		w.Header().Set("Trailer", trailerKeys(resp.Trailer))
	}
	w.WriteHeader(statusCode)
	if err := copyUntilCanceled(ctx, stream, func() error {
		return writeChunked(w, resp.Body, resp.Trailer)
	}); err != nil {
		h.logger.Debug("Write chunked response failed", zap.Error(err))
	}
}

// readFinalResponse reads responses for req from br until it gets one that
// is not an interim 1xx response. The request body has already been
// forwarded in full, so interim responses such as 100 Continue (provoked by a
// forwarded Expect header) or 103 Early Hints are read and discarded.
// 101 Switching Protocols is final and returned as-is. At most a few interim
// responses are tolerated so a misbehaving upstream cannot loop forever.
func readFinalResponse(br *bufio.Reader, req *http.Request) (*http.Response, error) {
	const maxInterim = 5
	for skipped := 0; skipped <= maxInterim; skipped++ {
		resp, err := http.ReadResponse(br, req)
		if err != nil {
			return nil, err
		}
		code := resp.StatusCode
		if code < 100 || code > 199 || code == http.StatusSwitchingProtocols {
			return resp, nil
		}
		_ = resp.Body.Close()
	}
	return nil, fmt.Errorf("too many informational (1xx) responses")
}

// copyUntilCanceled runs copyFn while watching ctx. If ctx is canceled first
// (e.g. the client went away), stream is closed so a copy blocked on the
// tunnel unblocks. Once copyFn returns, stream is closed as well: this runs
// before the caller's deferred resp.Body.Close, which for a body read by
// http.ReadResponse may otherwise try to drain the rest of the upstream body.
func copyUntilCanceled(ctx context.Context, stream net.Conn, copyFn func() error) error {
	done := make(chan struct{})
	go func() {
		select {
		case <-ctx.Done():
			stream.Close()
		case <-done:
		}
	}()

	err := copyFn()
	close(done)
	stream.Close()
	return err
}

// openStreamWithTimeout opens a new stream on tconn, giving up after
// openStreamTimeout. If OpenStream completes only after the timeout, the
// late stream is closed in the background instead of being leaked.
func (h *Handler) openStreamWithTimeout(tconn *tunnel.Connection) (net.Conn, error) {
	type result struct {
		stream net.Conn
		err    error
	}

	// Buffered so the opener goroutine never blocks on send.
	ch := make(chan result, 1)
	go func() {
		s, err := tconn.OpenStream()
		ch <- result{s, err}
	}()

	timer := time.NewTimer(openStreamTimeout)
	defer timer.Stop()

	select {
	case res := <-ch:
		return res.stream, res.err
	case <-timer.C:
		go func() {
			if res := <-ch; res.err == nil && res.stream != nil {
				res.stream.Close()
			}
		}()
		return nil, fmt.Errorf("open stream timeout")
	}
}

// handleWebSocket forwards a WebSocket upgrade request over a new tunnel
// stream, hijacks the client connection, and then pipes raw bytes in both
// directions in a background goroutine. The tunnel's active-connection count
// is held for the lifetime of the pipe.
func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn *tunnel.Connection) {
	stream, err := h.openStreamWithTimeout(tconn)
	if err != nil {
		http.Error(w, "Tunnel unavailable", http.StatusBadGateway)
		return
	}
	tconn.IncActiveConnections()

	hj, ok := w.(http.Hijacker)
	if !ok {
		stream.Close()
		tconn.DecActiveConnections()
		http.Error(w, "WebSocket not supported", http.StatusInternalServerError)
		return
	}

	clientConn, clientBuf, err := hj.Hijack()
	if err != nil {
		stream.Close()
		tconn.DecActiveConnections()
		http.Error(w, "Failed to hijack connection", http.StatusInternalServerError)
		return
	}

	// abort releases everything acquired so far when forwarding fails.
	abort := func() {
		stream.Close()
		clientConn.Close()
		tconn.DecActiveConnections()
	}

	if err := r.Write(stream); err != nil {
		h.logger.Debug("Forward WebSocket upgrade failed", zap.Error(err))
		abort()
		return
	}

	// Bytes the client sent right after the upgrade request may already sit
	// in the hijacked reader's buffer; they would be lost if only clientConn
	// were piped, so forward them first.
	if clientBuf != nil && clientBuf.Reader != nil {
		if n := clientBuf.Reader.Buffered(); n > 0 {
			pending, _ := clientBuf.Reader.Peek(n) // n <= Buffered, cannot fail
			if _, err := stream.Write(pending); err != nil {
				h.logger.Debug("Forward buffered WebSocket data failed", zap.Error(err))
				abort()
				return
			}
			tconn.AddBytesIn(int64(n))
		}
	}

	go func() {
		defer stream.Close()
		defer clientConn.Close()
		defer tconn.DecActiveConnections()
		// Not r.Context(): the standard net/http server cancels a request's
		// context as soon as ServeHTTP returns — right after this goroutine
		// starts — even for hijacked connections. The pipe is presumably
		// torn down when either side closes instead (confirm in netutil).
		_ = netutil.PipeWithCallbacks(context.Background(), stream, clientConn,
			func(n int64) { tconn.AddBytesOut(n) },
			func(n int64) { tconn.AddBytesIn(n) },
		)
	}()
}

// copyResponseHeaders copies upstream response headers from src into dst,
// dropping hop-by-hop headers (the fixed set below plus any header named in
// upstream's Connection header, RFC 9110 §7.6.1) and rewriting a Location
// that points at localhost so it targets proxyHost instead.
func (h *Handler) copyResponseHeaders(dst http.Header, src http.Header, proxyHost string) {
	connectionListed := make(map[string]bool)
	for key, values := range src {
		if http.CanonicalHeaderKey(key) != "Connection" {
			continue
		}
		for _, value := range values {
			for _, token := range strings.Split(value, ",") {
				if token = strings.TrimSpace(token); token != "" {
					connectionListed[http.CanonicalHeaderKey(token)] = true
				}
			}
		}
	}

	for key, values := range src {
		canonicalKey := http.CanonicalHeaderKey(key)

		// Hop-by-hop headers must not be forwarded.
		switch canonicalKey {
		case "Connection", "Keep-Alive", "Transfer-Encoding", "Upgrade",
			"Proxy-Connection", "Te", "Trailer":
			continue
		}
		if connectionListed[canonicalKey] {
			continue
		}

		if canonicalKey == "Location" && len(values) > 0 {
			dst.Set("Location", h.rewriteLocationHeader(values[0], proxyHost))
			continue
		}

		for _, value := range values {
			dst.Add(key, value)
		}
	}
}

// trailerKeys returns the keys of hdr, sorted and joined with ", ", for use
// as the value of a Trailer header.
func trailerKeys(hdr http.Header) string {
	keys := make([]string, 0, len(hdr))
	for k := range hdr {
		keys = append(keys, k)
	}
	// Deterministic order is nicer for debugging; no semantic impact.
	sortStrings(keys)
	return strings.Join(keys, ", ")
}

// writeChunked re-encodes everything read from r using HTTP/1.1 chunked
// transfer coding on w: one chunk per successful Read, then the zero-length
// last chunk, the trailer fields, and the final CRLF. trailer is consulted
// only after r returns io.EOF; for a net/http response body that is when
// resp.Trailer holds the received trailer values.
func writeChunked(w io.Writer, r io.Reader, trailer http.Header) error {
	buf := make([]byte, 32*1024)
	for {
		n, err := r.Read(buf)
		if n > 0 {
			if _, werr := fmt.Fprintf(w, "%x\r\n", n); werr != nil {
				return werr
			}
			if _, werr := w.Write(buf[:n]); werr != nil {
				return werr
			}
			if _, werr := io.WriteString(w, "\r\n"); werr != nil {
				return werr
			}
		}
		if err == io.EOF {
			break
		}
		if err != nil {
			return err
		}
	}

	if _, err := io.WriteString(w, "0\r\n"); err != nil {
		return err
	}
	for k, vv := range trailer {
		for _, v := range vv {
			if _, err := fmt.Fprintf(w, "%s: %s\r\n", k, v); err != nil {
				return err
			}
		}
	}
	_, err := io.WriteString(w, "\r\n")
	return err
}

// rewriteLocationHeader rewrites an absolute http(s) redirect that points at
// a loopback host (localhost, 127.0.0.1, ::1 — any port) so it points at
// https://proxyHost instead, keeping path, query, and fragment with their
// original escaping. Any other value is returned unchanged.
func (h *Handler) rewriteLocationHeader(location, proxyHost string) string {
	if !strings.HasPrefix(location, "http://") && !strings.HasPrefix(location, "https://") {
		return location
	}

	locationURL, err := url.Parse(location)
	if err != nil {
		return location
	}

	switch locationURL.Hostname() {
	case "localhost", "127.0.0.1", "::1":
	default:
		return location
	}

	// URL.String re-encodes with EscapedPath/EscapedFragment, so
	// percent-encoded characters (e.g. %2F, %20) survive the rewrite.
	rewritten := *locationURL
	rewritten.Scheme = "https"
	rewritten.Host = proxyHost
	rewritten.User = nil
	return rewritten.String()
}

// extractSubdomain returns the part of host (port stripped) that precedes
// "."+h.domain, or "" when host is the bare domain or lies outside it.
func (h *Handler) extractSubdomain(host string) string {
	if idx := strings.Index(host, ":"); idx != -1 {
		host = host[:idx]
	}

	if host == h.domain {
		return ""
	}

	suffix := "." + h.domain
	if strings.HasSuffix(host, suffix) {
		return strings.TrimSuffix(host, suffix)
	}

	return ""
}

// serveHomePage writes the static landing page served when the request host
// does not select a tunnel.
func (h *Handler) serveHomePage(w http.ResponseWriter, r *http.Request) {
	html := `
A self-hosted tunneling solution to securely expose your services to the internet.
Install the client:
bash <(curl -fsSL https://raw.githubusercontent.com/Gouryella/drip/main/scripts/install.sh)
Start a tunnel:
drip http 3000drip https 443drip tcp 5432
`
	data := []byte(html)
	w.Header().Set("Content-Type", "text/html")
	w.Header().Set("Content-Length", strconv.Itoa(len(data)))
	// Best effort: nothing useful can be done if the client has gone away.
	_, _ = w.Write(data)
}
func (h *Handler) serveHealth(w http.ResponseWriter, r *http.Request) {
	// serveHealth answers liveness probes with a small JSON object holding
	// "status" ("ok"), "active_tunnels" (h.manager.Count()), and "timestamp"
	// (the current Unix time in seconds).
	payload := make(map[string]interface{}, 3)
	payload["status"] = "ok"
	payload["active_tunnels"] = h.manager.Count()
	payload["timestamp"] = time.Now().Unix()

	body, encErr := json.Marshal(payload)
	if encErr != nil {
		http.Error(w, "Failed to encode response", http.StatusInternalServerError)
		return
	}

	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	hdr.Set("Content-Length", fmt.Sprintf("%d", len(body)))
	w.Write(body)
}
func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
	// serveStats reports per-tunnel traffic counters as JSON:
	// {"total_tunnels": N, "tunnels": [...]}.
	//
	// When h.authToken is set, the caller must present it either as the
	// "token" query parameter or as an "Authorization: Bearer <token>"
	// header; otherwise the request is rejected with 401.
	if h.authToken != "" {
		token := r.URL.Query().Get("token")
		if token == "" {
			if authHeader := r.Header.Get("Authorization"); strings.HasPrefix(authHeader, "Bearer ") {
				token = strings.TrimPrefix(authHeader, "Bearer ")
			}
		}
		// NOTE(review): != is not constant-time; crypto/subtle.ConstantTimeCompare
		// would avoid leaking token prefixes through response timing.
		if token != h.authToken {
			http.Error(w, "Unauthorized: invalid or missing token", http.StatusUnauthorized)
			return
		}
	}

	connections := h.manager.List()
	tunnels := make([]map[string]interface{}, 0, len(connections))
	for _, conn := range connections {
		if conn == nil {
			continue
		}
		// Read each counter once: traffic updates them concurrently, and
		// re-reading for total_bytes could make it disagree with the
		// bytes_in/bytes_out values emitted alongside it.
		bytesIn := conn.GetBytesIn()
		bytesOut := conn.GetBytesOut()
		tunnels = append(tunnels, map[string]interface{}{
			"subdomain":          conn.Subdomain,
			"tunnel_type":        string(conn.GetTunnelType()),
			"last_active":        conn.LastActive.Unix(),
			"bytes_in":           bytesIn,
			"bytes_out":          bytesOut,
			"active_connections": conn.GetActiveConnections(),
			"total_bytes":        bytesIn + bytesOut,
		})
	}

	data, err := json.Marshal(map[string]interface{}{
		"total_tunnels": len(connections),
		"tunnels":       tunnels,
	})
	if err != nil {
		http.Error(w, "Failed to encode response", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
	// Best effort: nothing useful can be done if the client has gone away.
	_, _ = w.Write(data)
}
// sortStrings sorts s in place into ascending byte-wise (Go string
// comparison) order. It is a simple quadratic insertion sort, adequate for
// short slices.
func sortStrings(s []string) {
	for next := 1; next < len(s); next++ {
		pending := s[next]
		slot := next
		// Shift larger elements one position right to open a slot for pending.
		for slot > 0 && s[slot-1] > pending {
			s[slot] = s[slot-1]
			slot--
		}
		s[slot] = pending
	}
}