Merge branch 'main' into dev

This commit is contained in:
Luis Pater
2026-04-02 21:21:26 +08:00
4 changed files with 339 additions and 9 deletions

View File

@@ -136,6 +136,8 @@ type authAwareStreamExecutor struct {
type invalidJSONStreamExecutor struct{}
type splitResponsesEventStreamExecutor struct{}
func (e *invalidJSONStreamExecutor) Identifier() string { return "codex" }
func (e *invalidJSONStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
@@ -165,6 +167,36 @@ func (e *invalidJSONStreamExecutor) HttpRequest(ctx context.Context, auth *corea
}
}
func (e *splitResponsesEventStreamExecutor) Identifier() string { return "split-sse" }
func (e *splitResponsesEventStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
}
func (e *splitResponsesEventStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
ch := make(chan coreexecutor.StreamChunk, 2)
ch <- coreexecutor.StreamChunk{Payload: []byte("event: response.completed")}
ch <- coreexecutor.StreamChunk{Payload: []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}")}
close(ch)
return &coreexecutor.StreamResult{Chunks: ch}, nil
}
func (e *splitResponsesEventStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
return auth, nil
}
func (e *splitResponsesEventStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
}
func (e *splitResponsesEventStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
return nil, &coreauth.Error{
Code: "not_implemented",
Message: "HttpRequest not implemented",
HTTPStatus: http.StatusNotImplemented,
}
}
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
@@ -607,3 +639,52 @@ func TestExecuteStreamWithAuthManager_ValidatesOpenAIResponsesStreamDataJSON(t *
t.Fatalf("expected terminal error")
}
}
func TestExecuteStreamWithAuthManager_AllowsSplitOpenAIResponsesSSEEventLines(t *testing.T) {
executor := &splitResponsesEventStreamExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(executor)
auth1 := &coreauth.Auth{
ID: "auth1",
Provider: "split-sse",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test1@example.com"},
}
if _, err := manager.Register(context.Background(), auth1); err != nil {
t.Fatalf("manager.Register(auth1): %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
})
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai-response", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels")
}
var got []string
for chunk := range dataChan {
got = append(got, string(chunk))
}
for msg := range errChan {
if msg != nil {
t.Fatalf("unexpected error: %+v", msg)
}
}
if len(got) != 2 {
t.Fatalf("expected 2 forwarded chunks, got %d: %#v", len(got), got)
}
if got[0] != "event: response.completed" {
t.Fatalf("unexpected first chunk: %q", got[0])
}
expectedData := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}"
if got[1] != expectedData {
t.Fatalf("unexpected second chunk.\nGot: %q\nWant: %q", got[1], expectedData)
}
}

View File

