mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-23 21:00:44 +00:00
feat(client): Implements a response cancellation mechanism to improve resource management
Added support for HTTP response context cancellation, including logic for registering, triggering, and cleaning up cancellation functions. Introduced a responseCancels mapping and corresponding synchronization lock in FrameHandler to track and control the request lifecycle. When the data frame type is closed, actively called cancelResponse to release related resources. Simultaneously, during the response body reading process, identified context cancellation or timeout errors and prematurely terminated the processing flow to avoid invalid operations.
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -28,6 +29,8 @@ type FrameHandler struct {
|
||||
streamMu sync.RWMutex
|
||||
streamingRequests map[string]*StreamingRequest
|
||||
streamingReqMu sync.RWMutex
|
||||
responseCancels map[string]context.CancelFunc
|
||||
responseCancelMu sync.RWMutex
|
||||
tunnelType protocol.TunnelType
|
||||
httpClient *http.Client
|
||||
stats *TrafficStats
|
||||
@@ -77,6 +80,7 @@ func NewFrameHandler(conn net.Conn, frameWriter *protocol.FrameWriter, localHost
|
||||
logger: logger,
|
||||
streams: make(map[string]*Stream),
|
||||
streamingRequests: make(map[string]*StreamingRequest),
|
||||
responseCancels: make(map[string]context.CancelFunc),
|
||||
tunnelType: tunnelType,
|
||||
stats: NewTrafficStats(),
|
||||
isClosedCheck: isClosedCheck,
|
||||
@@ -129,6 +133,11 @@ func (h *FrameHandler) HandleDataFrame(frame *protocol.Frame) error {
|
||||
}
|
||||
|
||||
if header.Type == protocol.DataTypeClose {
|
||||
cancelID := header.RequestID
|
||||
if cancelID == "" {
|
||||
cancelID = header.StreamID
|
||||
}
|
||||
h.cancelResponse(cancelID)
|
||||
h.closeStream(header.StreamID)
|
||||
return nil
|
||||
}
|
||||
@@ -386,6 +395,16 @@ func (h *FrameHandler) adaptiveHTTPResponse(streamID, requestID string, resp *ht
|
||||
// Clean response headers - remove hop-by-hop headers that are invalid after proxying
|
||||
cleanedHeaders := h.cleanResponseHeaders(resp.Header)
|
||||
|
||||
cancelID := requestID
|
||||
if cancelID == "" {
|
||||
cancelID = streamID
|
||||
}
|
||||
|
||||
h.registerResponseCancel(cancelID, func() {
|
||||
resp.Body.Close()
|
||||
})
|
||||
defer h.unregisterResponseCancel(cancelID)
|
||||
|
||||
// First send headers
|
||||
httpHead := protocol.HTTPResponseHead{
|
||||
StatusCode: resp.StatusCode,
|
||||
@@ -497,6 +516,9 @@ func (h *FrameHandler) adaptiveHTTPResponse(streamID, requestID string, resp *ht
|
||||
break
|
||||
}
|
||||
if readErr != nil {
|
||||
if errors.Is(readErr, context.Canceled) || errors.Is(readErr, context.DeadlineExceeded) || errors.Is(readErr, http.ErrBodyReadAfterClose) || errors.Is(readErr, net.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("read response body: %w", readErr)
|
||||
}
|
||||
}
|
||||
@@ -529,6 +551,16 @@ func (h *FrameHandler) streamHTTPResponse(streamID, requestID string, resp *http
|
||||
return nil
|
||||
}
|
||||
|
||||
cancelID := requestID
|
||||
if cancelID == "" {
|
||||
cancelID = streamID
|
||||
}
|
||||
|
||||
h.registerResponseCancel(cancelID, func() {
|
||||
resp.Body.Close()
|
||||
})
|
||||
defer h.unregisterResponseCancel(cancelID)
|
||||
|
||||
// Clean response headers - remove hop-by-hop headers that are invalid after proxying
|
||||
cleanedHeaders := h.cleanResponseHeaders(resp.Header)
|
||||
|
||||
@@ -620,6 +652,9 @@ func (h *FrameHandler) streamHTTPResponse(streamID, requestID string, resp *http
|
||||
break
|
||||
}
|
||||
if readErr != nil {
|
||||
if errors.Is(readErr, context.Canceled) || errors.Is(readErr, context.DeadlineExceeded) || errors.Is(readErr, http.ErrBodyReadAfterClose) || errors.Is(readErr, net.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("read response body: %w", readErr)
|
||||
}
|
||||
}
|
||||
@@ -1076,3 +1111,32 @@ func (h *FrameHandler) closeStreamingRequest(requestID string, streamingReq *Str
|
||||
close(streamingReq.Done)
|
||||
streamingReq.mu.Unlock()
|
||||
}
|
||||
|
||||
func (h *FrameHandler) registerResponseCancel(id string, cancel context.CancelFunc) {
|
||||
if cancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
h.responseCancelMu.Lock()
|
||||
h.responseCancels[id] = cancel
|
||||
h.responseCancelMu.Unlock()
|
||||
}
|
||||
|
||||
func (h *FrameHandler) cancelResponse(id string) {
|
||||
h.responseCancelMu.Lock()
|
||||
cancel := h.responseCancels[id]
|
||||
if cancel != nil {
|
||||
delete(h.responseCancels, id)
|
||||
}
|
||||
h.responseCancelMu.Unlock()
|
||||
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *FrameHandler) unregisterResponseCancel(id string) {
|
||||
h.responseCancelMu.Lock()
|
||||
delete(h.responseCancels, id)
|
||||
h.responseCancelMu.Unlock()
|
||||
}
|
||||
|
||||
@@ -77,6 +77,32 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string) {
|
||||
const streamingThreshold int64 = 1 * 1024 * 1024
|
||||
|
||||
if transport != nil {
|
||||
h.responses.RegisterCancelFunc(requestID, func() {
|
||||
header := protocol.DataHeader{
|
||||
StreamID: requestID,
|
||||
RequestID: requestID,
|
||||
Type: protocol.DataTypeClose,
|
||||
IsLast: true,
|
||||
}
|
||||
|
||||
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
|
||||
if err := transport.SendFrame(frame); err != nil {
|
||||
h.logger.Debug("Failed to send cancel frame to client",
|
||||
zap.String("request_id", requestID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
defer h.responses.CleanupCancelFunc(requestID)
|
||||
}
|
||||
|
||||
buffer := make([]byte, 0, streamingThreshold)
|
||||
tempBuf := make([]byte, 32*1024)
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ type streamingResponseEntry struct {
|
||||
type ResponseHandler struct {
|
||||
channels map[string]*responseChanEntry
|
||||
streamingChannels map[string]*streamingResponseEntry
|
||||
cancelFuncs map[string]func()
|
||||
mu sync.RWMutex
|
||||
logger *zap.Logger
|
||||
stopCh chan struct{}
|
||||
@@ -43,6 +44,7 @@ func NewResponseHandler(logger *zap.Logger) *ResponseHandler {
|
||||
h := &ResponseHandler{
|
||||
channels: make(map[string]*responseChanEntry),
|
||||
streamingChannels: make(map[string]*streamingResponseEntry),
|
||||
cancelFuncs: make(map[string]func()),
|
||||
logger: logger,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
@@ -86,6 +88,17 @@ func (h *ResponseHandler) CreateStreamingResponse(requestID string, w http.Respo
|
||||
return done
|
||||
}
|
||||
|
||||
// RegisterCancelFunc registers a callback to be invoked when the downstream disconnects.
|
||||
func (h *ResponseHandler) RegisterCancelFunc(requestID string, cancel func()) {
|
||||
if cancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
h.cancelFuncs[requestID] = cancel
|
||||
h.mu.Unlock()
|
||||
}
|
||||
|
||||
// GetResponseChan gets the response channel for a request ID
|
||||
func (h *ResponseHandler) GetResponseChan(requestID string) <-chan *protocol.HTTPResponse {
|
||||
h.mu.RLock()
|
||||
@@ -215,6 +228,7 @@ func (h *ResponseHandler) SendStreamingChunk(requestID string, chunk []byte, isL
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
h.triggerCancel(requestID)
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
@@ -222,6 +236,7 @@ func (h *ResponseHandler) SendStreamingChunk(requestID string, chunk []byte, isL
|
||||
default:
|
||||
close(entry.done)
|
||||
}
|
||||
h.triggerCancel(requestID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -265,6 +280,22 @@ func isClientDisconnectError(err error) bool {
|
||||
strings.Contains(errStr, "use of closed network connection")
|
||||
}
|
||||
|
||||
// triggerCancel invokes and removes the cancel callback for a request.
|
||||
func (h *ResponseHandler) triggerCancel(requestID string) {
|
||||
h.mu.Lock()
|
||||
cancel := h.cancelFuncs[requestID]
|
||||
if cancel != nil {
|
||||
delete(h.cancelFuncs, requestID)
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
if cancel != nil {
|
||||
go func() {
|
||||
cancel()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) CleanupResponseChan(requestID string) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
@@ -289,6 +320,13 @@ func (h *ResponseHandler) CleanupStreamingResponse(requestID string) {
|
||||
}
|
||||
}
|
||||
|
||||
// CleanupCancelFunc removes a registered cancel callback.
|
||||
func (h *ResponseHandler) CleanupCancelFunc(requestID string) {
|
||||
h.mu.Lock()
|
||||
delete(h.cancelFuncs, requestID)
|
||||
h.mu.Unlock()
|
||||
}
|
||||
|
||||
func (h *ResponseHandler) GetPendingCount() int {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
@@ -318,6 +356,7 @@ func (h *ResponseHandler) cleanupExpiredChannels() {
|
||||
defer h.mu.Unlock()
|
||||
|
||||
expiredCount := 0
|
||||
cancelList := make([]string, 0)
|
||||
for requestID, entry := range h.channels {
|
||||
if now.Sub(entry.createdAt) > timeout {
|
||||
close(entry.ch)
|
||||
@@ -334,10 +373,18 @@ func (h *ResponseHandler) cleanupExpiredChannels() {
|
||||
close(entry.done)
|
||||
}
|
||||
delete(h.streamingChannels, requestID)
|
||||
cancelList = append(cancelList, requestID)
|
||||
expiredCount++
|
||||
}
|
||||
}
|
||||
|
||||
for _, requestID := range cancelList {
|
||||
if cancel := h.cancelFuncs[requestID]; cancel != nil {
|
||||
delete(h.cancelFuncs, requestID)
|
||||
go cancel()
|
||||
}
|
||||
}
|
||||
|
||||
if expiredCount > 0 {
|
||||
h.logger.Debug("Cleaned up expired response channels",
|
||||
zap.Int("count", expiredCount),
|
||||
@@ -365,4 +412,9 @@ func (h *ResponseHandler) Close() {
|
||||
}
|
||||
}
|
||||
h.streamingChannels = make(map[string]*streamingResponseEntry)
|
||||
|
||||
for _, cancel := range h.cancelFuncs {
|
||||
cancel()
|
||||
}
|
||||
h.cancelFuncs = make(map[string]func())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user