feat(client): Implements a response cancellation mechanism to improve resource management

Added support for HTTP response context cancellation, including logic for registering, triggering, and cleaning up cancellation functions.

Introduced a responseCancels mapping and corresponding synchronization lock in FrameHandler to track and control the request lifecycle.

When the data frame type is closed, actively called cancelResponse to release related resources.

Simultaneously, during the response body reading process, identified context cancellation or timeout errors and prematurely terminated the processing flow to avoid invalid operations.
This commit is contained in:
Gouryella
2025-12-06 23:58:31 +08:00
parent 133a3892af
commit bbef7efb5e
3 changed files with 142 additions and 0 deletions

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
@@ -28,6 +29,8 @@ type FrameHandler struct {
streamMu sync.RWMutex
streamingRequests map[string]*StreamingRequest
streamingReqMu sync.RWMutex
responseCancels map[string]context.CancelFunc
responseCancelMu sync.RWMutex
tunnelType protocol.TunnelType
httpClient *http.Client
stats *TrafficStats
@@ -77,6 +80,7 @@ func NewFrameHandler(conn net.Conn, frameWriter *protocol.FrameWriter, localHost
logger: logger,
streams: make(map[string]*Stream),
streamingRequests: make(map[string]*StreamingRequest),
responseCancels: make(map[string]context.CancelFunc),
tunnelType: tunnelType,
stats: NewTrafficStats(),
isClosedCheck: isClosedCheck,
@@ -129,6 +133,11 @@ func (h *FrameHandler) HandleDataFrame(frame *protocol.Frame) error {
}
if header.Type == protocol.DataTypeClose {
cancelID := header.RequestID
if cancelID == "" {
cancelID = header.StreamID
}
h.cancelResponse(cancelID)
h.closeStream(header.StreamID)
return nil
}
@@ -386,6 +395,16 @@ func (h *FrameHandler) adaptiveHTTPResponse(streamID, requestID string, resp *ht
// Clean response headers - remove hop-by-hop headers that are invalid after proxying
cleanedHeaders := h.cleanResponseHeaders(resp.Header)
cancelID := requestID
if cancelID == "" {
cancelID = streamID
}
h.registerResponseCancel(cancelID, func() {
resp.Body.Close()
})
defer h.unregisterResponseCancel(cancelID)
// First send headers
httpHead := protocol.HTTPResponseHead{
StatusCode: resp.StatusCode,
@@ -497,6 +516,9 @@ func (h *FrameHandler) adaptiveHTTPResponse(streamID, requestID string, resp *ht
break
}
if readErr != nil {
if errors.Is(readErr, context.Canceled) || errors.Is(readErr, context.DeadlineExceeded) || errors.Is(readErr, http.ErrBodyReadAfterClose) || errors.Is(readErr, net.ErrClosed) {
return nil
}
return fmt.Errorf("read response body: %w", readErr)
}
}
@@ -529,6 +551,16 @@ func (h *FrameHandler) streamHTTPResponse(streamID, requestID string, resp *http
return nil
}
cancelID := requestID
if cancelID == "" {
cancelID = streamID
}
h.registerResponseCancel(cancelID, func() {
resp.Body.Close()
})
defer h.unregisterResponseCancel(cancelID)
// Clean response headers - remove hop-by-hop headers that are invalid after proxying
cleanedHeaders := h.cleanResponseHeaders(resp.Header)
@@ -620,6 +652,9 @@ func (h *FrameHandler) streamHTTPResponse(streamID, requestID string, resp *http
break
}
if readErr != nil {
if errors.Is(readErr, context.Canceled) || errors.Is(readErr, context.DeadlineExceeded) || errors.Is(readErr, http.ErrBodyReadAfterClose) || errors.Is(readErr, net.ErrClosed) {
return nil
}
return fmt.Errorf("read response body: %w", readErr)
}
}
@@ -1076,3 +1111,32 @@ func (h *FrameHandler) closeStreamingRequest(requestID string, streamingReq *Str
close(streamingReq.Done)
streamingReq.mu.Unlock()
}
func (h *FrameHandler) registerResponseCancel(id string, cancel context.CancelFunc) {
if cancel == nil {
return
}
h.responseCancelMu.Lock()
h.responseCancels[id] = cancel
h.responseCancelMu.Unlock()
}
func (h *FrameHandler) cancelResponse(id string) {
h.responseCancelMu.Lock()
cancel := h.responseCancels[id]
if cancel != nil {
delete(h.responseCancels, id)
}
h.responseCancelMu.Unlock()
if cancel != nil {
cancel()
}
}
func (h *FrameHandler) unregisterResponseCancel(id string) {
h.responseCancelMu.Lock()
delete(h.responseCancels, id)
h.responseCancelMu.Unlock()
}

View File

@@ -77,6 +77,32 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (h *Handler) handleAdaptiveRequest(w http.ResponseWriter, r *http.Request, transport tunnel.Transport, requestID string, subdomain string) {
const streamingThreshold int64 = 1 * 1024 * 1024
if transport != nil {
h.responses.RegisterCancelFunc(requestID, func() {
header := protocol.DataHeader{
StreamID: requestID,
RequestID: requestID,
Type: protocol.DataTypeClose,
IsLast: true,
}
payload, poolBuffer, err := protocol.EncodeDataPayloadPooled(header, nil)
if err != nil {
return
}
frame := protocol.NewFramePooled(protocol.FrameTypeData, payload, poolBuffer)
if err := transport.SendFrame(frame); err != nil {
h.logger.Debug("Failed to send cancel frame to client",
zap.String("request_id", requestID),
zap.Error(err),
)
}
})
defer h.responses.CleanupCancelFunc(requestID)
}
buffer := make([]byte, 0, streamingThreshold)
tempBuf := make([]byte, 32*1024)

View File

@@ -33,6 +33,7 @@ type streamingResponseEntry struct {
type ResponseHandler struct {
channels map[string]*responseChanEntry
streamingChannels map[string]*streamingResponseEntry
cancelFuncs map[string]func()
mu sync.RWMutex
logger *zap.Logger
stopCh chan struct{}
@@ -43,6 +44,7 @@ func NewResponseHandler(logger *zap.Logger) *ResponseHandler {
h := &ResponseHandler{
channels: make(map[string]*responseChanEntry),
streamingChannels: make(map[string]*streamingResponseEntry),
cancelFuncs: make(map[string]func()),
logger: logger,
stopCh: make(chan struct{}),
}
@@ -86,6 +88,17 @@ func (h *ResponseHandler) CreateStreamingResponse(requestID string, w http.Respo
return done
}
// RegisterCancelFunc registers a callback to be invoked when the downstream disconnects.
func (h *ResponseHandler) RegisterCancelFunc(requestID string, cancel func()) {
if cancel == nil {
return
}
h.mu.Lock()
h.cancelFuncs[requestID] = cancel
h.mu.Unlock()
}
// GetResponseChan gets the response channel for a request ID
func (h *ResponseHandler) GetResponseChan(requestID string) <-chan *protocol.HTTPResponse {
h.mu.RLock()
@@ -215,6 +228,7 @@ func (h *ResponseHandler) SendStreamingChunk(requestID string, chunk []byte, isL
default:
close(entry.done)
}
h.triggerCancel(requestID)
return nil
}
select {
@@ -222,6 +236,7 @@ func (h *ResponseHandler) SendStreamingChunk(requestID string, chunk []byte, isL
default:
close(entry.done)
}
h.triggerCancel(requestID)
return nil
}
@@ -265,6 +280,22 @@ func isClientDisconnectError(err error) bool {
strings.Contains(errStr, "use of closed network connection")
}
// triggerCancel invokes and removes the cancel callback for a request.
func (h *ResponseHandler) triggerCancel(requestID string) {
h.mu.Lock()
cancel := h.cancelFuncs[requestID]
if cancel != nil {
delete(h.cancelFuncs, requestID)
}
h.mu.Unlock()
if cancel != nil {
go func() {
cancel()
}()
}
}
func (h *ResponseHandler) CleanupResponseChan(requestID string) {
h.mu.Lock()
defer h.mu.Unlock()
@@ -289,6 +320,13 @@ func (h *ResponseHandler) CleanupStreamingResponse(requestID string) {
}
}
// CleanupCancelFunc removes a registered cancel callback.
func (h *ResponseHandler) CleanupCancelFunc(requestID string) {
h.mu.Lock()
delete(h.cancelFuncs, requestID)
h.mu.Unlock()
}
func (h *ResponseHandler) GetPendingCount() int {
h.mu.RLock()
defer h.mu.RUnlock()
@@ -318,6 +356,7 @@ func (h *ResponseHandler) cleanupExpiredChannels() {
defer h.mu.Unlock()
expiredCount := 0
cancelList := make([]string, 0)
for requestID, entry := range h.channels {
if now.Sub(entry.createdAt) > timeout {
close(entry.ch)
@@ -334,10 +373,18 @@ func (h *ResponseHandler) cleanupExpiredChannels() {
close(entry.done)
}
delete(h.streamingChannels, requestID)
cancelList = append(cancelList, requestID)
expiredCount++
}
}
for _, requestID := range cancelList {
if cancel := h.cancelFuncs[requestID]; cancel != nil {
delete(h.cancelFuncs, requestID)
go cancel()
}
}
if expiredCount > 0 {
h.logger.Debug("Cleaned up expired response channels",
zap.Int("count", expiredCount),
@@ -365,4 +412,9 @@ func (h *ResponseHandler) Close() {
}
}
h.streamingChannels = make(map[string]*streamingResponseEntry)
for _, cancel := range h.cancelFuncs {
cancel()
}
h.cancelFuncs = make(map[string]func())
}