From 07d6689d87545f34666f1ba491a2ed9d968cd7ba Mon Sep 17 00:00:00 2001 From: Blue-B Date: Sat, 7 Mar 2026 21:31:10 +0900 Subject: [PATCH 01/21] fix(claude): add interleaved-thinking beta header, AMP gzip error decoding, normalizeClaudeBudget max_tokens MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Always include interleaved-thinking-2025-05-14 beta header so that thinking blocks are returned correctly for all Claude models. 2. Remove status-code guard in AMP reverse proxy ModifyResponse so that error responses (4xx/5xx) with hidden gzip encoding are decoded properly — prevents garbled error messages reaching the client. 3. In normalizeClaudeBudget, when the adjusted budget falls below the model minimum, set max_tokens = budgetTokens+1 instead of leaving the request unchanged (which causes a 400 from the API). --- internal/api/modules/amp/proxy.go | 5 ----- internal/runtime/executor/claude_executor.go | 3 +++ internal/thinking/provider/claude/apply.go | 4 +++- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/internal/api/modules/amp/proxy.go b/internal/api/modules/amp/proxy.go index ecc9da77..c8010854 100644 --- a/internal/api/modules/amp/proxy.go +++ b/internal/api/modules/amp/proxy.go @@ -108,11 +108,6 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi // Modify incoming responses to handle gzip without Content-Encoding // This addresses the same issue as inline handler gzip handling, but at the proxy level proxy.ModifyResponse = func(resp *http.Response) error { - // Only process successful responses - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return nil - } - // Skip if already marked as gzip (Content-Encoding set) if resp.Header.Get("Content-Encoding") != "" { return nil diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 7d0ddcf2..8cdbbf4f 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -832,6 +832,9 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, baseBetas += ",oauth-2025-04-20" } } + if !strings.Contains(baseBetas, "interleaved-thinking") { + baseBetas += ",interleaved-thinking-2025-05-14" + } hasClaude1MHeader := false if ginHeaders != nil { diff --git a/internal/thinking/provider/claude/apply.go b/internal/thinking/provider/claude/apply.go index 275be469..af031907 100644 --- a/internal/thinking/provider/claude/apply.go +++ b/internal/thinking/provider/claude/apply.go @@ -194,7 +194,9 @@ func (a *Applier) normalizeClaudeBudget(body []byte, budgetTokens int, modelInfo } if minBudget > 0 && adjustedBudget > 0 && adjustedBudget < minBudget { // If enforcing the max_tokens constraint would push the budget below the model minimum, - // leave the request unchanged. + // increase max_tokens to accommodate the original budget instead of leaving the + // request unchanged (which would cause a 400 error from the API). + body, _ = sjson.SetBytes(body, "max_tokens", budgetTokens+1) return body } From 5f58248016c33c7a3cc01691abf2452e4810f7e7 Mon Sep 17 00:00:00 2001 From: Blue-B Date: Mon, 9 Mar 2026 22:10:30 +0900 Subject: [PATCH 02/21] fix(claude): clamp max_tokens to model limit in normalizeClaudeBudget When adjustedBudget < minBudget, the previous fix blindly set max_tokens = budgetTokens+1 which could exceed MaxCompletionTokens. Now: cap max_tokens at MaxCompletionTokens, recalculate budget, and disable thinking entirely if constraints are unsatisfiable. Add unit tests covering raise, clamp, disable, and no-op scenarios. --- internal/thinking/provider/claude/apply.go | 29 +++++- .../thinking/provider/claude/apply_test.go | 99 +++++++++++++++++++ 2 files changed, 123 insertions(+), 5 deletions(-) create mode 100644 internal/thinking/provider/claude/apply_test.go diff --git a/internal/thinking/provider/claude/apply.go b/internal/thinking/provider/claude/apply.go index af031907..c92f539e 100644 --- a/internal/thinking/provider/claude/apply.go +++ b/internal/thinking/provider/claude/apply.go @@ -174,7 +174,8 @@ func (a *Applier) normalizeClaudeBudget(body []byte, budgetTokens int, modelInfo // Ensure the request satisfies Claude constraints: // 1) Determine effective max_tokens (request overrides model default) // 2) If budget_tokens >= max_tokens, reduce budget_tokens to max_tokens-1 - // 3) If the adjusted budget falls below the model minimum, leave the request unchanged + // 3) If the adjusted budget falls below the model minimum, try raising max_tokens + // (clamped to MaxCompletionTokens); disable thinking if constraints are unsatisfiable // 4) If max_tokens came from model default, write it back into the request effectiveMax, setDefaultMax := a.effectiveMaxTokens(body, modelInfo) @@ -193,10 +194,28 @@ func (a *Applier) normalizeClaudeBudget(body []byte, budgetTokens int, modelInfo minBudget = modelInfo.Thinking.Min } if minBudget > 0 && adjustedBudget > 0 && adjustedBudget < minBudget { - // If enforcing the max_tokens constraint would push the budget below the model minimum, - // increase max_tokens to accommodate the original budget instead of leaving the - // request unchanged (which would cause a 400 error from the API). - body, _ = sjson.SetBytes(body, "max_tokens", budgetTokens+1) + // Enforcing budget_tokens < max_tokens pushed the budget below the model minimum. + // Try raising max_tokens to fit the original budget. + needed := budgetTokens + 1 + maxAllowed := 0 + if modelInfo != nil { + maxAllowed = modelInfo.MaxCompletionTokens + } + if maxAllowed > 0 && needed > maxAllowed { + // Cannot use original budget; cap max_tokens at model limit. + needed = maxAllowed + } + cappedBudget := needed - 1 + if cappedBudget < minBudget { + // Impossible to satisfy both budget >= minBudget and budget < max_tokens + // within the model's completion limit. Disable thinking entirely. + body, _ = sjson.DeleteBytes(body, "thinking") + return body + } + body, _ = sjson.SetBytes(body, "max_tokens", needed) + if cappedBudget != budgetTokens { + body, _ = sjson.SetBytes(body, "thinking.budget_tokens", cappedBudget) + } return body } diff --git a/internal/thinking/provider/claude/apply_test.go b/internal/thinking/provider/claude/apply_test.go new file mode 100644 index 00000000..46b3f3b7 --- /dev/null +++ b/internal/thinking/provider/claude/apply_test.go @@ -0,0 +1,99 @@ +package claude + +import ( + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + "github.com/tidwall/gjson" +) + +func TestNormalizeClaudeBudget_RaisesMaxTokens(t *testing.T) { + a := &Applier{} + modelInfo := ®istry.ModelInfo{ + MaxCompletionTokens: 64000, + Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 128000}, + } + body := []byte(`{"max_tokens":1000,"thinking":{"type":"enabled","budget_tokens":5000}}`) + + out := a.normalizeClaudeBudget(body, 5000, modelInfo) + + maxTok := gjson.GetBytes(out, "max_tokens").Int() + if maxTok != 5001 { + t.Fatalf("max_tokens = %d, want 5001, body=%s", maxTok, string(out)) + } +} + +func TestNormalizeClaudeBudget_ClampsToModelMax(t *testing.T) { + a := &Applier{} + modelInfo := ®istry.ModelInfo{ + MaxCompletionTokens: 64000, + Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 128000}, + } + body := []byte(`{"max_tokens":500,"thinking":{"type":"enabled","budget_tokens":200000}}`) + + out := a.normalizeClaudeBudget(body, 200000, modelInfo) + + maxTok := gjson.GetBytes(out, "max_tokens").Int() + if maxTok != 64000 { + t.Fatalf("max_tokens = %d, want 64000 (capped to model limit), body=%s", maxTok, string(out)) + } + budget := gjson.GetBytes(out, "thinking.budget_tokens").Int() + if budget != 63999 { + t.Fatalf("budget_tokens = %d, want 63999 (max_tokens-1), body=%s", budget, string(out)) + } +} + +func TestNormalizeClaudeBudget_DisablesThinkingWhenUnsatisfiable(t *testing.T) { + a := &Applier{} + modelInfo := ®istry.ModelInfo{ + MaxCompletionTokens: 1000, + Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 128000}, + } + body := []byte(`{"max_tokens":500,"thinking":{"type":"enabled","budget_tokens":2000}}`) + + out := a.normalizeClaudeBudget(body, 2000, modelInfo) + + if gjson.GetBytes(out, "thinking").Exists() { + t.Fatalf("thinking should be removed when constraints are unsatisfiable, body=%s", string(out)) + } +} + +func TestNormalizeClaudeBudget_NoClamping(t *testing.T) { + a := &Applier{} + modelInfo := ®istry.ModelInfo{ + MaxCompletionTokens: 64000, + Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 128000}, + } + body := []byte(`{"max_tokens":32000,"thinking":{"type":"enabled","budget_tokens":16000}}`) + + out := a.normalizeClaudeBudget(body, 16000, modelInfo) + + maxTok := gjson.GetBytes(out, "max_tokens").Int() + if maxTok != 32000 { + t.Fatalf("max_tokens should remain 32000, got %d, body=%s", maxTok, string(out)) + } + budget := gjson.GetBytes(out, "thinking.budget_tokens").Int() + if budget != 16000 { + t.Fatalf("budget_tokens should remain 16000, got %d, body=%s", budget, string(out)) + } +} + +func TestNormalizeClaudeBudget_AdjustsBudgetToMaxMinus1(t *testing.T) { + a := &Applier{} + modelInfo := ®istry.ModelInfo{ + MaxCompletionTokens: 8192, + Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 128000}, + } + body := []byte(`{"max_tokens":8192,"thinking":{"type":"enabled","budget_tokens":10000}}`) + + out := a.normalizeClaudeBudget(body, 10000, modelInfo) + + maxTok := gjson.GetBytes(out, "max_tokens").Int() + if maxTok != 8192 { + t.Fatalf("max_tokens = %d, want 8192 (unchanged), body=%s", maxTok, string(out)) + } + budget := gjson.GetBytes(out, "thinking.budget_tokens").Int() + if budget != 8191 { + t.Fatalf("budget_tokens = %d, want 8191 (max_tokens-1), body=%s", budget, string(out)) + } +} From e166e56249f7d42d9d05823bd2e6cf20a1667fa5 Mon Sep 17 00:00:00 2001 From: destinoantagonista-wq Date: Fri, 13 Mar 2026 19:41:49 +0000 Subject: [PATCH 03/21] Reconcile registry model states on auth changes Add Manager.ReconcileRegistryModelStates to clear stale per-model runtime failures for models currently registered in the global model registry. The method finds models supported for an auth, resets non-clean ModelState entries, updates aggregated availability, persists changes, and pushes a snapshot to the scheduler. Introduce modelStateIsClean helper to determine when a model state needs resetting. Call ReconcileRegistryModelStates from Service paths that register/refresh models (applyCoreAuthAddOrUpdate and refreshModelRegistrationForAuth) to keep the scheduler and global registry aligned after model re-registration. --- sdk/cliproxy/auth/conductor.go | 91 ++++++++++++++++++++++++++++++++++ sdk/cliproxy/service.go | 3 ++ 2 files changed, 94 insertions(+) diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index b29e04db..9fc65274 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -233,6 +233,81 @@ func (m *Manager) RefreshSchedulerEntry(authID string) { m.scheduler.upsertAuth(snapshot) } +// ReconcileRegistryModelStates clears stale per-model runtime failures for +// models that are currently registered for the auth in the global model registry. +// +// This keeps the scheduler and the global registry aligned after model +// re-registration. Without this reconciliation, a model can reappear in +// /v1/models after registry refresh while the scheduler still blocks it because +// auth.ModelStates retained an older failure such as not_found or quota. +func (m *Manager) ReconcileRegistryModelStates(ctx context.Context, authID string) { + if m == nil || authID == "" { + return + } + + supportedModels := registry.GetGlobalRegistry().GetModelsForClient(authID) + if len(supportedModels) == 0 { + return + } + + supported := make(map[string]struct{}, len(supportedModels)) + for _, model := range supportedModels { + if model == nil { + continue + } + modelKey := canonicalModelKey(model.ID) + if modelKey == "" { + continue + } + supported[modelKey] = struct{}{} + } + if len(supported) == 0 { + return + } + + var snapshot *Auth + now := time.Now() + + m.mu.Lock() + auth, ok := m.auths[authID] + if ok && auth != nil && len(auth.ModelStates) > 0 { + changed := false + for modelKey, state := range auth.ModelStates { + if state == nil { + continue + } + baseModel := canonicalModelKey(modelKey) + if baseModel == "" { + baseModel = strings.TrimSpace(modelKey) + } + if _, supportedModel := supported[baseModel]; !supportedModel { + continue + } + if modelStateIsClean(state) { + continue + } + resetModelState(state, now) + changed = true + } + if changed { + updateAggregatedAvailability(auth, now) + if !hasModelError(auth, now) { + auth.LastError = nil + auth.StatusMessage = "" + auth.Status = StatusActive + } + auth.UpdatedAt = now + _ = m.persist(ctx, auth) + snapshot = auth.Clone() + } + } + m.mu.Unlock() + + if m.scheduler != nil && snapshot != nil { + m.scheduler.upsertAuth(snapshot) + } +} + func (m *Manager) SetSelector(selector Selector) { if m == nil { return @@ -1735,6 +1810,22 @@ func resetModelState(state *ModelState, now time.Time) { state.UpdatedAt = now } +func modelStateIsClean(state *ModelState) bool { + if state == nil { + return true + } + if state.Status != StatusActive { + return false + } + if state.Unavailable || state.StatusMessage != "" || !state.NextRetryAfter.IsZero() || state.LastError != nil { + return false + } + if state.Quota.Exceeded || state.Quota.Reason != "" || !state.Quota.NextRecoverAt.IsZero() || state.Quota.BackoffLevel != 0 { + return false + } + return true +} + func updateAggregatedAvailability(auth *Auth, now time.Time) { if auth == nil || len(auth.ModelStates) == 0 { return diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index abe1deed..a562cfb3 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -310,6 +310,7 @@ func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.A // This operation may block on network calls, but the auth configuration // is already effective at this point. s.registerModelsForAuth(auth) + s.coreManager.ReconcileRegistryModelStates(ctx, auth.ID) // Refresh the scheduler entry so that the auth's supportedModelSet is rebuilt // from the now-populated global model registry. Without this, newly added auths @@ -1019,6 +1020,7 @@ func (s *Service) refreshModelRegistrationForAuth(current *coreauth.Auth) bool { s.ensureExecutorsForAuth(current) } s.registerModelsForAuth(current) + s.coreManager.ReconcileRegistryModelStates(context.Background(), current.ID) latest, ok := s.latestAuthForModelRegistration(current.ID) if !ok || latest.Disabled { @@ -1032,6 +1034,7 @@ func (s *Service) refreshModelRegistrationForAuth(current *coreauth.Auth) bool { // no auth fields changed, but keeps the refresh path simple and correct. s.ensureExecutorsForAuth(latest) s.registerModelsForAuth(latest) + s.coreManager.ReconcileRegistryModelStates(context.Background(), latest.ID) s.coreManager.RefreshSchedulerEntry(current.ID) return true } From f09ed25fd365a37e4469414f0d2754c19789ef60 Mon Sep 17 00:00:00 2001 From: destinoantagonista-wq Date: Sat, 14 Mar 2026 14:40:06 +0000 Subject: [PATCH 04/21] fix(auth): tighten registry model reconciliation --- sdk/cliproxy/auth/conductor.go | 58 ++++-- .../auth/conductor_registry_reconcile_test.go | 182 ++++++++++++++++++ 2 files changed, 222 insertions(+), 18 deletions(-) create mode 100644 sdk/cliproxy/auth/conductor_registry_reconcile_test.go diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 9fc65274..1152bca0 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -233,23 +233,19 @@ func (m *Manager) RefreshSchedulerEntry(authID string) { m.scheduler.upsertAuth(snapshot) } -// ReconcileRegistryModelStates clears stale per-model runtime failures for -// models that are currently registered for the auth in the global model registry. +// ReconcileRegistryModelStates aligns per-model runtime state with the current +// registry snapshot for one auth. // -// This keeps the scheduler and the global registry aligned after model -// re-registration. Without this reconciliation, a model can reappear in -// /v1/models after registry refresh while the scheduler still blocks it because -// auth.ModelStates retained an older failure such as not_found or quota. +// Supported models are reset to a clean state because re-registration already +// cleared the registry-side cooldown/suspension snapshot. ModelStates for +// models that are no longer present in the registry are pruned entirely so +// renamed/removed models cannot keep auth-level status stale. func (m *Manager) ReconcileRegistryModelStates(ctx context.Context, authID string) { if m == nil || authID == "" { return } supportedModels := registry.GetGlobalRegistry().GetModelsForClient(authID) - if len(supportedModels) == 0 { - return - } - supported := make(map[string]struct{}, len(supportedModels)) for _, model := range supportedModels { if model == nil { @@ -261,9 +257,6 @@ func (m *Manager) ReconcileRegistryModelStates(ctx context.Context, authID strin } supported[modelKey] = struct{}{} } - if len(supported) == 0 { - return - } var snapshot *Auth now := time.Now() @@ -273,14 +266,19 @@ func (m *Manager) ReconcileRegistryModelStates(ctx context.Context, authID strin if ok && auth != nil && len(auth.ModelStates) > 0 { changed := false for modelKey, state := range auth.ModelStates { - if state == nil { - continue - } baseModel := canonicalModelKey(modelKey) if baseModel == "" { baseModel = strings.TrimSpace(modelKey) } if _, supportedModel := supported[baseModel]; !supportedModel { + // Drop state for models that disappeared from the current registry + // snapshot. Keeping them around leaks stale errors into auth-level + // status, management output, and websocket fallback checks. + delete(auth.ModelStates, modelKey) + changed = true + continue + } + if state == nil { continue } if modelStateIsClean(state) { @@ -289,6 +287,9 @@ func (m *Manager) ReconcileRegistryModelStates(ctx context.Context, authID strin resetModelState(state, now) changed = true } + if len(auth.ModelStates) == 0 { + auth.ModelStates = nil + } if changed { updateAggregatedAvailability(auth, now) if !hasModelError(auth, now) { @@ -297,7 +298,9 @@ func (m *Manager) ReconcileRegistryModelStates(ctx context.Context, authID strin auth.Status = StatusActive } auth.UpdatedAt = now - _ = m.persist(ctx, auth) + if errPersist := m.persist(ctx, auth); errPersist != nil { + logEntryWithRequestID(ctx).WithField("auth_id", auth.ID).Warnf("failed to persist auth changes during model state reconciliation: %v", errPersist) + } snapshot = auth.Clone() } } @@ -1827,7 +1830,11 @@ func modelStateIsClean(state *ModelState) bool { } func updateAggregatedAvailability(auth *Auth, now time.Time) { - if auth == nil || len(auth.ModelStates) == 0 { + if auth == nil { + return + } + if len(auth.ModelStates) == 0 { + clearAggregatedAvailability(auth) return } allUnavailable := true @@ -1835,10 +1842,12 @@ func updateAggregatedAvailability(auth *Auth, now time.Time) { quotaExceeded := false quotaRecover := time.Time{} maxBackoffLevel := 0 + hasState := false for _, state := range auth.ModelStates { if state == nil { continue } + hasState = true stateUnavailable := false if state.Status == StatusDisabled { stateUnavailable = true @@ -1868,6 +1877,10 @@ func updateAggregatedAvailability(auth *Auth, now time.Time) { } } } + if !hasState { + clearAggregatedAvailability(auth) + return + } auth.Unavailable = allUnavailable if allUnavailable { auth.NextRetryAfter = earliestRetry @@ -1887,6 +1900,15 @@ func updateAggregatedAvailability(auth *Auth, now time.Time) { } } +func clearAggregatedAvailability(auth *Auth) { + if auth == nil { + return + } + auth.Unavailable = false + auth.NextRetryAfter = time.Time{} + auth.Quota = QuotaState{} +} + func hasModelError(auth *Auth, now time.Time) bool { if auth == nil || len(auth.ModelStates) == 0 { return false diff --git a/sdk/cliproxy/auth/conductor_registry_reconcile_test.go b/sdk/cliproxy/auth/conductor_registry_reconcile_test.go new file mode 100644 index 00000000..dc4b95a9 --- /dev/null +++ b/sdk/cliproxy/auth/conductor_registry_reconcile_test.go @@ -0,0 +1,182 @@ +package auth + +import ( + "context" + "errors" + "net/http" + "testing" + "time" + + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" +) + +func TestManager_ReconcileRegistryModelStates_ClearsStaleSupportedModelErrors(t *testing.T) { + ctx := context.Background() + manager := NewManager(nil, &RoundRobinSelector{}, nil) + + auth := &Auth{ + ID: "reconcile-auth", + Provider: "codex", + ModelStates: map[string]*ModelState{ + "gpt-5.4": { + Status: StatusError, + StatusMessage: "not_found", + Unavailable: true, + NextRetryAfter: time.Now().Add(12 * time.Hour), + LastError: &Error{HTTPStatus: http.StatusNotFound, Message: "not_found"}, + }, + }, + } + if _, errRegister := manager.Register(ctx, auth); errRegister != nil { + t.Fatalf("register auth: %v", errRegister) + } + + registerSchedulerModels(t, "codex", "gpt-5.4", auth.ID) + manager.RefreshSchedulerEntry(auth.ID) + + got, errPick := manager.scheduler.pickSingle(ctx, "codex", "gpt-5.4", cliproxyexecutor.Options{}, nil) + var authErr *Error + if !errors.As(errPick, &authErr) || authErr == nil { + t.Fatalf("pickSingle() before reconcile error = %v, want auth_unavailable", errPick) + } + if authErr.Code != "auth_unavailable" { + t.Fatalf("pickSingle() before reconcile code = %q, want %q", authErr.Code, "auth_unavailable") + } + if got != nil { + t.Fatalf("pickSingle() before reconcile auth = %v, want nil", got) + } + + manager.ReconcileRegistryModelStates(ctx, auth.ID) + + got, errPick = manager.scheduler.pickSingle(ctx, "codex", "gpt-5.4", cliproxyexecutor.Options{}, nil) + if errPick != nil { + t.Fatalf("pickSingle() after reconcile error = %v", errPick) + } + if got == nil || got.ID != auth.ID { + t.Fatalf("pickSingle() after reconcile auth = %v, want %q", got, auth.ID) + } + + reconciled, ok := manager.GetByID(auth.ID) + if !ok || reconciled == nil { + t.Fatalf("expected auth to still exist") + } + state := reconciled.ModelStates["gpt-5.4"] + if state == nil { + t.Fatalf("expected reconciled model state to exist") + } + if state.Unavailable { + t.Fatalf("state.Unavailable = true, want false") + } + if state.Status != StatusActive { + t.Fatalf("state.Status = %q, want %q", state.Status, StatusActive) + } + if !state.NextRetryAfter.IsZero() { + t.Fatalf("state.NextRetryAfter = %v, want zero", state.NextRetryAfter) + } + if state.LastError != nil { + t.Fatalf("state.LastError = %v, want nil", state.LastError) + } +} + +func TestManager_ReconcileRegistryModelStates_PrunesUnsupportedModelStates(t *testing.T) { + ctx := context.Background() + manager := NewManager(nil, &RoundRobinSelector{}, nil) + + nextRetry := time.Now().Add(30 * time.Minute) + auth := &Auth{ + ID: "reconcile-unsupported-auth", + Provider: "codex", + Status: StatusError, + Unavailable: true, + StatusMessage: "payment_required", + LastError: &Error{HTTPStatus: http.StatusPaymentRequired, Message: "payment_required"}, + ModelStates: map[string]*ModelState{ + "gpt-5.4": { + Status: StatusError, + StatusMessage: "payment_required", + Unavailable: true, + NextRetryAfter: nextRetry, + }, + }, + } + if _, errRegister := manager.Register(ctx, auth); errRegister != nil { + t.Fatalf("register auth: %v", errRegister) + } + + registerSchedulerModels(t, "codex", "gpt-5.5", auth.ID) + manager.ReconcileRegistryModelStates(ctx, auth.ID) + + reconciled, ok := manager.GetByID(auth.ID) + if !ok || reconciled == nil { + t.Fatalf("expected auth to still exist") + } + if len(reconciled.ModelStates) != 0 { + t.Fatalf("expected stale unsupported model state to be pruned, got %+v", reconciled.ModelStates) + } + if reconciled.Unavailable { + t.Fatalf("auth.Unavailable = true, want false") + } + if reconciled.Status != StatusActive { + t.Fatalf("auth.Status = %q, want %q", reconciled.Status, StatusActive) + } + if reconciled.StatusMessage != "" { + t.Fatalf("auth.StatusMessage = %q, want empty", reconciled.StatusMessage) + } + if reconciled.LastError != nil { + t.Fatalf("auth.LastError = %v, want nil", reconciled.LastError) + } + if !reconciled.NextRetryAfter.IsZero() { + t.Fatalf("auth.NextRetryAfter = %v, want zero", reconciled.NextRetryAfter) + } +} + +func TestManager_ReconcileRegistryModelStates_ClearsRemovedModelStateWhenRegistryIsEmpty(t *testing.T) { + ctx := context.Background() + manager := NewManager(nil, &RoundRobinSelector{}, nil) + + auth := &Auth{ + ID: "reconcile-empty-registry-auth", + Provider: "codex", + Status: StatusError, + Unavailable: true, + StatusMessage: "not_found", + LastError: &Error{HTTPStatus: http.StatusNotFound, Message: "not_found"}, + ModelStates: map[string]*ModelState{ + "gpt-5.4": { + Status: StatusError, + StatusMessage: "not_found", + Unavailable: true, + NextRetryAfter: time.Now().Add(12 * time.Hour), + LastError: &Error{HTTPStatus: http.StatusNotFound, Message: "not_found"}, + }, + }, + } + if _, errRegister := manager.Register(ctx, auth); errRegister != nil { + t.Fatalf("register auth: %v", errRegister) + } + + manager.ReconcileRegistryModelStates(ctx, auth.ID) + + reconciled, ok := manager.GetByID(auth.ID) + if !ok || reconciled == nil { + t.Fatalf("expected auth to still exist") + } + if len(reconciled.ModelStates) != 0 { + t.Fatalf("expected stale model state to be pruned when registry is empty, got %+v", reconciled.ModelStates) + } + if reconciled.Unavailable { + t.Fatalf("auth.Unavailable = true, want false") + } + if reconciled.Status != StatusActive { + t.Fatalf("auth.Status = %q, want %q", reconciled.Status, StatusActive) + } + if reconciled.StatusMessage != "" { + t.Fatalf("auth.StatusMessage = %q, want empty", reconciled.StatusMessage) + } + if reconciled.LastError != nil { + t.Fatalf("auth.LastError = %v, want nil", reconciled.LastError) + } + if !reconciled.NextRetryAfter.IsZero() { + t.Fatalf("auth.NextRetryAfter = %v, want zero", reconciled.NextRetryAfter) + } +} From e08f68ed7c7bafe4e0291a570d92b7e53b6e1352 Mon Sep 17 00:00:00 2001 From: destinoantagonista-wq Date: Sat, 14 Mar 2026 14:41:26 +0000 Subject: [PATCH 05/21] chore(auth): drop reconcile test file from pr --- .../auth/conductor_registry_reconcile_test.go | 182 ------------------ 1 file changed, 182 deletions(-) delete mode 100644 sdk/cliproxy/auth/conductor_registry_reconcile_test.go diff --git a/sdk/cliproxy/auth/conductor_registry_reconcile_test.go b/sdk/cliproxy/auth/conductor_registry_reconcile_test.go deleted file mode 100644 index dc4b95a9..00000000 --- a/sdk/cliproxy/auth/conductor_registry_reconcile_test.go +++ /dev/null @@ -1,182 +0,0 @@ -package auth - -import ( - "context" - "errors" - "net/http" - "testing" - "time" - - cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" -) - -func TestManager_ReconcileRegistryModelStates_ClearsStaleSupportedModelErrors(t *testing.T) { - ctx := context.Background() - manager := NewManager(nil, &RoundRobinSelector{}, nil) - - auth := &Auth{ - ID: "reconcile-auth", - Provider: "codex", - ModelStates: map[string]*ModelState{ - "gpt-5.4": { - Status: StatusError, - StatusMessage: "not_found", - Unavailable: true, - NextRetryAfter: time.Now().Add(12 * time.Hour), - LastError: &Error{HTTPStatus: http.StatusNotFound, Message: "not_found"}, - }, - }, - } - if _, errRegister := manager.Register(ctx, auth); errRegister != nil { - t.Fatalf("register auth: %v", errRegister) - } - - registerSchedulerModels(t, "codex", "gpt-5.4", auth.ID) - manager.RefreshSchedulerEntry(auth.ID) - - got, errPick := manager.scheduler.pickSingle(ctx, "codex", "gpt-5.4", cliproxyexecutor.Options{}, nil) - var authErr *Error - if !errors.As(errPick, &authErr) || authErr == nil { - t.Fatalf("pickSingle() before reconcile error = %v, want auth_unavailable", errPick) - } - if authErr.Code != "auth_unavailable" { - t.Fatalf("pickSingle() before reconcile code = %q, want %q", authErr.Code, "auth_unavailable") - } - if got != nil { - t.Fatalf("pickSingle() before reconcile auth = %v, want nil", got) - } - - manager.ReconcileRegistryModelStates(ctx, auth.ID) - - got, errPick = manager.scheduler.pickSingle(ctx, "codex", "gpt-5.4", cliproxyexecutor.Options{}, nil) - if errPick != nil { - t.Fatalf("pickSingle() after reconcile error = %v", errPick) - } - if got == nil || got.ID != auth.ID { - t.Fatalf("pickSingle() after reconcile auth = %v, want %q", got, auth.ID) - } - - reconciled, ok := manager.GetByID(auth.ID) - if !ok || reconciled == nil { - t.Fatalf("expected auth to still exist") - } - state := reconciled.ModelStates["gpt-5.4"] - if state == nil { - t.Fatalf("expected reconciled model state to exist") - } - if state.Unavailable { - t.Fatalf("state.Unavailable = true, want false") - } - if state.Status != StatusActive { - t.Fatalf("state.Status = %q, want %q", state.Status, StatusActive) - } - if !state.NextRetryAfter.IsZero() { - t.Fatalf("state.NextRetryAfter = %v, want zero", state.NextRetryAfter) - } - if state.LastError != nil { - t.Fatalf("state.LastError = %v, want nil", state.LastError) - } -} - -func TestManager_ReconcileRegistryModelStates_PrunesUnsupportedModelStates(t *testing.T) { - ctx := context.Background() - manager := NewManager(nil, &RoundRobinSelector{}, nil) - - nextRetry := time.Now().Add(30 * time.Minute) - auth := &Auth{ - ID: "reconcile-unsupported-auth", - Provider: "codex", - Status: StatusError, - Unavailable: true, - StatusMessage: "payment_required", - LastError: &Error{HTTPStatus: http.StatusPaymentRequired, Message: "payment_required"}, - ModelStates: map[string]*ModelState{ - "gpt-5.4": { - Status: StatusError, - StatusMessage: "payment_required", - Unavailable: true, - NextRetryAfter: nextRetry, - }, - }, - } - if _, errRegister := manager.Register(ctx, auth); errRegister != nil { - t.Fatalf("register auth: %v", errRegister) - } - - registerSchedulerModels(t, "codex", "gpt-5.5", auth.ID) - manager.ReconcileRegistryModelStates(ctx, auth.ID) - - reconciled, ok := manager.GetByID(auth.ID) - if !ok || reconciled == nil { - t.Fatalf("expected auth to still exist") - } - if len(reconciled.ModelStates) != 0 { - t.Fatalf("expected stale unsupported model state to be pruned, got %+v", reconciled.ModelStates) - } - if reconciled.Unavailable { - t.Fatalf("auth.Unavailable = true, want false") - } - if reconciled.Status != StatusActive { - t.Fatalf("auth.Status = %q, want %q", reconciled.Status, StatusActive) - } - if reconciled.StatusMessage != "" { - t.Fatalf("auth.StatusMessage = %q, want empty", reconciled.StatusMessage) - } - if reconciled.LastError != nil { - t.Fatalf("auth.LastError = %v, want nil", reconciled.LastError) - } - if !reconciled.NextRetryAfter.IsZero() { - t.Fatalf("auth.NextRetryAfter = %v, want zero", reconciled.NextRetryAfter) - } -} - -func TestManager_ReconcileRegistryModelStates_ClearsRemovedModelStateWhenRegistryIsEmpty(t *testing.T) { - ctx := context.Background() - manager := NewManager(nil, &RoundRobinSelector{}, nil) - - auth := &Auth{ - ID: "reconcile-empty-registry-auth", - Provider: "codex", - Status: StatusError, - Unavailable: true, - StatusMessage: "not_found", - LastError: &Error{HTTPStatus: http.StatusNotFound, Message: "not_found"}, - ModelStates: map[string]*ModelState{ - "gpt-5.4": { - Status: StatusError, - StatusMessage: "not_found", - Unavailable: true, - NextRetryAfter: time.Now().Add(12 * time.Hour), - LastError: &Error{HTTPStatus: http.StatusNotFound, Message: "not_found"}, - }, - }, - } - if _, errRegister := manager.Register(ctx, auth); errRegister != nil { - t.Fatalf("register auth: %v", errRegister) - } - - manager.ReconcileRegistryModelStates(ctx, auth.ID) - - reconciled, ok := manager.GetByID(auth.ID) - if !ok || reconciled == nil { - t.Fatalf("expected auth to still exist") - } - if len(reconciled.ModelStates) != 0 { - t.Fatalf("expected stale model state to be pruned when registry is empty, got %+v", reconciled.ModelStates) - } - if reconciled.Unavailable { - t.Fatalf("auth.Unavailable = true, want false") - } - if reconciled.Status != StatusActive { - t.Fatalf("auth.Status = %q, want %q", reconciled.Status, StatusActive) - } - if reconciled.StatusMessage != "" { - t.Fatalf("auth.StatusMessage = %q, want empty", reconciled.StatusMessage) - } - if reconciled.LastError != nil { - t.Fatalf("auth.LastError = %v, want nil", reconciled.LastError) - } - if !reconciled.NextRetryAfter.IsZero() { - t.Fatalf("auth.NextRetryAfter = %v, want zero", reconciled.NextRetryAfter) - } -} From a34dfed3780ec5ab654183148460020a635450b6 Mon Sep 17 00:00:00 2001 From: Ravi Tharuma Date: Tue, 24 Mar 2026 19:12:52 +0100 Subject: [PATCH 06/21] fix: preserve Claude thinking signatures in Codex translator --- .../codex/claude/codex_claude_response.go | 121 ++++++++----- .../claude/codex_claude_response_test.go | 160 ++++++++++++++++++ 2 files changed, 237 insertions(+), 44 deletions(-) create mode 100644 internal/translator/codex/claude/codex_claude_response_test.go diff --git a/internal/translator/codex/claude/codex_claude_response.go b/internal/translator/codex/claude/codex_claude_response.go index b436cd3f..0ddd0845 100644 --- a/internal/translator/codex/claude/codex_claude_response.go +++ b/internal/translator/codex/claude/codex_claude_response.go @@ -26,6 +26,9 @@ type ConvertCodexResponseToClaudeParams struct { HasToolCall bool BlockIndex int HasReceivedArgumentsDelta bool + ThinkingBlockOpen bool + ThinkingStopPending bool + ThinkingSignature string } // ConvertCodexResponseToClaude performs sophisticated streaming response format conversion. @@ -44,7 +47,7 @@ type ConvertCodexResponseToClaudeParams struct { // // Returns: // - [][]byte: A slice of Claude Code-compatible JSON responses -func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte { +func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, param *any) [][]byte { if *param == nil { *param = &ConvertCodexResponseToClaudeParams{ HasToolCall: false, @@ -52,7 +55,6 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa } } - // log.Debugf("rawJSON: %s", string(rawJSON)) if !bytes.HasPrefix(rawJSON, dataTag) { return [][]byte{} } @@ -60,9 +62,18 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa output := make([]byte, 0, 512) rootResult := gjson.ParseBytes(rawJSON) + params := (*param).(*ConvertCodexResponseToClaudeParams) + if params.ThinkingBlockOpen && params.ThinkingStopPending { + switch rootResult.Get("type").String() { + case "response.content_part.added", "response.completed": + output = append(output, finalizeCodexThinkingBlock(params)...) + } + } + typeResult := rootResult.Get("type") typeStr := typeResult.String() var template []byte + if typeStr == "response.created" { template = []byte(`{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"claude-opus-4-1-20250805","stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0},"content":[],"stop_reason":null}}`) template, _ = sjson.SetBytes(template, "message.model", rootResult.Get("response.model").String()) @@ -70,43 +81,44 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa output = translatorcommon.AppendSSEEventBytes(output, "message_start", template, 2) } else if typeStr == "response.reasoning_summary_part.added" { - template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`) - template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) + template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":"","signature":""}}`) + template, _ = sjson.SetBytes(template, "index", params.BlockIndex) + params.ThinkingBlockOpen = true + params.ThinkingStopPending = false + params.ThinkingSignature = "" output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2) } else if typeStr == "response.reasoning_summary_text.delta" { template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`) - template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) + template, _ = sjson.SetBytes(template, "index", params.BlockIndex) template, _ = sjson.SetBytes(template, "delta.thinking", rootResult.Get("delta").String()) output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2) } else if typeStr == "response.reasoning_summary_part.done" { - template = []byte(`{"type":"content_block_stop","index":0}`) - template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) - (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++ - - output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2) - + params.ThinkingStopPending = true + if params.ThinkingSignature != "" { + output = append(output, finalizeCodexThinkingBlock(params)...) + } } else if typeStr == "response.content_part.added" { template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`) - template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) + template, _ = sjson.SetBytes(template, "index", params.BlockIndex) output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2) } else if typeStr == "response.output_text.delta" { template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`) - template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) + template, _ = sjson.SetBytes(template, "index", params.BlockIndex) template, _ = sjson.SetBytes(template, "delta.text", rootResult.Get("delta").String()) output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2) } else if typeStr == "response.content_part.done" { template = []byte(`{"type":"content_block_stop","index":0}`) - template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) - (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++ + template, _ = sjson.SetBytes(template, "index", params.BlockIndex) + params.BlockIndex++ output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2) } else if typeStr == "response.completed" { template = []byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`) - p := (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall + p := params.HasToolCall stopReason := rootResult.Get("response.stop_reason").String() if p { template, _ = sjson.SetBytes(template, "delta.stop_reason", "tool_use") @@ -128,13 +140,13 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa itemResult := rootResult.Get("item") itemType := itemResult.Get("type").String() if itemType == "function_call" { - (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true - (*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false + output = append(output, finalizeCodexThinkingBlock(params)...) + params.HasToolCall = true + params.HasReceivedArgumentsDelta = false template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`) - template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) + template, _ = sjson.SetBytes(template, "index", params.BlockIndex) template, _ = sjson.SetBytes(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String())) { - // Restore original tool name if shortened name := itemResult.Get("name").String() rev := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON) if orig, ok := rev[name]; ok { @@ -146,37 +158,40 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2) template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`) - template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) + template, _ = sjson.SetBytes(template, "index", params.BlockIndex) output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2) + } else if itemType == "reasoning" { + params.ThinkingSignature = itemResult.Get("encrypted_content").String() + if params.ThinkingStopPending { + output = append(output, finalizeCodexThinkingBlock(params)...) + } } } else if typeStr == "response.output_item.done" { itemResult := rootResult.Get("item") itemType := itemResult.Get("type").String() if itemType == "function_call" { template = []byte(`{"type":"content_block_stop","index":0}`) - template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) - (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++ + template, _ = sjson.SetBytes(template, "index", params.BlockIndex) + params.BlockIndex++ output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2) + } else if itemType == "reasoning" { + params.ThinkingSignature = itemResult.Get("encrypted_content").String() + output = append(output, finalizeCodexThinkingBlock(params)...) } } else if typeStr == "response.function_call_arguments.delta" { - (*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = true + params.HasReceivedArgumentsDelta = true template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`) - template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) + template, _ = sjson.SetBytes(template, "index", params.BlockIndex) template, _ = sjson.SetBytes(template, "delta.partial_json", rootResult.Get("delta").String()) output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2) } else if typeStr == "response.function_call_arguments.done" { - // Some models (e.g. gpt-5.3-codex-spark) send function call arguments - // in a single "done" event without preceding "delta" events. - // Emit the full arguments as a single input_json_delta so the - // downstream Claude client receives the complete tool input. - // When delta events were already received, skip to avoid duplicating arguments. - if !(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta { + if !params.HasReceivedArgumentsDelta { if args := rootResult.Get("arguments").String(); args != "" { template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`) - template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) + template, _ = sjson.SetBytes(template, "index", params.BlockIndex) template, _ = sjson.SetBytes(template, "delta.partial_json", args) output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2) @@ -191,15 +206,6 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa // This function processes the complete Codex response and transforms it into a single Claude Code-compatible // JSON response. It handles message content, tool calls, reasoning content, and usage metadata, combining all // the information into a single response that matches the Claude Code API format. -// -// Parameters: -// - ctx: The context for the request, used for cancellation and timeout handling -// - modelName: The name of the model being used for the response (unused in current implementation) -// - rawJSON: The raw JSON response from the Codex API -// - param: A pointer to a parameter object for the conversion (unused in current implementation) -// -// Returns: -// - []byte: A Claude Code-compatible JSON response containing all message content and metadata func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, _ *any) []byte { revNames := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON) @@ -230,6 +236,7 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original switch item.Get("type").String() { case "reasoning": thinkingBuilder := strings.Builder{} + signature := item.Get("encrypted_content").String() if summary := item.Get("summary"); summary.Exists() { if summary.IsArray() { summary.ForEach(func(_, part gjson.Result) bool { @@ -260,9 +267,10 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original } } } - if thinkingBuilder.Len() > 0 { - block := []byte(`{"type":"thinking","thinking":""}`) + if thinkingBuilder.Len() > 0 || signature != "" { + block := []byte(`{"type":"thinking","thinking":"","signature":""}`) block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String()) + block, _ = sjson.SetBytes(block, "signature", signature) out, _ = sjson.SetRawBytes(out, "content.-1", block) } case "message": @@ -371,6 +379,31 @@ func buildReverseMapFromClaudeOriginalShortToOriginal(original []byte) map[strin return rev } -func ClaudeTokenCount(ctx context.Context, count int64) []byte { +func ClaudeTokenCount(_ context.Context, count int64) []byte { return translatorcommon.ClaudeInputTokensJSON(count) } + +func finalizeCodexThinkingBlock(params *ConvertCodexResponseToClaudeParams) []byte { + if !params.ThinkingBlockOpen { + return nil + } + + output := make([]byte, 0, 256) + if params.ThinkingSignature != "" { + signatureDelta := []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"signature_delta","signature":""}}`) + signatureDelta, _ = sjson.SetBytes(signatureDelta, "index", params.BlockIndex) + signatureDelta, _ = sjson.SetBytes(signatureDelta, "delta.signature", params.ThinkingSignature) + output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", signatureDelta, 2) + } + + contentBlockStop := []byte(`{"type":"content_block_stop","index":0}`) + contentBlockStop, _ = sjson.SetBytes(contentBlockStop, "index", params.BlockIndex) + output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", contentBlockStop, 2) + + params.BlockIndex++ + params.ThinkingBlockOpen = false + params.ThinkingStopPending = false + params.ThinkingSignature = "" + + return output +} diff --git a/internal/translator/codex/claude/codex_claude_response_test.go b/internal/translator/codex/claude/codex_claude_response_test.go new file mode 100644 index 00000000..d903dcf7 --- /dev/null +++ b/internal/translator/codex/claude/codex_claude_response_test.go @@ -0,0 +1,160 @@ +package claude + +import ( + "context" + "strings" + "testing" + + "github.com/tidwall/gjson" +) + +func TestConvertCodexResponseToClaude_StreamThinkingIncludesSignature(t *testing.T) { + ctx := context.Background() + originalRequest := []byte(`{"messages":[]}`) + var param any + + chunks := [][]byte{ + []byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_123\",\"model\":\"gpt-5\"}}"), + []byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"), + []byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Let me think\"}"), + []byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"), + []byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"reasoning\",\"encrypted_content\":\"enc_sig_123\"}}"), + } + + var outputs [][]byte + for _, chunk := range chunks { + outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, chunk, ¶m)...) + } + + startFound := false + signatureDeltaFound := false + stopFound := false + + for _, out := range outputs { + for _, line := range strings.Split(string(out), "\n") { + if !strings.HasPrefix(line, "data: ") { + continue + } + data := gjson.Parse(strings.TrimPrefix(line, "data: ")) + switch data.Get("type").String() { + case "content_block_start": + if data.Get("content_block.type").String() == "thinking" { + startFound = true + if !data.Get("content_block.signature").Exists() { + t.Fatalf("thinking start block missing signature field: %s", line) + } + } + case "content_block_delta": + if data.Get("delta.type").String() == "signature_delta" { + signatureDeltaFound = true + if got := data.Get("delta.signature").String(); got != "enc_sig_123" { + t.Fatalf("unexpected signature delta: %q", got) + } + } + case "content_block_stop": + stopFound = true + } + } + } + + if !startFound { + t.Fatal("expected thinking content_block_start event") + } + if !signatureDeltaFound { + t.Fatal("expected signature_delta event for thinking block") + } + if !stopFound { + t.Fatal("expected content_block_stop event for thinking block") + } +} + +func TestConvertCodexResponseToClaude_StreamThinkingWithoutReasoningItemStillIncludesSignatureField(t *testing.T) { + ctx := context.Background() + originalRequest := []byte(`{"messages":[]}`) + var param any + + chunks := [][]byte{ + []byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"), + []byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Let me think\"}"), + []byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"), + []byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}"), + } + + var outputs [][]byte + for _, chunk := range chunks { + outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, chunk, ¶m)...) + } + + thinkingStartFound := false + thinkingStopFound := false + signatureDeltaFound := false + + for _, out := range outputs { + for _, line := range strings.Split(string(out), "\n") { + if !strings.HasPrefix(line, "data: ") { + continue + } + data := gjson.Parse(strings.TrimPrefix(line, "data: ")) + if data.Get("type").String() == "content_block_start" && data.Get("content_block.type").String() == "thinking" { + thinkingStartFound = true + if !data.Get("content_block.signature").Exists() { + t.Fatalf("thinking start block missing signature field: %s", line) + } + } + if data.Get("type").String() == "content_block_stop" && data.Get("index").Int() == 0 { + thinkingStopFound = true + } + if data.Get("type").String() == "content_block_delta" && data.Get("delta.type").String() == "signature_delta" { + signatureDeltaFound = true + } + } + } + + if !thinkingStartFound { + t.Fatal("expected thinking content_block_start event") + } + if !thinkingStopFound { + t.Fatal("expected thinking content_block_stop event") + } + if signatureDeltaFound { + t.Fatal("did not expect signature_delta without encrypted_content") + } +} + +func TestConvertCodexResponseToClaudeNonStream_ThinkingIncludesSignature(t *testing.T) { + ctx := context.Background() + originalRequest := []byte(`{"messages":[]}`) + response := []byte(`{ + "type":"response.completed", + "response":{ + "id":"resp_123", + "model":"gpt-5", + "usage":{"input_tokens":10,"output_tokens":20}, + "output":[ + { + "type":"reasoning", + "encrypted_content":"enc_sig_nonstream", + "summary":[{"type":"summary_text","text":"internal reasoning"}] + }, + { + "type":"message", + "content":[{"type":"output_text","text":"final answer"}] + } + ] + } + }`) + + out := ConvertCodexResponseToClaudeNonStream(ctx, "", originalRequest, nil, response, nil) + parsed := gjson.ParseBytes(out) + + thinking := parsed.Get("content.0") + if thinking.Get("type").String() != "thinking" { + t.Fatalf("expected first content block to be thinking, got %s", thinking.Raw) + } + if got := thinking.Get("signature").String(); got != "enc_sig_nonstream" { + t.Fatalf("expected signature to be preserved, got %q", got) + } + if got := thinking.Get("thinking").String(); got != "internal reasoning" { + t.Fatalf("unexpected thinking text: %q", got) + } +} From 76b53d6b5b6c7cc48b174d2cfcf611b4f5ccefce Mon Sep 17 00:00:00 2001 From: Ravi Tharuma Date: Tue, 24 Mar 2026 19:34:11 +0100 Subject: [PATCH 07/21] fix: finalize pending thinking block before next summary part --- .../codex/claude/codex_claude_response.go | 3 ++ .../claude/codex_claude_response_test.go | 42 +++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/internal/translator/codex/claude/codex_claude_response.go b/internal/translator/codex/claude/codex_claude_response.go index 0ddd0845..4f027543 100644 --- a/internal/translator/codex/claude/codex_claude_response.go +++ b/internal/translator/codex/claude/codex_claude_response.go @@ -81,6 +81,9 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa output = translatorcommon.AppendSSEEventBytes(output, "message_start", template, 2) } else if typeStr == "response.reasoning_summary_part.added" { + if params.ThinkingBlockOpen && params.ThinkingStopPending { + output = append(output, finalizeCodexThinkingBlock(params)...) + } template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":"","signature":""}}`) template, _ = sjson.SetBytes(template, "index", params.BlockIndex) params.ThinkingBlockOpen = true diff --git a/internal/translator/codex/claude/codex_claude_response_test.go b/internal/translator/codex/claude/codex_claude_response_test.go index d903dcf7..5a25057c 100644 --- a/internal/translator/codex/claude/codex_claude_response_test.go +++ b/internal/translator/codex/claude/codex_claude_response_test.go @@ -121,6 +121,48 @@ func TestConvertCodexResponseToClaude_StreamThinkingWithoutReasoningItemStillInc } } +func TestConvertCodexResponseToClaude_StreamThinkingFinalizesPendingBlockBeforeNextSummaryPart(t *testing.T) { + ctx := context.Background() + originalRequest := []byte(`{"messages":[]}`) + var param any + + chunks := [][]byte{ + []byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"), + []byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"First part\"}"), + []byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"), + []byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"), + } + + var outputs [][]byte + for _, chunk := range chunks { + outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, chunk, ¶m)...) + } + + startCount := 0 + stopCount := 0 + for _, out := range outputs { + for _, line := range strings.Split(string(out), "\n") { + if !strings.HasPrefix(line, "data: ") { + continue + } + data := gjson.Parse(strings.TrimPrefix(line, "data: ")) + if data.Get("type").String() == "content_block_start" && data.Get("content_block.type").String() == "thinking" { + startCount++ + } + if data.Get("type").String() == "content_block_stop" { + stopCount++ + } + } + } + + if startCount != 2 { + t.Fatalf("expected 2 thinking block starts, got %d", startCount) + } + if stopCount != 1 { + t.Fatalf("expected pending thinking block to be finalized before second start, got %d stops", stopCount) + } +} + func TestConvertCodexResponseToClaudeNonStream_ThinkingIncludesSignature(t *testing.T) { ctx := context.Background() originalRequest := []byte(`{"messages":[]}`) From c31ae2f3b598e228ca4058984504cde278d513e5 Mon Sep 17 00:00:00 2001 From: Ravi Tharuma Date: Tue, 24 Mar 2026 20:08:23 +0100 Subject: [PATCH 08/21] fix: retain previously captured thinking signature on new summary part --- internal/translator/codex/claude/codex_claude_response.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/translator/codex/claude/codex_claude_response.go b/internal/translator/codex/claude/codex_claude_response.go index 4f027543..798089d0 100644 --- a/internal/translator/codex/claude/codex_claude_response.go +++ b/internal/translator/codex/claude/codex_claude_response.go @@ -88,7 +88,6 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa template, _ = sjson.SetBytes(template, "index", params.BlockIndex) params.ThinkingBlockOpen = true params.ThinkingStopPending = false - params.ThinkingSignature = "" output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2) } else if typeStr == "response.reasoning_summary_text.delta" { From 73b22ec29b26e6fbb682df0d873e4e538ecfbd1f Mon Sep 17 00:00:00 2001 From: Ravi Tharuma Date: Wed, 25 Mar 2026 07:44:21 +0100 Subject: [PATCH 09/21] fix: omit empty signature field from thinking blocks Emit signature only when non-empty in both streaming content_block_start and non-streaming thinking blocks. Avoids turning 'missing signature' into 'empty/invalid signature' which Claude clients may reject. --- internal/translator/codex/claude/codex_claude_response.go | 8 +++++--- .../translator/codex/claude/codex_claude_response_test.go | 8 ++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/internal/translator/codex/claude/codex_claude_response.go b/internal/translator/codex/claude/codex_claude_response.go index 798089d0..4557606f 100644 --- a/internal/translator/codex/claude/codex_claude_response.go +++ b/internal/translator/codex/claude/codex_claude_response.go @@ -84,7 +84,7 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa if params.ThinkingBlockOpen && params.ThinkingStopPending { output = append(output, finalizeCodexThinkingBlock(params)...) } - template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":"","signature":""}}`) + template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`) template, _ = sjson.SetBytes(template, "index", params.BlockIndex) params.ThinkingBlockOpen = true params.ThinkingStopPending = false @@ -270,9 +270,11 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original } } if thinkingBuilder.Len() > 0 || signature != "" { - block := []byte(`{"type":"thinking","thinking":"","signature":""}`) + block := []byte(`{"type":"thinking","thinking":""}`) block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String()) - block, _ = sjson.SetBytes(block, "signature", signature) + if signature != "" { + block, _ = sjson.SetBytes(block, "signature", signature) + } out, _ = sjson.SetRawBytes(out, "content.-1", block) } case "message": diff --git a/internal/translator/codex/claude/codex_claude_response_test.go b/internal/translator/codex/claude/codex_claude_response_test.go index 5a25057c..f436711e 100644 --- a/internal/translator/codex/claude/codex_claude_response_test.go +++ b/internal/translator/codex/claude/codex_claude_response_test.go @@ -40,8 +40,8 @@ func TestConvertCodexResponseToClaude_StreamThinkingIncludesSignature(t *testing case "content_block_start": if data.Get("content_block.type").String() == "thinking" { startFound = true - if !data.Get("content_block.signature").Exists() { - t.Fatalf("thinking start block missing signature field: %s", line) + if data.Get("content_block.signature").Exists() { + t.Fatalf("thinking start block should NOT have signature field when signature is unknown: %s", line) } } case "content_block_delta": @@ -97,8 +97,8 @@ func TestConvertCodexResponseToClaude_StreamThinkingWithoutReasoningItemStillInc data := gjson.Parse(strings.TrimPrefix(line, "data: ")) if data.Get("type").String() == "content_block_start" && data.Get("content_block.type").String() == "thinking" { thinkingStartFound = true - if !data.Get("content_block.signature").Exists() { - t.Fatalf("thinking start block missing signature field: %s", line) + if data.Get("content_block.signature").Exists() { + t.Fatalf("thinking start block should NOT have signature field without encrypted_content: %s", line) } } if data.Get("type").String() == "content_block_stop" && data.Get("index").Int() == 0 { From 66eb12294a5f96269a1d478149d5fc0ba1e44234 Mon Sep 17 00:00:00 2001 From: Ravi Tharuma Date: Wed, 25 Mar 2026 07:52:32 +0100 Subject: [PATCH 10/21] fix: clear stale thinking signature when no block is open --- internal/translator/codex/claude/codex_claude_response.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/translator/codex/claude/codex_claude_response.go b/internal/translator/codex/claude/codex_claude_response.go index 4557606f..4db4c9fc 100644 --- a/internal/translator/codex/claude/codex_claude_response.go +++ b/internal/translator/codex/claude/codex_claude_response.go @@ -389,6 +389,7 @@ func ClaudeTokenCount(_ context.Context, count int64) []byte { func finalizeCodexThinkingBlock(params *ConvertCodexResponseToClaudeParams) []byte { if !params.ThinkingBlockOpen { + params.ThinkingSignature = "" return nil } From 5fc2bd393eb9a360476d9db1762f2b69441bb7e4 Mon Sep 17 00:00:00 2001 From: Ravi Tharuma Date: Sat, 28 Mar 2026 14:41:25 +0100 Subject: [PATCH 11/21] fix: retain codex thinking signature until item done --- .../codex/claude/codex_claude_response.go | 7 +- .../claude/codex_claude_response_test.go | 80 +++++++++++++++++++ 2 files changed, 84 insertions(+), 3 deletions(-) diff --git a/internal/translator/codex/claude/codex_claude_response.go b/internal/translator/codex/claude/codex_claude_response.go index 4db4c9fc..708194e6 100644 --- a/internal/translator/codex/claude/codex_claude_response.go +++ b/internal/translator/codex/claude/codex_claude_response.go @@ -179,8 +179,11 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2) } else if itemType == "reasoning" { - params.ThinkingSignature = itemResult.Get("encrypted_content").String() + if signature := itemResult.Get("encrypted_content").String(); signature != "" { + params.ThinkingSignature = signature + } output = append(output, finalizeCodexThinkingBlock(params)...) + params.ThinkingSignature = "" } } else if typeStr == "response.function_call_arguments.delta" { params.HasReceivedArgumentsDelta = true @@ -389,7 +392,6 @@ func ClaudeTokenCount(_ context.Context, count int64) []byte { func finalizeCodexThinkingBlock(params *ConvertCodexResponseToClaudeParams) []byte { if !params.ThinkingBlockOpen { - params.ThinkingSignature = "" return nil } @@ -408,7 +410,6 @@ func finalizeCodexThinkingBlock(params *ConvertCodexResponseToClaudeParams) []by params.BlockIndex++ params.ThinkingBlockOpen = false params.ThinkingStopPending = false - params.ThinkingSignature = "" return output } diff --git a/internal/translator/codex/claude/codex_claude_response_test.go b/internal/translator/codex/claude/codex_claude_response_test.go index f436711e..a8d4d189 100644 --- a/internal/translator/codex/claude/codex_claude_response_test.go +++ b/internal/translator/codex/claude/codex_claude_response_test.go @@ -163,6 +163,86 @@ func TestConvertCodexResponseToClaude_StreamThinkingFinalizesPendingBlockBeforeN } } +func TestConvertCodexResponseToClaude_StreamThinkingRetainsSignatureAcrossMultipartReasoning(t *testing.T) { + ctx := context.Background() + originalRequest := []byte(`{"messages":[]}`) + var param any + + chunks := [][]byte{ + []byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"reasoning\",\"encrypted_content\":\"enc_sig_multipart\"}}"), + []byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"), + []byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"First part\"}"), + []byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"), + []byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"), + []byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Second part\"}"), + []byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"), + []byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"reasoning\"}}"), + } + + var outputs [][]byte + for _, chunk := range chunks { + outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, chunk, ¶m)...) + } + + signatureDeltaCount := 0 + for _, out := range outputs { + for _, line := range strings.Split(string(out), "\n") { + if !strings.HasPrefix(line, "data: ") { + continue + } + data := gjson.Parse(strings.TrimPrefix(line, "data: ")) + if data.Get("type").String() == "content_block_delta" && data.Get("delta.type").String() == "signature_delta" { + signatureDeltaCount++ + if got := data.Get("delta.signature").String(); got != "enc_sig_multipart" { + t.Fatalf("unexpected signature delta: %q", got) + } + } + } + } + + if signatureDeltaCount != 2 { + t.Fatalf("expected signature_delta for both multipart thinking blocks, got %d", signatureDeltaCount) + } +} + +func TestConvertCodexResponseToClaude_StreamThinkingUsesEarlyCapturedSignatureWhenDoneOmitsIt(t *testing.T) { + ctx := context.Background() + originalRequest := []byte(`{"messages":[]}`) + var param any + + chunks := [][]byte{ + []byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"reasoning\",\"encrypted_content\":\"enc_sig_early\"}}"), + []byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"), + []byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Let me think\"}"), + []byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"reasoning\"}}"), + } + + var outputs [][]byte + for _, chunk := range chunks { + outputs = append(outputs, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, chunk, ¶m)...) + } + + signatureDeltaCount := 0 + for _, out := range outputs { + for _, line := range strings.Split(string(out), "\n") { + if !strings.HasPrefix(line, "data: ") { + continue + } + data := gjson.Parse(strings.TrimPrefix(line, "data: ")) + if data.Get("type").String() == "content_block_delta" && data.Get("delta.type").String() == "signature_delta" { + signatureDeltaCount++ + if got := data.Get("delta.signature").String(); got != "enc_sig_early" { + t.Fatalf("unexpected signature delta: %q", got) + } + } + } + } + + if signatureDeltaCount != 1 { + t.Fatalf("expected signature_delta from early-captured signature, got %d", signatureDeltaCount) + } +} + func TestConvertCodexResponseToClaudeNonStream_ThinkingIncludesSignature(t *testing.T) { ctx := context.Background() originalRequest := []byte(`{"messages":[]}`) From da3a498a28819a59b2b9cca0cdcc1d051c3cb983 Mon Sep 17 00:00:00 2001 From: mpfo0106 Date: Thu, 2 Apr 2026 20:35:39 +0900 Subject: [PATCH 12/21] Keep Claude Code compatibility work low-risk and reviewable This change stops short of broader Claude Code runtime alignment and instead hardens two safe edges: builtin tool prefix handling and source-informed sentinel coverage for future drift checks. Constraint: Must preserve existing default behavior for current users Rejected: Implement control-plane/session alignment now | too much runtime risk for a first slice Confidence: high Scope-risk: narrow Reversibility: clean Directive: Treat the new fixtures as compatibility sentinels, not a full Claude Code schema contract Tested: go test ./test/...; go test ./sdk/translator/...; go test ./internal/runtime/executor -run 'Claude|Builtin|Tool'; go test ./... Not-tested: End-to-end Claude Code direct-connect/session runtime behavior --- .../runtime/executor/claude_builtin_tools.go | 38 +++++++ .../executor/claude_builtin_tools_test.go | 46 ++++++++ internal/runtime/executor/claude_executor.go | 9 +- ...claude_code_compatibility_sentinel_test.go | 106 ++++++++++++++++++ .../control_request_can_use_tool.json | 11 ++ .../session_state_changed.json | 7 ++ .../claude_code_sentinels/tool_progress.json | 10 ++ .../tool_use_summary.json | 7 ++ 8 files changed, 228 insertions(+), 6 deletions(-) create mode 100644 internal/runtime/executor/claude_builtin_tools.go create mode 100644 internal/runtime/executor/claude_builtin_tools_test.go create mode 100644 test/claude_code_compatibility_sentinel_test.go create mode 100644 test/testdata/claude_code_sentinels/control_request_can_use_tool.json create mode 100644 test/testdata/claude_code_sentinels/session_state_changed.json create mode 100644 test/testdata/claude_code_sentinels/tool_progress.json create mode 100644 test/testdata/claude_code_sentinels/tool_use_summary.json diff --git a/internal/runtime/executor/claude_builtin_tools.go b/internal/runtime/executor/claude_builtin_tools.go new file mode 100644 index 00000000..8c3592f7 --- /dev/null +++ b/internal/runtime/executor/claude_builtin_tools.go @@ -0,0 +1,38 @@ +package executor + +import "github.com/tidwall/gjson" + +var defaultClaudeBuiltinToolNames = []string{ + "web_search", + "code_execution", + "text_editor", + "computer", +} + +func newClaudeBuiltinToolRegistry() map[string]bool { + registry := make(map[string]bool, len(defaultClaudeBuiltinToolNames)) + for _, name := range defaultClaudeBuiltinToolNames { + registry[name] = true + } + return registry +} + +func augmentClaudeBuiltinToolRegistry(body []byte, registry map[string]bool) map[string]bool { + if registry == nil { + registry = newClaudeBuiltinToolRegistry() + } + tools := gjson.GetBytes(body, "tools") + if !tools.Exists() || !tools.IsArray() { + return registry + } + tools.ForEach(func(_, tool gjson.Result) bool { + if tool.Get("type").String() == "" { + return true + } + if name := tool.Get("name").String(); name != "" { + registry[name] = true + } + return true + }) + return registry +} diff --git a/internal/runtime/executor/claude_builtin_tools_test.go b/internal/runtime/executor/claude_builtin_tools_test.go new file mode 100644 index 00000000..34036fa0 --- /dev/null +++ b/internal/runtime/executor/claude_builtin_tools_test.go @@ -0,0 +1,46 @@ +package executor + +import ( + "fmt" + "testing" + + "github.com/tidwall/gjson" +) + +func TestClaudeBuiltinToolRegistry_DefaultSeedFallback(t *testing.T) { + registry := augmentClaudeBuiltinToolRegistry(nil, nil) + for _, name := range defaultClaudeBuiltinToolNames { + if !registry[name] { + t.Fatalf("default builtin %q missing from fallback registry", name) + } + } +} + +func TestApplyClaudeToolPrefix_KnownFallbackBuiltinsRemainUnprefixed(t *testing.T) { + for _, builtin := range defaultClaudeBuiltinToolNames { + t.Run(builtin, func(t *testing.T) { + input := []byte(fmt.Sprintf(`{ + "tools":[{"name":"Read"}], + "tool_choice":{"type":"tool","name":%q}, + "messages":[{"role":"assistant","content":[{"type":"tool_use","name":%q,"id":"toolu_1","input":{}},{"type":"tool_reference","tool_name":%q},{"type":"tool_result","tool_use_id":"toolu_1","content":[{"type":"tool_reference","tool_name":%q}]}]}] + }`, builtin, builtin, builtin, builtin)) + out := applyClaudeToolPrefix(input, "proxy_") + + if got := gjson.GetBytes(out, "tool_choice.name").String(); got != builtin { + t.Fatalf("tool_choice.name = %q, want %q", got, builtin) + } + if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != builtin { + t.Fatalf("messages.0.content.0.name = %q, want %q", got, builtin) + } + if got := gjson.GetBytes(out, "messages.0.content.1.tool_name").String(); got != builtin { + t.Fatalf("messages.0.content.1.tool_name = %q, want %q", got, builtin) + } + if got := gjson.GetBytes(out, "messages.0.content.2.content.0.tool_name").String(); got != builtin { + t.Fatalf("messages.0.content.2.content.0.tool_name = %q, want %q", got, builtin) + } + if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" { + t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read") + } + }) + } +} diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index f5e7e409..d1d2e136 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -919,12 +919,9 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte { return body } - // Collect built-in tool names (those with a non-empty "type" field) so we can - // skip them consistently in both tools and message history. - builtinTools := map[string]bool{} - for _, name := range []string{"web_search", "code_execution", "text_editor", "computer"} { - builtinTools[name] = true - } + // Collect built-in tool names from the authoritative fallback seed list and + // augment it with any typed built-ins present in the current request body. + builtinTools := augmentClaudeBuiltinToolRegistry(body, nil) if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() { tools.ForEach(func(index, tool gjson.Result) bool { diff --git a/test/claude_code_compatibility_sentinel_test.go b/test/claude_code_compatibility_sentinel_test.go new file mode 100644 index 00000000..793b3c6a --- /dev/null +++ b/test/claude_code_compatibility_sentinel_test.go @@ -0,0 +1,106 @@ +package test + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" +) + +type jsonObject = map[string]any + +func loadClaudeCodeSentinelFixture(t *testing.T, name string) jsonObject { + t.Helper() + path := filepath.Join("testdata", "claude_code_sentinels", name) + data := mustReadFile(t, path) + var payload jsonObject + if err := json.Unmarshal(data, &payload); err != nil { + t.Fatalf("unmarshal %s: %v", name, err) + } + return payload +} + +func mustReadFile(t *testing.T, path string) []byte { + t.Helper() + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read %s: %v", path, err) + } + return data +} + +func requireStringField(t *testing.T, obj jsonObject, key string) string { + t.Helper() + value, ok := obj[key].(string) + if !ok || value == "" { + t.Fatalf("field %q missing or empty: %#v", key, obj[key]) + } + return value +} + +func TestClaudeCodeSentinel_ToolProgressShape(t *testing.T) { + payload := loadClaudeCodeSentinelFixture(t, "tool_progress.json") + if got := requireStringField(t, payload, "type"); got != "tool_progress" { + t.Fatalf("type = %q, want tool_progress", got) + } + requireStringField(t, payload, "tool_use_id") + requireStringField(t, payload, "tool_name") + requireStringField(t, payload, "session_id") + if _, ok := payload["elapsed_time_seconds"].(float64); !ok { + t.Fatalf("elapsed_time_seconds missing or non-number: %#v", payload["elapsed_time_seconds"]) + } +} + +func TestClaudeCodeSentinel_SessionStateShape(t *testing.T) { + payload := loadClaudeCodeSentinelFixture(t, "session_state_changed.json") + if got := requireStringField(t, payload, "type"); got != "system" { + t.Fatalf("type = %q, want system", got) + } + if got := requireStringField(t, payload, "subtype"); got != "session_state_changed" { + t.Fatalf("subtype = %q, want session_state_changed", got) + } + state := requireStringField(t, payload, "state") + switch state { + case "idle", "running", "requires_action": + default: + t.Fatalf("unexpected session state %q", state) + } + requireStringField(t, payload, "session_id") +} + +func TestClaudeCodeSentinel_ToolUseSummaryShape(t *testing.T) { + payload := loadClaudeCodeSentinelFixture(t, "tool_use_summary.json") + if got := requireStringField(t, payload, "type"); got != "tool_use_summary" { + t.Fatalf("type = %q, want tool_use_summary", got) + } + requireStringField(t, payload, "summary") + rawIDs, ok := payload["preceding_tool_use_ids"].([]any) + if !ok || len(rawIDs) == 0 { + t.Fatalf("preceding_tool_use_ids missing or empty: %#v", payload["preceding_tool_use_ids"]) + } + for i, raw := range rawIDs { + if id, ok := raw.(string); !ok || id == "" { + t.Fatalf("preceding_tool_use_ids[%d] invalid: %#v", i, raw) + } + } +} + +func TestClaudeCodeSentinel_ControlRequestCanUseToolShape(t *testing.T) { + payload := loadClaudeCodeSentinelFixture(t, "control_request_can_use_tool.json") + if got := requireStringField(t, payload, "type"); got != "control_request" { + t.Fatalf("type = %q, want control_request", got) + } + requireStringField(t, payload, "request_id") + request, ok := payload["request"].(map[string]any) + if !ok { + t.Fatalf("request missing or invalid: %#v", payload["request"]) + } + if got := requireStringField(t, request, "subtype"); got != "can_use_tool" { + t.Fatalf("request.subtype = %q, want can_use_tool", got) + } + requireStringField(t, request, "tool_name") + requireStringField(t, request, "tool_use_id") + if input, ok := request["input"].(map[string]any); !ok || len(input) == 0 { + t.Fatalf("request.input missing or empty: %#v", request["input"]) + } +} diff --git a/test/testdata/claude_code_sentinels/control_request_can_use_tool.json b/test/testdata/claude_code_sentinels/control_request_can_use_tool.json new file mode 100644 index 00000000..cafdb00a --- /dev/null +++ b/test/testdata/claude_code_sentinels/control_request_can_use_tool.json @@ -0,0 +1,11 @@ +{ + "type": "control_request", + "request_id": "req_123", + "request": { + "subtype": "can_use_tool", + "tool_name": "Bash", + "input": {"command": "npm test"}, + "tool_use_id": "toolu_123", + "description": "Running npm test" + } +} diff --git a/test/testdata/claude_code_sentinels/session_state_changed.json b/test/testdata/claude_code_sentinels/session_state_changed.json new file mode 100644 index 00000000..db411ace --- /dev/null +++ b/test/testdata/claude_code_sentinels/session_state_changed.json @@ -0,0 +1,7 @@ +{ + "type": "system", + "subtype": "session_state_changed", + "state": "requires_action", + "uuid": "22222222-2222-4222-8222-222222222222", + "session_id": "sess_123" +} diff --git a/test/testdata/claude_code_sentinels/tool_progress.json b/test/testdata/claude_code_sentinels/tool_progress.json new file mode 100644 index 00000000..45a3a22e --- /dev/null +++ b/test/testdata/claude_code_sentinels/tool_progress.json @@ -0,0 +1,10 @@ +{ + "type": "tool_progress", + "tool_use_id": "toolu_123", + "tool_name": "Bash", + "parent_tool_use_id": null, + "elapsed_time_seconds": 2.5, + "task_id": "task_123", + "uuid": "11111111-1111-4111-8111-111111111111", + "session_id": "sess_123" +} diff --git a/test/testdata/claude_code_sentinels/tool_use_summary.json b/test/testdata/claude_code_sentinels/tool_use_summary.json new file mode 100644 index 00000000..da3c4c3e --- /dev/null +++ b/test/testdata/claude_code_sentinels/tool_use_summary.json @@ -0,0 +1,7 @@ +{ + "type": "tool_use_summary", + "summary": "Searched in auth/", + "preceding_tool_use_ids": ["toolu_1", "toolu_2"], + "uuid": "33333333-3333-4333-8333-333333333333", + "session_id": "sess_123" +} From 058793c73a1de810ff5aaf1fb90c268228f58d5e Mon Sep 17 00:00:00 2001 From: "Duong M. CUONG" <> Date: Thu, 2 Apr 2026 14:44:44 +0000 Subject: [PATCH 13/21] feat(gitstore): honor configured branch and follow live remote default --- README.md | 2 + README_CN.md | 2 + README_JA.md | 2 + cmd/server/main.go | 6 +- internal/store/gitstore.go | 247 +++++++++++++- internal/store/gitstore_test.go | 585 ++++++++++++++++++++++++++++++++ 6 files changed, 838 insertions(+), 6 deletions(-) create mode 100644 internal/store/gitstore_test.go diff --git a/README.md b/README.md index 25e0090e..aeb6964e 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,8 @@ Get 10% OFF GLM CODING PLAN:https://z.ai/subscribe?ic=8JVLJQFSKB CLIProxyAPI Guides: [https://help.router-for.me/](https://help.router-for.me/) +For the optional git-backed config store, `GITSTORE_GIT_BRANCH` is optional. Leave it empty or unset to follow the remote repository's default branch, and set it only when you want to force a specific branch. + ## Management API see [MANAGEMENT_API.md](https://help.router-for.me/management/api) diff --git a/README_CN.md b/README_CN.md index 671bd992..db1b4115 100644 --- a/README_CN.md +++ b/README_CN.md @@ -60,6 +60,8 @@ GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元 CLIProxyAPI 用户手册: [https://help.router-for.me/](https://help.router-for.me/cn/) +对于可选的 git 存储配置,`GITSTORE_GIT_BRANCH` 是可选项。留空或不设置时会跟随远程仓库的默认分支,只有在你需要强制指定分支时才设置它。 + ## 管理 API 文档 请参见 [MANAGEMENT_API_CN.md](https://help.router-for.me/cn/management/api) diff --git a/README_JA.md b/README_JA.md index cb0ae1de..d9b250e6 100644 --- a/README_JA.md +++ b/README_JA.md @@ -60,6 +60,8 @@ GLM CODING PLANを10%割引で取得:https://z.ai/subscribe?ic=8JVLJQFSKB CLIProxyAPIガイド:[https://help.router-for.me/ja/](https://help.router-for.me/ja/) +オプションのgitバックアップ設定ストアでは、`GITSTORE_GIT_BRANCH` は任意です。空のままにするか未設定にすると、リモートリポジトリのデフォルトブランチに従います。特定のブランチを強制したい場合のみ設定してください。 + ## 管理API [MANAGEMENT_API.md](https://help.router-for.me/ja/management/api)を参照 diff --git a/cmd/server/main.go b/cmd/server/main.go index e12e5261..1e3f88b1 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -140,6 +140,7 @@ func main() { gitStoreRemoteURL string gitStoreUser string gitStorePassword string + gitStoreBranch string gitStoreLocalPath string gitStoreInst *store.GitTokenStore gitStoreRoot string @@ -209,6 +210,9 @@ func main() { if value, ok := lookupEnv("GITSTORE_LOCAL_PATH", "gitstore_local_path"); ok { gitStoreLocalPath = value } + if value, ok := lookupEnv("GITSTORE_GIT_BRANCH", "gitstore_git_branch"); ok { + gitStoreBranch = value + } if value, ok := lookupEnv("OBJECTSTORE_ENDPOINT", "objectstore_endpoint"); ok { useObjectStore = true objectStoreEndpoint = value @@ -343,7 +347,7 @@ func main() { } gitStoreRoot = filepath.Join(gitStoreLocalPath, "gitstore") authDir := filepath.Join(gitStoreRoot, "auths") - gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword) + gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword, gitStoreBranch) gitStoreInst.SetBaseDir(authDir) if errRepo := gitStoreInst.EnsureRepository(); errRepo != nil { log.Errorf("failed to prepare git token store: %v", errRepo) diff --git a/internal/store/gitstore.go b/internal/store/gitstore.go index c8db660c..12e49794 100644 --- a/internal/store/gitstore.go +++ b/internal/store/gitstore.go @@ -32,16 +32,24 @@ type GitTokenStore struct { repoDir string configDir string remote string + branch string username string password string lastGC time.Time } +type resolvedRemoteBranch struct { + name plumbing.ReferenceName + hash plumbing.Hash +} + // NewGitTokenStore creates a token store that saves credentials to disk through the // TokenStorage implementation embedded in the token record. -func NewGitTokenStore(remote, username, password string) *GitTokenStore { +// When branch is non-empty, clone/pull/push operations target that branch instead of the remote default. +func NewGitTokenStore(remote, username, password, branch string) *GitTokenStore { return &GitTokenStore{ remote: remote, + branch: strings.TrimSpace(branch), username: username, password: password, } @@ -120,7 +128,11 @@ func (s *GitTokenStore) EnsureRepository() error { s.dirLock.Unlock() return fmt.Errorf("git token store: create repo dir: %w", errMk) } - if _, errClone := git.PlainClone(repoDir, &git.CloneOptions{Auth: authMethod, URL: s.remote}); errClone != nil { + cloneOpts := &git.CloneOptions{Auth: authMethod, URL: s.remote} + if s.branch != "" { + cloneOpts.ReferenceName = plumbing.NewBranchReferenceName(s.branch) + } + if _, errClone := git.PlainClone(repoDir, cloneOpts); errClone != nil { if errors.Is(errClone, transport.ErrEmptyRemoteRepository) { _ = os.RemoveAll(gitDir) repo, errInit := git.PlainInit(repoDir, false) @@ -128,6 +140,13 @@ func (s *GitTokenStore) EnsureRepository() error { s.dirLock.Unlock() return fmt.Errorf("git token store: init empty repo: %w", errInit) } + if s.branch != "" { + headRef := plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(s.branch)) + if errHead := repo.Storer.SetReference(headRef); errHead != nil { + s.dirLock.Unlock() + return fmt.Errorf("git token store: set head to branch %s: %w", s.branch, errHead) + } + } if _, errRemote := repo.Remote("origin"); errRemote != nil { if _, errCreate := repo.CreateRemote(&config.RemoteConfig{ Name: "origin", @@ -176,16 +195,39 @@ func (s *GitTokenStore) EnsureRepository() error { s.dirLock.Unlock() return fmt.Errorf("git token store: worktree: %w", errWorktree) } - if errPull := worktree.Pull(&git.PullOptions{Auth: authMethod, RemoteName: "origin"}); errPull != nil { + if s.branch != "" { + if errCheckout := s.checkoutConfiguredBranch(repo, worktree, authMethod); errCheckout != nil { + s.dirLock.Unlock() + return errCheckout + } + } else { + // When branch is unset, ensure the working tree follows the remote default branch + if err := checkoutRemoteDefaultBranch(repo, worktree, authMethod); err != nil { + if !shouldFallbackToCurrentBranch(repo, err) { + s.dirLock.Unlock() + return fmt.Errorf("git token store: checkout remote default: %w", err) + } + } + } + pullOpts := &git.PullOptions{Auth: authMethod, RemoteName: "origin"} + if s.branch != "" { + pullOpts.ReferenceName = plumbing.NewBranchReferenceName(s.branch) + } + if errPull := worktree.Pull(pullOpts); errPull != nil { switch { case errors.Is(errPull, git.NoErrAlreadyUpToDate), errors.Is(errPull, git.ErrUnstagedChanges), errors.Is(errPull, git.ErrNonFastForwardUpdate): // Ignore clean syncs, local edits, and remote divergence—local changes win. case errors.Is(errPull, transport.ErrAuthenticationRequired), - errors.Is(errPull, plumbing.ErrReferenceNotFound), errors.Is(errPull, transport.ErrEmptyRemoteRepository): // Ignore authentication prompts and empty remote references on initial sync. + case errors.Is(errPull, plumbing.ErrReferenceNotFound): + if s.branch != "" { + s.dirLock.Unlock() + return fmt.Errorf("git token store: pull: %w", errPull) + } + // Ignore missing references only when following the remote default branch. default: s.dirLock.Unlock() return fmt.Errorf("git token store: pull: %w", errPull) @@ -553,6 +595,192 @@ func (s *GitTokenStore) relativeToRepo(path string) (string, error) { return rel, nil } +func (s *GitTokenStore) checkoutConfiguredBranch(repo *git.Repository, worktree *git.Worktree, authMethod transport.AuthMethod) error { + branchRefName := plumbing.NewBranchReferenceName(s.branch) + headRef, errHead := repo.Head() + switch { + case errHead == nil && headRef.Name() == branchRefName: + return nil + case errHead != nil && !errors.Is(errHead, plumbing.ErrReferenceNotFound): + return fmt.Errorf("git token store: get head: %w", errHead) + } + + if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName}); err == nil { + return nil + } else if _, errRef := repo.Reference(branchRefName, true); errRef == nil { + return fmt.Errorf("git token store: checkout branch %s: %w", s.branch, err) + } else if !errors.Is(errRef, plumbing.ErrReferenceNotFound) { + return fmt.Errorf("git token store: inspect branch %s: %w", s.branch, errRef) + } else if err := s.checkoutConfiguredRemoteTrackingBranch(repo, worktree, branchRefName, authMethod); err != nil { + return fmt.Errorf("git token store: checkout branch %s: %w", s.branch, err) + } + + return nil +} + +func (s *GitTokenStore) checkoutConfiguredRemoteTrackingBranch(repo *git.Repository, worktree *git.Worktree, branchRefName plumbing.ReferenceName, authMethod transport.AuthMethod) error { + remoteRefName := plumbing.ReferenceName("refs/remotes/origin/" + s.branch) + remoteRef, err := repo.Reference(remoteRefName, true) + if errors.Is(err, plumbing.ErrReferenceNotFound) { + if errSync := syncRemoteReferences(repo, authMethod); errSync != nil { + return fmt.Errorf("sync remote refs: %w", errSync) + } + remoteRef, err = repo.Reference(remoteRefName, true) + } + if err != nil { + return err + } + if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName, Create: true, Hash: remoteRef.Hash()}); err != nil { + return err + } + + cfg, err := repo.Config() + if err != nil { + return fmt.Errorf("git token store: repo config: %w", err) + } + if _, ok := cfg.Branches[s.branch]; !ok { + cfg.Branches[s.branch] = &config.Branch{Name: s.branch} + } + cfg.Branches[s.branch].Remote = "origin" + cfg.Branches[s.branch].Merge = branchRefName + if err := repo.SetConfig(cfg); err != nil { + return fmt.Errorf("git token store: set branch config: %w", err) + } + return nil +} + +func syncRemoteReferences(repo *git.Repository, authMethod transport.AuthMethod) error { + if err := repo.Fetch(&git.FetchOptions{Auth: authMethod, RemoteName: "origin"}); err != nil && !errors.Is(err, git.NoErrAlreadyUpToDate) { + return err + } + return nil +} + +// resolveRemoteDefaultBranch queries the origin remote to determine the remote's default branch +// (the target of HEAD) and returns the corresponding local branch reference name (e.g. refs/heads/master). +func resolveRemoteDefaultBranch(repo *git.Repository, authMethod transport.AuthMethod) (resolvedRemoteBranch, error) { + if err := syncRemoteReferences(repo, authMethod); err != nil { + return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: sync remote refs: %w", err) + } + remote, err := repo.Remote("origin") + if err != nil { + return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: get remote: %w", err) + } + refs, err := remote.List(&git.ListOptions{Auth: authMethod}) + if err != nil { + if resolved, ok := resolveRemoteDefaultBranchFromLocal(repo); ok { + return resolved, nil + } + return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: list remote refs: %w", err) + } + for _, r := range refs { + if r.Name() == plumbing.HEAD { + if r.Type() == plumbing.SymbolicReference { + if target, ok := normalizeRemoteBranchReference(r.Target()); ok { + return resolvedRemoteBranch{name: target}, nil + } + } + s := r.String() + if idx := strings.Index(s, "->"); idx != -1 { + if target, ok := normalizeRemoteBranchReference(plumbing.ReferenceName(strings.TrimSpace(s[idx+2:]))); ok { + return resolvedRemoteBranch{name: target}, nil + } + } + } + } + if resolved, ok := resolveRemoteDefaultBranchFromLocal(repo); ok { + return resolved, nil + } + for _, r := range refs { + if normalized, ok := normalizeRemoteBranchReference(r.Name()); ok { + return resolvedRemoteBranch{name: normalized, hash: r.Hash()}, nil + } + } + return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: remote default branch not found") +} + +func resolveRemoteDefaultBranchFromLocal(repo *git.Repository) (resolvedRemoteBranch, bool) { + ref, err := repo.Reference(plumbing.ReferenceName("refs/remotes/origin/HEAD"), true) + if err != nil || ref.Type() != plumbing.SymbolicReference { + return resolvedRemoteBranch{}, false + } + target, ok := normalizeRemoteBranchReference(ref.Target()) + if !ok { + return resolvedRemoteBranch{}, false + } + return resolvedRemoteBranch{name: target}, true +} + +func normalizeRemoteBranchReference(name plumbing.ReferenceName) (plumbing.ReferenceName, bool) { + switch { + case strings.HasPrefix(name.String(), "refs/heads/"): + return name, true + case strings.HasPrefix(name.String(), "refs/remotes/origin/"): + return plumbing.NewBranchReferenceName(strings.TrimPrefix(name.String(), "refs/remotes/origin/")), true + default: + return "", false + } +} + +func shouldFallbackToCurrentBranch(repo *git.Repository, err error) bool { + if !errors.Is(err, transport.ErrAuthenticationRequired) && !errors.Is(err, transport.ErrEmptyRemoteRepository) { + return false + } + _, headErr := repo.Head() + return headErr == nil +} + +// checkoutRemoteDefaultBranch ensures the working tree is checked out to the remote's default branch +// (the branch target of origin/HEAD). If the local branch does not exist it will be created to track +// the remote branch. +func checkoutRemoteDefaultBranch(repo *git.Repository, worktree *git.Worktree, authMethod transport.AuthMethod) error { + resolved, err := resolveRemoteDefaultBranch(repo, authMethod) + if err != nil { + return err + } + branchRefName := resolved.name + // If HEAD already points to the desired branch, nothing to do. + headRef, errHead := repo.Head() + if errHead == nil && headRef.Name() == branchRefName { + return nil + } + // If local branch exists, attempt a checkout + if _, err := repo.Reference(branchRefName, true); err == nil { + if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName}); err != nil { + return fmt.Errorf("checkout branch %s: %w", branchRefName.String(), err) + } + return nil + } + // Try to find the corresponding remote tracking ref (refs/remotes/origin/) + branchShort := strings.TrimPrefix(branchRefName.String(), "refs/heads/") + remoteRefName := plumbing.ReferenceName("refs/remotes/origin/" + branchShort) + hash := resolved.hash + if remoteRef, err := repo.Reference(remoteRefName, true); err == nil { + hash = remoteRef.Hash() + } else if err != nil && !errors.Is(err, plumbing.ErrReferenceNotFound) { + return fmt.Errorf("checkout remote default: remote ref %s: %w", remoteRefName.String(), err) + } + if hash == plumbing.ZeroHash { + return fmt.Errorf("checkout remote default: remote ref %s not found", remoteRefName.String()) + } + if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName, Create: true, Hash: hash}); err != nil { + return fmt.Errorf("checkout create branch %s: %w", branchRefName.String(), err) + } + cfg, err := repo.Config() + if err != nil { + return fmt.Errorf("git token store: repo config: %w", err) + } + if _, ok := cfg.Branches[branchShort]; !ok { + cfg.Branches[branchShort] = &config.Branch{Name: branchShort} + } + cfg.Branches[branchShort].Remote = "origin" + cfg.Branches[branchShort].Merge = branchRefName + if err := repo.SetConfig(cfg); err != nil { + return fmt.Errorf("git token store: set branch config: %w", err) + } + return nil +} + func (s *GitTokenStore) commitAndPushLocked(message string, relPaths ...string) error { repoDir := s.repoDirSnapshot() if repoDir == "" { @@ -618,7 +846,16 @@ func (s *GitTokenStore) commitAndPushLocked(message string, relPaths ...string) return errRewrite } s.maybeRunGC(repo) - if err = repo.Push(&git.PushOptions{Auth: s.gitAuth(), Force: true}); err != nil { + pushOpts := &git.PushOptions{Auth: s.gitAuth(), Force: true} + if s.branch != "" { + pushOpts.RefSpecs = []config.RefSpec{config.RefSpec("refs/heads/" + s.branch + ":refs/heads/" + s.branch)} + } else { + // When branch is unset, pin push to the currently checked-out branch. + if headRef, err := repo.Head(); err == nil { + pushOpts.RefSpecs = []config.RefSpec{config.RefSpec(headRef.Name().String() + ":" + headRef.Name().String())} + } + } + if err = repo.Push(pushOpts); err != nil { if errors.Is(err, git.NoErrAlreadyUpToDate) { return nil } diff --git a/internal/store/gitstore_test.go b/internal/store/gitstore_test.go new file mode 100644 index 00000000..c5e99039 --- /dev/null +++ b/internal/store/gitstore_test.go @@ -0,0 +1,585 @@ +package store + +import ( + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "github.com/go-git/go-git/v6" + gitconfig "github.com/go-git/go-git/v6/config" + "github.com/go-git/go-git/v6/plumbing" + "github.com/go-git/go-git/v6/plumbing/object" +) + +type testBranchSpec struct { + name string + contents string +} + +func TestEnsureRepositoryUsesRemoteDefaultBranchWhenBranchNotConfigured(t *testing.T) { + root := t.TempDir() + remoteDir := setupGitRemoteRepository(t, root, "trunk", + testBranchSpec{name: "trunk", contents: "remote default branch\n"}, + testBranchSpec{name: "release/2026", contents: "release branch\n"}, + ) + + store := NewGitTokenStore(remoteDir, "", "", "") + store.SetBaseDir(filepath.Join(root, "workspace", "auths")) + + if err := store.EnsureRepository(); err != nil { + t.Fatalf("EnsureRepository: %v", err) + } + + assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "trunk", "remote default branch\n") + advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "trunk", "remote default branch updated\n", "advance trunk") + advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "release/2026", "release branch updated\n", "advance release") + + if err := store.EnsureRepository(); err != nil { + t.Fatalf("EnsureRepository second call: %v", err) + } + + assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "trunk", "remote default branch updated\n") + assertRemoteHeadBranch(t, remoteDir, "trunk") +} + +func TestEnsureRepositoryUsesConfiguredBranchWhenExplicitlySet(t *testing.T) { + root := t.TempDir() + remoteDir := setupGitRemoteRepository(t, root, "trunk", + testBranchSpec{name: "trunk", contents: "remote default branch\n"}, + testBranchSpec{name: "release/2026", contents: "release branch\n"}, + ) + + store := NewGitTokenStore(remoteDir, "", "", "release/2026") + store.SetBaseDir(filepath.Join(root, "workspace", "auths")) + + if err := store.EnsureRepository(); err != nil { + t.Fatalf("EnsureRepository: %v", err) + } + + assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "release/2026", "release branch\n") + advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "trunk", "remote default branch updated\n", "advance trunk") + advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "release/2026", "release branch updated\n", "advance release") + + if err := store.EnsureRepository(); err != nil { + t.Fatalf("EnsureRepository second call: %v", err) + } + + assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "release/2026", "release branch updated\n") + assertRemoteHeadBranch(t, remoteDir, "trunk") +} + +func TestEnsureRepositoryReturnsErrorForMissingConfiguredBranch(t *testing.T) { + root := t.TempDir() + remoteDir := setupGitRemoteRepository(t, root, "trunk", + testBranchSpec{name: "trunk", contents: "remote default branch\n"}, + ) + + store := NewGitTokenStore(remoteDir, "", "", "missing-branch") + store.SetBaseDir(filepath.Join(root, "workspace", "auths")) + + err := store.EnsureRepository() + if err == nil { + t.Fatal("EnsureRepository succeeded, want error for nonexistent configured branch") + } + assertRemoteHeadBranch(t, remoteDir, "trunk") +} + +func TestEnsureRepositoryReturnsErrorForMissingConfiguredBranchOnExistingRepositoryPull(t *testing.T) { + root := t.TempDir() + remoteDir := setupGitRemoteRepository(t, root, "trunk", + testBranchSpec{name: "trunk", contents: "remote default branch\n"}, + ) + + baseDir := filepath.Join(root, "workspace", "auths") + store := NewGitTokenStore(remoteDir, "", "", "") + store.SetBaseDir(baseDir) + + if err := store.EnsureRepository(); err != nil { + t.Fatalf("EnsureRepository initial clone: %v", err) + } + + reopened := NewGitTokenStore(remoteDir, "", "", "missing-branch") + reopened.SetBaseDir(baseDir) + + err := reopened.EnsureRepository() + if err == nil { + t.Fatal("EnsureRepository succeeded on reopen, want error for nonexistent configured branch") + } + assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), "trunk") + assertRemoteHeadBranch(t, remoteDir, "trunk") +} + +func TestEnsureRepositoryInitializesEmptyRemoteUsingConfiguredBranch(t *testing.T) { + root := t.TempDir() + remoteDir := filepath.Join(root, "remote.git") + if _, err := git.PlainInit(remoteDir, true); err != nil { + t.Fatalf("init bare remote: %v", err) + } + + branch := "feature/gemini-fix" + store := NewGitTokenStore(remoteDir, "", "", branch) + store.SetBaseDir(filepath.Join(root, "workspace", "auths")) + + if err := store.EnsureRepository(); err != nil { + t.Fatalf("EnsureRepository: %v", err) + } + + assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), branch) + assertRemoteBranchExistsWithCommit(t, remoteDir, branch) + assertRemoteBranchDoesNotExist(t, remoteDir, "master") +} + +func TestEnsureRepositoryExistingRepoSwitchesToConfiguredBranch(t *testing.T) { + root := t.TempDir() + remoteDir := setupGitRemoteRepository(t, root, "master", + testBranchSpec{name: "master", contents: "remote master branch\n"}, + testBranchSpec{name: "develop", contents: "remote develop branch\n"}, + ) + + baseDir := filepath.Join(root, "workspace", "auths") + store := NewGitTokenStore(remoteDir, "", "", "") + store.SetBaseDir(baseDir) + + if err := store.EnsureRepository(); err != nil { + t.Fatalf("EnsureRepository initial clone: %v", err) + } + assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "master", "remote master branch\n") + + reopened := NewGitTokenStore(remoteDir, "", "", "develop") + reopened.SetBaseDir(baseDir) + + if err := reopened.EnsureRepository(); err != nil { + t.Fatalf("EnsureRepository reopen: %v", err) + } + assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "develop", "remote develop branch\n") + + workspaceDir := filepath.Join(root, "workspace") + if err := os.WriteFile(filepath.Join(workspaceDir, "branch.txt"), []byte("local develop update\n"), 0o600); err != nil { + t.Fatalf("write local branch marker: %v", err) + } + + reopened.mu.Lock() + err := reopened.commitAndPushLocked("Update develop branch marker", "branch.txt") + reopened.mu.Unlock() + if err != nil { + t.Fatalf("commitAndPushLocked: %v", err) + } + + assertRepositoryHeadBranch(t, workspaceDir, "develop") + assertRemoteBranchContents(t, remoteDir, "develop", "local develop update\n") + assertRemoteBranchContents(t, remoteDir, "master", "remote master branch\n") +} + +func TestEnsureRepositoryExistingRepoSwitchesToConfiguredBranchCreatedAfterClone(t *testing.T) { + root := t.TempDir() + remoteDir := setupGitRemoteRepository(t, root, "master", + testBranchSpec{name: "master", contents: "remote master branch\n"}, + ) + + baseDir := filepath.Join(root, "workspace", "auths") + store := NewGitTokenStore(remoteDir, "", "", "") + store.SetBaseDir(baseDir) + + if err := store.EnsureRepository(); err != nil { + t.Fatalf("EnsureRepository initial clone: %v", err) + } + assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "master", "remote master branch\n") + + advanceRemoteBranchFromNewBranch(t, filepath.Join(root, "seed"), remoteDir, "release/2026", "release branch\n", "create release") + + reopened := NewGitTokenStore(remoteDir, "", "", "release/2026") + reopened.SetBaseDir(baseDir) + + if err := reopened.EnsureRepository(); err != nil { + t.Fatalf("EnsureRepository reopen: %v", err) + } + assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "release/2026", "release branch\n") +} + +func TestEnsureRepositoryResetsToRemoteDefaultWhenBranchUnset(t *testing.T) { + root := t.TempDir() + remoteDir := setupGitRemoteRepository(t, root, "master", + testBranchSpec{name: "master", contents: "remote master branch\n"}, + testBranchSpec{name: "develop", contents: "remote develop branch\n"}, + ) + + baseDir := filepath.Join(root, "workspace", "auths") + // First store pins to develop and prepares local workspace + storePinned := NewGitTokenStore(remoteDir, "", "", "develop") + storePinned.SetBaseDir(baseDir) + if err := storePinned.EnsureRepository(); err != nil { + t.Fatalf("EnsureRepository pinned: %v", err) + } + assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "develop", "remote develop branch\n") + + // Second store has branch unset and should reset local workspace to remote default (master) + storeDefault := NewGitTokenStore(remoteDir, "", "", "") + storeDefault.SetBaseDir(baseDir) + if err := storeDefault.EnsureRepository(); err != nil { + t.Fatalf("EnsureRepository default: %v", err) + } + // Local HEAD should now follow remote default (master) + assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), "master") + + // Make a local change and push using the store with branch unset; push should update remote master + workspaceDir := filepath.Join(root, "workspace") + if err := os.WriteFile(filepath.Join(workspaceDir, "branch.txt"), []byte("local master update\n"), 0o600); err != nil { + t.Fatalf("write local master marker: %v", err) + } + storeDefault.mu.Lock() + if err := storeDefault.commitAndPushLocked("Update master marker", "branch.txt"); err != nil { + storeDefault.mu.Unlock() + t.Fatalf("commitAndPushLocked: %v", err) + } + storeDefault.mu.Unlock() + + assertRemoteBranchContents(t, remoteDir, "master", "local master update\n") +} + +func TestEnsureRepositoryFollowsRenamedRemoteDefaultBranchWhenAvailable(t *testing.T) { + root := t.TempDir() + remoteDir := setupGitRemoteRepository(t, root, "master", + testBranchSpec{name: "master", contents: "remote master branch\n"}, + testBranchSpec{name: "main", contents: "remote main branch\n"}, + ) + + baseDir := filepath.Join(root, "workspace", "auths") + store := NewGitTokenStore(remoteDir, "", "", "") + store.SetBaseDir(baseDir) + + if err := store.EnsureRepository(); err != nil { + t.Fatalf("EnsureRepository initial clone: %v", err) + } + assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "master", "remote master branch\n") + + setRemoteHeadBranch(t, remoteDir, "main") + advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "main", "remote main branch updated\n", "advance main") + + reopened := NewGitTokenStore(remoteDir, "", "", "") + reopened.SetBaseDir(baseDir) + + if err := reopened.EnsureRepository(); err != nil { + t.Fatalf("EnsureRepository after remote default rename: %v", err) + } + assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "main", "remote main branch updated\n") + assertRemoteHeadBranch(t, remoteDir, "main") +} + +func TestEnsureRepositoryKeepsCurrentBranchWhenRemoteDefaultCannotBeResolved(t *testing.T) { + root := t.TempDir() + remoteDir := setupGitRemoteRepository(t, root, "master", + testBranchSpec{name: "master", contents: "remote master branch\n"}, + testBranchSpec{name: "develop", contents: "remote develop branch\n"}, + ) + + baseDir := filepath.Join(root, "workspace", "auths") + pinned := NewGitTokenStore(remoteDir, "", "", "develop") + pinned.SetBaseDir(baseDir) + if err := pinned.EnsureRepository(); err != nil { + t.Fatalf("EnsureRepository pinned: %v", err) + } + assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "develop", "remote develop branch\n") + + authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("WWW-Authenticate", `Basic realm="git"`) + http.Error(w, "auth required", http.StatusUnauthorized) + })) + defer authServer.Close() + + repo, err := git.PlainOpen(filepath.Join(root, "workspace")) + if err != nil { + t.Fatalf("open workspace repo: %v", err) + } + cfg, err := repo.Config() + if err != nil { + t.Fatalf("read repo config: %v", err) + } + cfg.Remotes["origin"].URLs = []string{authServer.URL} + if err := repo.SetConfig(cfg); err != nil { + t.Fatalf("set repo config: %v", err) + } + + reopened := NewGitTokenStore(remoteDir, "", "", "") + reopened.SetBaseDir(baseDir) + + if err := reopened.EnsureRepository(); err != nil { + t.Fatalf("EnsureRepository default branch fallback: %v", err) + } + assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), "develop") +} + +func setupGitRemoteRepository(t *testing.T, root, defaultBranch string, branches ...testBranchSpec) string { + t.Helper() + + remoteDir := filepath.Join(root, "remote.git") + if _, err := git.PlainInit(remoteDir, true); err != nil { + t.Fatalf("init bare remote: %v", err) + } + + seedDir := filepath.Join(root, "seed") + seedRepo, err := git.PlainInit(seedDir, false) + if err != nil { + t.Fatalf("init seed repo: %v", err) + } + if err := seedRepo.Storer.SetReference(plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(defaultBranch))); err != nil { + t.Fatalf("set seed HEAD: %v", err) + } + + worktree, err := seedRepo.Worktree() + if err != nil { + t.Fatalf("open seed worktree: %v", err) + } + + defaultSpec, ok := findBranchSpec(branches, defaultBranch) + if !ok { + t.Fatalf("missing default branch spec for %q", defaultBranch) + } + commitBranchMarker(t, seedDir, worktree, defaultSpec, "seed default branch") + + for _, branch := range branches { + if branch.name == defaultBranch { + continue + } + if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(defaultBranch)}); err != nil { + t.Fatalf("checkout default branch %s: %v", defaultBranch, err) + } + if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(branch.name), Create: true}); err != nil { + t.Fatalf("create branch %s: %v", branch.name, err) + } + commitBranchMarker(t, seedDir, worktree, branch, "seed branch "+branch.name) + } + + if _, err := seedRepo.CreateRemote(&gitconfig.RemoteConfig{Name: "origin", URLs: []string{remoteDir}}); err != nil { + t.Fatalf("create origin remote: %v", err) + } + if err := seedRepo.Push(&git.PushOptions{ + RemoteName: "origin", + RefSpecs: []gitconfig.RefSpec{gitconfig.RefSpec("refs/heads/*:refs/heads/*")}, + }); err != nil { + t.Fatalf("push seed branches: %v", err) + } + + remoteRepo, err := git.PlainOpen(remoteDir) + if err != nil { + t.Fatalf("open remote repo: %v", err) + } + if err := remoteRepo.Storer.SetReference(plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(defaultBranch))); err != nil { + t.Fatalf("set remote HEAD: %v", err) + } + + return remoteDir +} + +func commitBranchMarker(t *testing.T, seedDir string, worktree *git.Worktree, branch testBranchSpec, message string) { + t.Helper() + + if err := os.WriteFile(filepath.Join(seedDir, "branch.txt"), []byte(branch.contents), 0o600); err != nil { + t.Fatalf("write branch marker for %s: %v", branch.name, err) + } + if _, err := worktree.Add("branch.txt"); err != nil { + t.Fatalf("add branch marker for %s: %v", branch.name, err) + } + if _, err := worktree.Commit(message, &git.CommitOptions{ + Author: &object.Signature{ + Name: "CLIProxyAPI", + Email: "cliproxy@local", + When: time.Unix(1711929600, 0), + }, + }); err != nil { + t.Fatalf("commit branch marker for %s: %v", branch.name, err) + } +} + +func advanceRemoteBranch(t *testing.T, seedDir, remoteDir, branch, contents, message string) { + t.Helper() + + seedRepo, err := git.PlainOpen(seedDir) + if err != nil { + t.Fatalf("open seed repo: %v", err) + } + worktree, err := seedRepo.Worktree() + if err != nil { + t.Fatalf("open seed worktree: %v", err) + } + if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(branch)}); err != nil { + t.Fatalf("checkout branch %s: %v", branch, err) + } + commitBranchMarker(t, seedDir, worktree, testBranchSpec{name: branch, contents: contents}, message) + if err := seedRepo.Push(&git.PushOptions{ + RemoteName: "origin", + RefSpecs: []gitconfig.RefSpec{ + gitconfig.RefSpec(plumbing.NewBranchReferenceName(branch).String() + ":" + plumbing.NewBranchReferenceName(branch).String()), + }, + }); err != nil { + t.Fatalf("push branch %s update to %s: %v", branch, remoteDir, err) + } +} + +func advanceRemoteBranchFromNewBranch(t *testing.T, seedDir, remoteDir, branch, contents, message string) { + t.Helper() + + seedRepo, err := git.PlainOpen(seedDir) + if err != nil { + t.Fatalf("open seed repo: %v", err) + } + worktree, err := seedRepo.Worktree() + if err != nil { + t.Fatalf("open seed worktree: %v", err) + } + if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName("master")}); err != nil { + t.Fatalf("checkout master before creating %s: %v", branch, err) + } + if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(branch), Create: true}); err != nil { + t.Fatalf("create branch %s: %v", branch, err) + } + commitBranchMarker(t, seedDir, worktree, testBranchSpec{name: branch, contents: contents}, message) + if err := seedRepo.Push(&git.PushOptions{ + RemoteName: "origin", + RefSpecs: []gitconfig.RefSpec{ + gitconfig.RefSpec(plumbing.NewBranchReferenceName(branch).String() + ":" + plumbing.NewBranchReferenceName(branch).String()), + }, + }); err != nil { + t.Fatalf("push new branch %s update to %s: %v", branch, remoteDir, err) + } +} + +func findBranchSpec(branches []testBranchSpec, name string) (testBranchSpec, bool) { + for _, branch := range branches { + if branch.name == name { + return branch, true + } + } + return testBranchSpec{}, false +} + +func assertRepositoryBranchAndContents(t *testing.T, repoDir, branch, wantContents string) { + t.Helper() + + repo, err := git.PlainOpen(repoDir) + if err != nil { + t.Fatalf("open local repo: %v", err) + } + head, err := repo.Head() + if err != nil { + t.Fatalf("local repo head: %v", err) + } + if got, want := head.Name(), plumbing.NewBranchReferenceName(branch); got != want { + t.Fatalf("local head branch = %s, want %s", got, want) + } + contents, err := os.ReadFile(filepath.Join(repoDir, "branch.txt")) + if err != nil { + t.Fatalf("read branch marker: %v", err) + } + if got := string(contents); got != wantContents { + t.Fatalf("branch marker contents = %q, want %q", got, wantContents) + } +} + +func assertRepositoryHeadBranch(t *testing.T, repoDir, branch string) { + t.Helper() + + repo, err := git.PlainOpen(repoDir) + if err != nil { + t.Fatalf("open local repo: %v", err) + } + head, err := repo.Head() + if err != nil { + t.Fatalf("local repo head: %v", err) + } + if got, want := head.Name(), plumbing.NewBranchReferenceName(branch); got != want { + t.Fatalf("local head branch = %s, want %s", got, want) + } +} + +func assertRemoteHeadBranch(t *testing.T, remoteDir, branch string) { + t.Helper() + + remoteRepo, err := git.PlainOpen(remoteDir) + if err != nil { + t.Fatalf("open remote repo: %v", err) + } + head, err := remoteRepo.Reference(plumbing.HEAD, false) + if err != nil { + t.Fatalf("read remote HEAD: %v", err) + } + if got, want := head.Target(), plumbing.NewBranchReferenceName(branch); got != want { + t.Fatalf("remote HEAD target = %s, want %s", got, want) + } +} + +func setRemoteHeadBranch(t *testing.T, remoteDir, branch string) { + t.Helper() + + remoteRepo, err := git.PlainOpen(remoteDir) + if err != nil { + t.Fatalf("open remote repo: %v", err) + } + if err := remoteRepo.Storer.SetReference(plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(branch))); err != nil { + t.Fatalf("set remote HEAD to %s: %v", branch, err) + } +} + +func assertRemoteBranchExistsWithCommit(t *testing.T, remoteDir, branch string) { + t.Helper() + + remoteRepo, err := git.PlainOpen(remoteDir) + if err != nil { + t.Fatalf("open remote repo: %v", err) + } + ref, err := remoteRepo.Reference(plumbing.NewBranchReferenceName(branch), false) + if err != nil { + t.Fatalf("read remote branch %s: %v", branch, err) + } + if got := ref.Hash(); got == plumbing.ZeroHash { + t.Fatalf("remote branch %s hash = %s, want non-zero hash", branch, got) + } +} + +func assertRemoteBranchDoesNotExist(t *testing.T, remoteDir, branch string) { + t.Helper() + + remoteRepo, err := git.PlainOpen(remoteDir) + if err != nil { + t.Fatalf("open remote repo: %v", err) + } + if _, err := remoteRepo.Reference(plumbing.NewBranchReferenceName(branch), false); err == nil { + t.Fatalf("remote branch %s exists, want missing", branch) + } else if err != plumbing.ErrReferenceNotFound { + t.Fatalf("read remote branch %s: %v", branch, err) + } +} + +func assertRemoteBranchContents(t *testing.T, remoteDir, branch, wantContents string) { + t.Helper() + + remoteRepo, err := git.PlainOpen(remoteDir) + if err != nil { + t.Fatalf("open remote repo: %v", err) + } + ref, err := remoteRepo.Reference(plumbing.NewBranchReferenceName(branch), false) + if err != nil { + t.Fatalf("read remote branch %s: %v", branch, err) + } + commit, err := remoteRepo.CommitObject(ref.Hash()) + if err != nil { + t.Fatalf("read remote branch %s commit: %v", branch, err) + } + tree, err := commit.Tree() + if err != nil { + t.Fatalf("read remote branch %s tree: %v", branch, err) + } + file, err := tree.File("branch.txt") + if err != nil { + t.Fatalf("read remote branch %s file: %v", branch, err) + } + contents, err := file.Contents() + if err != nil { + t.Fatalf("read remote branch %s contents: %v", branch, err) + } + if contents != wantContents { + t.Fatalf("remote branch %s contents = %q, want %q", branch, contents, wantContents) + } +} From 9b5ce8c64f91cb6af2b34bb9c95eac7ca931c9b2 Mon Sep 17 00:00:00 2001 From: mpfo0106 Date: Fri, 3 Apr 2026 00:13:02 +0900 Subject: [PATCH 14/21] Keep Claude builtin helpers aligned with the shared helper layout The review asked for the builtin tool registry helper to live with the rest of executor support utilities. This moves the registry code into the helps package, exports the minimal surface executor needs, and keeps behavior tests with the executor while leaving registry-focused checks with the helper. Constraint: Requested layout keeps executor helper utilities centralized under internal/runtime/executor/helps Rejected: Keep the files in executor and reply with rationale | conflicts with requested package organization Confidence: high Scope-risk: narrow Reversibility: clean Directive: Keep executor behavior tests near applyClaudeToolPrefix and keep pure registry tests in helps Tested: go test ./internal/runtime/executor/helps ./internal/runtime/executor -run 'Claude|Builtin|Tool'; go test ./test/...; go test ./... Not-tested: End-to-end Claude Code direct-connect/session runtime behavior --- .../executor/claude_builtin_tools_test.go | 46 ------------------- internal/runtime/executor/claude_executor.go | 2 +- .../runtime/executor/claude_executor_test.go | 29 ++++++++++++ .../{ => helps}/claude_builtin_tools.go | 4 +- .../helps/claude_builtin_tools_test.go | 32 +++++++++++++ 5 files changed, 64 insertions(+), 49 deletions(-) delete mode 100644 internal/runtime/executor/claude_builtin_tools_test.go rename internal/runtime/executor/{ => helps}/claude_builtin_tools.go (90%) create mode 100644 internal/runtime/executor/helps/claude_builtin_tools_test.go diff --git a/internal/runtime/executor/claude_builtin_tools_test.go b/internal/runtime/executor/claude_builtin_tools_test.go deleted file mode 100644 index 34036fa0..00000000 --- a/internal/runtime/executor/claude_builtin_tools_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package executor - -import ( - "fmt" - "testing" - - "github.com/tidwall/gjson" -) - -func TestClaudeBuiltinToolRegistry_DefaultSeedFallback(t *testing.T) { - registry := augmentClaudeBuiltinToolRegistry(nil, nil) - for _, name := range defaultClaudeBuiltinToolNames { - if !registry[name] { - t.Fatalf("default builtin %q missing from fallback registry", name) - } - } -} - -func TestApplyClaudeToolPrefix_KnownFallbackBuiltinsRemainUnprefixed(t *testing.T) { - for _, builtin := range defaultClaudeBuiltinToolNames { - t.Run(builtin, func(t *testing.T) { - input := []byte(fmt.Sprintf(`{ - "tools":[{"name":"Read"}], - "tool_choice":{"type":"tool","name":%q}, - "messages":[{"role":"assistant","content":[{"type":"tool_use","name":%q,"id":"toolu_1","input":{}},{"type":"tool_reference","tool_name":%q},{"type":"tool_result","tool_use_id":"toolu_1","content":[{"type":"tool_reference","tool_name":%q}]}]}] - }`, builtin, builtin, builtin, builtin)) - out := applyClaudeToolPrefix(input, "proxy_") - - if got := gjson.GetBytes(out, "tool_choice.name").String(); got != builtin { - t.Fatalf("tool_choice.name = %q, want %q", got, builtin) - } - if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != builtin { - t.Fatalf("messages.0.content.0.name = %q, want %q", got, builtin) - } - if got := gjson.GetBytes(out, "messages.0.content.1.tool_name").String(); got != builtin { - t.Fatalf("messages.0.content.1.tool_name = %q, want %q", got, builtin) - } - if got := gjson.GetBytes(out, "messages.0.content.2.content.0.tool_name").String(); got != builtin { - t.Fatalf("messages.0.content.2.content.0.tool_name = %q, want %q", got, builtin) - } - if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" { - t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read") - } - }) - } -} diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index d1d2e136..120b1f35 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -921,7 +921,7 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte { // Collect built-in tool names from the authoritative fallback seed list and // augment it with any typed built-ins present in the current request body. - builtinTools := augmentClaudeBuiltinToolRegistry(body, nil) + builtinTools := helps.AugmentClaudeBuiltinToolRegistry(body, nil) if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() { tools.ForEach(func(index, tool gjson.Result) bool { diff --git a/internal/runtime/executor/claude_executor_test.go b/internal/runtime/executor/claude_executor_test.go index 8e8173dd..e5f907b7 100644 --- a/internal/runtime/executor/claude_executor_test.go +++ b/internal/runtime/executor/claude_executor_test.go @@ -739,6 +739,35 @@ func TestApplyClaudeToolPrefix_ToolChoiceBuiltin(t *testing.T) { } } +func TestApplyClaudeToolPrefix_KnownFallbackBuiltinsRemainUnprefixed(t *testing.T) { + for _, builtin := range []string{"web_search", "code_execution", "text_editor", "computer"} { + t.Run(builtin, func(t *testing.T) { + input := []byte(fmt.Sprintf(`{ + "tools":[{"name":"Read"}], + "tool_choice":{"type":"tool","name":%q}, + "messages":[{"role":"assistant","content":[{"type":"tool_use","name":%q,"id":"toolu_1","input":{}},{"type":"tool_reference","tool_name":%q},{"type":"tool_result","tool_use_id":"toolu_1","content":[{"type":"tool_reference","tool_name":%q}]}]}] + }`, builtin, builtin, builtin, builtin)) + out := applyClaudeToolPrefix(input, "proxy_") + + if got := gjson.GetBytes(out, "tool_choice.name").String(); got != builtin { + t.Fatalf("tool_choice.name = %q, want %q", got, builtin) + } + if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != builtin { + t.Fatalf("messages.0.content.0.name = %q, want %q", got, builtin) + } + if got := gjson.GetBytes(out, "messages.0.content.1.tool_name").String(); got != builtin { + t.Fatalf("messages.0.content.1.tool_name = %q, want %q", got, builtin) + } + if got := gjson.GetBytes(out, "messages.0.content.2.content.0.tool_name").String(); got != builtin { + t.Fatalf("messages.0.content.2.content.0.tool_name = %q, want %q", got, builtin) + } + if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" { + t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read") + } + }) + } +} + func TestStripClaudeToolPrefixFromResponse(t *testing.T) { input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`) out := stripClaudeToolPrefixFromResponse(input, "proxy_") diff --git a/internal/runtime/executor/claude_builtin_tools.go b/internal/runtime/executor/helps/claude_builtin_tools.go similarity index 90% rename from internal/runtime/executor/claude_builtin_tools.go rename to internal/runtime/executor/helps/claude_builtin_tools.go index 8c3592f7..5ee2b08d 100644 --- a/internal/runtime/executor/claude_builtin_tools.go +++ b/internal/runtime/executor/helps/claude_builtin_tools.go @@ -1,4 +1,4 @@ -package executor +package helps import "github.com/tidwall/gjson" @@ -17,7 +17,7 @@ func newClaudeBuiltinToolRegistry() map[string]bool { return registry } -func augmentClaudeBuiltinToolRegistry(body []byte, registry map[string]bool) map[string]bool { +func AugmentClaudeBuiltinToolRegistry(body []byte, registry map[string]bool) map[string]bool { if registry == nil { registry = newClaudeBuiltinToolRegistry() } diff --git a/internal/runtime/executor/helps/claude_builtin_tools_test.go b/internal/runtime/executor/helps/claude_builtin_tools_test.go new file mode 100644 index 00000000..d7badd19 --- /dev/null +++ b/internal/runtime/executor/helps/claude_builtin_tools_test.go @@ -0,0 +1,32 @@ +package helps + +import "testing" + +func TestClaudeBuiltinToolRegistry_DefaultSeedFallback(t *testing.T) { + registry := AugmentClaudeBuiltinToolRegistry(nil, nil) + for _, name := range defaultClaudeBuiltinToolNames { + if !registry[name] { + t.Fatalf("default builtin %q missing from fallback registry", name) + } + } +} + +func TestClaudeBuiltinToolRegistry_AugmentsTypedBuiltinsFromBody(t *testing.T) { + registry := AugmentClaudeBuiltinToolRegistry([]byte(`{ + "tools": [ + {"type": "web_search_20250305", "name": "web_search"}, + {"type": "custom_builtin_20250401", "name": "special_builtin"}, + {"name": "Read"} + ] + }`), nil) + + if !registry["web_search"] { + t.Fatal("expected default typed builtin web_search in registry") + } + if !registry["special_builtin"] { + t.Fatal("expected typed builtin from body to be added to registry") + } + if registry["Read"] { + t.Fatal("expected untyped custom tool to stay out of builtin registry") + } +} From 65e9e892a4cff2dd1d68e17a23a7b7b405b767ba Mon Sep 17 00:00:00 2001 From: James Date: Sat, 4 Apr 2026 04:44:01 +0000 Subject: [PATCH 15/21] Fix missing `response.completed.usage` for late-usage OpenAI-compatible streams --- .../executor/openai_compat_executor.go | 8 + .../openai_openai-responses_response.go | 293 +++++++++--------- .../openai_openai-responses_response_test.go | 118 +++++++ 3 files changed, 279 insertions(+), 140 deletions(-) diff --git a/internal/runtime/executor/openai_compat_executor.go b/internal/runtime/executor/openai_compat_executor.go index a03e4987..7f202055 100644 --- a/internal/runtime/executor/openai_compat_executor.go +++ b/internal/runtime/executor/openai_compat_executor.go @@ -298,6 +298,14 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy helps.RecordAPIResponseError(ctx, e.cfg, errScan) reporter.PublishFailure(ctx) out <- cliproxyexecutor.StreamChunk{Err: errScan} + } else { + // In case the upstream close the stream without a terminal [DONE] marker. + // Feed a synthetic done marker through the translator so pending + // response.completed events are still emitted exactly once. + chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("data: [DONE]"), ¶m) + for i := range chunks { + out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]} + } } // Ensure we record the request if no usage chunk was ever seen reporter.EnsurePublished(ctx) diff --git a/internal/translator/openai/openai/responses/openai_openai-responses_response.go b/internal/translator/openai/openai/responses/openai_openai-responses_response.go index a34a6ff4..8a44aede 100644 --- a/internal/translator/openai/openai/responses/openai_openai-responses_response.go +++ b/internal/translator/openai/openai/responses/openai_openai-responses_response.go @@ -20,12 +20,14 @@ type oaiToResponsesStateReasoning struct { OutputIndex int } type oaiToResponsesState struct { - Seq int - ResponseID string - Created int64 - Started bool - ReasoningID string - ReasoningIndex int + Seq int + ResponseID string + Created int64 + Started bool + CompletionPending bool + CompletedEmitted bool + ReasoningID string + ReasoningIndex int // aggregation buffers for response.output // Per-output message text buffers by index MsgTextBuf map[int]*strings.Builder @@ -60,6 +62,141 @@ func emitRespEvent(event string, payload []byte) []byte { return translatorcommon.SSEEventData(event, payload) } +func buildResponsesCompletedEvent(st *oaiToResponsesState, requestRawJSON []byte, nextSeq func() int) []byte { + completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`) + completed, _ = sjson.SetBytes(completed, "sequence_number", nextSeq()) + completed, _ = sjson.SetBytes(completed, "response.id", st.ResponseID) + completed, _ = sjson.SetBytes(completed, "response.created_at", st.Created) + // Inject original request fields into response as per docs/response.completed.json + if requestRawJSON != nil { + req := gjson.ParseBytes(requestRawJSON) + if v := req.Get("instructions"); v.Exists() { + completed, _ = sjson.SetBytes(completed, "response.instructions", v.String()) + } + if v := req.Get("max_output_tokens"); v.Exists() { + completed, _ = sjson.SetBytes(completed, "response.max_output_tokens", v.Int()) + } + if v := req.Get("max_tool_calls"); v.Exists() { + completed, _ = sjson.SetBytes(completed, "response.max_tool_calls", v.Int()) + } + if v := req.Get("model"); v.Exists() { + completed, _ = sjson.SetBytes(completed, "response.model", v.String()) + } + if v := req.Get("parallel_tool_calls"); v.Exists() { + completed, _ = sjson.SetBytes(completed, "response.parallel_tool_calls", v.Bool()) + } + if v := req.Get("previous_response_id"); v.Exists() { + completed, _ = sjson.SetBytes(completed, "response.previous_response_id", v.String()) + } + if v := req.Get("prompt_cache_key"); v.Exists() { + completed, _ = sjson.SetBytes(completed, "response.prompt_cache_key", v.String()) + } + if v := req.Get("reasoning"); v.Exists() { + completed, _ = sjson.SetBytes(completed, "response.reasoning", v.Value()) + } + if v := req.Get("safety_identifier"); v.Exists() { + completed, _ = sjson.SetBytes(completed, "response.safety_identifier", v.String()) + } + if v := req.Get("service_tier"); v.Exists() { + completed, _ = sjson.SetBytes(completed, "response.service_tier", v.String()) + } + if v := req.Get("store"); v.Exists() { + completed, _ = sjson.SetBytes(completed, "response.store", v.Bool()) + } + if v := req.Get("temperature"); v.Exists() { + completed, _ = sjson.SetBytes(completed, "response.temperature", v.Float()) + } + if v := req.Get("text"); v.Exists() { + completed, _ = sjson.SetBytes(completed, "response.text", v.Value()) + } + if v := req.Get("tool_choice"); v.Exists() { + completed, _ = sjson.SetBytes(completed, "response.tool_choice", v.Value()) + } + if v := req.Get("tools"); v.Exists() { + completed, _ = sjson.SetBytes(completed, "response.tools", v.Value()) + } + if v := req.Get("top_logprobs"); v.Exists() { + completed, _ = sjson.SetBytes(completed, "response.top_logprobs", v.Int()) + } + if v := req.Get("top_p"); v.Exists() { + completed, _ = sjson.SetBytes(completed, "response.top_p", v.Float()) + } + if v := req.Get("truncation"); v.Exists() { + completed, _ = sjson.SetBytes(completed, "response.truncation", v.String()) + } + if v := req.Get("user"); v.Exists() { + completed, _ = sjson.SetBytes(completed, "response.user", v.Value()) + } + if v := req.Get("metadata"); v.Exists() { + completed, _ = sjson.SetBytes(completed, "response.metadata", v.Value()) + } + } + + outputsWrapper := []byte(`{"arr":[]}`) + type completedOutputItem struct { + index int + raw []byte + } + outputItems := make([]completedOutputItem, 0, len(st.Reasonings)+len(st.MsgItemAdded)+len(st.FuncArgsBuf)) + if len(st.Reasonings) > 0 { + for _, r := range st.Reasonings { + item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`) + item, _ = sjson.SetBytes(item, "id", r.ReasoningID) + item, _ = sjson.SetBytes(item, "summary.0.text", r.ReasoningData) + outputItems = append(outputItems, completedOutputItem{index: r.OutputIndex, raw: item}) + } + } + if len(st.MsgItemAdded) > 0 { + for i := range st.MsgItemAdded { + txt := "" + if b := st.MsgTextBuf[i]; b != nil { + txt = b.String() + } + item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`) + item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i)) + item, _ = sjson.SetBytes(item, "content.0.text", txt) + outputItems = append(outputItems, completedOutputItem{index: st.MsgOutputIx[i], raw: item}) + } + } + if len(st.FuncArgsBuf) > 0 { + for key := range st.FuncArgsBuf { + args := "" + if b := st.FuncArgsBuf[key]; b != nil { + args = b.String() + } + callID := st.FuncCallIDs[key] + name := st.FuncNames[key] + item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`) + item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID)) + item, _ = sjson.SetBytes(item, "arguments", args) + item, _ = sjson.SetBytes(item, "call_id", callID) + item, _ = sjson.SetBytes(item, "name", name) + outputItems = append(outputItems, completedOutputItem{index: st.FuncOutputIx[key], raw: item}) + } + } + sort.Slice(outputItems, func(i, j int) bool { return outputItems[i].index < outputItems[j].index }) + for _, item := range outputItems { + outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item.raw) + } + if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 { + completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw)) + } + if st.UsageSeen { + completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens", st.PromptTokens) + completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens_details.cached_tokens", st.CachedTokens) + completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens", st.CompletionTokens) + if st.ReasoningTokens > 0 { + completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens_details.reasoning_tokens", st.ReasoningTokens) + } + total := st.TotalTokens + if total == 0 { + total = st.PromptTokens + st.CompletionTokens + } + completed, _ = sjson.SetBytes(completed, "response.usage.total_tokens", total) + } + return emitRespEvent("response.completed", completed) +} + // ConvertOpenAIChatCompletionsResponseToOpenAIResponses converts OpenAI Chat Completions streaming chunks // to OpenAI Responses SSE events (response.*). func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte { @@ -90,6 +227,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context, return [][]byte{} } if bytes.Equal(rawJSON, []byte("[DONE]")) { + if st.CompletionPending && !st.CompletedEmitted { + st.CompletedEmitted = true + return [][]byte{buildResponsesCompletedEvent(st, requestRawJSON, func() int { st.Seq++; return st.Seq })} + } return [][]byte{} } @@ -165,6 +306,8 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context, st.TotalTokens = 0 st.ReasoningTokens = 0 st.UsageSeen = false + st.CompletionPending = false + st.CompletedEmitted = false // response.created created := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`) created, _ = sjson.SetBytes(created, "sequence_number", nextSeq()) @@ -374,8 +517,9 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context, } } - // finish_reason triggers finalization, including text done/content done/item done, - // reasoning done/part.done, function args done/item done, and completed + // finish_reason triggers item-level finalization. response.completed is + // deferred until the terminal [DONE] marker so late usage-only chunks can + // still populate response.usage. if fr := choice.Get("finish_reason"); fr.Exists() && fr.String() != "" { // Emit message done events for all indices that started a message if len(st.MsgItemAdded) > 0 { @@ -464,138 +608,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context, st.FuncArgsDone[key] = true } } - completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`) - completed, _ = sjson.SetBytes(completed, "sequence_number", nextSeq()) - completed, _ = sjson.SetBytes(completed, "response.id", st.ResponseID) - completed, _ = sjson.SetBytes(completed, "response.created_at", st.Created) - // Inject original request fields into response as per docs/response.completed.json - if requestRawJSON != nil { - req := gjson.ParseBytes(requestRawJSON) - if v := req.Get("instructions"); v.Exists() { - completed, _ = sjson.SetBytes(completed, "response.instructions", v.String()) - } - if v := req.Get("max_output_tokens"); v.Exists() { - completed, _ = sjson.SetBytes(completed, "response.max_output_tokens", v.Int()) - } - if v := req.Get("max_tool_calls"); v.Exists() { - completed, _ = sjson.SetBytes(completed, "response.max_tool_calls", v.Int()) - } - if v := req.Get("model"); v.Exists() { - completed, _ = sjson.SetBytes(completed, "response.model", v.String()) - } - if v := req.Get("parallel_tool_calls"); v.Exists() { - completed, _ = sjson.SetBytes(completed, "response.parallel_tool_calls", v.Bool()) - } - if v := req.Get("previous_response_id"); v.Exists() { - completed, _ = sjson.SetBytes(completed, "response.previous_response_id", v.String()) - } - if v := req.Get("prompt_cache_key"); v.Exists() { - completed, _ = sjson.SetBytes(completed, "response.prompt_cache_key", v.String()) - } - if v := req.Get("reasoning"); v.Exists() { - completed, _ = sjson.SetBytes(completed, "response.reasoning", v.Value()) - } - if v := req.Get("safety_identifier"); v.Exists() { - completed, _ = sjson.SetBytes(completed, "response.safety_identifier", v.String()) - } - if v := req.Get("service_tier"); v.Exists() { - completed, _ = sjson.SetBytes(completed, "response.service_tier", v.String()) - } - if v := req.Get("store"); v.Exists() { - completed, _ = sjson.SetBytes(completed, "response.store", v.Bool()) - } - if v := req.Get("temperature"); v.Exists() { - completed, _ = sjson.SetBytes(completed, "response.temperature", v.Float()) - } - if v := req.Get("text"); v.Exists() { - completed, _ = sjson.SetBytes(completed, "response.text", v.Value()) - } - if v := req.Get("tool_choice"); v.Exists() { - completed, _ = sjson.SetBytes(completed, "response.tool_choice", v.Value()) - } - if v := req.Get("tools"); v.Exists() { - completed, _ = sjson.SetBytes(completed, "response.tools", v.Value()) - } - if v := req.Get("top_logprobs"); v.Exists() { - completed, _ = sjson.SetBytes(completed, "response.top_logprobs", v.Int()) - } - if v := req.Get("top_p"); v.Exists() { - completed, _ = sjson.SetBytes(completed, "response.top_p", v.Float()) - } - if v := req.Get("truncation"); v.Exists() { - completed, _ = sjson.SetBytes(completed, "response.truncation", v.String()) - } - if v := req.Get("user"); v.Exists() { - completed, _ = sjson.SetBytes(completed, "response.user", v.Value()) - } - if v := req.Get("metadata"); v.Exists() { - completed, _ = sjson.SetBytes(completed, "response.metadata", v.Value()) - } - } - // Build response.output using aggregated buffers - outputsWrapper := []byte(`{"arr":[]}`) - type completedOutputItem struct { - index int - raw []byte - } - outputItems := make([]completedOutputItem, 0, len(st.Reasonings)+len(st.MsgItemAdded)+len(st.FuncArgsBuf)) - if len(st.Reasonings) > 0 { - for _, r := range st.Reasonings { - item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`) - item, _ = sjson.SetBytes(item, "id", r.ReasoningID) - item, _ = sjson.SetBytes(item, "summary.0.text", r.ReasoningData) - outputItems = append(outputItems, completedOutputItem{index: r.OutputIndex, raw: item}) - } - } - if len(st.MsgItemAdded) > 0 { - for i := range st.MsgItemAdded { - txt := "" - if b := st.MsgTextBuf[i]; b != nil { - txt = b.String() - } - item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`) - item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i)) - item, _ = sjson.SetBytes(item, "content.0.text", txt) - outputItems = append(outputItems, completedOutputItem{index: st.MsgOutputIx[i], raw: item}) - } - } - if len(st.FuncArgsBuf) > 0 { - for key := range st.FuncArgsBuf { - args := "" - if b := st.FuncArgsBuf[key]; b != nil { - args = b.String() - } - callID := st.FuncCallIDs[key] - name := st.FuncNames[key] - item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`) - item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID)) - item, _ = sjson.SetBytes(item, "arguments", args) - item, _ = sjson.SetBytes(item, "call_id", callID) - item, _ = sjson.SetBytes(item, "name", name) - outputItems = append(outputItems, completedOutputItem{index: st.FuncOutputIx[key], raw: item}) - } - } - sort.Slice(outputItems, func(i, j int) bool { return outputItems[i].index < outputItems[j].index }) - for _, item := range outputItems { - outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item.raw) - } - if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 { - completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw)) - } - if st.UsageSeen { - completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens", st.PromptTokens) - completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens_details.cached_tokens", st.CachedTokens) - completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens", st.CompletionTokens) - if st.ReasoningTokens > 0 { - completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens_details.reasoning_tokens", st.ReasoningTokens) - } - total := st.TotalTokens - if total == 0 { - total = st.PromptTokens + st.CompletionTokens - } - completed, _ = sjson.SetBytes(completed, "response.usage.total_tokens", total) - } - out = append(out, emitRespEvent("response.completed", completed)) + st.CompletionPending = true } return true diff --git a/internal/translator/openai/openai/responses/openai_openai-responses_response_test.go b/internal/translator/openai/openai/responses/openai_openai-responses_response_test.go index 9f3ed3f4..cafcacb7 100644 --- a/internal/translator/openai/openai/responses/openai_openai-responses_response_test.go +++ b/internal/translator/openai/openai/responses/openai_openai-responses_response_test.go @@ -24,6 +24,120 @@ func parseOpenAIResponsesSSEEvent(t *testing.T, chunk []byte) (string, gjson.Res return event, gjson.Parse(dataLine) } +func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_ResponseCompletedWaitsForDone(t *testing.T) { + t.Parallel() + + request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`) + + tests := []struct { + name string + in []string + doneInputIndex int // Index in tt.in where the terminal [DONE] chunk arrives and response.completed must be emitted. + hasUsage bool + inputTokens int64 + outputTokens int64 + totalTokens int64 + }{ + { + // A provider may send finish_reason first and only attach usage in a later chunk (e.g. Vertex AI), + // so response.completed must wait for [DONE] to include that usage. + name: "late usage after finish reason", + in: []string{ + `data: {"id":"resp_late_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_late_usage","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`, + `data: {"id":"resp_late_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\"}"}}]},"finish_reason":"tool_calls"}]}`, + `data: {"id":"resp_late_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[],"usage":{"prompt_tokens":11,"completion_tokens":7,"total_tokens":18}}`, + `data: [DONE]`, + }, + doneInputIndex: 3, + hasUsage: true, + inputTokens: 11, + outputTokens: 7, + totalTokens: 18, + }, + { + // When usage arrives on the same chunk as finish_reason, we still expect a + // single response.completed event and it should remain deferred until [DONE]. + name: "usage on finish reason chunk", + in: []string{ + `data: {"id":"resp_usage_same_chunk","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_usage_same_chunk","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`, + `data: {"id":"resp_usage_same_chunk","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\"}"}}]},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":13,"completion_tokens":5,"total_tokens":18}}`, + `data: [DONE]`, + }, + doneInputIndex: 2, + hasUsage: true, + inputTokens: 13, + outputTokens: 5, + totalTokens: 18, + }, + { + // An OpenAI-compatible streams from a buggy server might never send usage, so response.completed should + // still wait for [DONE] but omit the usage object entirely. + name: "no usage chunk", + in: []string{ + `data: {"id":"resp_no_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_no_usage","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`, + `data: {"id":"resp_no_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\"}"}}]},"finish_reason":"tool_calls"}]}`, + `data: [DONE]`, + }, + doneInputIndex: 2, + hasUsage: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + completedCount := 0 + completedInputIndex := -1 + var completedData gjson.Result + + // Reuse converter state across input lines to simulate one streaming response. + var param any + + for i, line := range tt.in { + // One upstream chunk can emit multiple downstream SSE events. + for _, chunk := range ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), ¶m) { + event, data := parseOpenAIResponsesSSEEvent(t, chunk) + if event != "response.completed" { + continue + } + + completedCount++ + completedInputIndex = i + completedData = data + if i < tt.doneInputIndex { + t.Fatalf("unexpected early response.completed on input index %d", i) + } + } + } + + if completedCount != 1 { + t.Fatalf("expected exactly 1 response.completed event, got %d", completedCount) + } + if completedInputIndex != tt.doneInputIndex { + t.Fatalf("expected response.completed on terminal [DONE] chunk at input index %d, got %d", tt.doneInputIndex, completedInputIndex) + } + + // Missing upstream usage should stay omitted in the final completed event. + if !tt.hasUsage { + if completedData.Get("response.usage").Exists() { + t.Fatalf("expected response.completed to omit usage when none was provided, got %s", completedData.Get("response.usage").Raw) + } + return + } + + // When usage is present, the final response.completed event must preserve the usage values. + if got := completedData.Get("response.usage.input_tokens").Int(); got != tt.inputTokens { + t.Fatalf("unexpected response.usage.input_tokens: got %d want %d", got, tt.inputTokens) + } + if got := completedData.Get("response.usage.output_tokens").Int(); got != tt.outputTokens { + t.Fatalf("unexpected response.usage.output_tokens: got %d want %d", got, tt.outputTokens) + } + if got := completedData.Get("response.usage.total_tokens").Int(); got != tt.totalTokens { + t.Fatalf("unexpected response.usage.total_tokens: got %d want %d", got, tt.totalTokens) + } + }) + } +} + func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultipleToolCallsRemainSeparate(t *testing.T) { in := []string{ `data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_read","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`, @@ -31,6 +145,7 @@ func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultipleToolCalls `data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":1,"id":"call_glob","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null}]}`, `data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.{yml,yaml}\"}"}}]},"finish_reason":null}]}`, `data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`, + `data: [DONE]`, } request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`) @@ -131,6 +246,7 @@ func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultiChoiceToolCa `data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice0","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice1","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`, `data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.go\"}"}}]},"finish_reason":null},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":null}]}`, `data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`, + `data: [DONE]`, } request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`) @@ -213,6 +329,7 @@ func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MixedMessageAndTo in := []string{ `data: {"id":"resp_mixed","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":"hello","reasoning_content":null,"tool_calls":null},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice1","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`, `data: {"id":"resp_mixed","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"stop"},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`, + `data: [DONE]`, } request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`) @@ -261,6 +378,7 @@ func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_FunctionCallDoneA `data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":1,"id":"call_read","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`, `data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":null}]}`, `data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`, + `data: [DONE]`, } request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`) From 22a1a24cf58bb3965267b56a21404f20df8b77af Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sun, 5 Apr 2026 17:58:13 +0800 Subject: [PATCH 16/21] feat(executor): add tests for preserving key order in cache control functions Added comprehensive tests to ensure key order is maintained when modifying payloads in `normalizeCacheControlTTL` and `enforceCacheControlLimit` functions. Removed unused helper functions and refactored implementations for better readability and efficiency. --- internal/runtime/executor/claude_executor.go | 430 ++++++++---------- .../runtime/executor/claude_executor_test.go | 47 ++ 2 files changed, 233 insertions(+), 244 deletions(-) diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 7b2e5d8d..131ac3ea 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -8,7 +8,6 @@ import ( "context" "crypto/sha256" "encoding/hex" - "encoding/json" "fmt" "io" "net/http" @@ -1463,182 +1462,6 @@ func countCacheControls(payload []byte) int { return count } -func parsePayloadObject(payload []byte) (map[string]any, bool) { - if len(payload) == 0 { - return nil, false - } - var root map[string]any - if err := json.Unmarshal(payload, &root); err != nil { - return nil, false - } - return root, true -} - -func marshalPayloadObject(original []byte, root map[string]any) []byte { - if root == nil { - return original - } - out, err := json.Marshal(root) - if err != nil { - return original - } - return out -} - -func asObject(v any) (map[string]any, bool) { - obj, ok := v.(map[string]any) - return obj, ok -} - -func asArray(v any) ([]any, bool) { - arr, ok := v.([]any) - return arr, ok -} - -func countCacheControlsMap(root map[string]any) int { - count := 0 - - if system, ok := asArray(root["system"]); ok { - for _, item := range system { - if obj, ok := asObject(item); ok { - if _, exists := obj["cache_control"]; exists { - count++ - } - } - } - } - - if tools, ok := asArray(root["tools"]); ok { - for _, item := range tools { - if obj, ok := asObject(item); ok { - if _, exists := obj["cache_control"]; exists { - count++ - } - } - } - } - - if messages, ok := asArray(root["messages"]); ok { - for _, msg := range messages { - msgObj, ok := asObject(msg) - if !ok { - continue - } - content, ok := asArray(msgObj["content"]) - if !ok { - continue - } - for _, item := range content { - if obj, ok := asObject(item); ok { - if _, exists := obj["cache_control"]; exists { - count++ - } - } - } - } - } - - return count -} - -func normalizeTTLForBlock(obj map[string]any, seen5m *bool) bool { - ccRaw, exists := obj["cache_control"] - if !exists { - return false - } - cc, ok := asObject(ccRaw) - if !ok { - *seen5m = true - return false - } - ttlRaw, ttlExists := cc["ttl"] - ttl, ttlIsString := ttlRaw.(string) - if !ttlExists || !ttlIsString || ttl != "1h" { - *seen5m = true - return false - } - if *seen5m { - delete(cc, "ttl") - return true - } - return false -} - -func findLastCacheControlIndex(arr []any) int { - last := -1 - for idx, item := range arr { - obj, ok := asObject(item) - if !ok { - continue - } - if _, exists := obj["cache_control"]; exists { - last = idx - } - } - return last -} - -func stripCacheControlExceptIndex(arr []any, preserveIdx int, excess *int) { - for idx, item := range arr { - if *excess <= 0 { - return - } - obj, ok := asObject(item) - if !ok { - continue - } - if _, exists := obj["cache_control"]; exists && idx != preserveIdx { - delete(obj, "cache_control") - *excess-- - } - } -} - -func stripAllCacheControl(arr []any, excess *int) { - for _, item := range arr { - if *excess <= 0 { - return - } - obj, ok := asObject(item) - if !ok { - continue - } - if _, exists := obj["cache_control"]; exists { - delete(obj, "cache_control") - *excess-- - } - } -} - -func stripMessageCacheControl(messages []any, excess *int) { - for _, msg := range messages { - if *excess <= 0 { - return - } - msgObj, ok := asObject(msg) - if !ok { - continue - } - content, ok := asArray(msgObj["content"]) - if !ok { - continue - } - for _, item := range content { - if *excess <= 0 { - return - } - obj, ok := asObject(item) - if !ok { - continue - } - if _, exists := obj["cache_control"]; exists { - delete(obj, "cache_control") - *excess-- - } - } - } -} - // normalizeCacheControlTTL ensures cache_control TTL values don't violate the // prompt-caching-scope-2026-01-05 ordering constraint: a 1h-TTL block must not // appear after a 5m-TTL block anywhere in the evaluation order. @@ -1651,58 +1474,75 @@ func stripMessageCacheControl(messages []any, excess *int) { // Strategy: walk all cache_control blocks in evaluation order. Once a 5m block // is seen, strip ttl from ALL subsequent 1h blocks (downgrading them to 5m). func normalizeCacheControlTTL(payload []byte) []byte { - root, ok := parsePayloadObject(payload) - if !ok { + if len(payload) == 0 || !gjson.ValidBytes(payload) { return payload } + original := payload seen5m := false modified := false - if tools, ok := asArray(root["tools"]); ok { - for _, tool := range tools { - if obj, ok := asObject(tool); ok { - if normalizeTTLForBlock(obj, &seen5m) { - modified = true - } - } + processBlock := func(path string, obj gjson.Result) { + cc := obj.Get("cache_control") + if !cc.Exists() { + return } + if !cc.IsObject() { + seen5m = true + return + } + ttl := cc.Get("ttl") + if ttl.Type != gjson.String || ttl.String() != "1h" { + seen5m = true + return + } + if !seen5m { + return + } + ttlPath := path + ".cache_control.ttl" + updated, errDel := sjson.DeleteBytes(payload, ttlPath) + if errDel != nil { + return + } + payload = updated + modified = true } - if system, ok := asArray(root["system"]); ok { - for _, item := range system { - if obj, ok := asObject(item); ok { - if normalizeTTLForBlock(obj, &seen5m) { - modified = true - } - } - } + tools := gjson.GetBytes(payload, "tools") + if tools.IsArray() { + tools.ForEach(func(idx, item gjson.Result) bool { + processBlock(fmt.Sprintf("tools.%d", int(idx.Int())), item) + return true + }) } - if messages, ok := asArray(root["messages"]); ok { - for _, msg := range messages { - msgObj, ok := asObject(msg) - if !ok { - continue + system := gjson.GetBytes(payload, "system") + if system.IsArray() { + system.ForEach(func(idx, item gjson.Result) bool { + processBlock(fmt.Sprintf("system.%d", int(idx.Int())), item) + return true + }) + } + + messages := gjson.GetBytes(payload, "messages") + if messages.IsArray() { + messages.ForEach(func(msgIdx, msg gjson.Result) bool { + content := msg.Get("content") + if !content.IsArray() { + return true } - content, ok := asArray(msgObj["content"]) - if !ok { - continue - } - for _, item := range content { - if obj, ok := asObject(item); ok { - if normalizeTTLForBlock(obj, &seen5m) { - modified = true - } - } - } - } + content.ForEach(func(itemIdx, item gjson.Result) bool { + processBlock(fmt.Sprintf("messages.%d.content.%d", int(msgIdx.Int()), int(itemIdx.Int())), item) + return true + }) + return true + }) } if !modified { - return payload + return original } - return marshalPayloadObject(payload, root) + return payload } // enforceCacheControlLimit removes excess cache_control blocks from a payload @@ -1722,64 +1562,166 @@ func normalizeCacheControlTTL(payload []byte) []byte { // Phase 4: remaining system blocks (last system). // Phase 5: remaining tool blocks (last tool). func enforceCacheControlLimit(payload []byte, maxBlocks int) []byte { - root, ok := parsePayloadObject(payload) - if !ok { + if len(payload) == 0 || !gjson.ValidBytes(payload) { return payload } - total := countCacheControlsMap(root) + total := countCacheControls(payload) if total <= maxBlocks { return payload } excess := total - maxBlocks - var system []any - if arr, ok := asArray(root["system"]); ok { - system = arr - } - var tools []any - if arr, ok := asArray(root["tools"]); ok { - tools = arr - } - var messages []any - if arr, ok := asArray(root["messages"]); ok { - messages = arr - } - - if len(system) > 0 { - stripCacheControlExceptIndex(system, findLastCacheControlIndex(system), &excess) + system := gjson.GetBytes(payload, "system") + if system.IsArray() { + lastIdx := -1 + system.ForEach(func(idx, item gjson.Result) bool { + if item.Get("cache_control").Exists() { + lastIdx = int(idx.Int()) + } + return true + }) + if lastIdx >= 0 { + system.ForEach(func(idx, item gjson.Result) bool { + if excess <= 0 { + return false + } + i := int(idx.Int()) + if i == lastIdx { + return true + } + if !item.Get("cache_control").Exists() { + return true + } + path := fmt.Sprintf("system.%d.cache_control", i) + updated, errDel := sjson.DeleteBytes(payload, path) + if errDel != nil { + return true + } + payload = updated + excess-- + return true + }) + } } if excess <= 0 { - return marshalPayloadObject(payload, root) + return payload } - if len(tools) > 0 { - stripCacheControlExceptIndex(tools, findLastCacheControlIndex(tools), &excess) + tools := gjson.GetBytes(payload, "tools") + if tools.IsArray() { + lastIdx := -1 + tools.ForEach(func(idx, item gjson.Result) bool { + if item.Get("cache_control").Exists() { + lastIdx = int(idx.Int()) + } + return true + }) + if lastIdx >= 0 { + tools.ForEach(func(idx, item gjson.Result) bool { + if excess <= 0 { + return false + } + i := int(idx.Int()) + if i == lastIdx { + return true + } + if !item.Get("cache_control").Exists() { + return true + } + path := fmt.Sprintf("tools.%d.cache_control", i) + updated, errDel := sjson.DeleteBytes(payload, path) + if errDel != nil { + return true + } + payload = updated + excess-- + return true + }) + } } if excess <= 0 { - return marshalPayloadObject(payload, root) + return payload } - if len(messages) > 0 { - stripMessageCacheControl(messages, &excess) + messages := gjson.GetBytes(payload, "messages") + if messages.IsArray() { + messages.ForEach(func(msgIdx, msg gjson.Result) bool { + if excess <= 0 { + return false + } + content := msg.Get("content") + if !content.IsArray() { + return true + } + content.ForEach(func(itemIdx, item gjson.Result) bool { + if excess <= 0 { + return false + } + if !item.Get("cache_control").Exists() { + return true + } + path := fmt.Sprintf("messages.%d.content.%d.cache_control", int(msgIdx.Int()), int(itemIdx.Int())) + updated, errDel := sjson.DeleteBytes(payload, path) + if errDel != nil { + return true + } + payload = updated + excess-- + return true + }) + return true + }) } if excess <= 0 { - return marshalPayloadObject(payload, root) + return payload } - if len(system) > 0 { - stripAllCacheControl(system, &excess) + system = gjson.GetBytes(payload, "system") + if system.IsArray() { + system.ForEach(func(idx, item gjson.Result) bool { + if excess <= 0 { + return false + } + if !item.Get("cache_control").Exists() { + return true + } + path := fmt.Sprintf("system.%d.cache_control", int(idx.Int())) + updated, errDel := sjson.DeleteBytes(payload, path) + if errDel != nil { + return true + } + payload = updated + excess-- + return true + }) } if excess <= 0 { - return marshalPayloadObject(payload, root) + return payload } - if len(tools) > 0 { - stripAllCacheControl(tools, &excess) + tools = gjson.GetBytes(payload, "tools") + if tools.IsArray() { + tools.ForEach(func(idx, item gjson.Result) bool { + if excess <= 0 { + return false + } + if !item.Get("cache_control").Exists() { + return true + } + path := fmt.Sprintf("tools.%d.cache_control", int(idx.Int())) + updated, errDel := sjson.DeleteBytes(payload, path) + if errDel != nil { + return true + } + payload = updated + excess-- + return true + }) } - return marshalPayloadObject(payload, root) + return payload } // injectMessagesCacheControl adds cache_control to the second-to-last user turn for multi-turn caching. diff --git a/internal/runtime/executor/claude_executor_test.go b/internal/runtime/executor/claude_executor_test.go index 74cec0a3..c6220fe9 100644 --- a/internal/runtime/executor/claude_executor_test.go +++ b/internal/runtime/executor/claude_executor_test.go @@ -965,6 +965,28 @@ func TestNormalizeCacheControlTTL_PreservesOriginalBytesWhenNoChange(t *testing. } } +func TestNormalizeCacheControlTTL_PreservesKeyOrderWhenModified(t *testing.T) { + payload := []byte(`{"model":"m","messages":[{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]}],"tools":[{"name":"t1","cache_control":{"type":"ephemeral"}}],"system":[{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}]}`) + + out := normalizeCacheControlTTL(payload) + + if gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").Exists() { + t.Fatalf("messages.0.content.0.cache_control.ttl should be removed after a default-5m block") + } + + outStr := string(out) + idxModel := strings.Index(outStr, `"model"`) + idxMessages := strings.Index(outStr, `"messages"`) + idxTools := strings.Index(outStr, `"tools"`) + idxSystem := strings.Index(outStr, `"system"`) + if idxModel == -1 || idxMessages == -1 || idxTools == -1 || idxSystem == -1 { + t.Fatalf("failed to locate top-level keys in output: %s", outStr) + } + if !(idxModel < idxMessages && idxMessages < idxTools && idxTools < idxSystem) { + t.Fatalf("top-level key order changed:\noriginal: %s\ngot: %s", payload, out) + } +} + func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) { payload := []byte(`{ "tools": [ @@ -994,6 +1016,31 @@ func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) } } +func TestEnforceCacheControlLimit_PreservesKeyOrderWhenModified(t *testing.T) { + payload := []byte(`{"model":"m","messages":[{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"u2","cache_control":{"type":"ephemeral"}}]}],"tools":[{"name":"t1","cache_control":{"type":"ephemeral"}},{"name":"t2","cache_control":{"type":"ephemeral"}}],"system":[{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}]}`) + + out := enforceCacheControlLimit(payload, 4) + + if got := countCacheControls(out); got != 4 { + t.Fatalf("cache_control count = %d, want 4", got) + } + if gjson.GetBytes(out, "tools.0.cache_control").Exists() { + t.Fatalf("tools.0.cache_control should be removed first (non-last tool)") + } + + outStr := string(out) + idxModel := strings.Index(outStr, `"model"`) + idxMessages := strings.Index(outStr, `"messages"`) + idxTools := strings.Index(outStr, `"tools"`) + idxSystem := strings.Index(outStr, `"system"`) + if idxModel == -1 || idxMessages == -1 || idxTools == -1 || idxSystem == -1 { + t.Fatalf("failed to locate top-level keys in output: %s", outStr) + } + if !(idxModel < idxMessages && idxMessages < idxTools && idxTools < idxSystem) { + t.Fatalf("top-level key order changed:\noriginal: %s\ngot: %s", payload, out) + } +} + func TestEnforceCacheControlLimit_ToolOnlyPayloadStillRespectsLimit(t *testing.T) { payload := []byte(`{ "tools": [ From b0653cec7b4942aa844777c69932e996c437aad2 Mon Sep 17 00:00:00 2001 From: Aikins Laryea Date: Fri, 3 Apr 2026 17:20:43 +0000 Subject: [PATCH 17/21] fix(amp): strip signature from tool_use blocks before forwarding to Claude ensureAmpSignature injects signature:"" into tool_use blocks so the Amp TUI does not crash on P.signature.length. when Amp sends the conversation back, Claude rejects the extra field with 400: tool_use.signature: Extra inputs are not permitted strip the proxy-injected signature from tool_use blocks in SanitizeAmpRequestBody before forwarding to the upstream API. --- internal/api/modules/amp/response_rewriter.go | 27 ++++++++++++----- .../api/modules/amp/response_rewriter_test.go | 30 +++++++++++++++++++ 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/internal/api/modules/amp/response_rewriter.go b/internal/api/modules/amp/response_rewriter.go index 390f301d..707fe576 100644 --- a/internal/api/modules/amp/response_rewriter.go +++ b/internal/api/modules/amp/response_rewriter.go @@ -2,6 +2,7 @@ package amp import ( "bytes" + "encoding/json" "fmt" "net/http" "strings" @@ -290,8 +291,10 @@ func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte { } // SanitizeAmpRequestBody removes thinking blocks with empty/missing/invalid signatures -// from the messages array in a request body before forwarding to the upstream API. -// This prevents 400 errors from the API which requires valid signatures on thinking blocks. +// and strips the proxy-injected "signature" field from tool_use blocks in the messages +// array before forwarding to the upstream API. +// This prevents 400 errors from the API which requires valid signatures on thinking +// blocks and does not accept a signature field on tool_use blocks. func SanitizeAmpRequestBody(body []byte) []byte { messages := gjson.GetBytes(body, "messages") if !messages.Exists() || !messages.IsArray() { @@ -309,21 +312,30 @@ func SanitizeAmpRequestBody(body []byte) []byte { } var keepBlocks []interface{} - removedCount := 0 + contentModified := false for _, block := range content.Array() { blockType := block.Get("type").String() if blockType == "thinking" { sig := block.Get("signature") if !sig.Exists() || sig.Type != gjson.String || strings.TrimSpace(sig.String()) == "" { - removedCount++ + contentModified = true continue } } - keepBlocks = append(keepBlocks, block.Value()) + + // Use raw JSON to prevent float64 rounding of large integers in tool_use inputs + blockRaw := []byte(block.Raw) + if blockType == "tool_use" && block.Get("signature").Exists() { + blockRaw, _ = sjson.DeleteBytes(blockRaw, "signature") + contentModified = true + } + + // sjson.SetBytes supports raw JSON strings if wrapped in gjson.Raw + keepBlocks = append(keepBlocks, json.RawMessage(blockRaw)) } - if removedCount > 0 { + if contentModified { contentPath := fmt.Sprintf("messages.%d.content", msgIdx) var err error if len(keepBlocks) == 0 { @@ -332,11 +344,10 @@ func SanitizeAmpRequestBody(body []byte) []byte { body, err = sjson.SetBytes(body, contentPath, keepBlocks) } if err != nil { - log.Warnf("Amp RequestSanitizer: failed to remove thinking blocks from message %d: %v", msgIdx, err) + log.Warnf("Amp RequestSanitizer: failed to sanitize message %d: %v", msgIdx, err) continue } modified = true - log.Debugf("Amp RequestSanitizer: removed %d thinking blocks with invalid signatures from message %d", removedCount, msgIdx) } } diff --git a/internal/api/modules/amp/response_rewriter_test.go b/internal/api/modules/amp/response_rewriter_test.go index 31ba56bd..ac95dfc6 100644 --- a/internal/api/modules/amp/response_rewriter_test.go +++ b/internal/api/modules/amp/response_rewriter_test.go @@ -145,6 +145,36 @@ func TestSanitizeAmpRequestBody_RemovesWhitespaceAndNonStringSignatures(t *testi } } +func TestSanitizeAmpRequestBody_StripsSignatureFromToolUseBlocks(t *testing.T) { + input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"thought","signature":"valid-sig"},{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"},"signature":""}]}]}`) + result := SanitizeAmpRequestBody(input) + + if contains(result, []byte(`"signature":""`)) { + t.Fatalf("expected signature to be stripped from tool_use block, got %s", string(result)) + } + if !contains(result, []byte(`"valid-sig"`)) { + t.Fatalf("expected thinking signature to remain, got %s", string(result)) + } + if !contains(result, []byte(`"tool_use"`)) { + t.Fatalf("expected tool_use block to remain, got %s", string(result)) + } +} + +func TestSanitizeAmpRequestBody_MixedInvalidThinkingAndToolUseSignature(t *testing.T) { + input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop-me","signature":""},{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"},"signature":""}]}]}`) + result := SanitizeAmpRequestBody(input) + + if contains(result, []byte("drop-me")) { + t.Fatalf("expected invalid thinking block to be removed, got %s", string(result)) + } + if contains(result, []byte(`"signature"`)) { + t.Fatalf("expected signature to be stripped from tool_use block, got %s", string(result)) + } + if !contains(result, []byte(`"tool_use"`)) { + t.Fatalf("expected tool_use block to remain, got %s", string(result)) + } +} + func contains(data, substr []byte) bool { for i := 0; i <= len(data)-len(substr); i++ { if string(data[i:i+len(substr)]) == string(substr) { From 6f58518c6912a26b46875653af6b26e023aa3a00 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Mon, 6 Apr 2026 09:23:04 +0800 Subject: [PATCH 18/21] docs(readme): remove redundant GITSTORE_GIT_BRANCH description in README files --- README.md | 2 -- README_CN.md | 2 -- README_JA.md | 2 -- 3 files changed, 6 deletions(-) diff --git a/README.md b/README.md index 63548b3f..c027be19 100644 --- a/README.md +++ b/README.md @@ -72,8 +72,6 @@ Get 10% OFF GLM CODING PLAN:https://z.ai/subscribe?ic=8JVLJQFSKB CLIProxyAPI Guides: [https://help.router-for.me/](https://help.router-for.me/) -For the optional git-backed config store, `GITSTORE_GIT_BRANCH` is optional. Leave it empty or unset to follow the remote repository's default branch, and set it only when you want to force a specific branch. - ## Management API see [MANAGEMENT_API.md](https://help.router-for.me/management/api) diff --git a/README_CN.md b/README_CN.md index 08ad6919..3e71528d 100644 --- a/README_CN.md +++ b/README_CN.md @@ -72,8 +72,6 @@ GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元 CLIProxyAPI 用户手册: [https://help.router-for.me/](https://help.router-for.me/cn/) -对于可选的 git 存储配置,`GITSTORE_GIT_BRANCH` 是可选项。留空或不设置时会跟随远程仓库的默认分支,只有在你需要强制指定分支时才设置它。 - ## 管理 API 文档 请参见 [MANAGEMENT_API_CN.md](https://help.router-for.me/cn/management/api) diff --git a/README_JA.md b/README_JA.md index de37690e..d3f06949 100644 --- a/README_JA.md +++ b/README_JA.md @@ -72,8 +72,6 @@ GLM CODING PLANを10%割引で取得:https://z.ai/subscribe?ic=8JVLJQFSKB CLIProxyAPIガイド:[https://help.router-for.me/](https://help.router-for.me/) -オプションのgitバックアップ設定ストアでは、`GITSTORE_GIT_BRANCH` は任意です。空のままにするか未設定にすると、リモートリポジトリのデフォルトブランチに従います。特定のブランチを強制したい場合のみ設定してください。 - ## 管理API [MANAGEMENT_API.md](https://help.router-for.me/management/api)を参照 From 0ea768011b93de3d45fbc05442e451df55069280 Mon Sep 17 00:00:00 2001 From: zilianpn Date: Tue, 7 Apr 2026 00:24:08 +0800 Subject: [PATCH 19/21] fix(auth): honor disable-cooling and enrich no-auth errors --- config.example.yaml | 3 + sdk/api/handlers/handlers.go | 54 ++++- .../handlers/handlers_error_response_test.go | 45 ++++ .../handlers_stream_bootstrap_test.go | 73 +++++++ sdk/cliproxy/auth/conductor.go | 100 ++++++--- sdk/cliproxy/auth/conductor_overrides_test.go | 196 ++++++++++++++++++ 6 files changed, 436 insertions(+), 35 deletions(-) diff --git a/config.example.yaml b/config.example.yaml index 5dd872ea..ce2d0a5a 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -87,6 +87,9 @@ max-retry-credentials: 0 # Maximum wait time in seconds for a cooled-down credential before triggering a retry. max-retry-interval: 30 +# When true, disable auth/model cooldown scheduling globally (prevents blackout windows after failure states). +disable-cooling: false + # Quota exceeded behavior quota-exceeded: switch-project: true # Whether to automatically switch to another project when a quota is exceeded diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 28ab970d..1f7996c0 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -6,6 +6,7 @@ package handlers import ( "bytes" "encoding/json" + "errors" "fmt" "net/http" "strings" @@ -492,6 +493,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType opts.Metadata = reqMeta resp, err := h.AuthManager.Execute(ctx, providers, req, opts) if err != nil { + err = enrichAuthSelectionError(err, providers, normalizedModel) status := http.StatusInternalServerError if se, ok := err.(interface{ StatusCode() int }); ok && se != nil { if code := se.StatusCode(); code > 0 { @@ -538,6 +540,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle opts.Metadata = reqMeta resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts) if err != nil { + err = enrichAuthSelectionError(err, providers, normalizedModel) status := http.StatusInternalServerError if se, ok := err.(interface{ StatusCode() int }); ok && se != nil { if code := se.StatusCode(); code > 0 { @@ -588,6 +591,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl opts.Metadata = reqMeta streamResult, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts) if err != nil { + err = enrichAuthSelectionError(err, providers, normalizedModel) errChan := make(chan *interfaces.ErrorMessage, 1) status := http.StatusInternalServerError if se, ok := err.(interface{ StatusCode() int }); ok && se != nil { @@ -697,7 +701,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl chunks = retryResult.Chunks continue outer } - streamErr = retryErr + streamErr = enrichAuthSelectionError(retryErr, providers, normalizedModel) } } @@ -840,6 +844,54 @@ func replaceHeader(dst http.Header, src http.Header) { } } +func enrichAuthSelectionError(err error, providers []string, model string) error { + if err == nil { + return nil + } + + var authErr *coreauth.Error + if !errors.As(err, &authErr) || authErr == nil { + return err + } + + code := strings.TrimSpace(authErr.Code) + if code != "auth_not_found" && code != "auth_unavailable" { + return err + } + + providerText := strings.Join(providers, ",") + if providerText == "" { + providerText = "unknown" + } + modelText := strings.TrimSpace(model) + if modelText == "" { + modelText = "unknown" + } + + baseMessage := strings.TrimSpace(authErr.Message) + if baseMessage == "" { + baseMessage = "no auth available" + } + detail := fmt.Sprintf("%s (providers=%s, model=%s)", baseMessage, providerText, modelText) + + // Clarify the most common alias confusion between Anthropic route names and internal provider keys. + if strings.Contains(","+providerText+",", ",claude,") { + detail += "; check Claude auth/key session and cooldown state via /v0/management/auth-files" + } + + status := authErr.HTTPStatus + if status <= 0 { + status = http.StatusServiceUnavailable + } + + return &coreauth.Error{ + Code: authErr.Code, + Message: detail, + Retryable: authErr.Retryable, + HTTPStatus: status, + } +} + // WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message. func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) { status := http.StatusInternalServerError diff --git a/sdk/api/handlers/handlers_error_response_test.go b/sdk/api/handlers/handlers_error_response_test.go index cde4547f..917971c2 100644 --- a/sdk/api/handlers/handlers_error_response_test.go +++ b/sdk/api/handlers/handlers_error_response_test.go @@ -5,10 +5,12 @@ import ( "net/http" "net/http/httptest" "reflect" + "strings" "testing" "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" ) @@ -66,3 +68,46 @@ func TestWriteErrorResponse_AddonHeadersEnabled(t *testing.T) { t.Fatalf("X-Request-Id = %#v, want %#v", got, []string{"new-1", "new-2"}) } } + +func TestEnrichAuthSelectionError_DefaultsTo503WithContext(t *testing.T) { + in := &coreauth.Error{Code: "auth_not_found", Message: "no auth available"} + out := enrichAuthSelectionError(in, []string{"claude"}, "claude-sonnet-4-6") + + var got *coreauth.Error + if !errors.As(out, &got) || got == nil { + t.Fatalf("expected coreauth.Error, got %T", out) + } + if got.StatusCode() != http.StatusServiceUnavailable { + t.Fatalf("status = %d, want %d", got.StatusCode(), http.StatusServiceUnavailable) + } + if !strings.Contains(got.Message, "providers=claude") { + t.Fatalf("message missing provider context: %q", got.Message) + } + if !strings.Contains(got.Message, "model=claude-sonnet-4-6") { + t.Fatalf("message missing model context: %q", got.Message) + } + if !strings.Contains(got.Message, "/v0/management/auth-files") { + t.Fatalf("message missing management hint: %q", got.Message) + } +} + +func TestEnrichAuthSelectionError_PreservesExplicitStatus(t *testing.T) { + in := &coreauth.Error{Code: "auth_unavailable", Message: "no auth available", HTTPStatus: http.StatusTooManyRequests} + out := enrichAuthSelectionError(in, []string{"gemini"}, "gemini-2.5-pro") + + var got *coreauth.Error + if !errors.As(out, &got) || got == nil { + t.Fatalf("expected coreauth.Error, got %T", out) + } + if got.StatusCode() != http.StatusTooManyRequests { + t.Fatalf("status = %d, want %d", got.StatusCode(), http.StatusTooManyRequests) + } +} + +func TestEnrichAuthSelectionError_IgnoresOtherErrors(t *testing.T) { + in := errors.New("boom") + out := enrichAuthSelectionError(in, []string{"claude"}, "claude-sonnet-4-6") + if out != in { + t.Fatalf("expected original error to be returned unchanged") + } +} diff --git a/sdk/api/handlers/handlers_stream_bootstrap_test.go b/sdk/api/handlers/handlers_stream_bootstrap_test.go index 61c03332..f357962f 100644 --- a/sdk/api/handlers/handlers_stream_bootstrap_test.go +++ b/sdk/api/handlers/handlers_stream_bootstrap_test.go @@ -2,10 +2,13 @@ package handlers import ( "context" + "errors" "net/http" + "strings" "sync" "testing" + "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" @@ -463,6 +466,76 @@ func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) { } } +func TestExecuteStreamWithAuthManager_EnrichesBootstrapRetryAuthUnavailableError(t *testing.T) { + executor := &failOnceStreamExecutor{} + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(executor) + + auth1 := &coreauth.Auth{ + ID: "auth1", + Provider: "codex", + Status: coreauth.StatusActive, + Metadata: map[string]any{"email": "test1@example.com"}, + } + if _, err := manager.Register(context.Background(), auth1); err != nil { + t.Fatalf("manager.Register(auth1): %v", err) + } + + registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth1.ID) + }) + + handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{ + Streaming: sdkconfig.StreamingConfig{ + BootstrapRetries: 1, + }, + }, manager) + dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "") + if dataChan == nil || errChan == nil { + t.Fatalf("expected non-nil channels") + } + + var got []byte + for chunk := range dataChan { + got = append(got, chunk...) + } + if len(got) != 0 { + t.Fatalf("expected empty payload, got %q", string(got)) + } + + var gotErr *interfaces.ErrorMessage + for msg := range errChan { + if msg != nil { + gotErr = msg + } + } + if gotErr == nil { + t.Fatalf("expected terminal error") + } + if gotErr.StatusCode != http.StatusServiceUnavailable { + t.Fatalf("status = %d, want %d", gotErr.StatusCode, http.StatusServiceUnavailable) + } + + var authErr *coreauth.Error + if !errors.As(gotErr.Error, &authErr) || authErr == nil { + t.Fatalf("expected coreauth.Error, got %T", gotErr.Error) + } + if authErr.Code != "auth_unavailable" { + t.Fatalf("code = %q, want %q", authErr.Code, "auth_unavailable") + } + if !strings.Contains(authErr.Message, "providers=codex") { + t.Fatalf("message missing provider context: %q", authErr.Message) + } + if !strings.Contains(authErr.Message, "model=test-model") { + t.Fatalf("message missing model context: %q", authErr.Message) + } + + if executor.Calls() != 1 { + t.Fatalf("expected exactly one upstream call before retry path selection failure, got %d", executor.Calls()) + } +} + func TestExecuteStreamWithAuthManager_PinnedAuthKeepsSameUpstream(t *testing.T) { executor := &authAwareStreamExecutor{} manager := coreauth.NewManager(nil, nil, nil) diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 478c7921..f5f7a60a 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -1838,6 +1838,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) { } else { if result.Model != "" { if !isRequestScopedNotFoundResultError(result.Error) { + disableCooling := quotaCooldownDisabledForAuth(auth) state := ensureModelState(auth, result.Model) state.Unavailable = true state.Status = StatusError @@ -1858,31 +1859,45 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) { } else { switch statusCode { case 401: - next := now.Add(30 * time.Minute) - state.NextRetryAfter = next - suspendReason = "unauthorized" - shouldSuspendModel = true + if disableCooling { + state.NextRetryAfter = time.Time{} + } else { + next := now.Add(30 * time.Minute) + state.NextRetryAfter = next + suspendReason = "unauthorized" + shouldSuspendModel = true + } case 402, 403: - next := now.Add(30 * time.Minute) - state.NextRetryAfter = next - suspendReason = "payment_required" - shouldSuspendModel = true + if disableCooling { + state.NextRetryAfter = time.Time{} + } else { + next := now.Add(30 * time.Minute) + state.NextRetryAfter = next + suspendReason = "payment_required" + shouldSuspendModel = true + } case 404: - next := now.Add(12 * time.Hour) - state.NextRetryAfter = next - suspendReason = "not_found" - shouldSuspendModel = true + if disableCooling { + state.NextRetryAfter = time.Time{} + } else { + next := now.Add(12 * time.Hour) + state.NextRetryAfter = next + suspendReason = "not_found" + shouldSuspendModel = true + } case 429: var next time.Time backoffLevel := state.Quota.BackoffLevel - if result.RetryAfter != nil { - next = now.Add(*result.RetryAfter) - } else { - cooldown, nextLevel := nextQuotaCooldown(backoffLevel, quotaCooldownDisabledForAuth(auth)) - if cooldown > 0 { - next = now.Add(cooldown) + if !disableCooling { + if result.RetryAfter != nil { + next = now.Add(*result.RetryAfter) + } else { + cooldown, nextLevel := nextQuotaCooldown(backoffLevel, disableCooling) + if cooldown > 0 { + next = now.Add(cooldown) + } + backoffLevel = nextLevel } - backoffLevel = nextLevel } state.NextRetryAfter = next state.Quota = QuotaState{ @@ -1891,11 +1906,13 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) { NextRecoverAt: next, BackoffLevel: backoffLevel, } - suspendReason = "quota" - shouldSuspendModel = true - setModelQuota = true + if !disableCooling { + suspendReason = "quota" + shouldSuspendModel = true + setModelQuota = true + } case 408, 500, 502, 503, 504: - if quotaCooldownDisabledForAuth(auth) { + if disableCooling { state.NextRetryAfter = time.Time{} } else { next := now.Add(1 * time.Minute) @@ -2211,6 +2228,7 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati if isRequestScopedNotFoundResultError(resultErr) { return } + disableCooling := quotaCooldownDisabledForAuth(auth) auth.Unavailable = true auth.Status = StatusError auth.UpdatedAt = now @@ -2224,32 +2242,46 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati switch statusCode { case 401: auth.StatusMessage = "unauthorized" - auth.NextRetryAfter = now.Add(30 * time.Minute) + if disableCooling { + auth.NextRetryAfter = time.Time{} + } else { + auth.NextRetryAfter = now.Add(30 * time.Minute) + } case 402, 403: auth.StatusMessage = "payment_required" - auth.NextRetryAfter = now.Add(30 * time.Minute) + if disableCooling { + auth.NextRetryAfter = time.Time{} + } else { + auth.NextRetryAfter = now.Add(30 * time.Minute) + } case 404: auth.StatusMessage = "not_found" - auth.NextRetryAfter = now.Add(12 * time.Hour) + if disableCooling { + auth.NextRetryAfter = time.Time{} + } else { + auth.NextRetryAfter = now.Add(12 * time.Hour) + } case 429: auth.StatusMessage = "quota exhausted" auth.Quota.Exceeded = true auth.Quota.Reason = "quota" var next time.Time - if retryAfter != nil { - next = now.Add(*retryAfter) - } else { - cooldown, nextLevel := nextQuotaCooldown(auth.Quota.BackoffLevel, quotaCooldownDisabledForAuth(auth)) - if cooldown > 0 { - next = now.Add(cooldown) + if !disableCooling { + if retryAfter != nil { + next = now.Add(*retryAfter) + } else { + cooldown, nextLevel := nextQuotaCooldown(auth.Quota.BackoffLevel, disableCooling) + if cooldown > 0 { + next = now.Add(cooldown) + } + auth.Quota.BackoffLevel = nextLevel } - auth.Quota.BackoffLevel = nextLevel } auth.Quota.NextRecoverAt = next auth.NextRetryAfter = next case 408, 500, 502, 503, 504: auth.StatusMessage = "transient upstream error" - if quotaCooldownDisabledForAuth(auth) { + if disableCooling { auth.NextRetryAfter = time.Time{} } else { auth.NextRetryAfter = now.Add(1 * time.Minute) diff --git a/sdk/cliproxy/auth/conductor_overrides_test.go b/sdk/cliproxy/auth/conductor_overrides_test.go index 50915ce0..0c72c833 100644 --- a/sdk/cliproxy/auth/conductor_overrides_test.go +++ b/sdk/cliproxy/auth/conductor_overrides_test.go @@ -180,6 +180,34 @@ func (e *authFallbackExecutor) StreamCalls() []string { return out } +type retryAfterStatusError struct { + status int + message string + retryAfter time.Duration +} + +func (e *retryAfterStatusError) Error() string { + if e == nil { + return "" + } + return e.message +} + +func (e *retryAfterStatusError) StatusCode() int { + if e == nil { + return 0 + } + return e.status +} + +func (e *retryAfterStatusError) RetryAfter() *time.Duration { + if e == nil { + return nil + } + d := e.retryAfter + return &d +} + func newCredentialRetryLimitTestManager(t *testing.T, maxRetryCredentials int) (*Manager, *credentialRetryLimitExecutor) { t.Helper() @@ -450,6 +478,174 @@ func TestManager_MarkResult_RespectsAuthDisableCoolingOverride(t *testing.T) { } } +func TestManager_MarkResult_RespectsAuthDisableCoolingOverride_On403(t *testing.T) { + prev := quotaCooldownDisabled.Load() + quotaCooldownDisabled.Store(false) + t.Cleanup(func() { quotaCooldownDisabled.Store(prev) }) + + m := NewManager(nil, nil, nil) + + auth := &Auth{ + ID: "auth-403", + Provider: "claude", + Metadata: map[string]any{ + "disable_cooling": true, + }, + } + if _, errRegister := m.Register(context.Background(), auth); errRegister != nil { + t.Fatalf("register auth: %v", errRegister) + } + + model := "test-model-403" + reg := registry.GetGlobalRegistry() + reg.RegisterClient(auth.ID, "claude", []*registry.ModelInfo{{ID: model}}) + t.Cleanup(func() { reg.UnregisterClient(auth.ID) }) + + m.MarkResult(context.Background(), Result{ + AuthID: auth.ID, + Provider: "claude", + Model: model, + Success: false, + Error: &Error{HTTPStatus: http.StatusForbidden, Message: "forbidden"}, + }) + + updated, ok := m.GetByID(auth.ID) + if !ok || updated == nil { + t.Fatalf("expected auth to be present") + } + state := updated.ModelStates[model] + if state == nil { + t.Fatalf("expected model state to be present") + } + if !state.NextRetryAfter.IsZero() { + t.Fatalf("expected NextRetryAfter to be zero when disable_cooling=true, got %v", state.NextRetryAfter) + } + + if count := reg.GetModelCount(model); count <= 0 { + t.Fatalf("expected model count > 0 when disable_cooling=true, got %d", count) + } +} + +func TestManager_Execute_DisableCooling_DoesNotBlackoutAfter403(t *testing.T) { + prev := quotaCooldownDisabled.Load() + quotaCooldownDisabled.Store(false) + t.Cleanup(func() { quotaCooldownDisabled.Store(prev) }) + + m := NewManager(nil, nil, nil) + executor := &authFallbackExecutor{ + id: "claude", + executeErrors: map[string]error{ + "auth-403-exec": &Error{ + HTTPStatus: http.StatusForbidden, + Message: "forbidden", + }, + }, + } + m.RegisterExecutor(executor) + + auth := &Auth{ + ID: "auth-403-exec", + Provider: "claude", + Metadata: map[string]any{ + "disable_cooling": true, + }, + } + if _, errRegister := m.Register(context.Background(), auth); errRegister != nil { + t.Fatalf("register auth: %v", errRegister) + } + + model := "test-model-403-exec" + reg := registry.GetGlobalRegistry() + reg.RegisterClient(auth.ID, "claude", []*registry.ModelInfo{{ID: model}}) + t.Cleanup(func() { reg.UnregisterClient(auth.ID) }) + + req := cliproxyexecutor.Request{Model: model} + _, errExecute1 := m.Execute(context.Background(), []string{"claude"}, req, cliproxyexecutor.Options{}) + if errExecute1 == nil { + t.Fatal("expected first execute error") + } + if statusCodeFromError(errExecute1) != http.StatusForbidden { + t.Fatalf("first execute status = %d, want %d", statusCodeFromError(errExecute1), http.StatusForbidden) + } + + _, errExecute2 := m.Execute(context.Background(), []string{"claude"}, req, cliproxyexecutor.Options{}) + if errExecute2 == nil { + t.Fatal("expected second execute error") + } + if statusCodeFromError(errExecute2) != http.StatusForbidden { + t.Fatalf("second execute status = %d, want %d", statusCodeFromError(errExecute2), http.StatusForbidden) + } +} + +func TestManager_Execute_DisableCooling_DoesNotBlackoutAfter429RetryAfter(t *testing.T) { + prev := quotaCooldownDisabled.Load() + quotaCooldownDisabled.Store(false) + t.Cleanup(func() { quotaCooldownDisabled.Store(prev) }) + + m := NewManager(nil, nil, nil) + executor := &authFallbackExecutor{ + id: "claude", + executeErrors: map[string]error{ + "auth-429-exec": &retryAfterStatusError{ + status: http.StatusTooManyRequests, + message: "quota exhausted", + retryAfter: 2 * time.Minute, + }, + }, + } + m.RegisterExecutor(executor) + + auth := &Auth{ + ID: "auth-429-exec", + Provider: "claude", + Metadata: map[string]any{ + "disable_cooling": true, + }, + } + if _, errRegister := m.Register(context.Background(), auth); errRegister != nil { + t.Fatalf("register auth: %v", errRegister) + } + + model := "test-model-429-exec" + reg := registry.GetGlobalRegistry() + reg.RegisterClient(auth.ID, "claude", []*registry.ModelInfo{{ID: model}}) + t.Cleanup(func() { reg.UnregisterClient(auth.ID) }) + + req := cliproxyexecutor.Request{Model: model} + _, errExecute1 := m.Execute(context.Background(), []string{"claude"}, req, cliproxyexecutor.Options{}) + if errExecute1 == nil { + t.Fatal("expected first execute error") + } + if statusCodeFromError(errExecute1) != http.StatusTooManyRequests { + t.Fatalf("first execute status = %d, want %d", statusCodeFromError(errExecute1), http.StatusTooManyRequests) + } + + _, errExecute2 := m.Execute(context.Background(), []string{"claude"}, req, cliproxyexecutor.Options{}) + if errExecute2 == nil { + t.Fatal("expected second execute error") + } + if statusCodeFromError(errExecute2) != http.StatusTooManyRequests { + t.Fatalf("second execute status = %d, want %d", statusCodeFromError(errExecute2), http.StatusTooManyRequests) + } + + calls := executor.ExecuteCalls() + if len(calls) != 2 { + t.Fatalf("execute calls = %d, want 2", len(calls)) + } + + updated, ok := m.GetByID(auth.ID) + if !ok || updated == nil { + t.Fatalf("expected auth to be present") + } + state := updated.ModelStates[model] + if state == nil { + t.Fatalf("expected model state to be present") + } + if !state.NextRetryAfter.IsZero() { + t.Fatalf("expected NextRetryAfter to be zero when disable_cooling=true, got %v", state.NextRetryAfter) + } +} + func TestManager_MarkResult_RequestScopedNotFoundDoesNotCooldownAuth(t *testing.T) { m := NewManager(nil, nil, nil) From 163d68318f8f844fe1aa6423806b6f0273357b07 Mon Sep 17 00:00:00 2001 From: Lemon Date: Tue, 7 Apr 2026 07:46:11 +0800 Subject: [PATCH 20/21] feat: support socks5h scheme for proxy settings --- .../executor/codex_websockets_executor.go | 2 +- sdk/proxyutil/proxy.go | 4 ++-- sdk/proxyutil/proxy_test.go | 22 +++++++++++++++++++ 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go index 2041cebc..94c9b262 100644 --- a/internal/runtime/executor/codex_websockets_executor.go +++ b/internal/runtime/executor/codex_websockets_executor.go @@ -734,7 +734,7 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) * } switch setting.URL.Scheme { - case "socks5": + case "socks5", "socks5h": var proxyAuth *proxy.Auth if setting.URL.User != nil { username := setting.URL.User.Username() diff --git a/sdk/proxyutil/proxy.go b/sdk/proxyutil/proxy.go index 029efeb7..c0d8b328 100644 --- a/sdk/proxyutil/proxy.go +++ b/sdk/proxyutil/proxy.go @@ -58,7 +58,7 @@ func Parse(raw string) (Setting, error) { } switch parsedURL.Scheme { - case "socks5", "http", "https": + case "socks5", "socks5h", "http", "https": setting.Mode = ModeProxy setting.URL = parsedURL return setting, nil @@ -95,7 +95,7 @@ func BuildHTTPTransport(raw string) (*http.Transport, Mode, error) { case ModeDirect: return NewDirectTransport(), setting.Mode, nil case ModeProxy: - if setting.URL.Scheme == "socks5" { + if setting.URL.Scheme == "socks5" || setting.URL.Scheme == "socks5h" { var proxyAuth *proxy.Auth if setting.URL.User != nil { username := setting.URL.User.Username() diff --git a/sdk/proxyutil/proxy_test.go b/sdk/proxyutil/proxy_test.go index 5b250117..f214bf6d 100644 --- a/sdk/proxyutil/proxy_test.go +++ b/sdk/proxyutil/proxy_test.go @@ -30,6 +30,7 @@ func TestParse(t *testing.T) { {name: "http", input: "http://proxy.example.com:8080", want: ModeProxy}, {name: "https", input: "https://proxy.example.com:8443", want: ModeProxy}, {name: "socks5", input: "socks5://proxy.example.com:1080", want: ModeProxy}, + {name: "socks5h", input: "socks5h://proxy.example.com:1080", want: ModeProxy}, {name: "invalid", input: "bad-value", want: ModeInvalid, wantErr: true}, } @@ -137,3 +138,24 @@ func TestBuildHTTPTransportSOCKS5ProxyInheritsDefaultTransportSettings(t *testin t.Fatalf("TLSHandshakeTimeout = %v, want %v", transport.TLSHandshakeTimeout, defaultTransport.TLSHandshakeTimeout) } } + +func TestBuildHTTPTransportSOCKS5HProxy(t *testing.T) { + t.Parallel() + + transport, mode, errBuild := BuildHTTPTransport("socks5h://proxy.example.com:1080") + if errBuild != nil { + t.Fatalf("BuildHTTPTransport returned error: %v", errBuild) + } + if mode != ModeProxy { + t.Fatalf("mode = %d, want %d", mode, ModeProxy) + } + if transport == nil { + t.Fatal("expected transport, got nil") + } + if transport.Proxy != nil { + t.Fatal("expected SOCKS5H transport to bypass http proxy function") + } + if transport.DialContext == nil { + t.Fatal("expected SOCKS5H transport to have custom DialContext") + } +} From c8b7e2b8d6f24462b724925dfe4f984ae6b9e302 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Tue, 7 Apr 2026 18:21:12 +0800 Subject: [PATCH 21/21] fix(executor): ensure empty stream completions use output_item.done as fallback Fixed: #2583 --- internal/runtime/executor/codex_executor.go | 50 +++++++++++++++++-- .../codex_executor_stream_output_test.go | 46 +++++++++++++++++ 2 files changed, 92 insertions(+), 4 deletions(-) create mode 100644 internal/runtime/executor/codex_executor_stream_output_test.go diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index e48a4ac3..acca590a 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "sort" "strings" "time" @@ -167,22 +168,63 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re helps.AppendAPIResponseChunk(ctx, e.cfg, data) lines := bytes.Split(data, []byte("\n")) + outputItemsByIndex := make(map[int64][]byte) + var outputItemsFallback [][]byte for _, line := range lines { if !bytes.HasPrefix(line, dataTag) { continue } - line = bytes.TrimSpace(line[5:]) - if gjson.GetBytes(line, "type").String() != "response.completed" { + eventData := bytes.TrimSpace(line[5:]) + eventType := gjson.GetBytes(eventData, "type").String() + + if eventType == "response.output_item.done" { + itemResult := gjson.GetBytes(eventData, "item") + if !itemResult.Exists() || itemResult.Type != gjson.JSON { + continue + } + outputIndexResult := gjson.GetBytes(eventData, "output_index") + if outputIndexResult.Exists() { + outputItemsByIndex[outputIndexResult.Int()] = []byte(itemResult.Raw) + } else { + outputItemsFallback = append(outputItemsFallback, []byte(itemResult.Raw)) + } continue } - if detail, ok := helps.ParseCodexUsage(line); ok { + if eventType != "response.completed" { + continue + } + + if detail, ok := helps.ParseCodexUsage(eventData); ok { reporter.Publish(ctx, detail) } + completedData := eventData + outputResult := gjson.GetBytes(completedData, "response.output") + shouldPatchOutput := (!outputResult.Exists() || !outputResult.IsArray() || len(outputResult.Array()) == 0) && (len(outputItemsByIndex) > 0 || len(outputItemsFallback) > 0) + if shouldPatchOutput { + completedDataPatched := completedData + completedDataPatched, _ = sjson.SetRawBytes(completedDataPatched, "response.output", []byte(`[]`)) + + indexes := make([]int64, 0, len(outputItemsByIndex)) + for idx := range outputItemsByIndex { + indexes = append(indexes, idx) + } + sort.Slice(indexes, func(i, j int) bool { + return indexes[i] < indexes[j] + }) + for _, idx := range indexes { + completedDataPatched, _ = sjson.SetRawBytes(completedDataPatched, "response.output.-1", outputItemsByIndex[idx]) + } + for _, item := range outputItemsFallback { + completedDataPatched, _ = sjson.SetRawBytes(completedDataPatched, "response.output.-1", item) + } + completedData = completedDataPatched + } + var param any - out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, ¶m) + out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, completedData, ¶m) resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()} return resp, nil } diff --git a/internal/runtime/executor/codex_executor_stream_output_test.go b/internal/runtime/executor/codex_executor_stream_output_test.go new file mode 100644 index 00000000..91d9b076 --- /dev/null +++ b/internal/runtime/executor/codex_executor_stream_output_test.go @@ -0,0 +1,46 @@ +package executor + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + "github.com/tidwall/gjson" +) + +func TestCodexExecutorExecute_EmptyStreamCompletionOutputUsesOutputItemDone(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]},\"output_index\":0}\n")) + _, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":1775555723,\"status\":\"completed\",\"model\":\"gpt-5.4-mini-2026-03-17\",\"output\":[],\"usage\":{\"input_tokens\":8,\"output_tokens\":28,\"total_tokens\":36}}}\n\n")) + })) + defer server.Close() + + executor := NewCodexExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "base_url": server.URL, + "api_key": "test", + }} + + resp, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "gpt-5.4-mini", + Payload: []byte(`{"model":"gpt-5.4-mini","messages":[{"role":"user","content":"Say ok"}]}`), + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("openai"), + Stream: false, + }) + if err != nil { + t.Fatalf("Execute error: %v", err) + } + + gotContent := gjson.GetBytes(resp.Payload, "choices.0.message.content").String() + if gotContent != "ok" { + t.Fatalf("choices.0.message.content = %q, want %q; payload=%s", gotContent, "ok", string(resp.Payload)) + } +}