diff --git a/internal/wsrelay/http.go b/internal/wsrelay/http.go index 52ea2a1d..abdb277c 100644 --- a/internal/wsrelay/http.go +++ b/internal/wsrelay/http.go @@ -124,32 +124,47 @@ func (m *Manager) Stream(ctx context.Context, provider string, req *HTTPRequest) out := make(chan StreamEvent) go func() { defer close(out) + send := func(ev StreamEvent) bool { + if ctx == nil { + out <- ev + return true + } + select { + case <-ctx.Done(): + return false + case out <- ev: + return true + } + } for { select { case <-ctx.Done(): - out <- StreamEvent{Err: ctx.Err()} return case msg, ok := <-respCh: if !ok { - out <- StreamEvent{Err: errors.New("wsrelay: stream closed")} + _ = send(StreamEvent{Err: errors.New("wsrelay: stream closed")}) return } switch msg.Type { case MessageTypeStreamStart: resp := decodeResponse(msg.Payload) - out <- StreamEvent{Type: MessageTypeStreamStart, Status: resp.Status, Headers: resp.Headers} + if okSend := send(StreamEvent{Type: MessageTypeStreamStart, Status: resp.Status, Headers: resp.Headers}); !okSend { + return + } case MessageTypeStreamChunk: chunk := decodeChunk(msg.Payload) - out <- StreamEvent{Type: MessageTypeStreamChunk, Payload: chunk} + if okSend := send(StreamEvent{Type: MessageTypeStreamChunk, Payload: chunk}); !okSend { + return + } case MessageTypeStreamEnd: - out <- StreamEvent{Type: MessageTypeStreamEnd} + _ = send(StreamEvent{Type: MessageTypeStreamEnd}) return case MessageTypeError: - out <- StreamEvent{Type: MessageTypeError, Err: decodeError(msg.Payload)} + _ = send(StreamEvent{Type: MessageTypeError, Err: decodeError(msg.Payload)}) return case MessageTypeHTTPResp: resp := decodeResponse(msg.Payload) - out <- StreamEvent{Type: MessageTypeHTTPResp, Status: resp.Status, Headers: resp.Headers, Payload: resp.Body} + _ = send(StreamEvent{Type: MessageTypeHTTPResp, Status: resp.Status, Headers: resp.Headers, Payload: resp.Body}) return default: } diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 7108749d..b1da9664 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -506,6 +506,32 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl bootstrapRetries := 0 maxBootstrapRetries := StreamingBootstrapRetries(h.Cfg) + sendErr := func(msg *interfaces.ErrorMessage) bool { + if ctx == nil { + errChan <- msg + return true + } + select { + case <-ctx.Done(): + return false + case errChan <- msg: + return true + } + } + + sendData := func(chunk []byte) bool { + if ctx == nil { + dataChan <- chunk + return true + } + select { + case <-ctx.Done(): + return false + case dataChan <- chunk: + return true + } + } + bootstrapEligible := func(err error) bool { status := statusFromError(err) if status == 0 { @@ -565,12 +591,14 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl addon = hdr.Clone() } } - errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: streamErr, Addon: addon} + _ = sendErr(&interfaces.ErrorMessage{StatusCode: status, Error: streamErr, Addon: addon}) return } if len(chunk.Payload) > 0 { sentPayload = true - dataChan <- cloneBytes(chunk.Payload) + if okSendData := sendData(cloneBytes(chunk.Payload)); !okSendData { + return + } } } } diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index fd7543b4..3a64c8c3 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -718,6 +718,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) { defer close(out) var failed bool + forward := true for chunk := range streamChunks { if chunk.Err != nil && !failed { failed = true @@ -728,7 +729,18 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string } m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr}) } - out <- chunk + if !forward { + continue + } + if streamCtx == nil { + out <- chunk + continue + } + select { + case <-streamCtx.Done(): + forward = false + case out <- chunk: + } } if !failed { m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})