@@ -9,6 +9,7 @@ package openai
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
@@ -29,11 +30,13 @@ func writeResponsesSSEChunk(w io.Writer, chunk []byte) {
if _, err := w.Write(chunk); err != nil {
return
}
if bytes.HasSuffix(chunk, []byte("\n\n")) {
if bytes.HasSuffix(chunk, []byte("\n\n")) || bytes.HasSuffix(chunk, []byte("\r\n\r\n")) {
return
}
suffix := []byte("\n\n")
if bytes.HasSuffix(chunk, []byte("\n")) {
if bytes.HasSuffix(chunk, []byte("\r\n")) {
suffix = []byte("\r\n")
} else if bytes.HasSuffix(chunk, []byte("\n")) {
suffix = []byte("\n")
}
if _, err := w.Write(suffix); err != nil {
@@ -41,6 +44,156 @@ func writeResponsesSSEChunk(w io.Writer, chunk []byte) {
}
}
type responsesSSEFramer struct {
pending []byte
}
func (f *responsesSSEFramer) WriteChunk(w io.Writer, chunk []byte) {
if len(chunk) == 0 {
return
}
if responsesSSENeedsLineBreak(f.pending, chunk) {
f.pending = append(f.pending, '\n')
}
f.pending = append(f.pending, chunk...)
for {
frameLen := responsesSSEFrameLen(f.pending)
if frameLen == 0 {
break
}
writeResponsesSSEChunk(w, f.pending[:frameLen])
copy(f.pending, f.pending[frameLen:])
f.pending = f.pending[:len(f.pending)-frameLen]
}
if len(bytes.TrimSpace(f.pending)) == 0 {
f.pending = f.pending[:0]
return
}
if len(f.pending) == 0 || !responsesSSECanEmitWithoutDelimiter(f.pending) {
return
}
writeResponsesSSEChunk(w, f.pending)
f.pending = f.pending[:0]
}
func (f *responsesSSEFramer) Flush(w io.Writer) {
if len(f.pending) == 0 {
return
}
if len(bytes.TrimSpace(f.pending)) == 0 {
f.pending = f.pending[:0]
return
}
if !responsesSSECanEmitWithoutDelimiter(f.pending) {
f.pending = f.pending[:0]
return
}
writeResponsesSSEChunk(w, f.pending)
f.pending = f.pending[:0]
}
func responsesSSEFrameLen(chunk []byte) int {
if len(chunk) == 0 {
return 0
}
lf := bytes.Index(chunk, []byte("\n\n"))
crlf := bytes.Index(chunk, []byte("\r\n\r\n"))
switch {
case lf < 0:
if crlf < 0 {
return 0
}
return crlf + 4
case crlf < 0:
return lf + 2
case lf < crlf:
return lf + 2
default:
return crlf + 4
}
}
func responsesSSENeedsMoreData(chunk []byte) bool {
trimmed := bytes.TrimSpace(chunk)
if len(trimmed) == 0 {
return false
}
return responsesSSEHasField(trimmed, []byte("event:")) && !responsesSSEHasField(trimmed, []byte("data:"))
}
func responsesSSEHasField(chunk []byte, prefix []byte) bool {
s := chunk
for len(s) > 0 {
line := s
if i := bytes.IndexByte(s, '\n'); i >= 0 {
line = s[:i]
s = s[i+1:]
} else {
s = nil
}
line = bytes.TrimSpace(line)
if bytes.HasPrefix(line, prefix) {
return true
}
}
return false
}
func responsesSSECanEmitWithoutDelimiter(chunk []byte) bool {
trimmed := bytes.TrimSpace(chunk)
if len(trimmed) == 0 || responsesSSENeedsMoreData(trimmed) || !responsesSSEHasField(trimmed, []byte("data:")) {
return false
}
return responsesSSEDataLinesValid(trimmed)
}
func responsesSSEDataLinesValid(chunk []byte) bool {
s := chunk
for len(s) > 0 {
line := s
if i := bytes.IndexByte(s, '\n'); i >= 0 {
line = s[:i]
s = s[i+1:]
} else {
s = nil
}
line = bytes.TrimSpace(line)
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data:")) {
continue
}
data := bytes.TrimSpace(line[len("data:"):])
if len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
continue
}
if !json.Valid(data) {
return false
}
}
return true
}
func responsesSSENeedsLineBreak(pending, chunk []byte) bool {
if len(pending) == 0 || len(chunk) == 0 {
return false
}
if bytes.HasSuffix(pending, []byte("\n")) || bytes.HasSuffix(pending, []byte("\r")) {
return false
}
if chunk[0] == '\n' || chunk[0] == '\r' {
return false
}
trimmed := bytes.TrimLeft(chunk, " \t")
if len(trimmed) == 0 {
return false
}
for _, prefix := range [][]byte{[]byte("data:"), []byte("event:"), []byte("id:"), []byte("retry:"), []byte(":")} {
if bytes.HasPrefix(trimmed, prefix) {
return true
}
}
return false
}
// OpenAIResponsesAPIHandler contains the handlers for OpenAIResponses API endpoints.
// It holds a pool of clients to interact with the backend service.
type OpenAIResponsesAPIHandler struct {
@@ -213,6 +366,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
framer := &responsesSSEFramer{}
// Peek at the first chunk
for {
@@ -250,22 +404,26 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
// Write first chunk logic (matching forwardResponsesStream)
writeResponsesSSEChunk(c.Writer, chunk)
framer.WriteChunk(c.Writer, chunk)
flusher.Flush()
// Continue
h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan, framer)
return
}
}
}
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage, framer *responsesSSEFramer) {
if framer == nil {
framer = &responsesSSEFramer{}
}
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
WriteChunk: func(chunk []byte) {
writeResponsesSSEChunk(c.Writer, chunk)
framer.WriteChunk(c.Writer, chunk)
},
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
framer.Flush(c.Writer)
if errMsg == nil {
return
}
@@ -281,6 +439,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flush
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(chunk))
},
WriteDone: func() {
framer.Flush(c.Writer)
_, _ = c.Writer.Write([]byte("\n"))
},
})

View File

