mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-30 09:18:12 +00:00
Restore Claude continuity after the continuity refactor, keep auth-affinity keys out of upstream Codex session identifiers, and only persist affinity after successful execution so retries can still rotate to healthy credentials when the first auth fails.
1395 lines
40 KiB
Go
1395 lines
40 KiB
Go
// Package executor provides runtime execution capabilities for various AI service providers.
|
|
// This file implements a Codex executor that uses the Responses API WebSocket transport.
|
|
package executor
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
"golang.org/x/net/proxy"
|
|
)
|
|
|
|
const (
|
|
codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-06"
|
|
codexResponsesWebsocketIdleTimeout = 5 * time.Minute
|
|
codexResponsesWebsocketHandshakeTO = 30 * time.Second
|
|
)
|
|
|
|
// CodexWebsocketsExecutor executes Codex Responses requests using a WebSocket transport.
|
|
//
|
|
// It preserves the existing CodexExecutor HTTP implementation as a fallback for endpoints
|
|
// not available over WebSocket (e.g. /responses/compact) and for websocket upgrade failures.
|
|
type CodexWebsocketsExecutor struct {
|
|
*CodexExecutor
|
|
|
|
sessMu sync.Mutex
|
|
sessions map[string]*codexWebsocketSession
|
|
}
|
|
|
|
type codexWebsocketSession struct {
|
|
sessionID string
|
|
|
|
reqMu sync.Mutex
|
|
|
|
connMu sync.Mutex
|
|
conn *websocket.Conn
|
|
wsURL string
|
|
authID string
|
|
|
|
writeMu sync.Mutex
|
|
|
|
activeMu sync.Mutex
|
|
activeCh chan codexWebsocketRead
|
|
activeDone <-chan struct{}
|
|
activeCancel context.CancelFunc
|
|
|
|
readerConn *websocket.Conn
|
|
}
|
|
|
|
func NewCodexWebsocketsExecutor(cfg *config.Config) *CodexWebsocketsExecutor {
|
|
return &CodexWebsocketsExecutor{
|
|
CodexExecutor: NewCodexExecutor(cfg),
|
|
sessions: make(map[string]*codexWebsocketSession),
|
|
}
|
|
}
|
|
|
|
type codexWebsocketRead struct {
|
|
conn *websocket.Conn
|
|
msgType int
|
|
payload []byte
|
|
err error
|
|
}
|
|
|
|
func (s *codexWebsocketSession) setActive(ch chan codexWebsocketRead) {
|
|
if s == nil {
|
|
return
|
|
}
|
|
s.activeMu.Lock()
|
|
if s.activeCancel != nil {
|
|
s.activeCancel()
|
|
s.activeCancel = nil
|
|
s.activeDone = nil
|
|
}
|
|
s.activeCh = ch
|
|
if ch != nil {
|
|
activeCtx, activeCancel := context.WithCancel(context.Background())
|
|
s.activeDone = activeCtx.Done()
|
|
s.activeCancel = activeCancel
|
|
}
|
|
s.activeMu.Unlock()
|
|
}
|
|
|
|
func (s *codexWebsocketSession) clearActive(ch chan codexWebsocketRead) {
|
|
if s == nil {
|
|
return
|
|
}
|
|
s.activeMu.Lock()
|
|
if s.activeCh == ch {
|
|
s.activeCh = nil
|
|
if s.activeCancel != nil {
|
|
s.activeCancel()
|
|
}
|
|
s.activeCancel = nil
|
|
s.activeDone = nil
|
|
}
|
|
s.activeMu.Unlock()
|
|
}
|
|
|
|
func (s *codexWebsocketSession) writeMessage(conn *websocket.Conn, msgType int, payload []byte) error {
|
|
if s == nil {
|
|
return fmt.Errorf("codex websockets executor: session is nil")
|
|
}
|
|
if conn == nil {
|
|
return fmt.Errorf("codex websockets executor: websocket conn is nil")
|
|
}
|
|
s.writeMu.Lock()
|
|
defer s.writeMu.Unlock()
|
|
return conn.WriteMessage(msgType, payload)
|
|
}
|
|
|
|
func (s *codexWebsocketSession) configureConn(conn *websocket.Conn) {
|
|
if s == nil || conn == nil {
|
|
return
|
|
}
|
|
conn.SetPingHandler(func(appData string) error {
|
|
s.writeMu.Lock()
|
|
defer s.writeMu.Unlock()
|
|
// Reply pongs from the same write lock to avoid concurrent writes.
|
|
return conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(10*time.Second))
|
|
})
|
|
}
|
|
|
|
func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
if opts.Alt == "responses/compact" {
|
|
return e.CodexExecutor.executeCompact(ctx, auth, req, opts)
|
|
}
|
|
|
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
|
apiKey, baseURL := codexCreds(auth)
|
|
if baseURL == "" {
|
|
baseURL = "https://chatgpt.com/backend-api/codex"
|
|
}
|
|
|
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
|
defer reporter.trackFailure(ctx, &err)
|
|
|
|
from := opts.SourceFormat
|
|
to := sdktranslator.FromString("codex")
|
|
originalPayloadSource := req.Payload
|
|
if len(opts.OriginalRequest) > 0 {
|
|
originalPayloadSource = opts.OriginalRequest
|
|
}
|
|
originalPayload := originalPayloadSource
|
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
|
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
|
|
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
|
if err != nil {
|
|
return resp, err
|
|
}
|
|
|
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
|
body, _ = sjson.SetBytes(body, "stream", true)
|
|
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
|
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
|
if !gjson.GetBytes(body, "instructions").Exists() {
|
|
body, _ = sjson.SetBytes(body, "instructions", "")
|
|
}
|
|
|
|
httpURL := strings.TrimSuffix(baseURL, "/") + "/responses"
|
|
wsURL, err := buildCodexResponsesWebsocketURL(httpURL)
|
|
if err != nil {
|
|
return resp, err
|
|
}
|
|
|
|
body, wsHeaders, continuity := applyCodexPromptCacheHeaders(ctx, auth, from, req, opts, body)
|
|
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
|
|
|
|
var authID, authLabel, authType, authValue string
|
|
if auth != nil {
|
|
authID = auth.ID
|
|
authLabel = auth.Label
|
|
authType, authValue = auth.AccountInfo()
|
|
}
|
|
|
|
executionSessionID := executionSessionIDFromOptions(opts)
|
|
var sess *codexWebsocketSession
|
|
if executionSessionID != "" {
|
|
sess = e.getOrCreateSession(executionSessionID)
|
|
sess.reqMu.Lock()
|
|
defer sess.reqMu.Unlock()
|
|
}
|
|
|
|
wsReqBody := buildCodexWebsocketRequestBody(body)
|
|
logCodexRequestDiagnostics(ctx, auth, req, opts, wsHeaders, body, continuity)
|
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
|
URL: wsURL,
|
|
Method: "WEBSOCKET",
|
|
Headers: wsHeaders.Clone(),
|
|
Body: wsReqBody,
|
|
Provider: e.Identifier(),
|
|
AuthID: authID,
|
|
AuthLabel: authLabel,
|
|
AuthType: authType,
|
|
AuthValue: authValue,
|
|
})
|
|
|
|
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
|
if respHS != nil {
|
|
recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
|
|
}
|
|
if errDial != nil {
|
|
bodyErr := websocketHandshakeBody(respHS)
|
|
if len(bodyErr) > 0 {
|
|
appendAPIResponseChunk(ctx, e.cfg, bodyErr)
|
|
}
|
|
if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired {
|
|
return e.CodexExecutor.Execute(ctx, auth, req, opts)
|
|
}
|
|
if respHS != nil && respHS.StatusCode > 0 {
|
|
return resp, statusErr{code: respHS.StatusCode, msg: string(bodyErr)}
|
|
}
|
|
recordAPIResponseError(ctx, e.cfg, errDial)
|
|
return resp, errDial
|
|
}
|
|
closeHTTPResponseBody(respHS, "codex websockets executor: close handshake response body error")
|
|
if sess == nil {
|
|
logCodexWebsocketConnected(executionSessionID, authID, wsURL)
|
|
defer func() {
|
|
reason := "completed"
|
|
if err != nil {
|
|
reason = "error"
|
|
}
|
|
logCodexWebsocketDisconnected(executionSessionID, authID, wsURL, reason, err)
|
|
if errClose := conn.Close(); errClose != nil {
|
|
log.Errorf("codex websockets executor: close websocket error: %v", errClose)
|
|
}
|
|
}()
|
|
}
|
|
|
|
var readCh chan codexWebsocketRead
|
|
if sess != nil {
|
|
readCh = make(chan codexWebsocketRead, 4096)
|
|
sess.setActive(readCh)
|
|
defer sess.clearActive(readCh)
|
|
}
|
|
|
|
if errSend := writeCodexWebsocketMessage(sess, conn, wsReqBody); errSend != nil {
|
|
if sess != nil {
|
|
e.invalidateUpstreamConn(sess, conn, "send_error", errSend)
|
|
|
|
// Retry once with a fresh websocket connection. This is mainly to handle
|
|
// upstream closing the socket between sequential requests within the same
|
|
// execution session.
|
|
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
|
if errDialRetry == nil && connRetry != nil {
|
|
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
|
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
|
URL: wsURL,
|
|
Method: "WEBSOCKET",
|
|
Headers: wsHeaders.Clone(),
|
|
Body: wsReqBodyRetry,
|
|
Provider: e.Identifier(),
|
|
AuthID: authID,
|
|
AuthLabel: authLabel,
|
|
AuthType: authType,
|
|
AuthValue: authValue,
|
|
})
|
|
if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry == nil {
|
|
conn = connRetry
|
|
wsReqBody = wsReqBodyRetry
|
|
} else {
|
|
e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry)
|
|
recordAPIResponseError(ctx, e.cfg, errSendRetry)
|
|
return resp, errSendRetry
|
|
}
|
|
} else {
|
|
recordAPIResponseError(ctx, e.cfg, errDialRetry)
|
|
return resp, errDialRetry
|
|
}
|
|
} else {
|
|
recordAPIResponseError(ctx, e.cfg, errSend)
|
|
return resp, errSend
|
|
}
|
|
}
|
|
|
|
for {
|
|
if ctx != nil && ctx.Err() != nil {
|
|
return resp, ctx.Err()
|
|
}
|
|
msgType, payload, errRead := readCodexWebsocketMessage(ctx, sess, conn, readCh)
|
|
if errRead != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
|
return resp, errRead
|
|
}
|
|
if msgType != websocket.TextMessage {
|
|
if msgType == websocket.BinaryMessage {
|
|
err = fmt.Errorf("codex websockets executor: unexpected binary message")
|
|
if sess != nil {
|
|
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err)
|
|
}
|
|
recordAPIResponseError(ctx, e.cfg, err)
|
|
return resp, err
|
|
}
|
|
continue
|
|
}
|
|
|
|
payload = bytes.TrimSpace(payload)
|
|
if len(payload) == 0 {
|
|
continue
|
|
}
|
|
appendAPIResponseChunk(ctx, e.cfg, payload)
|
|
|
|
if wsErr, ok := parseCodexWebsocketError(payload); ok {
|
|
if sess != nil {
|
|
e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr)
|
|
}
|
|
recordAPIResponseError(ctx, e.cfg, wsErr)
|
|
return resp, wsErr
|
|
}
|
|
|
|
payload = normalizeCodexWebsocketCompletion(payload)
|
|
eventType := gjson.GetBytes(payload, "type").String()
|
|
if eventType == "response.completed" {
|
|
if detail, ok := parseCodexUsage(payload); ok {
|
|
reporter.publish(ctx, detail)
|
|
}
|
|
var param any
|
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, payload, ¶m)
|
|
resp = cliproxyexecutor.Response{Payload: out}
|
|
return resp, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
|
log.Debugf("Executing Codex Websockets stream request with auth ID: %s, model: %s", auth.ID, req.Model)
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
if opts.Alt == "responses/compact" {
|
|
return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"}
|
|
}
|
|
|
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
|
apiKey, baseURL := codexCreds(auth)
|
|
if baseURL == "" {
|
|
baseURL = "https://chatgpt.com/backend-api/codex"
|
|
}
|
|
|
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
|
defer reporter.trackFailure(ctx, &err)
|
|
|
|
from := opts.SourceFormat
|
|
to := sdktranslator.FromString("codex")
|
|
body := req.Payload
|
|
|
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, body, requestedModel)
|
|
|
|
httpURL := strings.TrimSuffix(baseURL, "/") + "/responses"
|
|
wsURL, err := buildCodexResponsesWebsocketURL(httpURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
body, wsHeaders, continuity := applyCodexPromptCacheHeaders(ctx, auth, from, req, opts, body)
|
|
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
|
|
|
|
var authID, authLabel, authType, authValue string
|
|
authID = auth.ID
|
|
authLabel = auth.Label
|
|
authType, authValue = auth.AccountInfo()
|
|
|
|
executionSessionID := executionSessionIDFromOptions(opts)
|
|
var sess *codexWebsocketSession
|
|
if executionSessionID != "" {
|
|
sess = e.getOrCreateSession(executionSessionID)
|
|
if sess != nil {
|
|
sess.reqMu.Lock()
|
|
}
|
|
}
|
|
|
|
wsReqBody := buildCodexWebsocketRequestBody(body)
|
|
logCodexRequestDiagnostics(ctx, auth, req, opts, wsHeaders, body, continuity)
|
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
|
URL: wsURL,
|
|
Method: "WEBSOCKET",
|
|
Headers: wsHeaders.Clone(),
|
|
Body: wsReqBody,
|
|
Provider: e.Identifier(),
|
|
AuthID: authID,
|
|
AuthLabel: authLabel,
|
|
AuthType: authType,
|
|
AuthValue: authValue,
|
|
})
|
|
|
|
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
|
var upstreamHeaders http.Header
|
|
if respHS != nil {
|
|
upstreamHeaders = respHS.Header.Clone()
|
|
recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
|
|
}
|
|
if errDial != nil {
|
|
bodyErr := websocketHandshakeBody(respHS)
|
|
if len(bodyErr) > 0 {
|
|
appendAPIResponseChunk(ctx, e.cfg, bodyErr)
|
|
}
|
|
if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired {
|
|
return e.CodexExecutor.ExecuteStream(ctx, auth, req, opts)
|
|
}
|
|
if respHS != nil && respHS.StatusCode > 0 {
|
|
return nil, statusErr{code: respHS.StatusCode, msg: string(bodyErr)}
|
|
}
|
|
recordAPIResponseError(ctx, e.cfg, errDial)
|
|
if sess != nil {
|
|
sess.reqMu.Unlock()
|
|
}
|
|
return nil, errDial
|
|
}
|
|
closeHTTPResponseBody(respHS, "codex websockets executor: close handshake response body error")
|
|
|
|
if sess == nil {
|
|
logCodexWebsocketConnected(executionSessionID, authID, wsURL)
|
|
}
|
|
|
|
var readCh chan codexWebsocketRead
|
|
if sess != nil {
|
|
readCh = make(chan codexWebsocketRead, 4096)
|
|
sess.setActive(readCh)
|
|
}
|
|
|
|
if errSend := writeCodexWebsocketMessage(sess, conn, wsReqBody); errSend != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errSend)
|
|
if sess != nil {
|
|
e.invalidateUpstreamConn(sess, conn, "send_error", errSend)
|
|
|
|
// Retry once with a new websocket connection for the same execution session.
|
|
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
|
if errDialRetry != nil || connRetry == nil {
|
|
recordAPIResponseError(ctx, e.cfg, errDialRetry)
|
|
sess.clearActive(readCh)
|
|
sess.reqMu.Unlock()
|
|
return nil, errDialRetry
|
|
}
|
|
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
|
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
|
URL: wsURL,
|
|
Method: "WEBSOCKET",
|
|
Headers: wsHeaders.Clone(),
|
|
Body: wsReqBodyRetry,
|
|
Provider: e.Identifier(),
|
|
AuthID: authID,
|
|
AuthLabel: authLabel,
|
|
AuthType: authType,
|
|
AuthValue: authValue,
|
|
})
|
|
if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errSendRetry)
|
|
e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry)
|
|
sess.clearActive(readCh)
|
|
sess.reqMu.Unlock()
|
|
return nil, errSendRetry
|
|
}
|
|
conn = connRetry
|
|
wsReqBody = wsReqBodyRetry
|
|
} else {
|
|
logCodexWebsocketDisconnected(executionSessionID, authID, wsURL, "send_error", errSend)
|
|
if errClose := conn.Close(); errClose != nil {
|
|
log.Errorf("codex websockets executor: close websocket error: %v", errClose)
|
|
}
|
|
return nil, errSend
|
|
}
|
|
}
|
|
|
|
out := make(chan cliproxyexecutor.StreamChunk)
|
|
go func() {
|
|
terminateReason := "completed"
|
|
var terminateErr error
|
|
|
|
defer close(out)
|
|
defer func() {
|
|
if sess != nil {
|
|
sess.clearActive(readCh)
|
|
sess.reqMu.Unlock()
|
|
return
|
|
}
|
|
logCodexWebsocketDisconnected(executionSessionID, authID, wsURL, terminateReason, terminateErr)
|
|
if errClose := conn.Close(); errClose != nil {
|
|
log.Errorf("codex websockets executor: close websocket error: %v", errClose)
|
|
}
|
|
}()
|
|
|
|
send := func(chunk cliproxyexecutor.StreamChunk) bool {
|
|
if ctx == nil {
|
|
out <- chunk
|
|
return true
|
|
}
|
|
select {
|
|
case out <- chunk:
|
|
return true
|
|
case <-ctx.Done():
|
|
return false
|
|
}
|
|
}
|
|
|
|
var param any
|
|
for {
|
|
if ctx != nil && ctx.Err() != nil {
|
|
terminateReason = "context_done"
|
|
terminateErr = ctx.Err()
|
|
_ = send(cliproxyexecutor.StreamChunk{Err: ctx.Err()})
|
|
return
|
|
}
|
|
msgType, payload, errRead := readCodexWebsocketMessage(ctx, sess, conn, readCh)
|
|
if errRead != nil {
|
|
if sess != nil && ctx != nil && ctx.Err() != nil {
|
|
terminateReason = "context_done"
|
|
terminateErr = ctx.Err()
|
|
_ = send(cliproxyexecutor.StreamChunk{Err: ctx.Err()})
|
|
return
|
|
}
|
|
terminateReason = "read_error"
|
|
terminateErr = errRead
|
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
|
reporter.publishFailure(ctx)
|
|
_ = send(cliproxyexecutor.StreamChunk{Err: errRead})
|
|
return
|
|
}
|
|
if msgType != websocket.TextMessage {
|
|
if msgType == websocket.BinaryMessage {
|
|
err = fmt.Errorf("codex websockets executor: unexpected binary message")
|
|
terminateReason = "unexpected_binary"
|
|
terminateErr = err
|
|
recordAPIResponseError(ctx, e.cfg, err)
|
|
reporter.publishFailure(ctx)
|
|
if sess != nil {
|
|
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err)
|
|
}
|
|
_ = send(cliproxyexecutor.StreamChunk{Err: err})
|
|
return
|
|
}
|
|
continue
|
|
}
|
|
|
|
payload = bytes.TrimSpace(payload)
|
|
if len(payload) == 0 {
|
|
continue
|
|
}
|
|
appendAPIResponseChunk(ctx, e.cfg, payload)
|
|
|
|
if wsErr, ok := parseCodexWebsocketError(payload); ok {
|
|
terminateReason = "upstream_error"
|
|
terminateErr = wsErr
|
|
recordAPIResponseError(ctx, e.cfg, wsErr)
|
|
reporter.publishFailure(ctx)
|
|
if sess != nil {
|
|
e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr)
|
|
}
|
|
_ = send(cliproxyexecutor.StreamChunk{Err: wsErr})
|
|
return
|
|
}
|
|
|
|
payload = normalizeCodexWebsocketCompletion(payload)
|
|
eventType := gjson.GetBytes(payload, "type").String()
|
|
if eventType == "response.completed" || eventType == "response.done" {
|
|
if detail, ok := parseCodexUsage(payload); ok {
|
|
reporter.publish(ctx, detail)
|
|
}
|
|
}
|
|
|
|
line := encodeCodexWebsocketAsSSE(payload)
|
|
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, body, body, line, ¶m)
|
|
for i := range chunks {
|
|
if !send(cliproxyexecutor.StreamChunk{Payload: chunks[i]}) {
|
|
terminateReason = "context_done"
|
|
terminateErr = ctx.Err()
|
|
return
|
|
}
|
|
}
|
|
if eventType == "response.completed" || eventType == "response.done" {
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
return &cliproxyexecutor.StreamResult{Headers: upstreamHeaders, Chunks: out}, nil
|
|
}
|
|
|
|
func (e *CodexWebsocketsExecutor) dialCodexWebsocket(ctx context.Context, auth *cliproxyauth.Auth, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) {
|
|
dialer := newProxyAwareWebsocketDialer(e.cfg, auth)
|
|
dialer.HandshakeTimeout = codexResponsesWebsocketHandshakeTO
|
|
dialer.EnableCompression = true
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
conn, resp, err := dialer.DialContext(ctx, wsURL, headers)
|
|
if conn != nil {
|
|
// Avoid gorilla/websocket flate tail validation issues on some upstreams/Go versions.
|
|
// Negotiating permessage-deflate is fine; we just don't compress outbound messages.
|
|
conn.EnableWriteCompression(false)
|
|
}
|
|
return conn, resp, err
|
|
}
|
|
|
|
func writeCodexWebsocketMessage(sess *codexWebsocketSession, conn *websocket.Conn, payload []byte) error {
|
|
if sess != nil {
|
|
return sess.writeMessage(conn, websocket.TextMessage, payload)
|
|
}
|
|
if conn == nil {
|
|
return fmt.Errorf("codex websockets executor: websocket conn is nil")
|
|
}
|
|
return conn.WriteMessage(websocket.TextMessage, payload)
|
|
}
|
|
|
|
func buildCodexWebsocketRequestBody(body []byte) []byte {
|
|
if len(body) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Match codex-rs websocket v2 semantics: every request is `response.create`.
|
|
// Incremental follow-up turns continue on the same websocket using
|
|
// `previous_response_id` + incremental `input`, not `response.append`.
|
|
wsReqBody, errSet := sjson.SetBytes(bytes.Clone(body), "type", "response.create")
|
|
if errSet == nil && len(wsReqBody) > 0 {
|
|
return wsReqBody
|
|
}
|
|
fallback := bytes.Clone(body)
|
|
fallback, _ = sjson.SetBytes(fallback, "type", "response.create")
|
|
return fallback
|
|
}
|
|
|
|
func readCodexWebsocketMessage(ctx context.Context, sess *codexWebsocketSession, conn *websocket.Conn, readCh chan codexWebsocketRead) (int, []byte, error) {
|
|
if sess == nil {
|
|
if conn == nil {
|
|
return 0, nil, fmt.Errorf("codex websockets executor: websocket conn is nil")
|
|
}
|
|
_ = conn.SetReadDeadline(time.Now().Add(codexResponsesWebsocketIdleTimeout))
|
|
msgType, payload, errRead := conn.ReadMessage()
|
|
return msgType, payload, errRead
|
|
}
|
|
if conn == nil {
|
|
return 0, nil, fmt.Errorf("codex websockets executor: websocket conn is nil")
|
|
}
|
|
if readCh == nil {
|
|
return 0, nil, fmt.Errorf("codex websockets executor: session read channel is nil")
|
|
}
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return 0, nil, ctx.Err()
|
|
case ev, ok := <-readCh:
|
|
if !ok {
|
|
return 0, nil, fmt.Errorf("codex websockets executor: session read channel closed")
|
|
}
|
|
if ev.conn != conn {
|
|
continue
|
|
}
|
|
if ev.err != nil {
|
|
return 0, nil, ev.err
|
|
}
|
|
return ev.msgType, ev.payload, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *websocket.Dialer {
|
|
dialer := &websocket.Dialer{
|
|
Proxy: http.ProxyFromEnvironment,
|
|
HandshakeTimeout: codexResponsesWebsocketHandshakeTO,
|
|
EnableCompression: true,
|
|
NetDialContext: (&net.Dialer{
|
|
Timeout: 30 * time.Second,
|
|
KeepAlive: 30 * time.Second,
|
|
}).DialContext,
|
|
}
|
|
|
|
proxyURL := ""
|
|
if auth != nil {
|
|
proxyURL = strings.TrimSpace(auth.ProxyURL)
|
|
}
|
|
if proxyURL == "" && cfg != nil {
|
|
proxyURL = strings.TrimSpace(cfg.ProxyURL)
|
|
}
|
|
if proxyURL == "" {
|
|
return dialer
|
|
}
|
|
|
|
setting, errParse := proxyutil.Parse(proxyURL)
|
|
if errParse != nil {
|
|
log.Errorf("codex websockets executor: %v", errParse)
|
|
return dialer
|
|
}
|
|
|
|
switch setting.Mode {
|
|
case proxyutil.ModeDirect:
|
|
dialer.Proxy = nil
|
|
return dialer
|
|
case proxyutil.ModeProxy:
|
|
default:
|
|
return dialer
|
|
}
|
|
|
|
switch setting.URL.Scheme {
|
|
case "socks5":
|
|
var proxyAuth *proxy.Auth
|
|
if setting.URL.User != nil {
|
|
username := setting.URL.User.Username()
|
|
password, _ := setting.URL.User.Password()
|
|
proxyAuth = &proxy.Auth{User: username, Password: password}
|
|
}
|
|
socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", setting.URL.Host, proxyAuth, proxy.Direct)
|
|
if errSOCKS5 != nil {
|
|
log.Errorf("codex websockets executor: create SOCKS5 dialer failed: %v", errSOCKS5)
|
|
return dialer
|
|
}
|
|
dialer.Proxy = nil
|
|
dialer.NetDialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
|
|
return socksDialer.Dial(network, addr)
|
|
}
|
|
case "http", "https":
|
|
dialer.Proxy = http.ProxyURL(setting.URL)
|
|
default:
|
|
log.Errorf("codex websockets executor: unsupported proxy scheme: %s", setting.URL.Scheme)
|
|
}
|
|
|
|
return dialer
|
|
}
|
|
|
|
func buildCodexResponsesWebsocketURL(httpURL string) (string, error) {
|
|
parsed, err := url.Parse(strings.TrimSpace(httpURL))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
switch strings.ToLower(parsed.Scheme) {
|
|
case "http":
|
|
parsed.Scheme = "ws"
|
|
case "https":
|
|
parsed.Scheme = "wss"
|
|
}
|
|
return parsed.String(), nil
|
|
}
|
|
|
|
func applyCodexPromptCacheHeaders(ctx context.Context, auth *cliproxyauth.Auth, from sdktranslator.Format, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, rawJSON []byte) ([]byte, http.Header, codexContinuity) {
|
|
headers := http.Header{}
|
|
if len(rawJSON) == 0 {
|
|
return rawJSON, headers, codexContinuity{}
|
|
}
|
|
|
|
var cache codexCache
|
|
continuity := codexContinuity{}
|
|
if from == "claude" {
|
|
userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
|
|
if userIDResult.Exists() {
|
|
key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String())
|
|
if cached, ok := getCodexCache(key); ok {
|
|
cache = cached
|
|
} else {
|
|
cache = codexCache{
|
|
ID: uuid.New().String(),
|
|
Expire: time.Now().Add(1 * time.Hour),
|
|
}
|
|
setCodexCache(key, cache)
|
|
}
|
|
continuity = codexContinuity{Key: cache.ID, Source: "claude_user_cache"}
|
|
}
|
|
} else if from == "openai-response" {
|
|
if promptCacheKey := gjson.GetBytes(req.Payload, "prompt_cache_key"); promptCacheKey.Exists() {
|
|
cache.ID = promptCacheKey.String()
|
|
continuity = codexContinuity{Key: cache.ID, Source: "prompt_cache_key"}
|
|
}
|
|
} else if from == "openai" {
|
|
continuity = resolveCodexContinuity(ctx, auth, req, opts)
|
|
cache.ID = continuity.Key
|
|
}
|
|
|
|
rawJSON = applyCodexContinuityBody(rawJSON, continuity)
|
|
applyCodexContinuityHeaders(headers, continuity)
|
|
|
|
return rawJSON, headers, continuity
|
|
}
|
|
|
|
func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *cliproxyauth.Auth, token string, cfg *config.Config) http.Header {
|
|
if headers == nil {
|
|
headers = http.Header{}
|
|
}
|
|
if strings.TrimSpace(token) != "" {
|
|
headers.Set("Authorization", "Bearer "+token)
|
|
}
|
|
|
|
var ginHeaders http.Header
|
|
if ginCtx := ginContextFrom(ctx); ginCtx != nil && ginCtx.Request != nil {
|
|
ginHeaders = ginCtx.Request.Header
|
|
}
|
|
|
|
cfgUserAgent, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
|
|
ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "")
|
|
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
|
|
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
|
|
misc.EnsureHeader(headers, ginHeaders, "x-client-request-id", "")
|
|
misc.EnsureHeader(headers, ginHeaders, "x-responsesapi-include-timing-metrics", "")
|
|
misc.EnsureHeader(headers, ginHeaders, "Version", "")
|
|
|
|
betaHeader := strings.TrimSpace(headers.Get("OpenAI-Beta"))
|
|
if betaHeader == "" && ginHeaders != nil {
|
|
betaHeader = strings.TrimSpace(ginHeaders.Get("OpenAI-Beta"))
|
|
}
|
|
if betaHeader == "" || !strings.Contains(betaHeader, "responses_websockets=") {
|
|
betaHeader = codexResponsesWebsocketBetaHeaderValue
|
|
}
|
|
headers.Set("OpenAI-Beta", betaHeader)
|
|
misc.EnsureHeader(headers, ginHeaders, "session_id", uuid.NewString())
|
|
ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
|
|
|
|
isAPIKey := false
|
|
if auth != nil && auth.Attributes != nil {
|
|
if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
|
|
isAPIKey = true
|
|
}
|
|
}
|
|
if originator := strings.TrimSpace(ginHeaders.Get("Originator")); originator != "" {
|
|
headers.Set("Originator", originator)
|
|
} else if !isAPIKey {
|
|
headers.Set("Originator", codexOriginator)
|
|
}
|
|
if !isAPIKey {
|
|
if auth != nil && auth.Metadata != nil {
|
|
if accountID, ok := auth.Metadata["account_id"].(string); ok {
|
|
if trimmed := strings.TrimSpace(accountID); trimmed != "" {
|
|
headers.Set("Chatgpt-Account-Id", trimmed)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
var attrs map[string]string
|
|
if auth != nil {
|
|
attrs = auth.Attributes
|
|
}
|
|
util.ApplyCustomHeadersFromAttrs(&http.Request{Header: headers}, attrs)
|
|
|
|
return headers
|
|
}
|
|
|
|
func codexHeaderDefaults(cfg *config.Config, auth *cliproxyauth.Auth) (string, string) {
|
|
if cfg == nil || auth == nil {
|
|
return "", ""
|
|
}
|
|
if auth.Attributes != nil {
|
|
if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
|
|
return "", ""
|
|
}
|
|
}
|
|
return strings.TrimSpace(cfg.CodexHeaderDefaults.UserAgent), strings.TrimSpace(cfg.CodexHeaderDefaults.BetaFeatures)
|
|
}
|
|
|
|
func ensureHeaderWithPriority(target http.Header, source http.Header, key, configValue, fallbackValue string) {
|
|
if target == nil {
|
|
return
|
|
}
|
|
if strings.TrimSpace(target.Get(key)) != "" {
|
|
return
|
|
}
|
|
if source != nil {
|
|
if val := strings.TrimSpace(source.Get(key)); val != "" {
|
|
target.Set(key, val)
|
|
return
|
|
}
|
|
}
|
|
if val := strings.TrimSpace(configValue); val != "" {
|
|
target.Set(key, val)
|
|
return
|
|
}
|
|
if val := strings.TrimSpace(fallbackValue); val != "" {
|
|
target.Set(key, val)
|
|
}
|
|
}
|
|
|
|
func ensureHeaderWithConfigPrecedence(target http.Header, source http.Header, key, configValue, fallbackValue string) {
|
|
if target == nil {
|
|
return
|
|
}
|
|
if strings.TrimSpace(target.Get(key)) != "" {
|
|
return
|
|
}
|
|
if val := strings.TrimSpace(configValue); val != "" {
|
|
target.Set(key, val)
|
|
return
|
|
}
|
|
if source != nil {
|
|
if val := strings.TrimSpace(source.Get(key)); val != "" {
|
|
target.Set(key, val)
|
|
return
|
|
}
|
|
}
|
|
if val := strings.TrimSpace(fallbackValue); val != "" {
|
|
target.Set(key, val)
|
|
}
|
|
}
|
|
|
|
type statusErrWithHeaders struct {
|
|
statusErr
|
|
headers http.Header
|
|
}
|
|
|
|
func (e statusErrWithHeaders) Headers() http.Header {
|
|
if e.headers == nil {
|
|
return nil
|
|
}
|
|
return e.headers.Clone()
|
|
}
|
|
|
|
func parseCodexWebsocketError(payload []byte) (error, bool) {
|
|
if len(payload) == 0 {
|
|
return nil, false
|
|
}
|
|
if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "error" {
|
|
return nil, false
|
|
}
|
|
status := int(gjson.GetBytes(payload, "status").Int())
|
|
if status == 0 {
|
|
status = int(gjson.GetBytes(payload, "status_code").Int())
|
|
}
|
|
if status <= 0 {
|
|
return nil, false
|
|
}
|
|
|
|
out := []byte(`{}`)
|
|
if errNode := gjson.GetBytes(payload, "error"); errNode.Exists() {
|
|
raw := errNode.Raw
|
|
if errNode.Type == gjson.String {
|
|
raw = errNode.Raw
|
|
}
|
|
out, _ = sjson.SetRawBytes(out, "error", []byte(raw))
|
|
} else {
|
|
out, _ = sjson.SetBytes(out, "error.type", "server_error")
|
|
out, _ = sjson.SetBytes(out, "error.message", http.StatusText(status))
|
|
}
|
|
|
|
headers := parseCodexWebsocketErrorHeaders(payload)
|
|
return statusErrWithHeaders{
|
|
statusErr: statusErr{code: status, msg: string(out)},
|
|
headers: headers,
|
|
}, true
|
|
}
|
|
|
|
func parseCodexWebsocketErrorHeaders(payload []byte) http.Header {
|
|
headersNode := gjson.GetBytes(payload, "headers")
|
|
if !headersNode.Exists() || !headersNode.IsObject() {
|
|
return nil
|
|
}
|
|
mapped := make(http.Header)
|
|
headersNode.ForEach(func(key, value gjson.Result) bool {
|
|
name := strings.TrimSpace(key.String())
|
|
if name == "" {
|
|
return true
|
|
}
|
|
switch value.Type {
|
|
case gjson.String:
|
|
if v := strings.TrimSpace(value.String()); v != "" {
|
|
mapped.Set(name, v)
|
|
}
|
|
case gjson.Number, gjson.True, gjson.False:
|
|
if v := strings.TrimSpace(value.Raw); v != "" {
|
|
mapped.Set(name, v)
|
|
}
|
|
default:
|
|
}
|
|
return true
|
|
})
|
|
if len(mapped) == 0 {
|
|
return nil
|
|
}
|
|
return mapped
|
|
}
|
|
|
|
func normalizeCodexWebsocketCompletion(payload []byte) []byte {
|
|
if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.done" {
|
|
updated, err := sjson.SetBytes(payload, "type", "response.completed")
|
|
if err == nil && len(updated) > 0 {
|
|
return updated
|
|
}
|
|
}
|
|
return payload
|
|
}
|
|
|
|
func encodeCodexWebsocketAsSSE(payload []byte) []byte {
|
|
if len(payload) == 0 {
|
|
return nil
|
|
}
|
|
line := make([]byte, 0, len("data: ")+len(payload))
|
|
line = append(line, []byte("data: ")...)
|
|
line = append(line, payload...)
|
|
return line
|
|
}
|
|
|
|
func websocketHandshakeBody(resp *http.Response) []byte {
|
|
if resp == nil || resp.Body == nil {
|
|
return nil
|
|
}
|
|
body, _ := io.ReadAll(resp.Body)
|
|
closeHTTPResponseBody(resp, "codex websockets executor: close handshake response body error")
|
|
if len(body) == 0 {
|
|
return nil
|
|
}
|
|
return body
|
|
}
|
|
|
|
func closeHTTPResponseBody(resp *http.Response, logPrefix string) {
|
|
if resp == nil || resp.Body == nil {
|
|
return
|
|
}
|
|
if errClose := resp.Body.Close(); errClose != nil {
|
|
log.Errorf("%s: %v", logPrefix, errClose)
|
|
}
|
|
}
|
|
|
|
func executionSessionIDFromOptions(opts cliproxyexecutor.Options) string {
|
|
if len(opts.Metadata) == 0 {
|
|
return ""
|
|
}
|
|
raw, ok := opts.Metadata[cliproxyexecutor.ExecutionSessionMetadataKey]
|
|
if !ok || raw == nil {
|
|
return ""
|
|
}
|
|
switch v := raw.(type) {
|
|
case string:
|
|
return strings.TrimSpace(v)
|
|
case []byte:
|
|
return strings.TrimSpace(string(v))
|
|
default:
|
|
return ""
|
|
}
|
|
}
|
|
|
|
func (e *CodexWebsocketsExecutor) getOrCreateSession(sessionID string) *codexWebsocketSession {
|
|
sessionID = strings.TrimSpace(sessionID)
|
|
if sessionID == "" {
|
|
return nil
|
|
}
|
|
e.sessMu.Lock()
|
|
defer e.sessMu.Unlock()
|
|
if e.sessions == nil {
|
|
e.sessions = make(map[string]*codexWebsocketSession)
|
|
}
|
|
if sess, ok := e.sessions[sessionID]; ok && sess != nil {
|
|
return sess
|
|
}
|
|
sess := &codexWebsocketSession{sessionID: sessionID}
|
|
e.sessions[sessionID] = sess
|
|
return sess
|
|
}
|
|
|
|
func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *cliproxyauth.Auth, sess *codexWebsocketSession, authID string, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) {
|
|
if sess == nil {
|
|
return e.dialCodexWebsocket(ctx, auth, wsURL, headers)
|
|
}
|
|
|
|
sess.connMu.Lock()
|
|
conn := sess.conn
|
|
readerConn := sess.readerConn
|
|
sess.connMu.Unlock()
|
|
if conn != nil {
|
|
if readerConn != conn {
|
|
sess.connMu.Lock()
|
|
sess.readerConn = conn
|
|
sess.connMu.Unlock()
|
|
sess.configureConn(conn)
|
|
go e.readUpstreamLoop(sess, conn)
|
|
}
|
|
return conn, nil, nil
|
|
}
|
|
|
|
conn, resp, errDial := e.dialCodexWebsocket(ctx, auth, wsURL, headers)
|
|
if errDial != nil {
|
|
return nil, resp, errDial
|
|
}
|
|
|
|
sess.connMu.Lock()
|
|
if sess.conn != nil {
|
|
previous := sess.conn
|
|
sess.connMu.Unlock()
|
|
if errClose := conn.Close(); errClose != nil {
|
|
log.Errorf("codex websockets executor: close websocket error: %v", errClose)
|
|
}
|
|
return previous, nil, nil
|
|
}
|
|
sess.conn = conn
|
|
sess.wsURL = wsURL
|
|
sess.authID = authID
|
|
sess.readerConn = conn
|
|
sess.connMu.Unlock()
|
|
|
|
sess.configureConn(conn)
|
|
go e.readUpstreamLoop(sess, conn)
|
|
logCodexWebsocketConnected(sess.sessionID, authID, wsURL)
|
|
return conn, resp, nil
|
|
}
|
|
|
|
func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession, conn *websocket.Conn) {
|
|
if e == nil || sess == nil || conn == nil {
|
|
return
|
|
}
|
|
for {
|
|
_ = conn.SetReadDeadline(time.Now().Add(codexResponsesWebsocketIdleTimeout))
|
|
msgType, payload, errRead := conn.ReadMessage()
|
|
if errRead != nil {
|
|
sess.activeMu.Lock()
|
|
ch := sess.activeCh
|
|
done := sess.activeDone
|
|
sess.activeMu.Unlock()
|
|
if ch != nil {
|
|
select {
|
|
case ch <- codexWebsocketRead{conn: conn, err: errRead}:
|
|
case <-done:
|
|
default:
|
|
}
|
|
sess.clearActive(ch)
|
|
close(ch)
|
|
}
|
|
e.invalidateUpstreamConn(sess, conn, "upstream_disconnected", errRead)
|
|
return
|
|
}
|
|
|
|
if msgType != websocket.TextMessage {
|
|
if msgType == websocket.BinaryMessage {
|
|
errBinary := fmt.Errorf("codex websockets executor: unexpected binary message")
|
|
sess.activeMu.Lock()
|
|
ch := sess.activeCh
|
|
done := sess.activeDone
|
|
sess.activeMu.Unlock()
|
|
if ch != nil {
|
|
select {
|
|
case ch <- codexWebsocketRead{conn: conn, err: errBinary}:
|
|
case <-done:
|
|
default:
|
|
}
|
|
sess.clearActive(ch)
|
|
close(ch)
|
|
}
|
|
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", errBinary)
|
|
return
|
|
}
|
|
continue
|
|
}
|
|
|
|
sess.activeMu.Lock()
|
|
ch := sess.activeCh
|
|
done := sess.activeDone
|
|
sess.activeMu.Unlock()
|
|
if ch == nil {
|
|
continue
|
|
}
|
|
select {
|
|
case ch <- codexWebsocketRead{conn: conn, msgType: msgType, payload: payload}:
|
|
case <-done:
|
|
}
|
|
}
|
|
}
|
|
|
|
func (e *CodexWebsocketsExecutor) invalidateUpstreamConn(sess *codexWebsocketSession, conn *websocket.Conn, reason string, err error) {
|
|
if sess == nil || conn == nil {
|
|
return
|
|
}
|
|
|
|
sess.connMu.Lock()
|
|
current := sess.conn
|
|
authID := sess.authID
|
|
wsURL := sess.wsURL
|
|
sessionID := sess.sessionID
|
|
if current == nil || current != conn {
|
|
sess.connMu.Unlock()
|
|
return
|
|
}
|
|
sess.conn = nil
|
|
if sess.readerConn == conn {
|
|
sess.readerConn = nil
|
|
}
|
|
sess.connMu.Unlock()
|
|
|
|
logCodexWebsocketDisconnected(sessionID, authID, wsURL, reason, err)
|
|
if errClose := conn.Close(); errClose != nil {
|
|
log.Errorf("codex websockets executor: close websocket error: %v", errClose)
|
|
}
|
|
}
|
|
|
|
func (e *CodexWebsocketsExecutor) CloseExecutionSession(sessionID string) {
|
|
sessionID = strings.TrimSpace(sessionID)
|
|
if e == nil {
|
|
return
|
|
}
|
|
if sessionID == "" {
|
|
return
|
|
}
|
|
if sessionID == cliproxyauth.CloseAllExecutionSessionsID {
|
|
e.closeAllExecutionSessions("executor_replaced")
|
|
return
|
|
}
|
|
|
|
e.sessMu.Lock()
|
|
sess := e.sessions[sessionID]
|
|
delete(e.sessions, sessionID)
|
|
e.sessMu.Unlock()
|
|
|
|
e.closeExecutionSession(sess, "session_closed")
|
|
}
|
|
|
|
func (e *CodexWebsocketsExecutor) closeAllExecutionSessions(reason string) {
|
|
if e == nil {
|
|
return
|
|
}
|
|
|
|
e.sessMu.Lock()
|
|
sessions := make([]*codexWebsocketSession, 0, len(e.sessions))
|
|
for sessionID, sess := range e.sessions {
|
|
delete(e.sessions, sessionID)
|
|
if sess != nil {
|
|
sessions = append(sessions, sess)
|
|
}
|
|
}
|
|
e.sessMu.Unlock()
|
|
|
|
for i := range sessions {
|
|
e.closeExecutionSession(sessions[i], reason)
|
|
}
|
|
}
|
|
|
|
func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSession, reason string) {
|
|
if sess == nil {
|
|
return
|
|
}
|
|
reason = strings.TrimSpace(reason)
|
|
if reason == "" {
|
|
reason = "session_closed"
|
|
}
|
|
|
|
sess.connMu.Lock()
|
|
conn := sess.conn
|
|
authID := sess.authID
|
|
wsURL := sess.wsURL
|
|
sess.conn = nil
|
|
if sess.readerConn == conn {
|
|
sess.readerConn = nil
|
|
}
|
|
sessionID := sess.sessionID
|
|
sess.connMu.Unlock()
|
|
|
|
if conn == nil {
|
|
return
|
|
}
|
|
logCodexWebsocketDisconnected(sessionID, authID, wsURL, reason, nil)
|
|
if errClose := conn.Close(); errClose != nil {
|
|
log.Errorf("codex websockets executor: close websocket error: %v", errClose)
|
|
}
|
|
}
|
|
|
|
func logCodexWebsocketConnected(sessionID string, authID string, wsURL string) {
|
|
log.Infof("codex websockets: upstream connected session=%s auth=%s url=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL))
|
|
}
|
|
|
|
func logCodexWebsocketDisconnected(sessionID string, authID string, wsURL string, reason string, err error) {
|
|
if err != nil {
|
|
log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s err=%v", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason), err)
|
|
return
|
|
}
|
|
log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason))
|
|
}
|
|
|
|
// CodexAutoExecutor routes Codex requests to the websocket transport only when:
|
|
// 1. The downstream transport is websocket, and
|
|
// 2. The selected auth enables websockets.
|
|
//
|
|
// For non-websocket downstream requests, it always uses the legacy HTTP implementation.
|
|
type CodexAutoExecutor struct {
|
|
httpExec *CodexExecutor
|
|
wsExec *CodexWebsocketsExecutor
|
|
}
|
|
|
|
func NewCodexAutoExecutor(cfg *config.Config) *CodexAutoExecutor {
|
|
return &CodexAutoExecutor{
|
|
httpExec: NewCodexExecutor(cfg),
|
|
wsExec: NewCodexWebsocketsExecutor(cfg),
|
|
}
|
|
}
|
|
|
|
func (e *CodexAutoExecutor) Identifier() string { return "codex" }
|
|
|
|
func (e *CodexAutoExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
|
if e == nil || e.httpExec == nil {
|
|
return nil
|
|
}
|
|
return e.httpExec.PrepareRequest(req, auth)
|
|
}
|
|
|
|
func (e *CodexAutoExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
|
if e == nil || e.httpExec == nil {
|
|
return nil, fmt.Errorf("codex auto executor: http executor is nil")
|
|
}
|
|
return e.httpExec.HttpRequest(ctx, auth, req)
|
|
}
|
|
|
|
func (e *CodexAutoExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
|
if e == nil || e.httpExec == nil || e.wsExec == nil {
|
|
return cliproxyexecutor.Response{}, fmt.Errorf("codex auto executor: executor is nil")
|
|
}
|
|
if cliproxyexecutor.DownstreamWebsocket(ctx) && codexWebsocketsEnabled(auth) {
|
|
return e.wsExec.Execute(ctx, auth, req, opts)
|
|
}
|
|
return e.httpExec.Execute(ctx, auth, req, opts)
|
|
}
|
|
|
|
func (e *CodexAutoExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
|
if e == nil || e.httpExec == nil || e.wsExec == nil {
|
|
return nil, fmt.Errorf("codex auto executor: executor is nil")
|
|
}
|
|
if cliproxyexecutor.DownstreamWebsocket(ctx) && codexWebsocketsEnabled(auth) {
|
|
return e.wsExec.ExecuteStream(ctx, auth, req, opts)
|
|
}
|
|
return e.httpExec.ExecuteStream(ctx, auth, req, opts)
|
|
}
|
|
|
|
func (e *CodexAutoExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
|
if e == nil || e.httpExec == nil {
|
|
return nil, fmt.Errorf("codex auto executor: http executor is nil")
|
|
}
|
|
return e.httpExec.Refresh(ctx, auth)
|
|
}
|
|
|
|
func (e *CodexAutoExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
|
if e == nil || e.httpExec == nil {
|
|
return cliproxyexecutor.Response{}, fmt.Errorf("codex auto executor: http executor is nil")
|
|
}
|
|
return e.httpExec.CountTokens(ctx, auth, req, opts)
|
|
}
|
|
|
|
func (e *CodexAutoExecutor) CloseExecutionSession(sessionID string) {
|
|
if e == nil || e.wsExec == nil {
|
|
return
|
|
}
|
|
e.wsExec.CloseExecutionSession(sessionID)
|
|
}
|
|
|
|
func codexWebsocketsEnabled(auth *cliproxyauth.Auth) bool {
|
|
if auth == nil {
|
|
return false
|
|
}
|
|
if len(auth.Attributes) > 0 {
|
|
if raw := strings.TrimSpace(auth.Attributes["websockets"]); raw != "" {
|
|
parsed, errParse := strconv.ParseBool(raw)
|
|
if errParse == nil {
|
|
return parsed
|
|
}
|
|
}
|
|
}
|
|
if len(auth.Metadata) == 0 {
|
|
return false
|
|
}
|
|
raw, ok := auth.Metadata["websockets"]
|
|
if !ok || raw == nil {
|
|
return false
|
|
}
|
|
switch v := raw.(type) {
|
|
case bool:
|
|
return v
|
|
case string:
|
|
parsed, errParse := strconv.ParseBool(strings.TrimSpace(v))
|
|
if errParse == nil {
|
|
return parsed
|
|
}
|
|
default:
|
|
}
|
|
return false
|
|
}
|