mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-13 01:34:18 +00:00
fix(codex): centralize session management with global store and add tests for executor session lifecycle
This commit is contained in:
@@ -46,10 +46,18 @@ const (
|
||||
type CodexWebsocketsExecutor struct {
|
||||
*CodexExecutor
|
||||
|
||||
sessMu sync.Mutex
|
||||
store *codexWebsocketSessionStore
|
||||
}
|
||||
|
||||
type codexWebsocketSessionStore struct {
|
||||
mu sync.Mutex
|
||||
sessions map[string]*codexWebsocketSession
|
||||
}
|
||||
|
||||
var globalCodexWebsocketSessionStore = &codexWebsocketSessionStore{
|
||||
sessions: make(map[string]*codexWebsocketSession),
|
||||
}
|
||||
|
||||
type codexWebsocketSession struct {
|
||||
sessionID string
|
||||
|
||||
@@ -73,7 +81,7 @@ type codexWebsocketSession struct {
|
||||
func NewCodexWebsocketsExecutor(cfg *config.Config) *CodexWebsocketsExecutor {
|
||||
return &CodexWebsocketsExecutor{
|
||||
CodexExecutor: NewCodexExecutor(cfg),
|
||||
sessions: make(map[string]*codexWebsocketSession),
|
||||
store: globalCodexWebsocketSessionStore,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1058,16 +1066,23 @@ func (e *CodexWebsocketsExecutor) getOrCreateSession(sessionID string) *codexWeb
|
||||
if sessionID == "" {
|
||||
return nil
|
||||
}
|
||||
e.sessMu.Lock()
|
||||
defer e.sessMu.Unlock()
|
||||
if e.sessions == nil {
|
||||
e.sessions = make(map[string]*codexWebsocketSession)
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
if sess, ok := e.sessions[sessionID]; ok && sess != nil {
|
||||
store := e.store
|
||||
if store == nil {
|
||||
store = globalCodexWebsocketSessionStore
|
||||
}
|
||||
store.mu.Lock()
|
||||
defer store.mu.Unlock()
|
||||
if store.sessions == nil {
|
||||
store.sessions = make(map[string]*codexWebsocketSession)
|
||||
}
|
||||
if sess, ok := store.sessions[sessionID]; ok && sess != nil {
|
||||
return sess
|
||||
}
|
||||
sess := &codexWebsocketSession{sessionID: sessionID}
|
||||
e.sessions[sessionID] = sess
|
||||
store.sessions[sessionID] = sess
|
||||
return sess
|
||||
}
|
||||
|
||||
@@ -1213,14 +1228,20 @@ func (e *CodexWebsocketsExecutor) CloseExecutionSession(sessionID string) {
|
||||
return
|
||||
}
|
||||
if sessionID == cliproxyauth.CloseAllExecutionSessionsID {
|
||||
e.closeAllExecutionSessions("executor_replaced")
|
||||
// Executor replacement can happen during hot reload (config/credential changes).
|
||||
// Do not force-close upstream websocket sessions here, otherwise in-flight
|
||||
// downstream websocket requests get interrupted.
|
||||
return
|
||||
}
|
||||
|
||||
e.sessMu.Lock()
|
||||
sess := e.sessions[sessionID]
|
||||
delete(e.sessions, sessionID)
|
||||
e.sessMu.Unlock()
|
||||
store := e.store
|
||||
if store == nil {
|
||||
store = globalCodexWebsocketSessionStore
|
||||
}
|
||||
store.mu.Lock()
|
||||
sess := store.sessions[sessionID]
|
||||
delete(store.sessions, sessionID)
|
||||
store.mu.Unlock()
|
||||
|
||||
e.closeExecutionSession(sess, "session_closed")
|
||||
}
|
||||
@@ -1230,15 +1251,19 @@ func (e *CodexWebsocketsExecutor) closeAllExecutionSessions(reason string) {
|
||||
return
|
||||
}
|
||||
|
||||
e.sessMu.Lock()
|
||||
sessions := make([]*codexWebsocketSession, 0, len(e.sessions))
|
||||
for sessionID, sess := range e.sessions {
|
||||
delete(e.sessions, sessionID)
|
||||
store := e.store
|
||||
if store == nil {
|
||||
store = globalCodexWebsocketSessionStore
|
||||
}
|
||||
store.mu.Lock()
|
||||
sessions := make([]*codexWebsocketSession, 0, len(store.sessions))
|
||||
for sessionID, sess := range store.sessions {
|
||||
delete(store.sessions, sessionID)
|
||||
if sess != nil {
|
||||
sessions = append(sessions, sess)
|
||||
}
|
||||
}
|
||||
e.sessMu.Unlock()
|
||||
store.mu.Unlock()
|
||||
|
||||
for i := range sessions {
|
||||
e.closeExecutionSession(sessions[i], reason)
|
||||
@@ -1246,6 +1271,10 @@ func (e *CodexWebsocketsExecutor) closeAllExecutionSessions(reason string) {
|
||||
}
|
||||
|
||||
func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSession, reason string) {
|
||||
closeCodexWebsocketSession(sess, reason)
|
||||
}
|
||||
|
||||
func closeCodexWebsocketSession(sess *codexWebsocketSession, reason string) {
|
||||
if sess == nil {
|
||||
return
|
||||
}
|
||||
@@ -1286,6 +1315,69 @@ func logCodexWebsocketDisconnected(sessionID string, authID string, wsURL string
|
||||
log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason))
|
||||
}
|
||||
|
||||
// CloseCodexWebsocketSessionsForAuthID closes all active Codex upstream websocket sessions
|
||||
// associated with the supplied auth ID.
|
||||
func CloseCodexWebsocketSessionsForAuthID(authID string, reason string) {
|
||||
authID = strings.TrimSpace(authID)
|
||||
if authID == "" {
|
||||
return
|
||||
}
|
||||
reason = strings.TrimSpace(reason)
|
||||
if reason == "" {
|
||||
reason = "auth_removed"
|
||||
}
|
||||
|
||||
store := globalCodexWebsocketSessionStore
|
||||
if store == nil {
|
||||
return
|
||||
}
|
||||
|
||||
type sessionItem struct {
|
||||
sessionID string
|
||||
sess *codexWebsocketSession
|
||||
}
|
||||
|
||||
store.mu.Lock()
|
||||
items := make([]sessionItem, 0, len(store.sessions))
|
||||
for sessionID, sess := range store.sessions {
|
||||
items = append(items, sessionItem{sessionID: sessionID, sess: sess})
|
||||
}
|
||||
store.mu.Unlock()
|
||||
|
||||
matches := make([]sessionItem, 0)
|
||||
for i := range items {
|
||||
sess := items[i].sess
|
||||
if sess == nil {
|
||||
continue
|
||||
}
|
||||
sess.connMu.Lock()
|
||||
sessAuthID := strings.TrimSpace(sess.authID)
|
||||
sess.connMu.Unlock()
|
||||
if sessAuthID == authID {
|
||||
matches = append(matches, items[i])
|
||||
}
|
||||
}
|
||||
if len(matches) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
toClose := make([]*codexWebsocketSession, 0, len(matches))
|
||||
store.mu.Lock()
|
||||
for i := range matches {
|
||||
current, ok := store.sessions[matches[i].sessionID]
|
||||
if !ok || current == nil || current != matches[i].sess {
|
||||
continue
|
||||
}
|
||||
delete(store.sessions, matches[i].sessionID)
|
||||
toClose = append(toClose, current)
|
||||
}
|
||||
store.mu.Unlock()
|
||||
|
||||
for i := range toClose {
|
||||
closeCodexWebsocketSession(toClose[i], reason)
|
||||
}
|
||||
}
|
||||
|
||||
// CodexAutoExecutor routes Codex requests to the websocket transport only when:
|
||||
// 1. The downstream transport is websocket, and
|
||||
// 2. The selected auth enables websockets.
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestCodexWebsocketsExecutor_SessionStoreSurvivesExecutorReplacement(t *testing.T) {
|
||||
sessionID := "test-session-store-survives-replace"
|
||||
|
||||
globalCodexWebsocketSessionStore.mu.Lock()
|
||||
delete(globalCodexWebsocketSessionStore.sessions, sessionID)
|
||||
globalCodexWebsocketSessionStore.mu.Unlock()
|
||||
|
||||
exec1 := NewCodexWebsocketsExecutor(nil)
|
||||
sess1 := exec1.getOrCreateSession(sessionID)
|
||||
if sess1 == nil {
|
||||
t.Fatalf("expected session to be created")
|
||||
}
|
||||
|
||||
exec2 := NewCodexWebsocketsExecutor(nil)
|
||||
sess2 := exec2.getOrCreateSession(sessionID)
|
||||
if sess2 == nil {
|
||||
t.Fatalf("expected session to be available across executors")
|
||||
}
|
||||
if sess1 != sess2 {
|
||||
t.Fatalf("expected the same session instance across executors")
|
||||
}
|
||||
|
||||
exec1.CloseExecutionSession(cliproxyauth.CloseAllExecutionSessionsID)
|
||||
|
||||
globalCodexWebsocketSessionStore.mu.Lock()
|
||||
_, stillPresent := globalCodexWebsocketSessionStore.sessions[sessionID]
|
||||
globalCodexWebsocketSessionStore.mu.Unlock()
|
||||
if !stillPresent {
|
||||
t.Fatalf("expected session to remain after executor replacement close marker")
|
||||
}
|
||||
|
||||
exec2.CloseExecutionSession(sessionID)
|
||||
|
||||
globalCodexWebsocketSessionStore.mu.Lock()
|
||||
_, presentAfterClose := globalCodexWebsocketSessionStore.sessions[sessionID]
|
||||
globalCodexWebsocketSessionStore.mu.Unlock()
|
||||
if presentAfterClose {
|
||||
t.Fatalf("expected session to be removed after explicit close")
|
||||
}
|
||||
}
|
||||
@@ -335,6 +335,7 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
|
||||
log.Errorf("failed to disable auth %s: %v", id, err)
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(existing.Provider), "codex") {
|
||||
executor.CloseCodexWebsocketSessionsForAuthID(existing.ID, "auth_removed")
|
||||
s.ensureExecutorsForAuth(existing)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user