feat(cli): Supports stopping HTTPS tunnels and optimizes configuration display logic.

- Added support for HTTPS tunnel types to the `drip stop` command and updated the example documentation.
- Optimized token display logic to adapt to token formats of different lengths.
- Adjust the alignment of FrameHandler buffer read/write and timeout configuration formats.
- Move the error handling logic location to ensure data read integrity.
- Introducing context to control request lifecycle and supporting cancel transfer in proxy handlers
- The hop-by-hop header judgment format in the unified response header filtering rules
- Add a context-aware streaming request cancellation mechanism and extend the channel cleanup timeout.
- Add a context control field to the TCP connection structure to support connection lifecycle management.
- Format the httpResponseWriter field comments
This commit is contained in:
Gouryella
2025-12-08 16:57:10 +08:00
parent 3bc7978999
commit d21bb4897f
6 changed files with 97 additions and 31 deletions

View File

@@ -117,9 +117,15 @@ func runConfigShow(cmd *cobra.Command, args []string) error {
var displayToken string
if cfg.Token != "" {
if len(cfg.Token) > 10 {
displayToken = cfg.Token[:3] + "***" + cfg.Token[len(cfg.Token)-3:]
tokenLen := len(cfg.Token)
if tokenLen <= 3 {
// For very short tokens, just show asterisks
displayToken = "***"
} else if tokenLen > 10 {
// For long tokens, show first 3 and last 3 characters
displayToken = cfg.Token[:3] + "***" + cfg.Token[tokenLen-3:]
} else {
// For medium tokens (4-10 chars), show first 3 characters
displayToken = cfg.Token[:3] + "***"
}
} else {

View File

@@ -14,6 +14,7 @@ var stopCmd = &cobra.Command{
Examples:
drip stop http 3000 Stop HTTP tunnel on port 3000
drip stop https 8443 Stop HTTPS tunnel on port 8443
drip stop tcp 5432 Stop TCP tunnel on port 5432
drip stop all Stop all running tunnels
@@ -37,8 +38,8 @@ func runStop(cmd *cobra.Command, args []string) error {
}
tunnelType := args[0]
if tunnelType != "http" && tunnelType != "tcp" {
return fmt.Errorf("invalid tunnel type: %s (must be 'http' or 'tcp')", tunnelType)
if tunnelType != "http" && tunnelType != "https" && tunnelType != "tcp" {
return fmt.Errorf("invalid tunnel type: %s (must be 'http', 'https', or 'tcp')", tunnelType)
}
port, err := strconv.Atoi(args[1])

View File

@@ -15,6 +15,7 @@ import (
"drip/internal/shared/pool"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
@@ -97,10 +98,10 @@ func NewFrameHandler(conn net.Conn, frameWriter *protocol.FrameWriter, localHost
DisableKeepAlives: false, // Enable keep-alive for connection reuse
TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s for faster failure detection
TLSClientConfig: tlsConfig,
ResponseHeaderTimeout: 15 * time.Second, // Reduced from 30s for faster timeout
ResponseHeaderTimeout: 15 * time.Second, // Reduced from 30s for faster timeout
ExpectContinueTimeout: 500 * time.Millisecond, // Reduced from 1s for better responsiveness
WriteBufferSize: 32 * 1024, // 32KB write buffer
ReadBufferSize: 32 * 1024, // 32KB read buffer
WriteBufferSize: 32 * 1024, // 32KB write buffer
ReadBufferSize: 32 * 1024, // 32KB read buffer
DialContext: (&net.Dialer{
Timeout: 3 * time.Second, // Reduced from 5s for faster connection attempts
KeepAlive: 30 * time.Second, // Keep TCP keepalive
@@ -226,9 +227,6 @@ func (h *FrameHandler) handleLocalResponse(stream *Stream) {
}
n, err := stream.LocalConn.Read(buf)
if err != nil {
break
}
if n > 0 {
if h.isClosedCheck != nil && h.isClosedCheck() {
@@ -264,6 +262,10 @@ func (h *FrameHandler) handleLocalResponse(stream *Stream) {
h.stats.AddBytesOut(int64(len(payload)))
}
if err != nil {
break
}
}
}

View File

@@ -1,11 +1,13 @@
package proxy
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
json "github.com/goccy/go-json"
@@ -90,8 +92,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string) {
const streamingThreshold int64 = 1 * 1024 * 1024
ctx := r.Context()
var cancelTransport func()
if transport != nil {
h.responses.RegisterCancelFunc(requestID, func() {
cancelOnce := sync.Once{}
cancelFunc := func() {
header := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
@@ -111,8 +117,13 @@ func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request,
zap.Error(err),
)
}
})
}
cancelTransport = func() {
cancelOnce.Do(cancelFunc)
}
h.responses.RegisterCancelFunc(requestID, cancelTransport)
defer h.responses.CleanupCancelFunc(requestID)
}
@@ -138,7 +149,7 @@ func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request,
}
if err == io.EOF {
r.Body.Close()
h.sendBufferedRequest(w, r, transport, requestID, subdomain, buffer)
h.sendBufferedRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
return
}
if err != nil {
@@ -155,14 +166,14 @@ func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request,
if !hitThreshold {
r.Body.Close()
h.sendBufferedRequest(w, r, transport, requestID, subdomain, buffer)
h.sendBufferedRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
return
}
h.streamLargeRequest(w, r, transport, requestID, subdomain, buffer)
h.streamLargeRequest(ctx, w, r, transport, requestID, subdomain, cancelTransport, buffer)
}
func (h *Handler) sendBufferedRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, body []byte) {
func (h *Handler) sendBufferedRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, cancelTransport func(), body []byte) {
headers := h.headerPool.Get()
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
@@ -220,6 +231,15 @@ func (h *Handler) sendBufferedRequest(w http.ResponseWriter, r *http.Request, tr
h.writeHTTPResponse(w, respMsg, subdomain, r)
case <-streamingDone:
// Streaming response has been fully written by SendStreamingChunk
case <-ctx.Done():
if cancelTransport != nil {
cancelTransport()
}
h.logger.Debug("HTTP request context cancelled",
zap.String("request_id", requestID),
zap.String("subdomain", subdomain),
)
return
case <-time.After(5 * time.Minute):
h.logger.Error("Request timeout",
zap.String("request_id", requestID),
@@ -229,7 +249,7 @@ func (h *Handler) sendBufferedRequest(w http.ResponseWriter, r *http.Request, tr
}
}
func (h *Handler) streamLargeRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, bufferedData []byte) {
func (h *Handler) streamLargeRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string, cancelTransport func(), bufferedData []byte) {
headers := h.headerPool.Get()
h.headerPool.CloneWithExtra(headers, r.Header, "Host", r.Host)
@@ -327,6 +347,19 @@ func (h *Handler) streamLargeRequest(w http.ResponseWriter, r *http.Request, tra
defer h.bufferPool.PutMedium(streamBufPtr)
buffer := (*streamBufPtr)[:pool.MediumBufferSize]
for {
select {
case <-ctx.Done():
if cancelTransport != nil {
cancelTransport()
}
h.logger.Debug("Streaming request cancelled via context",
zap.String("request_id", requestID),
zap.String("subdomain", subdomain),
)
return
default:
}
n, readErr := r.Body.Read(buffer)
if n > 0 {
isLast := readErr == io.EOF
@@ -422,6 +455,15 @@ func (h *Handler) streamLargeRequest(w http.ResponseWriter, r *http.Request, tra
h.writeHTTPResponse(w, respMsg, subdomain, r)
case <-streamingDone:
// Streaming response has been fully written by SendStreamingChunk
case <-ctx.Done():
if cancelTransport != nil {
cancelTransport()
}
h.logger.Debug("Streaming HTTP request context cancelled",
zap.String("request_id", requestID),
zap.String("subdomain", subdomain),
)
return
case <-time.After(5 * time.Minute):
h.logger.Error("Streaming request timeout",
zap.String("request_id", requestID),
@@ -444,12 +486,12 @@ func (h *Handler) writeHTTPResponse(w http.ResponseWriter, resp *protocol.HTTPRe
// Skip hop-by-hop headers completely using canonical key comparison
if canonicalKey == "Connection" ||
canonicalKey == "Keep-Alive" ||
canonicalKey == "Transfer-Encoding" ||
canonicalKey == "Upgrade" ||
canonicalKey == "Proxy-Connection" ||
canonicalKey == "Te" ||
canonicalKey == "Trailer" {
canonicalKey == "Keep-Alive" ||
canonicalKey == "Transfer-Encoding" ||
canonicalKey == "Upgrade" ||
canonicalKey == "Proxy-Connection" ||
canonicalKey == "Te" ||
canonicalKey == "Trailer" {
continue
}

View File

@@ -9,6 +9,7 @@ import (
"time"
"drip/internal/shared/protocol"
"go.uber.org/zap"
)
@@ -162,12 +163,12 @@ func (h *ResponseHandler) SendStreamingHead(requestID string, head *protocol.HTT
// Skip ALL hop-by-hop headers
if canonicalKey == "Connection" ||
canonicalKey == "Keep-Alive" ||
canonicalKey == "Transfer-Encoding" ||
canonicalKey == "Upgrade" ||
canonicalKey == "Proxy-Connection" ||
canonicalKey == "Te" ||
canonicalKey == "Trailer" {
canonicalKey == "Keep-Alive" ||
canonicalKey == "Transfer-Encoding" ||
canonicalKey == "Upgrade" ||
canonicalKey == "Proxy-Connection" ||
canonicalKey == "Te" ||
canonicalKey == "Trailer" {
continue
}
@@ -349,7 +350,7 @@ func (h *ResponseHandler) cleanupLoop() {
func (h *ResponseHandler) cleanupExpiredChannels() {
now := time.Now()
timeout := 30 * time.Second
timeout := 5 * time.Minute
streamingTimeout := 5 * time.Minute
h.mu.Lock()

View File

@@ -2,6 +2,7 @@ package tcp
import (
"bufio"
"context"
"errors"
"fmt"
"io"
@@ -41,6 +42,8 @@ type Connection struct {
httpHandler http.Handler
responseChans HTTPResponseHandler
tunnelType protocol.TunnelType // Track tunnel type
ctx context.Context
cancel context.CancelFunc
}
// HTTPResponseHandler interface for response channel operations
@@ -56,6 +59,7 @@ type HTTPResponseHandler interface {
// NewConnection creates a new connection handler
func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, logger *zap.Logger, portAlloc *PortAllocator, domain string, publicPort int, httpHandler http.Handler, responseChans HTTPResponseHandler) *Connection {
ctx, cancel := context.WithCancel(context.Background())
return &Connection{
conn: conn,
authToken: authToken,
@@ -68,6 +72,8 @@ func NewConnection(conn net.Conn, authToken string, manager *tunnel.Manager, log
responseChans: responseChans,
stopCh: make(chan struct{}),
lastHeartbeat: time.Now(),
ctx: ctx,
cancel: cancel,
}
}
@@ -289,6 +295,10 @@ func (c *Connection) handleHTTPRequest(reader *bufio.Reader) error {
return fmt.Errorf("failed to parse HTTP request: %w", err)
}
if c.ctx != nil {
req = req.WithContext(c.ctx)
}
c.logger.Info("Processing HTTP request on TCP port",
zap.String("method", req.Method),
zap.String("url", req.URL.String()),
@@ -588,6 +598,10 @@ func (c *Connection) Close() {
close(c.stopCh)
if c.cancel != nil {
c.cancel()
}
if c.frameWriter != nil {
c.frameWriter.Flush()
c.frameWriter.Close()
@@ -621,7 +635,7 @@ func (c *Connection) GetSubdomain() string {
// httpResponseWriter implements http.ResponseWriter for writing to a net.Conn
type httpResponseWriter struct {
conn net.Conn
writer *bufio.Writer // Buffered writer for efficient I/O
writer *bufio.Writer // Buffered writer for efficient I/O
header http.Header
statusCode int
headerWritten bool