mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-23 21:00:44 +00:00
feat(tunnel): switch to yamux stream proxying and connection pooling
- Introduce pooled tunnel sessions (TunnelID/DataConnect) on client/server - Proxy HTTP/HTTPS via raw HTTP over yamux streams; pipe TCP streams directly - Move UI/stats into internal/shared; refactor CLI tunnel helpers; drop msgpack/hpack legacy
This commit is contained in:
@@ -1,490 +1,251 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
|
||||
"drip/internal/server/tunnel"
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/httputil"
|
||||
"drip/internal/shared/netutil"
|
||||
"drip/internal/shared/protocol"
|
||||
"drip/internal/shared/utils"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const openStreamTimeout = 10 * time.Second
|
||||
|
||||
type Handler struct {
|
||||
manager *tunnel.Manager
|
||||
logger *zap.Logger
|
||||
responses *ResponseHandler
|
||||
domain string
|
||||
authToken string
|
||||
headerPool *pool.HeaderPool
|
||||
bufferPool *pool.AdaptiveBufferPool
|
||||
manager *tunnel.Manager
|
||||
logger *zap.Logger
|
||||
domain string
|
||||
authToken string
|
||||
}
|
||||
|
||||
func NewHandler(manager *tunnel.Manager, logger *zap.Logger, responses *ResponseHandler, domain string, authToken string) *Handler {
|
||||
func NewHandler(manager *tunnel.Manager, logger *zap.Logger, domain string, authToken string) *Handler {
|
||||
return &Handler{
|
||||
manager: manager,
|
||||
logger: logger,
|
||||
responses: responses,
|
||||
domain: domain,
|
||||
authToken: authToken,
|
||||
headerPool: pool.NewHeaderPool(),
|
||||
bufferPool: pool.NewAdaptiveBufferPool(),
|
||||
manager: manager,
|
||||
logger: logger,
|
||||
domain: domain,
|
||||
authToken: authToken,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// Always handle /health and /stats directly, regardless of subdomain
|
||||
// Always handle /health and /stats directly, regardless of subdomain.
|
||||
if r.URL.Path == "/health" {
|
||||
h.serveHealth(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/stats" {
|
||||
h.serveStats(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
subdomain := h.extractSubdomain(r.Host)
|
||||
|
||||
if subdomain == "" {
|
||||
h.serveHomePage(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
conn, ok := h.manager.Get(subdomain)
|
||||
if !ok {
|
||||
tconn, ok := h.manager.Get(subdomain)
|
||||
if !ok || tconn == nil {
|
||||
http.Error(w, "Tunnel not found. The tunnel may have been closed.", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
if conn.IsClosed() {
|
||||
if tconn.IsClosed() {
|
||||
http.Error(w, "Tunnel connection closed", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
transport := conn.GetTransport()
|
||||
if transport == nil {
|
||||
http.Error(w, "Tunnel control channel not ready", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
tType := conn.GetTunnelType()
|
||||
tType := tconn.GetTunnelType()
|
||||
if tType != "" && tType != protocol.TunnelTypeHTTP && tType != protocol.TunnelTypeHTTPS {
|
||||
http.Error(w, "Tunnel does not accept HTTP traffic", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
requestID := utils.GenerateID()
|
||||
// Check for WebSocket upgrade
|
||||
if httputil.IsWebSocketUpgrade(r) {
|
||||
h.handleWebSocket(w, r, tconn)
|
||||
return
|
||||
}
|
||||
|
||||
h.handleAdaptiveRequest(w, r, transport, requestID, subdomain)
|
||||
}
|
||||
// Open stream with timeout
|
||||
stream, err := h.openStreamWithTimeout(tconn)
|
||||
if err != nil {
|
||||
w.Header().Set("Connection", "close")
|
||||
http.Error(w, "Tunnel unavailable", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string) {
|
||||
const streamingThreshold int64 = 1 * 1024 * 1024
|
||||
// Track active connections
|
||||
tconn.IncActiveConnections()
|
||||
defer tconn.DecActiveConnections()
|
||||
|
||||
// Wrap stream with counting for traffic stats
|
||||
countingStream := netutil.NewCountingConn(stream,
|
||||
tconn.AddBytesOut, // Data read from stream = bytes out to client
|
||||
tconn.AddBytesIn, // Data written to stream = bytes in from client
|
||||
)
|
||||
|
||||
// 1) Write request over the stream (net/http handles large bodies correctly).
|
||||
if err := r.Write(countingStream); err != nil {
|
||||
w.Header().Set("Connection", "close")
|
||||
_ = r.Body.Close()
|
||||
http.Error(w, "Forward failed", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
// 2) Read response from stream.
|
||||
resp, err := http.ReadResponse(bufio.NewReaderSize(countingStream, 32*1024), r)
|
||||
if err != nil {
|
||||
w.Header().Set("Connection", "close")
|
||||
http.Error(w, "Read response failed", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 3) Copy headers (strip hop-by-hop).
|
||||
h.copyResponseHeaders(w.Header(), resp.Header, r.Host)
|
||||
|
||||
statusCode := resp.StatusCode
|
||||
if statusCode == 0 {
|
||||
statusCode = http.StatusOK
|
||||
}
|
||||
|
||||
// Ensure message delimiting works with our custom ResponseWriter:
|
||||
// - If Content-Length is known, send it.
|
||||
// - Otherwise, re-chunk the decoded body ourselves.
|
||||
if r.Method == http.MethodHead || statusCode == http.StatusNoContent || statusCode == http.StatusNotModified {
|
||||
if resp.ContentLength >= 0 {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
|
||||
} else {
|
||||
w.Header().Del("Content-Length")
|
||||
}
|
||||
w.WriteHeader(statusCode)
|
||||
return
|
||||
}
|
||||
|
||||
if resp.ContentLength >= 0 {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
ctx := r.Context()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
stream.Close()
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
_, _ = io.Copy(w, resp.Body)
|
||||
close(done)
|
||||
stream.Close()
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Del("Content-Length")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
if len(resp.Trailer) > 0 {
|
||||
w.Header().Set("Trailer", trailerKeys(resp.Trailer))
|
||||
}
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
var cancelTransport func()
|
||||
if transport != nil {
|
||||
cancelOnce := sync.Once{}
|
||||
cancelFunc := func() {
|
||||
header := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeClose,
|
||||
IsLast: true,
|
||||
}
|
||||
|
||||
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
|
||||
if err := transport.SendFrame(frame); err != nil {
|
||||
h.logger.Debug("Failed to send cancel frame to client",
|
||||
zap.String("request_id", requestID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
cancelTransport = func() {
|
||||
cancelOnce.Do(cancelFunc)
|
||||
}
|
||||
|
||||
h.responses.RegisterCancelFunc(requestID, cancelTransport)
|
||||
defer h.responses.CleanupCancelFunc(requestID)
|
||||
}
|
||||
|
||||
largeBufferPtr := h.bufferPool.GetLarge()
|
||||
tempBufPtr := h.bufferPool.GetMedium()
|
||||
|
||||
defer func() {
|
||||
h.bufferPool.PutLarge(largeBufferPtr)
|
||||
h.bufferPool.PutMedium(tempBufPtr)
|
||||
}()
|
||||
|
||||
buffer := (*largeBufferPtr)[:0]
|
||||
tempBuf := (*tempBufPtr)[:pool.MediumBufferSize]
|
||||
|
||||
var totalRead int64
|
||||
var hitThreshold bool
|
||||
|
||||
for totalRead < streamingThreshold {
|
||||
n, err := r.Body.Read(tempBuf)
|
||||
if n > 0 {
|
||||
buffer = append(buffer, tempBuf[:n]...)
|
||||
totalRead += int64(n)
|
||||
}
|
||||
if err == io.EOF {
|
||||
r.Body.Close()
|
||||
h.sendBufferedRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
r.Body.Close()
|
||||
h.logger.Error("Read request body failed", zap.Error(err))
|
||||
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if totalRead >= streamingThreshold {
|
||||
hitThreshold = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hitThreshold {
|
||||
r.Body.Close()
|
||||
h.sendBufferedRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
|
||||
return
|
||||
}
|
||||
|
||||
h.streamLargeRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
|
||||
}
|
||||
|
||||
func (h *Handler) sendBufferedRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, cancelTransport func(), body []byte) {
|
||||
headers := h.headerPool.Get()
|
||||
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
|
||||
|
||||
httpReq := protocol.HTTPRequest{
|
||||
Method: r.Method,
|
||||
URL: r.URL.String(),
|
||||
Headers: headers,
|
||||
Body: body,
|
||||
}
|
||||
|
||||
reqBytes, err := protocol.EncodeHTTPRequest(&httpReq)
|
||||
h.headerPool.Put(headers)
|
||||
|
||||
if err != nil {
|
||||
h.logger.Error("Encode HTTP request failed", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
header := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequest,
|
||||
IsLast: true,
|
||||
}
|
||||
|
||||
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, reqBytes)
|
||||
if err != nil {
|
||||
h.logger.Error("Encode data payload failed", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
|
||||
|
||||
respChan := h.responses.CreateResponseChan(requestID)
|
||||
streamingDone := h.responses.CreateStreamingResponse(requestID, w)
|
||||
defer func() {
|
||||
h.responses.CleanupResponseChan(requestID)
|
||||
h.responses.CleanupStreamingResponse(requestID)
|
||||
}()
|
||||
|
||||
if err := transport.SendFrame(frame); err != nil {
|
||||
h.logger.Error("Send frame to tunnel failed", zap.Error(err))
|
||||
http.Error(w, "Failed to forward request to tunnel", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case respMsg := <-respChan:
|
||||
if respMsg == nil {
|
||||
http.Error(w, "Internal server error: nil response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
h.writeHTTPResponse(w, respMsg, subdomain, r)
|
||||
case <-streamingDone:
|
||||
// Streaming response has been fully written by SendStreamingChunk
|
||||
case <-ctx.Done():
|
||||
if cancelTransport != nil {
|
||||
cancelTransport()
|
||||
}
|
||||
h.logger.Debug("HTTP request context cancelled",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("subdomain", subdomain),
|
||||
)
|
||||
return
|
||||
case <-time.After(5 * time.Minute):
|
||||
h.logger.Error("Request timeout",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("url", r.URL.String()),
|
||||
)
|
||||
http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) streamLargeRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, cancelTransport func(), bufferedData []byte) {
|
||||
headers := h.headerPool.Get()
|
||||
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
|
||||
|
||||
httpReqHead := protocol.HTTPRequestHead{
|
||||
Method: r.Method,
|
||||
URL: r.URL.String(),
|
||||
Headers: headers,
|
||||
ContentLength: r.ContentLength,
|
||||
}
|
||||
|
||||
headBytes, err := protocol.EncodeHTTPRequestHead(&httpReqHead)
|
||||
h.headerPool.Put(headers)
|
||||
|
||||
if err != nil {
|
||||
h.logger.Error("Encode HTTP request head failed", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
headHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPHead, // shared streaming head type
|
||||
IsLast: false,
|
||||
}
|
||||
|
||||
headPayload, headPoolBuffer, err := protocol.EncodeDataPayloadPooled(headHeader, headBytes)
|
||||
if err != nil {
|
||||
h.logger.Error("Encode head payload failed", zap.Error(err))
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
headFrame := protocol.NewFramePooled(protocol.FrameTypeData, headPayload, headPoolBuffer)
|
||||
|
||||
respChan := h.responses.CreateResponseChan(requestID)
|
||||
streamingDone := h.responses.CreateStreamingResponse(requestID, w)
|
||||
defer func() {
|
||||
h.responses.CleanupResponseChan(requestID)
|
||||
h.responses.CleanupStreamingResponse(requestID)
|
||||
}()
|
||||
|
||||
if err := transport.SendFrame(headFrame); err != nil {
|
||||
h.logger.Error("Send head frame failed", zap.Error(err))
|
||||
http.Error(w, "Failed to forward request to tunnel", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
if len(bufferedData) > 0 {
|
||||
chunkHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPBodyChunk, // shared streaming body type
|
||||
IsLast: false,
|
||||
}
|
||||
|
||||
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, bufferedData)
|
||||
if err != nil {
|
||||
h.logger.Error("Encode buffered chunk failed", zap.Error(err))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if ferr == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
|
||||
if err := transport.SendFrame(chunkFrame); err != nil {
|
||||
h.logger.Error("Send buffered chunk failed", zap.Error(err))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if ferr == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
streamBufPtr := h.bufferPool.GetMedium()
|
||||
defer h.bufferPool.PutMedium(streamBufPtr)
|
||||
buffer := (*streamBufPtr)[:pool.MediumBufferSize]
|
||||
for {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if cancelTransport != nil {
|
||||
cancelTransport()
|
||||
}
|
||||
h.logger.Debug("Streaming request cancelled via context",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("subdomain", subdomain),
|
||||
)
|
||||
return
|
||||
default:
|
||||
stream.Close()
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
|
||||
n, readErr := r.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
isLast := readErr == io.EOF
|
||||
|
||||
chunkHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPBodyChunk, // shared streaming body type
|
||||
IsLast: isLast,
|
||||
}
|
||||
|
||||
chunkPayload, chunkPoolBuffer, err := protocol.EncodeDataPayloadPooled(chunkHeader, buffer[:n])
|
||||
if err != nil {
|
||||
h.logger.Error("Encode chunk payload failed", zap.Error(err))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if ferr == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
chunkFrame := protocol.NewFramePooled(protocol.FrameTypeData, chunkPayload, chunkPoolBuffer)
|
||||
if err := transport.SendFrame(chunkFrame); err != nil {
|
||||
h.logger.Error("Send chunk frame failed", zap.Error(err))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, ferr := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if ferr == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if readErr == io.EOF {
|
||||
if n == 0 {
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if err == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
if readErr != nil {
|
||||
h.logger.Error("Read request body failed", zap.Error(readErr))
|
||||
|
||||
finalHeader := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeHTTPRequestBodyChunk,
|
||||
IsLast: true,
|
||||
}
|
||||
finalPayload, finalPoolBuffer, err := protocol.EncodeDataPayloadPooled(finalHeader, nil)
|
||||
if err == nil {
|
||||
finalFrame := protocol.NewFramePooled(protocol.FrameTypeData, finalPayload, finalPoolBuffer)
|
||||
transport.SendFrame(finalFrame)
|
||||
}
|
||||
|
||||
http.Error(w, "Failed to read request body", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err := writeChunked(w, resp.Body, resp.Trailer); err != nil {
|
||||
h.logger.Debug("Write chunked response failed", zap.Error(err))
|
||||
}
|
||||
close(done)
|
||||
stream.Close()
|
||||
}
|
||||
|
||||
r.Body.Close()
|
||||
func (h *Handler) openStreamWithTimeout(tconn *tunnel.Connection) (net.Conn, error) {
|
||||
type result struct {
|
||||
stream net.Conn
|
||||
err error
|
||||
}
|
||||
ch := make(chan result, 1)
|
||||
|
||||
go func() {
|
||||
s, err := tconn.OpenStream()
|
||||
ch <- result{s, err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case respMsg := <-respChan:
|
||||
if respMsg == nil {
|
||||
http.Error(w, "Internal server error: nil response", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
h.writeHTTPResponse(w, respMsg, subdomain, r)
|
||||
case <-streamingDone:
|
||||
// Streaming response has been fully written by SendStreamingChunk
|
||||
case <-ctx.Done():
|
||||
if cancelTransport != nil {
|
||||
cancelTransport()
|
||||
}
|
||||
h.logger.Debug("Streaming HTTP request context cancelled",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("subdomain", subdomain),
|
||||
)
|
||||
return
|
||||
case <-time.After(5 * time.Minute):
|
||||
h.logger.Error("Streaming request timeout",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("url", r.URL.String()),
|
||||
)
|
||||
http.Error(w, "Request timeout - the tunnel client did not respond in time", http.StatusGatewayTimeout)
|
||||
case r := <-ch:
|
||||
return r.stream, r.err
|
||||
case <-time.After(openStreamTimeout):
|
||||
return nil, fmt.Errorf("open stream timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPResponse, subdomain string, r *http.Request) {
|
||||
if resp == nil {
|
||||
http.Error(w, "Invalid response from tunnel", http.StatusBadGateway)
|
||||
func (h *Handler) handleWebSocket(w http.ResponseWriter, r *http.Request, tconn *tunnel.Connection) {
|
||||
stream, err := h.openStreamWithTimeout(tconn)
|
||||
if err != nil {
|
||||
http.Error(w, "Tunnel unavailable", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
// For buffered responses, we have the complete body, so we can set Content-Length
|
||||
// Skip ALL hop-by-hop headers - client should have already cleaned them
|
||||
for key, values := range resp.Headers {
|
||||
tconn.IncActiveConnections()
|
||||
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
stream.Close()
|
||||
tconn.DecActiveConnections()
|
||||
http.Error(w, "WebSocket not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
clientConn, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
stream.Close()
|
||||
tconn.DecActiveConnections()
|
||||
http.Error(w, "Failed to hijack connection", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.Write(stream); err != nil {
|
||||
stream.Close()
|
||||
clientConn.Close()
|
||||
tconn.DecActiveConnections()
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer stream.Close()
|
||||
defer clientConn.Close()
|
||||
defer tconn.DecActiveConnections()
|
||||
|
||||
_ = netutil.PipeWithCallbacks(r.Context(), stream, clientConn,
|
||||
func(n int64) { tconn.AddBytesOut(n) },
|
||||
func(n int64) { tconn.AddBytesIn(n) },
|
||||
)
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *Handler) copyResponseHeaders(dst http.Header, src http.Header, proxyHost string) {
|
||||
for key, values := range src {
|
||||
canonicalKey := http.CanonicalHeaderKey(key)
|
||||
|
||||
// Skip hop-by-hop headers completely using canonical key comparison
|
||||
// Hop-by-hop headers must not be forwarded.
|
||||
if canonicalKey == "Connection" ||
|
||||
canonicalKey == "Keep-Alive" ||
|
||||
canonicalKey == "Transfer-Encoding" ||
|
||||
@@ -496,29 +257,61 @@ func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPRe
|
||||
}
|
||||
|
||||
if canonicalKey == "Location" && len(values) > 0 {
|
||||
rewrittenLocation := h.rewriteLocationHeader(values[0], r.Host)
|
||||
w.Header().Set("Location", rewrittenLocation)
|
||||
dst.Set("Location", h.rewriteLocationHeader(values[0], proxyHost))
|
||||
continue
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
w.Header().Add(key, value)
|
||||
dst.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func trailerKeys(hdr http.Header) string {
|
||||
keys := make([]string, 0, len(hdr))
|
||||
for k := range hdr {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
// Deterministic order is nicer for debugging; no semantic impact.
|
||||
sortStrings(keys)
|
||||
return strings.Join(keys, ", ")
|
||||
}
|
||||
|
||||
func writeChunked(w io.Writer, r io.Reader, trailer http.Header) error {
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
n, err := r.Read(buf)
|
||||
if n > 0 {
|
||||
if _, werr := fmt.Fprintf(w, "%x\r\n", n); werr != nil {
|
||||
return werr
|
||||
}
|
||||
if _, werr := w.Write(buf[:n]); werr != nil {
|
||||
return werr
|
||||
}
|
||||
if _, werr := io.WriteString(w, "\r\n"); werr != nil {
|
||||
return werr
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// For buffered mode, always set Content-Length with the actual body size
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(resp.Body)))
|
||||
|
||||
statusCode := resp.StatusCode
|
||||
if statusCode == 0 {
|
||||
statusCode = http.StatusOK
|
||||
if _, err := io.WriteString(w, "0\r\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
if len(resp.Body) > 0 {
|
||||
w.Write(resp.Body)
|
||||
for k, vv := range trailer {
|
||||
for _, v := range vv {
|
||||
if _, err := io.WriteString(w, fmt.Sprintf("%s: %s\r\n", k, v)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
_, err := io.WriteString(w, "\r\n")
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *Handler) rewriteLocationHeader(location, proxyHost string) string {
|
||||
@@ -535,22 +328,13 @@ func (h *Handler) rewriteLocationHeader(location, proxyHost string) string {
|
||||
strings.HasPrefix(locationURL.Host, "localhost:") ||
|
||||
locationURL.Host == "127.0.0.1" ||
|
||||
strings.HasPrefix(locationURL.Host, "127.0.0.1:") {
|
||||
scheme := "https"
|
||||
if strings.Contains(proxyHost, ":") && !strings.Contains(proxyHost, "https") {
|
||||
parts := strings.Split(proxyHost, ":")
|
||||
if len(parts) == 2 && parts[1] != "443" {
|
||||
scheme = "https"
|
||||
}
|
||||
}
|
||||
|
||||
rewritten := fmt.Sprintf("%s://%s%s", scheme, proxyHost, locationURL.Path)
|
||||
rewritten := fmt.Sprintf("https://%s%s", proxyHost, locationURL.Path)
|
||||
if locationURL.RawQuery != "" {
|
||||
rewritten += "?" + locationURL.RawQuery
|
||||
}
|
||||
if locationURL.Fragment != "" {
|
||||
rewritten += "#" + locationURL.Fragment
|
||||
}
|
||||
|
||||
return rewritten
|
||||
}
|
||||
|
||||
@@ -568,8 +352,7 @@ func (h *Handler) extractSubdomain(host string) string {
|
||||
|
||||
suffix := "." + h.domain
|
||||
if strings.HasSuffix(host, suffix) {
|
||||
subdomain := strings.TrimSuffix(host, suffix)
|
||||
return subdomain
|
||||
return strings.TrimSuffix(host, suffix)
|
||||
}
|
||||
|
||||
return ""
|
||||
@@ -652,9 +435,17 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
for _, conn := range connections {
|
||||
if conn == nil {
|
||||
continue
|
||||
}
|
||||
stats["tunnels"] = append(stats["tunnels"].([]map[string]interface{}), map[string]interface{}{
|
||||
"subdomain": conn.Subdomain,
|
||||
"last_active": conn.LastActive.Unix(),
|
||||
"subdomain": conn.Subdomain,
|
||||
"tunnel_type": string(conn.GetTunnelType()),
|
||||
"last_active": conn.LastActive.Unix(),
|
||||
"bytes_in": conn.GetBytesIn(),
|
||||
"bytes_out": conn.GetBytesOut(),
|
||||
"active_connections": conn.GetActiveConnections(),
|
||||
"total_bytes": conn.GetBytesIn() + conn.GetBytesOut(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -668,3 +459,13 @@ func (h *Handler) serveStats(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
func sortStrings(s []string) {
|
||||
for i := 0; i < len(s); i++ {
|
||||
for j := i + 1; j < len(s); j++ {
|
||||
if s[j] < s[i] {
|
||||
s[i], s[j] = s[j], s[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user