@@ -32,7 +32,7 @@ func TestForwardResponsesStreamTerminalErrorUsesResponsesErrorChunk(t *testing.T
errs <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: errors.New("unexpected EOF")}
close(errs)
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs)
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
body := recorder.Body.String()
if !strings.Contains(body, `"type":"error"`) {
t.Fatalf("expected responses error chunk, got: %q", body)

View File

@@ -12,7 +12,9 @@ import (
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
func TestForwardResponsesStreamSeparatesDataOnlySSEChunks(t *testing.T) {
func newResponsesStreamTestHandler(t *testing.T) (*OpenAIResponsesAPIHandler, *httptest.ResponseRecorder, *gin.Context, http.Flusher) {
t.Helper()
gin.SetMode(gin.TestMode)
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, nil)
h := NewOpenAIResponsesAPIHandler(base)
@@ -26,6 +28,12 @@ func TestForwardResponsesStreamSeparatesDataOnlySSEChunks(t *testing.T) {
t.Fatalf("expected gin writer to implement http.Flusher")
}
return h, recorder, c, flusher
}
func TestForwardResponsesStreamSeparatesDataOnlySSEChunks(t *testing.T) {
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
data := make(chan []byte, 2)
errs := make(chan *interfaces.ErrorMessage)
data <- []byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"function_call\",\"arguments\":\"{}\"}}")
@@ -33,7 +41,7 @@ func TestForwardResponsesStreamSeparatesDataOnlySSEChunks(t *testing.T) {
close(data)
close(errs)
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs)
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
body := recorder.Body.String()
parts := strings.Split(strings.TrimSpace(body), "\n\n")
if len(parts) != 2 {
@@ -50,3 +58,85 @@ func TestForwardResponsesStreamSeparatesDataOnlySSEChunks(t *testing.T) {
t.Errorf("unexpected second event.\nGot: %q\nWant: %q", parts[1], expectedPart2)
}
}
func TestForwardResponsesStreamReassemblesSplitSSEEventChunks(t *testing.T) {
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
data := make(chan []byte, 3)
errs := make(chan *interfaces.ErrorMessage)
data <- []byte("event: response.created")
data <- []byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}")
data <- []byte("\n")
close(data)
close(errs)
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
got := strings.TrimSuffix(recorder.Body.String(), "\n")
want := "event: response.created\ndata: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}\n\n"
if got != want {
t.Fatalf("unexpected split-event framing.\nGot: %q\nWant: %q", got, want)
}
}
func TestForwardResponsesStreamPreservesValidFullSSEEventChunks(t *testing.T) {
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
data := make(chan []byte, 1)
errs := make(chan *interfaces.ErrorMessage)
chunk := []byte("event: response.created\ndata: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}\n\n")
data <- chunk
close(data)
close(errs)
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
got := strings.TrimSuffix(recorder.Body.String(), "\n")
if got != string(chunk) {
t.Fatalf("unexpected full-event framing.\nGot: %q\nWant: %q", got, string(chunk))
}
}
func TestForwardResponsesStreamBuffersSplitDataPayloadChunks(t *testing.T) {
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
data := make(chan []byte, 2)
errs := make(chan *interfaces.ErrorMessage)
data <- []byte("data: {\"type\":\"response.created\"")
data <- []byte(",\"response\":{\"id\":\"resp-1\"}}")
close(data)
close(errs)
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
got := recorder.Body.String()
want := "data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}\n\n\n"
if got != want {
t.Fatalf("unexpected split-data framing.\nGot: %q\nWant: %q", got, want)
}
}
func TestResponsesSSENeedsLineBreakSkipsChunksThatAlreadyStartWithNewline(t *testing.T) {
if responsesSSENeedsLineBreak([]byte("event: response.created"), []byte("\n")) {
t.Fatal("expected no injected newline before newline-only chunk")
}
if responsesSSENeedsLineBreak([]byte("event: response.created"), []byte("\r\n")) {
t.Fatal("expected no injected newline before CRLF chunk")
}
}
func TestForwardResponsesStreamDropsIncompleteTrailingDataChunkOnFlush(t *testing.T) {
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
data := make(chan []byte, 1)
errs := make(chan *interfaces.ErrorMessage)
data <- []byte("data: {\"type\":\"response.created\"")
close(data)
close(errs)
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
if got := recorder.Body.String(); got != "\n" {
t.Fatalf("expected incomplete trailing data to be dropped on flush.\nGot: %q", got)
}
}