diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go index 7c887221..1f340050 100644 --- a/internal/runtime/executor/codex_websockets_executor.go +++ b/internal/runtime/executor/codex_websockets_executor.go @@ -31,7 +31,7 @@ import ( ) const ( - codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-04" + codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-06" codexResponsesWebsocketIdleTimeout = 5 * time.Minute codexResponsesWebsocketHandshakeTO = 30 * time.Second ) @@ -57,11 +57,6 @@ type codexWebsocketSession struct { wsURL string authID string - // connCreateSent tracks whether a `response.create` message has been successfully sent - // on the current websocket connection. The upstream expects the first message on each - // connection to be `response.create`. - connCreateSent bool - writeMu sync.Mutex activeMu sync.Mutex @@ -212,13 +207,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut defer sess.reqMu.Unlock() } - allowAppend := true - if sess != nil { - sess.connMu.Lock() - allowAppend = sess.connCreateSent - sess.connMu.Unlock() - } - wsReqBody := buildCodexWebsocketRequestBody(body, allowAppend) + wsReqBody := buildCodexWebsocketRequestBody(body) recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ URL: wsURL, Method: "WEBSOCKET", @@ -280,10 +269,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut // execution session. connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) if errDialRetry == nil && connRetry != nil { - sess.connMu.Lock() - allowAppend = sess.connCreateSent - sess.connMu.Unlock() - wsReqBodyRetry := buildCodexWebsocketRequestBody(body, allowAppend) + wsReqBodyRetry := buildCodexWebsocketRequestBody(body) recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ URL: wsURL, Method: "WEBSOCKET", @@ -312,7 +298,6 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut return resp, errSend } } - markCodexWebsocketCreateSent(sess, conn, wsReqBody) for { if ctx != nil && ctx.Err() != nil { @@ -403,26 +388,20 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey) var authID, authLabel, authType, authValue string - if auth != nil { - authID = auth.ID - authLabel = auth.Label - authType, authValue = auth.AccountInfo() - } + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() executionSessionID := executionSessionIDFromOptions(opts) var sess *codexWebsocketSession if executionSessionID != "" { sess = e.getOrCreateSession(executionSessionID) - sess.reqMu.Lock() + if sess != nil { + sess.reqMu.Lock() + } } - allowAppend := true - if sess != nil { - sess.connMu.Lock() - allowAppend = sess.connCreateSent - sess.connMu.Unlock() - } - wsReqBody := buildCodexWebsocketRequestBody(body, allowAppend) + wsReqBody := buildCodexWebsocketRequestBody(body) recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ URL: wsURL, Method: "WEBSOCKET", @@ -483,10 +462,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr sess.reqMu.Unlock() return nil, errDialRetry } - sess.connMu.Lock() - allowAppend = sess.connCreateSent - sess.connMu.Unlock() - wsReqBodyRetry := buildCodexWebsocketRequestBody(body, allowAppend) + wsReqBodyRetry := buildCodexWebsocketRequestBody(body) recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ URL: wsURL, Method: "WEBSOCKET", @@ -515,7 +491,6 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr return nil, errSend } } - markCodexWebsocketCreateSent(sess, conn, wsReqBody) out := make(chan cliproxyexecutor.StreamChunk) go func() { @@ -657,31 +632,14 @@ func writeCodexWebsocketMessage(sess *codexWebsocketSession, conn *websocket.Con return conn.WriteMessage(websocket.TextMessage, payload) } -func buildCodexWebsocketRequestBody(body []byte, allowAppend bool) []byte { +func buildCodexWebsocketRequestBody(body []byte) []byte { if len(body) == 0 { return nil } - // Codex CLI websocket v2 uses `response.create` with `previous_response_id` for incremental turns. - // The upstream ChatGPT Codex websocket currently rejects that with close 1008 (policy violation). - // Fall back to v1 `response.append` semantics on the same websocket connection to keep the session alive. - // - // NOTE: The upstream expects the first websocket event on each connection to be `response.create`, - // so we only use `response.append` after we have initialized the current connection. - if allowAppend { - if prev := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String()); prev != "" { - inputNode := gjson.GetBytes(body, "input") - wsReqBody := []byte(`{}`) - wsReqBody, _ = sjson.SetBytes(wsReqBody, "type", "response.append") - if inputNode.Exists() && inputNode.IsArray() && strings.TrimSpace(inputNode.Raw) != "" { - wsReqBody, _ = sjson.SetRawBytes(wsReqBody, "input", []byte(inputNode.Raw)) - return wsReqBody - } - wsReqBody, _ = sjson.SetRawBytes(wsReqBody, "input", []byte("[]")) - return wsReqBody - } - } - + // Match codex-rs websocket v2 semantics: every request is `response.create`. + // Incremental follow-up turns continue on the same websocket using + // `previous_response_id` + incremental `input`, not `response.append`. wsReqBody, errSet := sjson.SetBytes(bytes.Clone(body), "type", "response.create") if errSet == nil && len(wsReqBody) > 0 { return wsReqBody @@ -725,21 +683,6 @@ func readCodexWebsocketMessage(ctx context.Context, sess *codexWebsocketSession, } } -func markCodexWebsocketCreateSent(sess *codexWebsocketSession, conn *websocket.Conn, payload []byte) { - if sess == nil || conn == nil || len(payload) == 0 { - return - } - if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "response.create" { - return - } - - sess.connMu.Lock() - if sess.conn == conn { - sess.connCreateSent = true - } - sess.connMu.Unlock() -} - func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *websocket.Dialer { dialer := &websocket.Dialer{ Proxy: http.ProxyFromEnvironment, @@ -1017,36 +960,6 @@ func closeHTTPResponseBody(resp *http.Response, logPrefix string) { } } -func closeOnContextDone(ctx context.Context, conn *websocket.Conn) chan struct{} { - done := make(chan struct{}) - if ctx == nil || conn == nil { - return done - } - go func() { - select { - case <-done: - case <-ctx.Done(): - _ = conn.Close() - } - }() - return done -} - -func cancelReadOnContextDone(ctx context.Context, conn *websocket.Conn) chan struct{} { - done := make(chan struct{}) - if ctx == nil || conn == nil { - return done - } - go func() { - select { - case <-done: - case <-ctx.Done(): - _ = conn.SetReadDeadline(time.Now()) - } - }() - return done -} - func executionSessionIDFromOptions(opts cliproxyexecutor.Options) string { if len(opts.Metadata) == 0 { return "" @@ -1120,7 +1033,6 @@ func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth * sess.conn = conn sess.wsURL = wsURL sess.authID = authID - sess.connCreateSent = false sess.readerConn = conn sess.connMu.Unlock() @@ -1206,7 +1118,6 @@ func (e *CodexWebsocketsExecutor) invalidateUpstreamConn(sess *codexWebsocketSes return } sess.conn = nil - sess.connCreateSent = false if sess.readerConn == conn { sess.readerConn = nil } @@ -1273,7 +1184,6 @@ func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSess authID := sess.authID wsURL := sess.wsURL sess.conn = nil - sess.connCreateSent = false if sess.readerConn == conn { sess.readerConn = nil } diff --git a/internal/runtime/executor/codex_websockets_executor_test.go b/internal/runtime/executor/codex_websockets_executor_test.go new file mode 100644 index 00000000..1fd68513 --- /dev/null +++ b/internal/runtime/executor/codex_websockets_executor_test.go @@ -0,0 +1,36 @@ +package executor + +import ( + "context" + "net/http" + "testing" + + "github.com/tidwall/gjson" +) + +func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T) { + body := []byte(`{"model":"gpt-5-codex","previous_response_id":"resp-1","input":[{"type":"message","id":"msg-1"}]}`) + + wsReqBody := buildCodexWebsocketRequestBody(body) + + if got := gjson.GetBytes(wsReqBody, "type").String(); got != "response.create" { + t.Fatalf("type = %s, want response.create", got) + } + if got := gjson.GetBytes(wsReqBody, "previous_response_id").String(); got != "resp-1" { + t.Fatalf("previous_response_id = %s, want resp-1", got) + } + if gjson.GetBytes(wsReqBody, "input.0.id").String() != "msg-1" { + t.Fatalf("input item id mismatch") + } + if got := gjson.GetBytes(wsReqBody, "type").String(); got == "response.append" { + t.Fatalf("unexpected websocket request type: %s", got) + } +} + +func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) { + headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "") + + if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue { + t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue) + } +} diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index f2d44f05..5e2beb94 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -26,7 +26,6 @@ const ( wsRequestTypeAppend = "response.append" wsEventTypeError = "error" wsEventTypeCompleted = "response.completed" - wsEventTypeDone = "response.done" wsDoneMarker = "[DONE]" wsTurnStateHeader = "x-codex-turn-state" wsRequestBodyKey = "REQUEST_BODY_OVERRIDE" @@ -469,9 +468,6 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( for i := range payloads { eventType := gjson.GetBytes(payloads[i], "type").String() if eventType == wsEventTypeCompleted { - // log.Infof("replace %s with %s", wsEventTypeCompleted, wsEventTypeDone) - payloads[i], _ = sjson.SetBytes(payloads[i], "type", wsEventTypeDone) - completed = true completedOutput = responseCompletedOutputFromPayload(payloads[i]) } diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index 9b6cec78..a04bb18c 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -2,12 +2,15 @@ package openai import ( "bytes" + "errors" "net/http" "net/http/httptest" "strings" "testing" "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/tidwall/gjson" ) @@ -247,3 +250,79 @@ func TestSetWebsocketRequestBody(t *testing.T) { t.Fatalf("request body = %q, want %q", string(bodyBytes), "event body") } } + +func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) { + gin.SetMode(gin.TestMode) + + serverErrCh := make(chan error, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := responsesWebsocketUpgrader.Upgrade(w, r, nil) + if err != nil { + serverErrCh <- err + return + } + defer func() { + errClose := conn.Close() + if errClose != nil { + serverErrCh <- errClose + } + }() + + ctx, _ := gin.CreateTestContext(httptest.NewRecorder()) + ctx.Request = r + + data := make(chan []byte, 1) + errCh := make(chan *interfaces.ErrorMessage) + data <- []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[{\"type\":\"message\",\"id\":\"out-1\"}]}}\n\n") + close(data) + close(errCh) + + var bodyLog strings.Builder + completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket( + ctx, + conn, + func(...interface{}) {}, + data, + errCh, + &bodyLog, + "session-1", + ) + if err != nil { + serverErrCh <- err + return + } + if gjson.GetBytes(completedOutput, "0.id").String() != "out-1" { + serverErrCh <- errors.New("completed output not captured") + return + } + serverErrCh <- nil + })) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("dial websocket: %v", err) + } + defer func() { + errClose := conn.Close() + if errClose != nil { + t.Fatalf("close websocket: %v", errClose) + } + }() + + _, payload, errReadMessage := conn.ReadMessage() + if errReadMessage != nil { + t.Fatalf("read websocket message: %v", errReadMessage) + } + if gjson.GetBytes(payload, "type").String() != wsEventTypeCompleted { + t.Fatalf("payload type = %s, want %s", gjson.GetBytes(payload, "type").String(), wsEventTypeCompleted) + } + if strings.Contains(string(payload), "response.done") { + t.Fatalf("payload unexpectedly rewrote completed event: %s", payload) + } + + if errServer := <-serverErrCh; errServer != nil { + t.Fatalf("server error: %v", errServer) + } +}