mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-28 23:36:00 +00:00
feat(cli): Supports stopping HTTPS tunnels and optimizes configuration display logic.
- Added support for HTTPS tunnel types to the `drip stop` command and updated the example documentation. - Optimized token display logic to adapt to token formats of different lengths. - Adjust the alignment of FrameHandler buffer read/write and timeout configuration formats. - Move the error handling logic location to ensure data read integrity. - Introducing context to control request lifecycle and supporting cancel transfer in proxy handlers - The hop-by-hop header judgment format in the unified response header filtering rules - Add a context-aware streaming request cancellation mechanism and extend the channel cleanup timeout. - Add a context control field to the TCP connection structure to support connection lifecycle management. - Format the httpResponseWriter field comments
This commit is contained in:
@@ -117,9 +117,15 @@ func runConfigShow(cmd *cobra.Command, args []string) error {
|
||||
|
||||
var displayToken string
|
||||
if cfg.Token != "" {
|
||||
if len(cfg.Token) > 10 {
|
||||
displayToken = cfg.Token[:3] + "***" + cfg.Token[len(cfg.Token)-3:]
|
||||
tokenLen := len(cfg.Token)
|
||||
if tokenLen <= 3 {
|
||||
// For very short tokens, just show asterisks
|
||||
displayToken = "***"
|
||||
} else if tokenLen > 10 {
|
||||
// For long tokens, show first 3 and last 3 characters
|
||||
displayToken = cfg.Token[:3] + "***" + cfg.Token[tokenLen-3:]
|
||||
} else {
|
||||
// For medium tokens (4-10 chars), show first 3 characters
|
||||
displayToken = cfg.Token[:3] + "***"
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -14,6 +14,7 @@ var stopCmd = &cobra.Command{
|
||||
|
||||
Examples:
|
||||
drip stop http 3000 Stop HTTP tunnel on port 3000
|
||||
drip stop https 8443 Stop HTTPS tunnel on port 8443
|
||||
drip stop tcp 5432 Stop TCP tunnel on port 5432
|
||||
drip stop all Stop all running tunnels
|
||||
|
||||
@@ -37,8 +38,8 @@ func runStop(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
tunnelType := args[0]
|
||||
if tunnelType != "http" && tunnelType != "tcp" {
|
||||
return fmt.Errorf("invalid tunnel type: %s (must be 'http' or 'tcp')", tunnelType)
|
||||
if tunnelType != "http" && tunnelType != "https" && tunnelType != "tcp" {
|
||||
return fmt.Errorf("invalid tunnel type: %s (must be 'http', 'https', or 'tcp')", tunnelType)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(args[1])
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
"drip/internal/shared/pool"
|
||||
"drip/internal/shared/protocol"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
@@ -97,10 +98,10 @@ func NewFrameHandler(conn net.Conn, frameWriter *protocol.FrameWriter, localHost
|
||||
DisableKeepAlives: false, // Enable keep-alive for connection reuse
|
||||
TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s for faster failure detection
|
||||
TLSClientConfig: tlsConfig,
|
||||
ResponseHeaderTimeout: 15 * time.Second, // Reduced from 30s for faster timeout
|
||||
ResponseHeaderTimeout: 15 * time.Second, // Reduced from 30s for faster timeout
|
||||
ExpectContinueTimeout: 500 * time.Millisecond, // Reduced from 1s for better responsiveness
|
||||
WriteBufferSize: 32 * 1024, // 32KB write buffer
|
||||
ReadBufferSize: 32 * 1024, // 32KB read buffer
|
||||
WriteBufferSize: 32 * 1024, // 32KB write buffer
|
||||
ReadBufferSize: 32 * 1024, // 32KB read buffer
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 3 * time.Second, // Reduced from 5s for faster connection attempts
|
||||
KeepAlive: 30 * time.Second, // Keep TCP keepalive
|
||||
@@ -226,9 +227,6 @@ func (h *FrameHandler) handleLocalResponse(stream *Stream) {
|
||||
}
|
||||
|
||||
n, err := stream.LocalConn.Read(buf)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
if h.isClosedCheck != nil && h.isClosedCheck() {
|
||||
@@ -264,6 +262,10 @@ func (h *FrameHandler) handleLocalResponse(stream *Stream) {
|
||||
|
||||
h.stats.AddBytesOut(int64(len(payload)))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
json "github.com/goccy/go-json"
|
||||
@@ -90,8 +92,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string) {
|
||||
const streamingThreshold int64 = 1 * 1024 * 1024
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
var cancelTransport func()
|
||||
if transport != nil {
|
||||
h.responses.RegisterCancelFunc(requestID, func() {
|
||||
cancelOnce := sync.Once{}
|
||||
cancelFunc := func() {
|
||||
header := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
@@ -111,8 +117,13 @@ func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request,
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
cancelTransport = func() {
|
||||
cancelOnce.Do(cancelFunc)
|
||||
}
|
||||
|
||||
h.responses.RegisterCancelFunc(requestID, cancelTransport)
|
||||
defer h.responses.CleanupCancelFunc(requestID)
|
||||
}
|
||||
|
||||
@@ -138,7 +149,7 @@ func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request,
|
||||
}
|
||||
if err == io.EOF {
|
||||
r.Body.Close()
|
||||
h.sendBufferedRequest(w, r, transport, requestID, subdomain, buffer)
|
||||
h.sendBufferedRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
@@ -155,14 +166,14 @@ func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request,
|
||||
|
||||
if !hitThreshold {
|
||||
r.Body.Close()
|
||||
h.sendBufferedRequest(w, r, transport, requestID, subdomain, buffer)
|
||||
h.sendBufferedRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
|
||||
return
|
||||
}
|
||||
|
||||
h.streamLargeRequest(w, r, transport, requestID, subdomain, buffer)
|
||||
h.streamLargeRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
|
||||
}
|
||||
|
||||
func (h *Handler) sendBufferedRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, body []byte) {
|
||||
func (h *Handler) sendBufferedRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, cancelTransport func(), body []byte) {
|
||||
headers := h.headerPool.Get()
|
||||
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
|
||||
|
||||
@@ -220,6 +231,15 @@ func (h *Handler) sendBufferedRequest(w http.ResponseWriter, r *http.Request, tr
|
||||
h.writeHTTPResponse(w, respMsg, subdomain, r)
|
||||
case <-streamingDone:
|
||||
// Streaming response has been fully written by SendStreamingChunk
|
||||
case <-ctx.Done():
|
||||
if cancelTransport != nil {
|
||||
cancelTransport()
|
||||
}
|
||||
h.logger.Debug("HTTP request context cancelled",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("subdomain", subdomain),
|
||||
)
|
||||
return
|
||||
case <-time.After(5 * time.Minute):
|
||||
h.logger.Error("Request timeout",
|
||||
zap.String("request_id", requestID),
|
||||
@@ -229,7 +249,7 @@ func (h *Handler) sendBufferedRequest(w http.ResponseWriter, r *http.Request, tr
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) streamLargeRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, bufferedData []byte) {
|
||||
func (h *Handler) streamLargeRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, cancelTransport func(), bufferedData []byte) {
|
||||
headers := h.headerPool.Get()
|
||||
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
|
||||
|
||||
@@ -327,6 +347,19 @@ func (h *Handler) streamLargeRequest(w http.ResponseWriter, r *http.Request, tra
|
||||
defer h.bufferPool.PutMedium(streamBufPtr)
|
||||
buffer := (*streamBufPtr)[:pool.MediumBufferSize]
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if cancelTransport != nil {
|
||||
cancelTransport()
|
||||
}
|
||||
h.logger.Debug("Streaming request cancelled via context",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("subdomain", subdomain),
|
||||
)
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
n, readErr := r.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
isLast := readErr == io.EOF
|
||||
@@ -422,6 +455,15 @@ func (h *Handler) streamLargeRequest(w http.ResponseWriter, r *http.Request, tra
|
||||
h.writeHTTPResponse(w, respMsg, subdomain, r)
|
||||
case <-streamingDone:
|
||||
// Streaming response has been fully written by SendStreamingChunk
|
||||
case <-ctx.Done():
|
||||
if cancelTransport != nil {
|
||||
cancelTransport()
|
||||
}
|
||||
h.logger.Debug("Streaming HTTP request context cancelled",
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("subdomain", subdomain),
|
||||
)
|
||||
return
|
||||
case <-time.After(5 * time.Minute):
|
||||
h.logger.Error("Streaming request timeout",
|
||||
zap.String("request_id", requestID),
|
||||
@@ -444,12 +486,12 @@ func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPRe
|
||||
|
||||
// Skip hop-by-hop headers completely using canonical key comparison
|
||||
if canonicalKey == "Connection" ||
|
||||
canonicalKey == "Keep-Alive" ||
|
||||
canonicalKey == "Transfer-Encoding" ||
|
||||
canonicalKey == "Upgrade" ||
|
||||
canonicalKey == "Proxy-Connection" ||
|
||||
canonicalKey == "Te" ||
|
||||
canonicalKey == "Trailer" {
|
||||
canonicalKey == "Keep-Alive" ||
|
||||
canonicalKey == "Transfer-Encoding" ||
|
||||
canonicalKey == "Upgrade" ||
|
||||
canonicalKey == "Proxy-Connection" ||
|
||||
canonicalKey == "Te" ||
|
||||
canonicalKey == "Trailer" {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"drip/internal/shared/protocol"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
@@ -162,12 +163,12 @@ func (h *ResponseHandler) SendStreamingHead(requestID string, head *protocol.HTT
|
||||
|
||||
// Skip ALL hop-by-hop headers
|
||||
if canonicalKey == "Connection" ||
|
||||
canonicalKey == "Keep-Alive" ||
|
||||
canonicalKey == "Transfer-Encoding" ||
|
||||
canonicalKey == "Upgrade" ||
|
||||
canonicalKey == "Proxy-Connection" ||
|
||||
canonicalKey == "Te" ||
|
||||
canonicalKey == "Trailer" {
|
||||
canonicalKey == "Keep-Alive" ||
|
||||
canonicalKey == "Transfer-Encoding" ||
|
||||
canonicalKey == "Upgrade" ||
|
||||
canonicalKey == "Proxy-Connection" ||
|
||||
canonicalKey == "Te" ||
|
||||
canonicalKey == "Trailer" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -349,7 +350,7 @@ func (h *ResponseHandler) cleanupLoop() {
|
||||
|
||||
func (h *ResponseHandler) cleanupExpiredChannels() {
|
||||
now := time.Now()
|
||||
timeout := 30 * time.Second
|
||||
timeout := 5 * time.Minute
|
||||
streamingTimeout := 5 * time.Minute
|
||||
|
||||
h.mu.Lock()
|
||||
|
||||
@@ -2,6 +2,7 @@ package tcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -41,6 +42,8 @@ type Connection struct {
|
||||
httpHandler http.Handler
|
||||
responseChans HTTPResponseHandler
|
||||
tunnelType protocol.TunnelType // Track tunnel type
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// HTTPResponseHandler interface for response channel operations
|
||||
@@ -56,6 +59,7 @@ type HTTPResponseHandler interface {
|
||||
|
||||
// NewConnection creates a new connection handler
|
||||
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, responseChans HTTPResponseHandler) *Connection {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Connection{
|
||||
conn: conn,
|
||||
authToken: authToken,
|
||||
@@ -68,6 +72,8 @@ func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, log
|
||||
responseChans: responseChans,
|
||||
stopCh: make(chan struct{}),
|
||||
lastHeartbeat: time.Now(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -289,6 +295,10 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
|
||||
return fmt.Errorf("failed to parse HTTP request: %w", err)
|
||||
}
|
||||
|
||||
if c.ctx != nil {
|
||||
req = req.WithContext(c.ctx)
|
||||
}
|
||||
|
||||
c.logger.Info("Processing HTTP request on TCP port",
|
||||
zap.String("method", req.Method),
|
||||
zap.String("url", req.URL.String()),
|
||||
@@ -588,6 +598,10 @@ func (c *Connection) Close() {
|
||||
|
||||
close(c.stopCh)
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
if c.frameWriter != nil {
|
||||
c.frameWriter.Flush()
|
||||
c.frameWriter.Close()
|
||||
@@ -621,7 +635,7 @@ func (c *Connection) GetSubdomain() string {
|
||||
// httpResponseWriter implements http.ResponseWriter for writing to a net.Conn
|
||||
type httpResponseWriter struct {
|
||||
conn net.Conn
|
||||
writer *bufio.Writer // Buffered writer for efficient I/O
|
||||
writer *bufio.Writer // Buffered writer for efficient I/O
|
||||
header http.Header
|
||||
statusCode int
|
||||
headerWritten bool
|
||||
|
||||
Reference in New Issue
Block a user