mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-08 06:43:41 +00:00
refactor(executor): remove legacy connCreateSent logic and standardize response.create usage for all websocket events
- Simplified connection logic by removing `connCreateSent` and related state handling. - Updated `buildCodexWebsocketRequestBody` to always use `response.create`. - Added unit tests to validate `response.create` behavior and beta header preservation. - Dropped unsupported `response.append` and outdated `response.done` event types.
This commit is contained in:
@@ -31,7 +31,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-04"
|
||||
codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-06"
|
||||
codexResponsesWebsocketIdleTimeout = 5 * time.Minute
|
||||
codexResponsesWebsocketHandshakeTO = 30 * time.Second
|
||||
)
|
||||
@@ -57,11 +57,6 @@ type codexWebsocketSession struct {
|
||||
wsURL string
|
||||
authID string
|
||||
|
||||
// connCreateSent tracks whether a `response.create` message has been successfully sent
|
||||
// on the current websocket connection. The upstream expects the first message on each
|
||||
// connection to be `response.create`.
|
||||
connCreateSent bool
|
||||
|
||||
writeMu sync.Mutex
|
||||
|
||||
activeMu sync.Mutex
|
||||
@@ -212,13 +207,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
defer sess.reqMu.Unlock()
|
||||
}
|
||||
|
||||
allowAppend := true
|
||||
if sess != nil {
|
||||
sess.connMu.Lock()
|
||||
allowAppend = sess.connCreateSent
|
||||
sess.connMu.Unlock()
|
||||
}
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body, allowAppend)
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
@@ -280,10 +269,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
// execution session.
|
||||
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||
if errDialRetry == nil && connRetry != nil {
|
||||
sess.connMu.Lock()
|
||||
allowAppend = sess.connCreateSent
|
||||
sess.connMu.Unlock()
|
||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body, allowAppend)
|
||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
@@ -312,7 +298,6 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
return resp, errSend
|
||||
}
|
||||
}
|
||||
markCodexWebsocketCreateSent(sess, conn, wsReqBody)
|
||||
|
||||
for {
|
||||
if ctx != nil && ctx.Err() != nil {
|
||||
@@ -403,26 +388,20 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
|
||||
executionSessionID := executionSessionIDFromOptions(opts)
|
||||
var sess *codexWebsocketSession
|
||||
if executionSessionID != "" {
|
||||
sess = e.getOrCreateSession(executionSessionID)
|
||||
sess.reqMu.Lock()
|
||||
if sess != nil {
|
||||
sess.reqMu.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
allowAppend := true
|
||||
if sess != nil {
|
||||
sess.connMu.Lock()
|
||||
allowAppend = sess.connCreateSent
|
||||
sess.connMu.Unlock()
|
||||
}
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body, allowAppend)
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
@@ -483,10 +462,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
sess.reqMu.Unlock()
|
||||
return nil, errDialRetry
|
||||
}
|
||||
sess.connMu.Lock()
|
||||
allowAppend = sess.connCreateSent
|
||||
sess.connMu.Unlock()
|
||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body, allowAppend)
|
||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: wsURL,
|
||||
Method: "WEBSOCKET",
|
||||
@@ -515,7 +491,6 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
||||
return nil, errSend
|
||||
}
|
||||
}
|
||||
markCodexWebsocketCreateSent(sess, conn, wsReqBody)
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func() {
|
||||
@@ -657,31 +632,14 @@ func writeCodexWebsocketMessage(sess *codexWebsocketSession, conn *websocket.Con
|
||||
return conn.WriteMessage(websocket.TextMessage, payload)
|
||||
}
|
||||
|
||||
func buildCodexWebsocketRequestBody(body []byte, allowAppend bool) []byte {
|
||||
func buildCodexWebsocketRequestBody(body []byte) []byte {
|
||||
if len(body) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Codex CLI websocket v2 uses `response.create` with `previous_response_id` for incremental turns.
|
||||
// The upstream ChatGPT Codex websocket currently rejects that with close 1008 (policy violation).
|
||||
// Fall back to v1 `response.append` semantics on the same websocket connection to keep the session alive.
|
||||
//
|
||||
// NOTE: The upstream expects the first websocket event on each connection to be `response.create`,
|
||||
// so we only use `response.append` after we have initialized the current connection.
|
||||
if allowAppend {
|
||||
if prev := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String()); prev != "" {
|
||||
inputNode := gjson.GetBytes(body, "input")
|
||||
wsReqBody := []byte(`{}`)
|
||||
wsReqBody, _ = sjson.SetBytes(wsReqBody, "type", "response.append")
|
||||
if inputNode.Exists() && inputNode.IsArray() && strings.TrimSpace(inputNode.Raw) != "" {
|
||||
wsReqBody, _ = sjson.SetRawBytes(wsReqBody, "input", []byte(inputNode.Raw))
|
||||
return wsReqBody
|
||||
}
|
||||
wsReqBody, _ = sjson.SetRawBytes(wsReqBody, "input", []byte("[]"))
|
||||
return wsReqBody
|
||||
}
|
||||
}
|
||||
|
||||
// Match codex-rs websocket v2 semantics: every request is `response.create`.
|
||||
// Incremental follow-up turns continue on the same websocket using
|
||||
// `previous_response_id` + incremental `input`, not `response.append`.
|
||||
wsReqBody, errSet := sjson.SetBytes(bytes.Clone(body), "type", "response.create")
|
||||
if errSet == nil && len(wsReqBody) > 0 {
|
||||
return wsReqBody
|
||||
@@ -725,21 +683,6 @@ func readCodexWebsocketMessage(ctx context.Context, sess *codexWebsocketSession,
|
||||
}
|
||||
}
|
||||
|
||||
func markCodexWebsocketCreateSent(sess *codexWebsocketSession, conn *websocket.Conn, payload []byte) {
|
||||
if sess == nil || conn == nil || len(payload) == 0 {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "response.create" {
|
||||
return
|
||||
}
|
||||
|
||||
sess.connMu.Lock()
|
||||
if sess.conn == conn {
|
||||
sess.connCreateSent = true
|
||||
}
|
||||
sess.connMu.Unlock()
|
||||
}
|
||||
|
||||
func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *websocket.Dialer {
|
||||
dialer := &websocket.Dialer{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
@@ -1017,36 +960,6 @@ func closeHTTPResponseBody(resp *http.Response, logPrefix string) {
|
||||
}
|
||||
}
|
||||
|
||||
func closeOnContextDone(ctx context.Context, conn *websocket.Conn) chan struct{} {
|
||||
done := make(chan struct{})
|
||||
if ctx == nil || conn == nil {
|
||||
return done
|
||||
}
|
||||
go func() {
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
_ = conn.Close()
|
||||
}
|
||||
}()
|
||||
return done
|
||||
}
|
||||
|
||||
func cancelReadOnContextDone(ctx context.Context, conn *websocket.Conn) chan struct{} {
|
||||
done := make(chan struct{})
|
||||
if ctx == nil || conn == nil {
|
||||
return done
|
||||
}
|
||||
go func() {
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
_ = conn.SetReadDeadline(time.Now())
|
||||
}
|
||||
}()
|
||||
return done
|
||||
}
|
||||
|
||||
func executionSessionIDFromOptions(opts cliproxyexecutor.Options) string {
|
||||
if len(opts.Metadata) == 0 {
|
||||
return ""
|
||||
@@ -1120,7 +1033,6 @@ func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *
|
||||
sess.conn = conn
|
||||
sess.wsURL = wsURL
|
||||
sess.authID = authID
|
||||
sess.connCreateSent = false
|
||||
sess.readerConn = conn
|
||||
sess.connMu.Unlock()
|
||||
|
||||
@@ -1206,7 +1118,6 @@ func (e *CodexWebsocketsExecutor) invalidateUpstreamConn(sess *codexWebsocketSes
|
||||
return
|
||||
}
|
||||
sess.conn = nil
|
||||
sess.connCreateSent = false
|
||||
if sess.readerConn == conn {
|
||||
sess.readerConn = nil
|
||||
}
|
||||
@@ -1273,7 +1184,6 @@ func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSess
|
||||
authID := sess.authID
|
||||
wsURL := sess.wsURL
|
||||
sess.conn = nil
|
||||
sess.connCreateSent = false
|
||||
if sess.readerConn == conn {
|
||||
sess.readerConn = nil
|
||||
}
|
||||
|
||||
36
internal/runtime/executor/codex_websockets_executor_test.go
Normal file
36
internal/runtime/executor/codex_websockets_executor_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-5-codex","previous_response_id":"resp-1","input":[{"type":"message","id":"msg-1"}]}`)
|
||||
|
||||
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||
|
||||
if got := gjson.GetBytes(wsReqBody, "type").String(); got != "response.create" {
|
||||
t.Fatalf("type = %s, want response.create", got)
|
||||
}
|
||||
if got := gjson.GetBytes(wsReqBody, "previous_response_id").String(); got != "resp-1" {
|
||||
t.Fatalf("previous_response_id = %s, want resp-1", got)
|
||||
}
|
||||
if gjson.GetBytes(wsReqBody, "input.0.id").String() != "msg-1" {
|
||||
t.Fatalf("input item id mismatch")
|
||||
}
|
||||
if got := gjson.GetBytes(wsReqBody, "type").String(); got == "response.append" {
|
||||
t.Fatalf("unexpected websocket request type: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
|
||||
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "")
|
||||
|
||||
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
|
||||
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
|
||||
}
|
||||
}
|
||||
@@ -26,7 +26,6 @@ const (
|
||||
wsRequestTypeAppend = "response.append"
|
||||
wsEventTypeError = "error"
|
||||
wsEventTypeCompleted = "response.completed"
|
||||
wsEventTypeDone = "response.done"
|
||||
wsDoneMarker = "[DONE]"
|
||||
wsTurnStateHeader = "x-codex-turn-state"
|
||||
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
|
||||
@@ -469,9 +468,6 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
||||
for i := range payloads {
|
||||
eventType := gjson.GetBytes(payloads[i], "type").String()
|
||||
if eventType == wsEventTypeCompleted {
|
||||
// log.Infof("replace %s with %s", wsEventTypeCompleted, wsEventTypeDone)
|
||||
payloads[i], _ = sjson.SetBytes(payloads[i], "type", wsEventTypeDone)
|
||||
|
||||
completed = true
|
||||
completedOutput = responseCompletedOutputFromPayload(payloads[i])
|
||||
}
|
||||
|
||||
@@ -2,12 +2,15 @@ package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
@@ -247,3 +250,79 @@ func TestSetWebsocketRequestBody(t *testing.T) {
|
||||
t.Fatalf("request body = %q, want %q", string(bodyBytes), "event body")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
serverErrCh := make(chan error, 1)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := responsesWebsocketUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
errClose := conn.Close()
|
||||
if errClose != nil {
|
||||
serverErrCh <- errClose
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
ctx.Request = r
|
||||
|
||||
data := make(chan []byte, 1)
|
||||
errCh := make(chan *interfaces.ErrorMessage)
|
||||
data <- []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[{\"type\":\"message\",\"id\":\"out-1\"}]}}\n\n")
|
||||
close(data)
|
||||
close(errCh)
|
||||
|
||||
var bodyLog strings.Builder
|
||||
completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
|
||||
ctx,
|
||||
conn,
|
||||
func(...interface{}) {},
|
||||
data,
|
||||
errCh,
|
||||
&bodyLog,
|
||||
"session-1",
|
||||
)
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if gjson.GetBytes(completedOutput, "0.id").String() != "out-1" {
|
||||
serverErrCh <- errors.New("completed output not captured")
|
||||
return
|
||||
}
|
||||
serverErrCh <- nil
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial websocket: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
errClose := conn.Close()
|
||||
if errClose != nil {
|
||||
t.Fatalf("close websocket: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
_, payload, errReadMessage := conn.ReadMessage()
|
||||
if errReadMessage != nil {
|
||||
t.Fatalf("read websocket message: %v", errReadMessage)
|
||||
}
|
||||
if gjson.GetBytes(payload, "type").String() != wsEventTypeCompleted {
|
||||
t.Fatalf("payload type = %s, want %s", gjson.GetBytes(payload, "type").String(), wsEventTypeCompleted)
|
||||
}
|
||||
if strings.Contains(string(payload), "response.done") {
|
||||
t.Fatalf("payload unexpectedly rewrote completed event: %s", payload)
|
||||
}
|
||||
|
||||
if errServer := <-serverErrCh; errServer != nil {
|
||||
t.Fatalf("server error: %v", errServer)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user