diff --git a/internal/runtime/executor/codex_continuity.go b/internal/runtime/executor/codex_continuity.go index 3ebb721f..e2fa8de0 100644 --- a/internal/runtime/executor/codex_continuity.go +++ b/internal/runtime/executor/codex_continuity.go @@ -57,9 +57,6 @@ func resolveCodexContinuity(ctx context.Context, auth *cliproxyauth.Auth, req cl if executionSession := metadataString(opts.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey); executionSession != "" { return codexContinuity{Key: executionSession, Source: "execution_session"} } - if affinityKey := metadataString(opts.Metadata, codexAuthAffinityMetadataKey); affinityKey != "" { - return codexContinuity{Key: affinityKey, Source: "auth_affinity"} - } if ginCtx := ginContextFrom(ctx); ginCtx != nil { if ginCtx.Request != nil { if v := strings.TrimSpace(ginCtx.GetHeader("Idempotency-Key")); v != "" { diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index 766a081a..5f06ace2 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -612,6 +612,7 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, auth *cliproxyauth.Auth } setCodexCache(key, cache) } + continuity = codexContinuity{Key: cache.ID, Source: "claude_user_cache"} } } else if from == "openai-response" { promptCacheKey := gjson.GetBytes(req.Payload, "prompt_cache_key") diff --git a/internal/runtime/executor/codex_executor_cache_test.go b/internal/runtime/executor/codex_executor_cache_test.go index 116b06ff..8c61a22e 100644 --- a/internal/runtime/executor/codex_executor_cache_test.go +++ b/internal/runtime/executor/codex_executor_cache_test.go @@ -151,3 +151,45 @@ func TestCodexExecutorCacheHelper_OpenAIChatCompletions_FallsBackToStableAuthID( t.Fatalf("session_id = %q, want %q", got, expected) } } + +func TestCodexExecutorCacheHelper_ClaudePreservesCacheContinuity(t *testing.T) { + executor := &CodexExecutor{} + req := cliproxyexecutor.Request{ + Model: "claude-3-7-sonnet", + Payload: []byte(`{"metadata":{"user_id":"user-1"}}`), + } + rawJSON := []byte(`{"model":"gpt-5.4","stream":true}`) + + httpReq, continuity, err := executor.cacheHelper(context.Background(), nil, sdktranslator.FromString("claude"), "https://example.com/responses", req, cliproxyexecutor.Options{}, rawJSON) + if err != nil { + t.Fatalf("cacheHelper error: %v", err) + } + if continuity.Key == "" { + t.Fatal("continuity.Key = empty, want non-empty") + } + body, err := io.ReadAll(httpReq.Body) + if err != nil { + t.Fatalf("read request body: %v", err) + } + if got := gjson.GetBytes(body, "prompt_cache_key").String(); got != continuity.Key { + t.Fatalf("prompt_cache_key = %q, want %q", got, continuity.Key) + } + if got := httpReq.Header.Get("session_id"); got != continuity.Key { + t.Fatalf("session_id = %q, want %q", got, continuity.Key) + } +} + +func TestResolveCodexContinuity_DoesNotForwardAuthAffinityKey(t *testing.T) { + req := cliproxyexecutor.Request{Payload: []byte(`{"model":"gpt-5.4"}`)} + opts := cliproxyexecutor.Options{Metadata: map[string]any{"auth_affinity_key": "principal:raw-client-secret"}} + auth := &cliproxyauth.Auth{ID: "codex-auth-1", Provider: "codex"} + + continuity := resolveCodexContinuity(context.Background(), auth, req, opts) + + if continuity.Source != "auth_id" { + t.Fatalf("continuity.Source = %q, want %q", continuity.Source, "auth_id") + } + if continuity.Key == "principal:raw-client-secret" { + t.Fatal("continuity.Key leaked raw auth affinity key") + } +} diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go index d0dd22c3..50cc736d 100644 --- a/internal/runtime/executor/codex_websockets_executor.go +++ b/internal/runtime/executor/codex_websockets_executor.go @@ -783,6 +783,7 @@ func applyCodexPromptCacheHeaders(ctx context.Context, auth *cliproxyauth.Auth, } setCodexCache(key, cache) } + continuity = codexContinuity{Key: cache.ID, Source: "claude_user_cache"} } } else if from == "openai-response" { if promptCacheKey := gjson.GetBytes(req.Payload, "prompt_cache_key"); promptCacheKey.Exists() { diff --git a/internal/runtime/executor/codex_websockets_executor_test.go b/internal/runtime/executor/codex_websockets_executor_test.go index e86036bc..0a06982f 100644 --- a/internal/runtime/executor/codex_websockets_executor_test.go +++ b/internal/runtime/executor/codex_websockets_executor_test.go @@ -57,6 +57,26 @@ func TestApplyCodexPromptCacheHeaders_PreservesPromptCacheRetention(t *testing.T } } +func TestApplyCodexPromptCacheHeaders_ClaudePreservesContinuity(t *testing.T) { + req := cliproxyexecutor.Request{ + Model: "claude-3-7-sonnet", + Payload: []byte(`{"metadata":{"user_id":"user-1"}}`), + } + body := []byte(`{"model":"gpt-5.4","stream":true}`) + + updatedBody, headers, continuity := applyCodexPromptCacheHeaders(context.Background(), nil, sdktranslator.FromString("claude"), req, cliproxyexecutor.Options{}, body) + + if continuity.Key == "" { + t.Fatal("continuity.Key = empty, want non-empty") + } + if got := gjson.GetBytes(updatedBody, "prompt_cache_key").String(); got != continuity.Key { + t.Fatalf("prompt_cache_key = %q, want %q", got, continuity.Key) + } + if got := headers.Get("session_id"); got != continuity.Key { + t.Fatalf("session_id = %q, want %q", got, continuity.Key) + } +} + func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) { headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil) diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index d7736cf4..6ef13baa 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -1093,12 +1093,6 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req entry := logEntryWithRequestID(ctx) debugLogAuthSelection(entry, auth, provider, req.Model) publishSelectedAuthMetadata(opts.Metadata, auth.ID) - if affinityKey := authAffinityKeyFromMetadata(opts.Metadata); affinityKey != "" { - m.SetAuthAffinity(affinityKey, auth.ID) - if log.IsLevelEnabled(log.DebugLevel) { - entry.Debugf("auth affinity pinned key=%s auth_id=%s provider=%s model=%s", affinityKey, auth.ID, provider, req.Model) - } - } tried[auth.ID] = struct{}{} execCtx := ctx @@ -1138,6 +1132,7 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req continue } m.MarkResult(execCtx, result) + m.persistAuthAffinity(entry, opts, auth.ID, provider, req.Model) return resp, nil } if authErr != nil { @@ -1177,12 +1172,6 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, entry := logEntryWithRequestID(ctx) debugLogAuthSelection(entry, auth, provider, req.Model) publishSelectedAuthMetadata(opts.Metadata, auth.ID) - if affinityKey := authAffinityKeyFromMetadata(opts.Metadata); affinityKey != "" { - m.SetAuthAffinity(affinityKey, auth.ID) - if log.IsLevelEnabled(log.DebugLevel) { - entry.Debugf("auth affinity pinned key=%s auth_id=%s provider=%s model=%s", affinityKey, auth.ID, provider, req.Model) - } - } tried[auth.ID] = struct{}{} execCtx := ctx @@ -1222,6 +1211,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, continue } m.MarkResult(execCtx, result) + m.persistAuthAffinity(entry, opts, auth.ID, provider, req.Model) return resp, nil } if authErr != nil { @@ -1269,12 +1259,6 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string entry := logEntryWithRequestID(ctx) debugLogAuthSelection(entry, auth, provider, req.Model) publishSelectedAuthMetadata(opts.Metadata, auth.ID) - if affinityKey := authAffinityKeyFromMetadata(opts.Metadata); affinityKey != "" { - m.SetAuthAffinity(affinityKey, auth.ID) - if log.IsLevelEnabled(log.DebugLevel) { - entry.Debugf("auth affinity pinned key=%s auth_id=%s provider=%s model=%s", affinityKey, auth.ID, provider, req.Model) - } - } tried[auth.ID] = struct{}{} execCtx := ctx @@ -1298,6 +1282,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string lastErr = errStream continue } + m.persistAuthAffinity(entry, opts, auth.ID, provider, req.Model) return streamResult, nil } } @@ -2285,6 +2270,18 @@ func (m *Manager) applyAuthAffinity(opts *cliproxyexecutor.Options) { } } +func (m *Manager) persistAuthAffinity(entry *log.Entry, opts cliproxyexecutor.Options, authID, provider, model string) { + if m == nil { + return + } + if affinityKey := authAffinityKeyFromMetadata(opts.Metadata); affinityKey != "" { + m.SetAuthAffinity(affinityKey, authID) + if entry != nil && log.IsLevelEnabled(log.DebugLevel) { + entry.Debugf("auth affinity pinned key=%s auth_id=%s provider=%s model=%s", affinityKey, authID, provider, model) + } + } +} + func (m *Manager) SetAuthAffinity(key, authID string) { key = strings.TrimSpace(key) authID = strings.TrimSpace(authID)