Merge pull request #494 from router-for-me/plus

v6.9.16
This commit is contained in:
Luis Pater
2026-04-07 18:23:51 +08:00
committed by GitHub
33 changed files with 2795 additions and 492 deletions

View File

@@ -190,6 +190,7 @@ func main() {
gitStoreRemoteURL string
gitStoreUser string
gitStorePassword string
gitStoreBranch string
gitStoreLocalPath string
gitStoreInst *store.GitTokenStore
gitStoreRoot string
@@ -259,6 +260,9 @@ func main() {
if value, ok := lookupEnv("GITSTORE_LOCAL_PATH", "gitstore_local_path"); ok {
gitStoreLocalPath = value
}
if value, ok := lookupEnv("GITSTORE_GIT_BRANCH", "gitstore_git_branch"); ok {
gitStoreBranch = value
}
if value, ok := lookupEnv("OBJECTSTORE_ENDPOINT", "objectstore_endpoint"); ok {
useObjectStore = true
objectStoreEndpoint = value
@@ -393,7 +397,7 @@ func main() {
}
gitStoreRoot = filepath.Join(gitStoreLocalPath, "gitstore")
authDir := filepath.Join(gitStoreRoot, "auths")
gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword)
gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword, gitStoreBranch)
gitStoreInst.SetBaseDir(authDir)
if errRepo := gitStoreInst.EnsureRepository(); errRepo != nil {
log.Errorf("failed to prepare git token store: %v", errRepo)

View File

@@ -92,6 +92,9 @@ max-retry-credentials: 0
# Maximum wait time in seconds for a cooled-down credential before triggering a retry.
max-retry-interval: 30
# When true, disable auth/model cooldown scheduling globally (prevents blackout windows after failure states).
disable-cooling: false
# Quota exceeded behavior
quota-exceeded:
switch-project: true # Whether to automatically switch to another project when a quota is exceeded

View File

@@ -2,6 +2,7 @@ package amp
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"strings"
@@ -298,8 +299,10 @@ func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
}
// SanitizeAmpRequestBody removes thinking blocks with empty/missing/invalid signatures
// from the messages array in a request body before forwarding to the upstream API.
// This prevents 400 errors from the API which requires valid signatures on thinking blocks.
// and strips the proxy-injected "signature" field from tool_use blocks in the messages
// array before forwarding to the upstream API.
// This prevents 400 errors from the API which requires valid signatures on thinking
// blocks and does not accept a signature field on tool_use blocks.
func SanitizeAmpRequestBody(body []byte) []byte {
messages := gjson.GetBytes(body, "messages")
if !messages.Exists() || !messages.IsArray() {
@@ -317,21 +320,30 @@ func SanitizeAmpRequestBody(body []byte) []byte {
}
var keepBlocks []interface{}
removedCount := 0
contentModified := false
for _, block := range content.Array() {
blockType := block.Get("type").String()
if blockType == "thinking" {
sig := block.Get("signature")
if !sig.Exists() || sig.Type != gjson.String || strings.TrimSpace(sig.String()) == "" {
removedCount++
contentModified = true
continue
}
}
keepBlocks = append(keepBlocks, block.Value())
// Use raw JSON to prevent float64 rounding of large integers in tool_use inputs
blockRaw := []byte(block.Raw)
if blockType == "tool_use" && block.Get("signature").Exists() {
blockRaw, _ = sjson.DeleteBytes(blockRaw, "signature")
contentModified = true
}
// sjson.SetBytes supports raw JSON strings if wrapped in gjson.Raw
keepBlocks = append(keepBlocks, json.RawMessage(blockRaw))
}
if removedCount > 0 {
if contentModified {
contentPath := fmt.Sprintf("messages.%d.content", msgIdx)
var err error
if len(keepBlocks) == 0 {
@@ -340,11 +352,10 @@ func SanitizeAmpRequestBody(body []byte) []byte {
body, err = sjson.SetBytes(body, contentPath, keepBlocks)
}
if err != nil {
log.Warnf("Amp RequestSanitizer: failed to remove thinking blocks from message %d: %v", msgIdx, err)
log.Warnf("Amp RequestSanitizer: failed to sanitize message %d: %v", msgIdx, err)
continue
}
modified = true
log.Debugf("Amp RequestSanitizer: removed %d thinking blocks with invalid signatures from message %d", removedCount, msgIdx)
}
}

View File

@@ -145,6 +145,36 @@ func TestSanitizeAmpRequestBody_RemovesWhitespaceAndNonStringSignatures(t *testi
}
}
// TestSanitizeAmpRequestBody_StripsSignatureFromToolUseBlocks verifies that the
// sanitizer removes the proxy-injected "signature" field from a tool_use block
// while keeping both the block itself and a thinking block whose signature is
// valid.
func TestSanitizeAmpRequestBody_StripsSignatureFromToolUseBlocks(t *testing.T) {
	input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"thought","signature":"valid-sig"},{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"},"signature":""}]}]}`)
	result := SanitizeAmpRequestBody(input)
	// The empty signature on the tool_use block must be gone entirely.
	if contains(result, []byte(`"signature":""`)) {
		t.Fatalf("expected signature to be stripped from tool_use block, got %s", string(result))
	}
	// The valid thinking-block signature must survive sanitization.
	if !contains(result, []byte(`"valid-sig"`)) {
		t.Fatalf("expected thinking signature to remain, got %s", string(result))
	}
	// Only the signature field is removed, not the tool_use block itself.
	if !contains(result, []byte(`"tool_use"`)) {
		t.Fatalf("expected tool_use block to remain, got %s", string(result))
	}
}
// TestSanitizeAmpRequestBody_MixedInvalidThinkingAndToolUseSignature covers
// both sanitization rules on one message: a thinking block with an empty
// signature is dropped entirely, and the tool_use block in the same content
// array keeps its place but loses its "signature" field.
func TestSanitizeAmpRequestBody_MixedInvalidThinkingAndToolUseSignature(t *testing.T) {
	input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop-me","signature":""},{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"},"signature":""}]}]}`)
	result := SanitizeAmpRequestBody(input)
	// The invalid-signature thinking block must be removed wholesale.
	if contains(result, []byte("drop-me")) {
		t.Fatalf("expected invalid thinking block to be removed, got %s", string(result))
	}
	// With the thinking block gone and the tool_use signature stripped, no
	// "signature" key should remain anywhere in the output.
	if contains(result, []byte(`"signature"`)) {
		t.Fatalf("expected signature to be stripped from tool_use block, got %s", string(result))
	}
	// The tool_use block itself must still be present.
	if !contains(result, []byte(`"tool_use"`)) {
		t.Fatalf("expected tool_use block to remain, got %s", string(result))
	}
}
func contains(data, substr []byte) bool {
for i := 0; i <= len(data)-len(substr); i++ {
if string(data[i:i+len(substr)]) == string(substr) {

View File

@@ -8,7 +8,6 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
@@ -841,6 +840,9 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
baseBetas += ",oauth-2025-04-20"
}
}
if !strings.Contains(baseBetas, "interleaved-thinking") {
baseBetas += ",interleaved-thinking-2025-05-14"
}
hasClaude1MHeader := false
if ginHeaders != nil {
@@ -957,12 +959,9 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
return body
}
// Collect built-in tool names (those with a non-empty "type" field) so we can
// skip them consistently in both tools and message history.
builtinTools := map[string]bool{}
for _, name := range []string{"web_search", "code_execution", "text_editor", "computer"} {
builtinTools[name] = true
}
// Collect built-in tool names from the authoritative fallback seed list and
// augment it with any typed built-ins present in the current request body.
builtinTools := helps.AugmentClaudeBuiltinToolRegistry(body, nil)
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
tools.ForEach(func(index, tool gjson.Result) bool {
@@ -1471,182 +1470,6 @@ func countCacheControls(payload []byte) int {
return count
}
func parsePayloadObject(payload []byte) (map[string]any, bool) {
if len(payload) == 0 {
return nil, false
}
var root map[string]any
if err := json.Unmarshal(payload, &root); err != nil {
return nil, false
}
return root, true
}
func marshalPayloadObject(original []byte, root map[string]any) []byte {
if root == nil {
return original
}
out, err := json.Marshal(root)
if err != nil {
return original
}
return out
}
func asObject(v any) (map[string]any, bool) {
obj, ok := v.(map[string]any)
return obj, ok
}
func asArray(v any) ([]any, bool) {
arr, ok := v.([]any)
return arr, ok
}
func countCacheControlsMap(root map[string]any) int {
count := 0
if system, ok := asArray(root["system"]); ok {
for _, item := range system {
if obj, ok := asObject(item); ok {
if _, exists := obj["cache_control"]; exists {
count++
}
}
}
}
if tools, ok := asArray(root["tools"]); ok {
for _, item := range tools {
if obj, ok := asObject(item); ok {
if _, exists := obj["cache_control"]; exists {
count++
}
}
}
}
if messages, ok := asArray(root["messages"]); ok {
for _, msg := range messages {
msgObj, ok := asObject(msg)
if !ok {
continue
}
content, ok := asArray(msgObj["content"])
if !ok {
continue
}
for _, item := range content {
if obj, ok := asObject(item); ok {
if _, exists := obj["cache_control"]; exists {
count++
}
}
}
}
}
return count
}
func normalizeTTLForBlock(obj map[string]any, seen5m *bool) bool {
ccRaw, exists := obj["cache_control"]
if !exists {
return false
}
cc, ok := asObject(ccRaw)
if !ok {
*seen5m = true
return false
}
ttlRaw, ttlExists := cc["ttl"]
ttl, ttlIsString := ttlRaw.(string)
if !ttlExists || !ttlIsString || ttl != "1h" {
*seen5m = true
return false
}
if *seen5m {
delete(cc, "ttl")
return true
}
return false
}
func findLastCacheControlIndex(arr []any) int {
last := -1
for idx, item := range arr {
obj, ok := asObject(item)
if !ok {
continue
}
if _, exists := obj["cache_control"]; exists {
last = idx
}
}
return last
}
func stripCacheControlExceptIndex(arr []any, preserveIdx int, excess *int) {
for idx, item := range arr {
if *excess <= 0 {
return
}
obj, ok := asObject(item)
if !ok {
continue
}
if _, exists := obj["cache_control"]; exists && idx != preserveIdx {
delete(obj, "cache_control")
*excess--
}
}
}
func stripAllCacheControl(arr []any, excess *int) {
for _, item := range arr {
if *excess <= 0 {
return
}
obj, ok := asObject(item)
if !ok {
continue
}
if _, exists := obj["cache_control"]; exists {
delete(obj, "cache_control")
*excess--
}
}
}
func stripMessageCacheControl(messages []any, excess *int) {
for _, msg := range messages {
if *excess <= 0 {
return
}
msgObj, ok := asObject(msg)
if !ok {
continue
}
content, ok := asArray(msgObj["content"])
if !ok {
continue
}
for _, item := range content {
if *excess <= 0 {
return
}
obj, ok := asObject(item)
if !ok {
continue
}
if _, exists := obj["cache_control"]; exists {
delete(obj, "cache_control")
*excess--
}
}
}
}
// normalizeCacheControlTTL ensures cache_control TTL values don't violate the
// prompt-caching-scope-2026-01-05 ordering constraint: a 1h-TTL block must not
// appear after a 5m-TTL block anywhere in the evaluation order.
@@ -1659,58 +1482,75 @@ func stripMessageCacheControl(messages []any, excess *int) {
// Strategy: walk all cache_control blocks in evaluation order. Once a 5m block
// is seen, strip ttl from ALL subsequent 1h blocks (downgrading them to 5m).
func normalizeCacheControlTTL(payload []byte) []byte {
root, ok := parsePayloadObject(payload)
if !ok {
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return payload
}
original := payload
seen5m := false
modified := false
if tools, ok := asArray(root["tools"]); ok {
for _, tool := range tools {
if obj, ok := asObject(tool); ok {
if normalizeTTLForBlock(obj, &seen5m) {
modified = true
}
}
processBlock := func(path string, obj gjson.Result) {
cc := obj.Get("cache_control")
if !cc.Exists() {
return
}
if !cc.IsObject() {
seen5m = true
return
}
ttl := cc.Get("ttl")
if ttl.Type != gjson.String || ttl.String() != "1h" {
seen5m = true
return
}
if !seen5m {
return
}
ttlPath := path + ".cache_control.ttl"
updated, errDel := sjson.DeleteBytes(payload, ttlPath)
if errDel != nil {
return
}
payload = updated
modified = true
}
if system, ok := asArray(root["system"]); ok {
for _, item := range system {
if obj, ok := asObject(item); ok {
if normalizeTTLForBlock(obj, &seen5m) {
modified = true
}
}
}
tools := gjson.GetBytes(payload, "tools")
if tools.IsArray() {
tools.ForEach(func(idx, item gjson.Result) bool {
processBlock(fmt.Sprintf("tools.%d", int(idx.Int())), item)
return true
})
}
if messages, ok := asArray(root["messages"]); ok {
for _, msg := range messages {
msgObj, ok := asObject(msg)
if !ok {
continue
system := gjson.GetBytes(payload, "system")
if system.IsArray() {
system.ForEach(func(idx, item gjson.Result) bool {
processBlock(fmt.Sprintf("system.%d", int(idx.Int())), item)
return true
})
}
messages := gjson.GetBytes(payload, "messages")
if messages.IsArray() {
messages.ForEach(func(msgIdx, msg gjson.Result) bool {
content := msg.Get("content")
if !content.IsArray() {
return true
}
content, ok := asArray(msgObj["content"])
if !ok {
continue
}
for _, item := range content {
if obj, ok := asObject(item); ok {
if normalizeTTLForBlock(obj, &seen5m) {
modified = true
}
}
}
}
content.ForEach(func(itemIdx, item gjson.Result) bool {
processBlock(fmt.Sprintf("messages.%d.content.%d", int(msgIdx.Int()), int(itemIdx.Int())), item)
return true
})
return true
})
}
if !modified {
return payload
return original
}
return marshalPayloadObject(payload, root)
return payload
}
// enforceCacheControlLimit removes excess cache_control blocks from a payload
@@ -1730,64 +1570,166 @@ func normalizeCacheControlTTL(payload []byte) []byte {
// Phase 4: remaining system blocks (last system).
// Phase 5: remaining tool blocks (last tool).
func enforceCacheControlLimit(payload []byte, maxBlocks int) []byte {
root, ok := parsePayloadObject(payload)
if !ok {
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return payload
}
total := countCacheControlsMap(root)
total := countCacheControls(payload)
if total <= maxBlocks {
return payload
}
excess := total - maxBlocks
var system []any
if arr, ok := asArray(root["system"]); ok {
system = arr
}
var tools []any
if arr, ok := asArray(root["tools"]); ok {
tools = arr
}
var messages []any
if arr, ok := asArray(root["messages"]); ok {
messages = arr
}
if len(system) > 0 {
stripCacheControlExceptIndex(system, findLastCacheControlIndex(system), &excess)
system := gjson.GetBytes(payload, "system")
if system.IsArray() {
lastIdx := -1
system.ForEach(func(idx, item gjson.Result) bool {
if item.Get("cache_control").Exists() {
lastIdx = int(idx.Int())
}
return true
})
if lastIdx >= 0 {
system.ForEach(func(idx, item gjson.Result) bool {
if excess <= 0 {
return false
}
i := int(idx.Int())
if i == lastIdx {
return true
}
if !item.Get("cache_control").Exists() {
return true
}
path := fmt.Sprintf("system.%d.cache_control", i)
updated, errDel := sjson.DeleteBytes(payload, path)
if errDel != nil {
return true
}
payload = updated
excess--
return true
})
}
}
if excess <= 0 {
return marshalPayloadObject(payload, root)
return payload
}
if len(tools) > 0 {
stripCacheControlExceptIndex(tools, findLastCacheControlIndex(tools), &excess)
tools := gjson.GetBytes(payload, "tools")
if tools.IsArray() {
lastIdx := -1
tools.ForEach(func(idx, item gjson.Result) bool {
if item.Get("cache_control").Exists() {
lastIdx = int(idx.Int())
}
return true
})
if lastIdx >= 0 {
tools.ForEach(func(idx, item gjson.Result) bool {
if excess <= 0 {
return false
}
i := int(idx.Int())
if i == lastIdx {
return true
}
if !item.Get("cache_control").Exists() {
return true
}
path := fmt.Sprintf("tools.%d.cache_control", i)
updated, errDel := sjson.DeleteBytes(payload, path)
if errDel != nil {
return true
}
payload = updated
excess--
return true
})
}
}
if excess <= 0 {
return marshalPayloadObject(payload, root)
return payload
}
if len(messages) > 0 {
stripMessageCacheControl(messages, &excess)
messages := gjson.GetBytes(payload, "messages")
if messages.IsArray() {
messages.ForEach(func(msgIdx, msg gjson.Result) bool {
if excess <= 0 {
return false
}
content := msg.Get("content")
if !content.IsArray() {
return true
}
content.ForEach(func(itemIdx, item gjson.Result) bool {
if excess <= 0 {
return false
}
if !item.Get("cache_control").Exists() {
return true
}
path := fmt.Sprintf("messages.%d.content.%d.cache_control", int(msgIdx.Int()), int(itemIdx.Int()))
updated, errDel := sjson.DeleteBytes(payload, path)
if errDel != nil {
return true
}
payload = updated
excess--
return true
})
return true
})
}
if excess <= 0 {
return marshalPayloadObject(payload, root)
return payload
}
if len(system) > 0 {
stripAllCacheControl(system, &excess)
system = gjson.GetBytes(payload, "system")
if system.IsArray() {
system.ForEach(func(idx, item gjson.Result) bool {
if excess <= 0 {
return false
}
if !item.Get("cache_control").Exists() {
return true
}
path := fmt.Sprintf("system.%d.cache_control", int(idx.Int()))
updated, errDel := sjson.DeleteBytes(payload, path)
if errDel != nil {
return true
}
payload = updated
excess--
return true
})
}
if excess <= 0 {
return marshalPayloadObject(payload, root)
return payload
}
if len(tools) > 0 {
stripAllCacheControl(tools, &excess)
tools = gjson.GetBytes(payload, "tools")
if tools.IsArray() {
tools.ForEach(func(idx, item gjson.Result) bool {
if excess <= 0 {
return false
}
if !item.Get("cache_control").Exists() {
return true
}
path := fmt.Sprintf("tools.%d.cache_control", int(idx.Int()))
updated, errDel := sjson.DeleteBytes(payload, path)
if errDel != nil {
return true
}
payload = updated
excess--
return true
})
}
return marshalPayloadObject(payload, root)
return payload
}
// injectMessagesCacheControl adds cache_control to the second-to-last user turn for multi-turn caching.

View File

@@ -739,6 +739,35 @@ func TestApplyClaudeToolPrefix_ToolChoiceBuiltin(t *testing.T) {
}
}
// TestApplyClaudeToolPrefix_KnownFallbackBuiltinsRemainUnprefixed verifies,
// for each name in the fallback built-in seed list, that applyClaudeToolPrefix
// leaves the built-in untouched in tool_choice, in assistant tool_use blocks,
// in tool_reference blocks, and in references nested inside tool_result
// content — while a regular custom tool ("Read") still receives the prefix.
func TestApplyClaudeToolPrefix_KnownFallbackBuiltinsRemainUnprefixed(t *testing.T) {
	for _, builtin := range []string{"web_search", "code_execution", "text_editor", "computer"} {
		t.Run(builtin, func(t *testing.T) {
			// The same built-in name is injected at every location the
			// prefixer inspects.
			input := []byte(fmt.Sprintf(`{
"tools":[{"name":"Read"}],
"tool_choice":{"type":"tool","name":%q},
"messages":[{"role":"assistant","content":[{"type":"tool_use","name":%q,"id":"toolu_1","input":{}},{"type":"tool_reference","tool_name":%q},{"type":"tool_result","tool_use_id":"toolu_1","content":[{"type":"tool_reference","tool_name":%q}]}]}]
}`, builtin, builtin, builtin, builtin))
			out := applyClaudeToolPrefix(input, "proxy_")
			if got := gjson.GetBytes(out, "tool_choice.name").String(); got != builtin {
				t.Fatalf("tool_choice.name = %q, want %q", got, builtin)
			}
			if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != builtin {
				t.Fatalf("messages.0.content.0.name = %q, want %q", got, builtin)
			}
			if got := gjson.GetBytes(out, "messages.0.content.1.tool_name").String(); got != builtin {
				t.Fatalf("messages.0.content.1.tool_name = %q, want %q", got, builtin)
			}
			if got := gjson.GetBytes(out, "messages.0.content.2.content.0.tool_name").String(); got != builtin {
				t.Fatalf("messages.0.content.2.content.0.tool_name = %q, want %q", got, builtin)
			}
			// The non-builtin custom tool must still be prefixed.
			if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Read" {
				t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Read")
			}
		})
	}
}
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
@@ -965,6 +994,28 @@ func TestNormalizeCacheControlTTL_PreservesOriginalBytesWhenNoChange(t *testing.
}
}
// TestNormalizeCacheControlTTL_PreservesKeyOrderWhenModified ensures that when
// normalization does rewrite the payload (stripping a 1h ttl that appears
// after a default-5m cache_control block), the in-place sjson edits keep the
// original top-level key order instead of re-marshalling through a Go map.
func TestNormalizeCacheControlTTL_PreservesKeyOrderWhenModified(t *testing.T) {
	payload := []byte(`{"model":"m","messages":[{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]}],"tools":[{"name":"t1","cache_control":{"type":"ephemeral"}}],"system":[{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}]}`)
	out := normalizeCacheControlTTL(payload)
	// The message's 1h ttl follows default-5m tool/system blocks in
	// evaluation order, so it must be downgraded (ttl removed).
	if gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").Exists() {
		t.Fatalf("messages.0.content.0.cache_control.ttl should be removed after a default-5m block")
	}
	// Top-level keys must appear in their original order in the output.
	outStr := string(out)
	idxModel := strings.Index(outStr, `"model"`)
	idxMessages := strings.Index(outStr, `"messages"`)
	idxTools := strings.Index(outStr, `"tools"`)
	idxSystem := strings.Index(outStr, `"system"`)
	if idxModel == -1 || idxMessages == -1 || idxTools == -1 || idxSystem == -1 {
		t.Fatalf("failed to locate top-level keys in output: %s", outStr)
	}
	if !(idxModel < idxMessages && idxMessages < idxTools && idxTools < idxSystem) {
		t.Fatalf("top-level key order changed:\noriginal: %s\ngot: %s", payload, out)
	}
}
func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) {
payload := []byte(`{
"tools": [
@@ -994,6 +1045,31 @@ func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T)
}
}
// TestEnforceCacheControlLimit_PreservesKeyOrderWhenModified ensures that
// trimming excess cache_control blocks down to the limit removes a non-last
// tool block first and, because edits are applied in place with sjson, leaves
// the payload's top-level key order unchanged.
func TestEnforceCacheControlLimit_PreservesKeyOrderWhenModified(t *testing.T) {
	// Five cache_control blocks total; the limit of 4 forces exactly one
	// removal.
	payload := []byte(`{"model":"m","messages":[{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral"}},{"type":"text","text":"u2","cache_control":{"type":"ephemeral"}}]}],"tools":[{"name":"t1","cache_control":{"type":"ephemeral"}},{"name":"t2","cache_control":{"type":"ephemeral"}}],"system":[{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}]}`)
	out := enforceCacheControlLimit(payload, 4)
	if got := countCacheControls(out); got != 4 {
		t.Fatalf("cache_control count = %d, want 4", got)
	}
	// Phase ordering: the non-last tool block is the first candidate stripped.
	if gjson.GetBytes(out, "tools.0.cache_control").Exists() {
		t.Fatalf("tools.0.cache_control should be removed first (non-last tool)")
	}
	// Top-level keys must appear in their original order in the output.
	outStr := string(out)
	idxModel := strings.Index(outStr, `"model"`)
	idxMessages := strings.Index(outStr, `"messages"`)
	idxTools := strings.Index(outStr, `"tools"`)
	idxSystem := strings.Index(outStr, `"system"`)
	if idxModel == -1 || idxMessages == -1 || idxTools == -1 || idxSystem == -1 {
		t.Fatalf("failed to locate top-level keys in output: %s", outStr)
	}
	if !(idxModel < idxMessages && idxMessages < idxTools && idxTools < idxSystem) {
		t.Fatalf("top-level key order changed:\noriginal: %s\ngot: %s", payload, out)
	}
}
func TestEnforceCacheControlLimit_ToolOnlyPayloadStillRespectsLimit(t *testing.T) {
payload := []byte(`{
"tools": [

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
"sort"
"strings"
"time"
@@ -167,22 +168,63 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
lines := bytes.Split(data, []byte("\n"))
outputItemsByIndex := make(map[int64][]byte)
var outputItemsFallback [][]byte
for _, line := range lines {
if !bytes.HasPrefix(line, dataTag) {
continue
}
line = bytes.TrimSpace(line[5:])
if gjson.GetBytes(line, "type").String() != "response.completed" {
eventData := bytes.TrimSpace(line[5:])
eventType := gjson.GetBytes(eventData, "type").String()
if eventType == "response.output_item.done" {
itemResult := gjson.GetBytes(eventData, "item")
if !itemResult.Exists() || itemResult.Type != gjson.JSON {
continue
}
outputIndexResult := gjson.GetBytes(eventData, "output_index")
if outputIndexResult.Exists() {
outputItemsByIndex[outputIndexResult.Int()] = []byte(itemResult.Raw)
} else {
outputItemsFallback = append(outputItemsFallback, []byte(itemResult.Raw))
}
continue
}
if detail, ok := helps.ParseCodexUsage(line); ok {
if eventType != "response.completed" {
continue
}
if detail, ok := helps.ParseCodexUsage(eventData); ok {
reporter.Publish(ctx, detail)
}
completedData := eventData
outputResult := gjson.GetBytes(completedData, "response.output")
shouldPatchOutput := (!outputResult.Exists() || !outputResult.IsArray() || len(outputResult.Array()) == 0) && (len(outputItemsByIndex) > 0 || len(outputItemsFallback) > 0)
if shouldPatchOutput {
completedDataPatched := completedData
completedDataPatched, _ = sjson.SetRawBytes(completedDataPatched, "response.output", []byte(`[]`))
indexes := make([]int64, 0, len(outputItemsByIndex))
for idx := range outputItemsByIndex {
indexes = append(indexes, idx)
}
sort.Slice(indexes, func(i, j int) bool {
return indexes[i] < indexes[j]
})
for _, idx := range indexes {
completedDataPatched, _ = sjson.SetRawBytes(completedDataPatched, "response.output.-1", outputItemsByIndex[idx])
}
for _, item := range outputItemsFallback {
completedDataPatched, _ = sjson.SetRawBytes(completedDataPatched, "response.output.-1", item)
}
completedData = completedDataPatched
}
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, &param)
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, completedData, &param)
resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}
return resp, nil
}

View File

@@ -0,0 +1,46 @@
package executor
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
// TestCodexExecutorExecute_EmptyStreamCompletionOutputUsesOutputItemDone
// verifies that when the upstream SSE stream's response.completed event
// carries an empty "output" array, the executor backfills it from earlier
// response.output_item.done events before translation, so the translated
// OpenAI-format payload still contains the assistant text.
func TestCodexExecutorExecute_EmptyStreamCompletionOutputUsesOutputItemDone(t *testing.T) {
	// Fake upstream: one output_item.done carrying the real message, then a
	// response.completed whose response.output is empty.
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		_, _ = w.Write([]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]},\"output_index\":0}\n"))
		_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":1775555723,\"status\":\"completed\",\"model\":\"gpt-5.4-mini-2026-03-17\",\"output\":[],\"usage\":{\"input_tokens\":8,\"output_tokens\":28,\"total_tokens\":36}}}\n\n"))
	}))
	defer server.Close()
	executor := NewCodexExecutor(&config.Config{})
	// Point the executor at the fake upstream via auth attributes.
	auth := &cliproxyauth.Auth{Attributes: map[string]string{
		"base_url": server.URL,
		"api_key":  "test",
	}}
	resp, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
		Model:   "gpt-5.4-mini",
		Payload: []byte(`{"model":"gpt-5.4-mini","messages":[{"role":"user","content":"Say ok"}]}`),
	}, cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FromString("openai"),
		Stream:       false,
	})
	if err != nil {
		t.Fatalf("Execute error: %v", err)
	}
	// The backfilled output must survive translation into the OpenAI shape.
	gotContent := gjson.GetBytes(resp.Payload, "choices.0.message.content").String()
	if gotContent != "ok" {
		t.Fatalf("choices.0.message.content = %q, want %q; payload=%s", gotContent, "ok", string(resp.Payload))
	}
}

View File

@@ -734,7 +734,7 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *
}
switch setting.URL.Scheme {
case "socks5":
case "socks5", "socks5h":
var proxyAuth *proxy.Auth
if setting.URL.User != nil {
username := setting.URL.User.Username()

View File

@@ -0,0 +1,38 @@
package helps
import "github.com/tidwall/gjson"
// defaultClaudeBuiltinToolNames is the authoritative fallback seed list of
// Claude built-in tool names; these are always treated as built-ins even when
// the request body declares no typed tools.
var defaultClaudeBuiltinToolNames = []string{
	"web_search",
	"code_execution",
	"text_editor",
	"computer",
}
// newClaudeBuiltinToolRegistry returns a fresh lookup set pre-populated with
// every name from the default built-in seed list.
func newClaudeBuiltinToolRegistry() map[string]bool {
	seeded := make(map[string]bool, len(defaultClaudeBuiltinToolNames))
	for _, builtin := range defaultClaudeBuiltinToolNames {
		seeded[builtin] = true
	}
	return seeded
}
// AugmentClaudeBuiltinToolRegistry returns a built-in tool-name registry for
// the given request body. A nil registry is replaced with the default seed
// set. Every entry of the body's "tools" array that carries a non-empty
// "type" field contributes its non-empty "name" to the registry; untyped
// (custom) tools are ignored.
func AugmentClaudeBuiltinToolRegistry(body []byte, registry map[string]bool) map[string]bool {
	if registry == nil {
		registry = newClaudeBuiltinToolRegistry()
	}
	toolList := gjson.GetBytes(body, "tools")
	if !toolList.Exists() || !toolList.IsArray() {
		return registry
	}
	for _, tool := range toolList.Array() {
		// A non-empty "type" marks the entry as a typed built-in.
		if tool.Get("type").String() == "" {
			continue
		}
		if name := tool.Get("name").String(); name != "" {
			registry[name] = true
		}
	}
	return registry
}

View File

@@ -0,0 +1,32 @@
package helps
import "testing"
// TestClaudeBuiltinToolRegistry_DefaultSeedFallback checks that augmenting a
// nil registry with a nil body still yields every default seed name.
func TestClaudeBuiltinToolRegistry_DefaultSeedFallback(t *testing.T) {
	got := AugmentClaudeBuiltinToolRegistry(nil, nil)
	for _, seed := range defaultClaudeBuiltinToolNames {
		if !got[seed] {
			t.Fatalf("default builtin %q missing from fallback registry", seed)
		}
	}
}
// TestClaudeBuiltinToolRegistry_AugmentsTypedBuiltinsFromBody verifies that
// typed tools from the request body are added to the registry (including a
// non-default typed built-in), that default seeds remain present, and that an
// untyped custom tool is kept out of the registry.
func TestClaudeBuiltinToolRegistry_AugmentsTypedBuiltinsFromBody(t *testing.T) {
	registry := AugmentClaudeBuiltinToolRegistry([]byte(`{
"tools": [
{"type": "web_search_20250305", "name": "web_search"},
{"type": "custom_builtin_20250401", "name": "special_builtin"},
{"name": "Read"}
]
}`), nil)
	// Seed-list name, also typed in the body.
	if !registry["web_search"] {
		t.Fatal("expected default typed builtin web_search in registry")
	}
	// Typed built-in known only from the body.
	if !registry["special_builtin"] {
		t.Fatal("expected typed builtin from body to be added to registry")
	}
	// Untyped custom tool must not be promoted to built-in.
	if registry["Read"] {
		t.Fatal("expected untyped custom tool to stay out of builtin registry")
	}
}

View File

@@ -298,6 +298,14 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
} else {
// In case the upstream close the stream without a terminal [DONE] marker.
// Feed a synthetic done marker through the translator so pending
// response.completed events are still emitted exactly once.
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("data: [DONE]"), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
}
}
// Ensure we record the request if no usage chunk was ever seen
reporter.EnsurePublished(ctx)

View File

@@ -32,16 +32,24 @@ type GitTokenStore struct {
repoDir string
configDir string
remote string
branch string
username string
password string
lastGC time.Time
}
type resolvedRemoteBranch struct {
name plumbing.ReferenceName
hash plumbing.Hash
}
// NewGitTokenStore creates a token store that saves credentials to disk through the
// TokenStorage implementation embedded in the token record.
func NewGitTokenStore(remote, username, password string) *GitTokenStore {
// When branch is non-empty, clone/pull/push operations target that branch instead of the remote default.
func NewGitTokenStore(remote, username, password, branch string) *GitTokenStore {
return &GitTokenStore{
remote: remote,
branch: strings.TrimSpace(branch),
username: username,
password: password,
}
@@ -120,7 +128,11 @@ func (s *GitTokenStore) EnsureRepository() error {
s.dirLock.Unlock()
return fmt.Errorf("git token store: create repo dir: %w", errMk)
}
if _, errClone := git.PlainClone(repoDir, &git.CloneOptions{Auth: authMethod, URL: s.remote}); errClone != nil {
cloneOpts := &git.CloneOptions{Auth: authMethod, URL: s.remote}
if s.branch != "" {
cloneOpts.ReferenceName = plumbing.NewBranchReferenceName(s.branch)
}
if _, errClone := git.PlainClone(repoDir, cloneOpts); errClone != nil {
if errors.Is(errClone, transport.ErrEmptyRemoteRepository) {
_ = os.RemoveAll(gitDir)
repo, errInit := git.PlainInit(repoDir, false)
@@ -128,6 +140,13 @@ func (s *GitTokenStore) EnsureRepository() error {
s.dirLock.Unlock()
return fmt.Errorf("git token store: init empty repo: %w", errInit)
}
if s.branch != "" {
headRef := plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(s.branch))
if errHead := repo.Storer.SetReference(headRef); errHead != nil {
s.dirLock.Unlock()
return fmt.Errorf("git token store: set head to branch %s: %w", s.branch, errHead)
}
}
if _, errRemote := repo.Remote("origin"); errRemote != nil {
if _, errCreate := repo.CreateRemote(&config.RemoteConfig{
Name: "origin",
@@ -176,16 +195,39 @@ func (s *GitTokenStore) EnsureRepository() error {
s.dirLock.Unlock()
return fmt.Errorf("git token store: worktree: %w", errWorktree)
}
if errPull := worktree.Pull(&git.PullOptions{Auth: authMethod, RemoteName: "origin"}); errPull != nil {
if s.branch != "" {
if errCheckout := s.checkoutConfiguredBranch(repo, worktree, authMethod); errCheckout != nil {
s.dirLock.Unlock()
return errCheckout
}
} else {
// When branch is unset, ensure the working tree follows the remote default branch
if err := checkoutRemoteDefaultBranch(repo, worktree, authMethod); err != nil {
if !shouldFallbackToCurrentBranch(repo, err) {
s.dirLock.Unlock()
return fmt.Errorf("git token store: checkout remote default: %w", err)
}
}
}
pullOpts := &git.PullOptions{Auth: authMethod, RemoteName: "origin"}
if s.branch != "" {
pullOpts.ReferenceName = plumbing.NewBranchReferenceName(s.branch)
}
if errPull := worktree.Pull(pullOpts); errPull != nil {
switch {
case errors.Is(errPull, git.NoErrAlreadyUpToDate),
errors.Is(errPull, git.ErrUnstagedChanges),
errors.Is(errPull, git.ErrNonFastForwardUpdate):
// Ignore clean syncs, local edits, and remote divergence—local changes win.
case errors.Is(errPull, transport.ErrAuthenticationRequired),
errors.Is(errPull, plumbing.ErrReferenceNotFound),
errors.Is(errPull, transport.ErrEmptyRemoteRepository):
// Ignore authentication prompts and empty remote references on initial sync.
case errors.Is(errPull, plumbing.ErrReferenceNotFound):
if s.branch != "" {
s.dirLock.Unlock()
return fmt.Errorf("git token store: pull: %w", errPull)
}
// Ignore missing references only when following the remote default branch.
default:
s.dirLock.Unlock()
return fmt.Errorf("git token store: pull: %w", errPull)
@@ -554,6 +596,192 @@ func (s *GitTokenStore) relativeToRepo(path string) (string, error) {
return rel, nil
}
// checkoutConfiguredBranch makes the working tree track the explicitly configured
// branch (s.branch). It is a no-op when HEAD already points there; otherwise it
// checks the branch out, creating it from the remote-tracking ref when the local
// branch does not exist yet.
func (s *GitTokenStore) checkoutConfiguredBranch(repo *git.Repository, worktree *git.Worktree, authMethod transport.AuthMethod) error {
	wantRef := plumbing.NewBranchReferenceName(s.branch)
	headRef, errHead := repo.Head()
	if errHead == nil && headRef.Name() == wantRef {
		// Already on the configured branch.
		return nil
	}
	if errHead != nil && !errors.Is(errHead, plumbing.ErrReferenceNotFound) {
		return fmt.Errorf("git token store: get head: %w", errHead)
	}
	errCheckout := worktree.Checkout(&git.CheckoutOptions{Branch: wantRef})
	if errCheckout == nil {
		return nil
	}
	// Checkout failed: distinguish "local branch missing" from a genuine failure.
	if _, errRef := repo.Reference(wantRef, true); errRef == nil {
		return fmt.Errorf("git token store: checkout branch %s: %w", s.branch, errCheckout)
	} else if !errors.Is(errRef, plumbing.ErrReferenceNotFound) {
		return fmt.Errorf("git token store: inspect branch %s: %w", s.branch, errRef)
	}
	// Local branch does not exist yet: create it from refs/remotes/origin/<branch>.
	if errTrack := s.checkoutConfiguredRemoteTrackingBranch(repo, worktree, wantRef, authMethod); errTrack != nil {
		return fmt.Errorf("git token store: checkout branch %s: %w", s.branch, errTrack)
	}
	return nil
}
// checkoutConfiguredRemoteTrackingBranch creates the configured local branch from its
// remote-tracking counterpart (refs/remotes/origin/<branch>) and records the upstream
// (remote + merge ref) in the repository config so later pulls track origin. When the
// remote-tracking ref is absent it fetches from origin once and retries the lookup.
// Callers wrap returned errors with "git token store: checkout branch <name>" context.
func (s *GitTokenStore) checkoutConfiguredRemoteTrackingBranch(repo *git.Repository, worktree *git.Worktree, branchRefName plumbing.ReferenceName, authMethod transport.AuthMethod) error {
	remoteRefName := plumbing.ReferenceName("refs/remotes/origin/" + s.branch)
	remoteRef, err := repo.Reference(remoteRefName, true)
	if errors.Is(err, plumbing.ErrReferenceNotFound) {
		// The branch may have been created on the remote after the last fetch.
		if errSync := syncRemoteReferences(repo, authMethod); errSync != nil {
			return fmt.Errorf("sync remote refs: %w", errSync)
		}
		remoteRef, err = repo.Reference(remoteRefName, true)
	}
	if err != nil {
		return err
	}
	if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName, Create: true, Hash: remoteRef.Hash()}); err != nil {
		return err
	}
	cfg, err := repo.Config()
	if err != nil {
		// No "git token store:" prefix here: the caller already adds that context,
		// and the old message produced a duplicated prefix in the final chain.
		return fmt.Errorf("repo config: %w", err)
	}
	if cfg.Branches == nil {
		// Defensive: make sure the branches map is writable before inserting.
		cfg.Branches = make(map[string]*config.Branch)
	}
	if _, ok := cfg.Branches[s.branch]; !ok {
		cfg.Branches[s.branch] = &config.Branch{Name: s.branch}
	}
	cfg.Branches[s.branch].Remote = "origin"
	cfg.Branches[s.branch].Merge = branchRefName
	if err := repo.SetConfig(cfg); err != nil {
		return fmt.Errorf("set branch config: %w", err)
	}
	return nil
}
// syncRemoteReferences fetches from origin so local remote-tracking refs are
// current. An already-up-to-date fetch is treated as success.
func syncRemoteReferences(repo *git.Repository, authMethod transport.AuthMethod) error {
	err := repo.Fetch(&git.FetchOptions{RemoteName: "origin", Auth: authMethod})
	if err == nil || errors.Is(err, git.NoErrAlreadyUpToDate) {
		return nil
	}
	return err
}
// resolveRemoteDefaultBranch queries the origin remote to determine the remote's default branch
// (the target of HEAD) and returns the corresponding local branch reference name (e.g. refs/heads/master).
// Resolution order:
//  1. a symbolic HEAD (or a "-> target" textual annotation) reported by remote.List
//  2. the locally cached refs/remotes/origin/HEAD symbolic reference
//  3. the first listed remote ref that normalizes to a branch name (hash retained)
func resolveRemoteDefaultBranch(repo *git.Repository, authMethod transport.AuthMethod) (resolvedRemoteBranch, error) {
	if err := syncRemoteReferences(repo, authMethod); err != nil {
		return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: sync remote refs: %w", err)
	}
	remote, err := repo.Remote("origin")
	if err != nil {
		return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: get remote: %w", err)
	}
	refs, err := remote.List(&git.ListOptions{Auth: authMethod})
	if err != nil {
		// Listing failed (offline, auth, ...): fall back to the cached origin/HEAD.
		if resolved, ok := resolveRemoteDefaultBranchFromLocal(repo); ok {
			return resolved, nil
		}
		return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: list remote refs: %w", err)
	}
	for _, r := range refs {
		if r.Name() == plumbing.HEAD {
			if r.Type() == plumbing.SymbolicReference {
				if target, ok := normalizeRemoteBranchReference(r.Target()); ok {
					return resolvedRemoteBranch{name: target}, nil
				}
			}
			// NOTE(review): this appears to parse a "hash -> refs/heads/<name>" textual
			// rendering of HEAD from some transports — confirm against go-git behavior.
			s := r.String()
			if idx := strings.Index(s, "->"); idx != -1 {
				if target, ok := normalizeRemoteBranchReference(plumbing.ReferenceName(strings.TrimSpace(s[idx+2:]))); ok {
					return resolvedRemoteBranch{name: target}, nil
				}
			}
		}
	}
	if resolved, ok := resolveRemoteDefaultBranchFromLocal(repo); ok {
		return resolved, nil
	}
	// Last resort: pick the first listed ref that looks like a branch, keeping its
	// hash so callers can create the local branch without a remote-tracking ref.
	for _, r := range refs {
		if normalized, ok := normalizeRemoteBranchReference(r.Name()); ok {
			return resolvedRemoteBranch{name: normalized, hash: r.Hash()}, nil
		}
	}
	return resolvedRemoteBranch{}, fmt.Errorf("resolve remote default: remote default branch not found")
}
// resolveRemoteDefaultBranchFromLocal derives the remote default branch from the
// locally cached refs/remotes/origin/HEAD symbolic reference, when one exists.
// The second result reports whether a usable branch name was found.
func resolveRemoteDefaultBranchFromLocal(repo *git.Repository) (resolvedRemoteBranch, bool) {
	headRef, err := repo.Reference(plumbing.ReferenceName("refs/remotes/origin/HEAD"), true)
	if err != nil || headRef.Type() != plumbing.SymbolicReference {
		return resolvedRemoteBranch{}, false
	}
	target, ok := normalizeRemoteBranchReference(headRef.Target())
	if ok {
		return resolvedRemoteBranch{name: target}, true
	}
	return resolvedRemoteBranch{}, false
}
// normalizeRemoteBranchReference maps a branch-like reference to its local branch
// form (refs/heads/<name>). Remote-tracking refs under refs/remotes/origin/ are
// translated; anything else (tags, HEAD, symbolic targets, ...) is rejected.
func normalizeRemoteBranchReference(name plumbing.ReferenceName) (plumbing.ReferenceName, bool) {
	const (
		localPrefix  = "refs/heads/"
		originPrefix = "refs/remotes/origin/"
	)
	raw := name.String()
	if strings.HasPrefix(raw, localPrefix) {
		return name, true
	}
	if strings.HasPrefix(raw, originPrefix) {
		short := strings.TrimPrefix(raw, originPrefix)
		return plumbing.NewBranchReferenceName(short), true
	}
	return "", false
}
// shouldFallbackToCurrentBranch reports whether a failure to resolve the remote
// default branch may be ignored in favor of the currently checked-out branch.
// Only authentication failures and empty remotes qualify, and only when the local
// repository has a valid HEAD to fall back to.
func shouldFallbackToCurrentBranch(repo *git.Repository, err error) bool {
	recoverable := errors.Is(err, transport.ErrAuthenticationRequired) ||
		errors.Is(err, transport.ErrEmptyRemoteRepository)
	if !recoverable {
		return false
	}
	_, errHead := repo.Head()
	return errHead == nil
}
// checkoutRemoteDefaultBranch ensures the working tree is checked out to the remote's
// default branch (the branch target of origin/HEAD). If the local branch does not
// exist it is created from the matching remote-tracking ref — or from the hash
// captured while resolving the default — and configured to track origin.
func checkoutRemoteDefaultBranch(repo *git.Repository, worktree *git.Worktree, authMethod transport.AuthMethod) error {
	resolved, err := resolveRemoteDefaultBranch(repo, authMethod)
	if err != nil {
		return err
	}
	branchRefName := resolved.name
	// If HEAD already points to the desired branch, nothing to do.
	headRef, errHead := repo.Head()
	if errHead == nil && headRef.Name() == branchRefName {
		return nil
	}
	// If the local branch exists, a plain checkout is enough.
	if _, errRef := repo.Reference(branchRefName, true); errRef == nil {
		if errCheckout := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName}); errCheckout != nil {
			return fmt.Errorf("checkout branch %s: %w", branchRefName.String(), errCheckout)
		}
		return nil
	}
	// Otherwise create the local branch from refs/remotes/origin/<name> when present,
	// falling back to the hash captured during resolution.
	branchShort := strings.TrimPrefix(branchRefName.String(), "refs/heads/")
	remoteRefName := plumbing.ReferenceName("refs/remotes/origin/" + branchShort)
	hash := resolved.hash
	remoteRef, errRemote := repo.Reference(remoteRefName, true)
	switch {
	case errRemote == nil:
		hash = remoteRef.Hash()
	case !errors.Is(errRemote, plumbing.ErrReferenceNotFound):
		// Note: the original guarded this with a redundant "err != nil &&" check.
		return fmt.Errorf("checkout remote default: remote ref %s: %w", remoteRefName.String(), errRemote)
	}
	if hash == plumbing.ZeroHash {
		return fmt.Errorf("checkout remote default: remote ref %s not found", remoteRefName.String())
	}
	if err := worktree.Checkout(&git.CheckoutOptions{Branch: branchRefName, Create: true, Hash: hash}); err != nil {
		return fmt.Errorf("checkout create branch %s: %w", branchRefName.String(), err)
	}
	cfg, err := repo.Config()
	if err != nil {
		// No "git token store:" prefix: callers add that context when wrapping,
		// and the old message produced a duplicated prefix in the error chain.
		return fmt.Errorf("repo config: %w", err)
	}
	if cfg.Branches == nil {
		// Defensive: make sure the branches map is writable before inserting.
		cfg.Branches = make(map[string]*config.Branch)
	}
	if _, ok := cfg.Branches[branchShort]; !ok {
		cfg.Branches[branchShort] = &config.Branch{Name: branchShort}
	}
	cfg.Branches[branchShort].Remote = "origin"
	cfg.Branches[branchShort].Merge = branchRefName
	if err := repo.SetConfig(cfg); err != nil {
		return fmt.Errorf("set branch config: %w", err)
	}
	return nil
}
func (s *GitTokenStore) commitAndPushLocked(message string, relPaths ...string) error {
repoDir := s.repoDirSnapshot()
if repoDir == "" {
@@ -619,7 +847,16 @@ func (s *GitTokenStore) commitAndPushLocked(message string, relPaths ...string)
return errRewrite
}
s.maybeRunGC(repo)
if err = repo.Push(&git.PushOptions{Auth: s.gitAuth(), Force: true}); err != nil {
pushOpts := &git.PushOptions{Auth: s.gitAuth(), Force: true}
if s.branch != "" {
pushOpts.RefSpecs = []config.RefSpec{config.RefSpec("refs/heads/" + s.branch + ":refs/heads/" + s.branch)}
} else {
// When branch is unset, pin push to the currently checked-out branch.
if headRef, err := repo.Head(); err == nil {
pushOpts.RefSpecs = []config.RefSpec{config.RefSpec(headRef.Name().String() + ":" + headRef.Name().String())}
}
}
if err = repo.Push(pushOpts); err != nil {
if errors.Is(err, git.NoErrAlreadyUpToDate) {
return nil
}

View File

@@ -0,0 +1,585 @@
package store
import (
	"errors"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"testing"
	"time"

	"github.com/go-git/go-git/v6"
	gitconfig "github.com/go-git/go-git/v6/config"
	"github.com/go-git/go-git/v6/plumbing"
	"github.com/go-git/go-git/v6/plumbing/object"
)
// testBranchSpec describes one branch to seed into a test remote repository:
// the branch name and the contents written to its branch.txt marker file.
type testBranchSpec struct {
	name     string // branch short name, e.g. "trunk" or "release/2026"
	contents string // payload written to branch.txt on that branch
}
// TestEnsureRepositoryUsesRemoteDefaultBranchWhenBranchNotConfigured verifies that a
// store with no configured branch clones the remote's default branch (HEAD target)
// and keeps pulling it — not other branches — on later EnsureRepository calls.
func TestEnsureRepositoryUsesRemoteDefaultBranchWhenBranchNotConfigured(t *testing.T) {
	root := t.TempDir()
	remoteDir := setupGitRemoteRepository(t, root, "trunk",
		testBranchSpec{name: "trunk", contents: "remote default branch\n"},
		testBranchSpec{name: "release/2026", contents: "release branch\n"},
	)
	store := NewGitTokenStore(remoteDir, "", "", "")
	store.SetBaseDir(filepath.Join(root, "workspace", "auths"))
	if err := store.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "trunk", "remote default branch\n")
	// Advance both remote branches; only the default branch should be pulled locally.
	advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "trunk", "remote default branch updated\n", "advance trunk")
	advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "release/2026", "release branch updated\n", "advance release")
	if err := store.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository second call: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "trunk", "remote default branch updated\n")
	assertRemoteHeadBranch(t, remoteDir, "trunk")
}
// TestEnsureRepositoryUsesConfiguredBranchWhenExplicitlySet verifies that an explicit
// branch setting wins over the remote default: the store clones and keeps pulling
// the configured branch while leaving the remote HEAD untouched.
func TestEnsureRepositoryUsesConfiguredBranchWhenExplicitlySet(t *testing.T) {
	root := t.TempDir()
	remoteDir := setupGitRemoteRepository(t, root, "trunk",
		testBranchSpec{name: "trunk", contents: "remote default branch\n"},
		testBranchSpec{name: "release/2026", contents: "release branch\n"},
	)
	store := NewGitTokenStore(remoteDir, "", "", "release/2026")
	store.SetBaseDir(filepath.Join(root, "workspace", "auths"))
	if err := store.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "release/2026", "release branch\n")
	// Advance both remote branches; the pinned branch should track its remote update.
	advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "trunk", "remote default branch updated\n", "advance trunk")
	advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "release/2026", "release branch updated\n", "advance release")
	if err := store.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository second call: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "release/2026", "release branch updated\n")
	assertRemoteHeadBranch(t, remoteDir, "trunk")
}
// TestEnsureRepositoryReturnsErrorForMissingConfiguredBranch verifies that a fresh
// clone fails loudly when the configured branch does not exist on the remote,
// rather than silently falling back to the default branch.
func TestEnsureRepositoryReturnsErrorForMissingConfiguredBranch(t *testing.T) {
	root := t.TempDir()
	remoteDir := setupGitRemoteRepository(t, root, "trunk",
		testBranchSpec{name: "trunk", contents: "remote default branch\n"},
	)
	store := NewGitTokenStore(remoteDir, "", "", "missing-branch")
	store.SetBaseDir(filepath.Join(root, "workspace", "auths"))
	err := store.EnsureRepository()
	if err == nil {
		t.Fatal("EnsureRepository succeeded, want error for nonexistent configured branch")
	}
	assertRemoteHeadBranch(t, remoteDir, "trunk")
}
// TestEnsureRepositoryReturnsErrorForMissingConfiguredBranchOnExistingRepositoryPull
// verifies that reopening an already-cloned workspace with a nonexistent configured
// branch fails, and leaves the local HEAD on the previously checked-out branch.
func TestEnsureRepositoryReturnsErrorForMissingConfiguredBranchOnExistingRepositoryPull(t *testing.T) {
	root := t.TempDir()
	remoteDir := setupGitRemoteRepository(t, root, "trunk",
		testBranchSpec{name: "trunk", contents: "remote default branch\n"},
	)
	baseDir := filepath.Join(root, "workspace", "auths")
	store := NewGitTokenStore(remoteDir, "", "", "")
	store.SetBaseDir(baseDir)
	if err := store.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository initial clone: %v", err)
	}
	// Reopen the same workspace with a branch that the remote does not have.
	reopened := NewGitTokenStore(remoteDir, "", "", "missing-branch")
	reopened.SetBaseDir(baseDir)
	err := reopened.EnsureRepository()
	if err == nil {
		t.Fatal("EnsureRepository succeeded on reopen, want error for nonexistent configured branch")
	}
	assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), "trunk")
	assertRemoteHeadBranch(t, remoteDir, "trunk")
}
// TestEnsureRepositoryInitializesEmptyRemoteUsingConfiguredBranch verifies that when
// the remote is an empty bare repository, the store initializes it with the
// configured branch as the first (and only) branch — no stray "master" is created.
func TestEnsureRepositoryInitializesEmptyRemoteUsingConfiguredBranch(t *testing.T) {
	root := t.TempDir()
	remoteDir := filepath.Join(root, "remote.git")
	if _, err := git.PlainInit(remoteDir, true); err != nil {
		t.Fatalf("init bare remote: %v", err)
	}
	branch := "feature/gemini-fix"
	store := NewGitTokenStore(remoteDir, "", "", branch)
	store.SetBaseDir(filepath.Join(root, "workspace", "auths"))
	if err := store.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository: %v", err)
	}
	assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), branch)
	assertRemoteBranchExistsWithCommit(t, remoteDir, branch)
	assertRemoteBranchDoesNotExist(t, remoteDir, "master")
}
// TestEnsureRepositoryExistingRepoSwitchesToConfiguredBranch verifies that an
// existing workspace switches from the default branch to a newly configured branch
// on reopen, and that subsequent commits push only to that branch.
func TestEnsureRepositoryExistingRepoSwitchesToConfiguredBranch(t *testing.T) {
	root := t.TempDir()
	remoteDir := setupGitRemoteRepository(t, root, "master",
		testBranchSpec{name: "master", contents: "remote master branch\n"},
		testBranchSpec{name: "develop", contents: "remote develop branch\n"},
	)
	baseDir := filepath.Join(root, "workspace", "auths")
	store := NewGitTokenStore(remoteDir, "", "", "")
	store.SetBaseDir(baseDir)
	if err := store.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository initial clone: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "master", "remote master branch\n")
	// Reopen the same workspace pinned to develop; it should switch branches.
	reopened := NewGitTokenStore(remoteDir, "", "", "develop")
	reopened.SetBaseDir(baseDir)
	if err := reopened.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository reopen: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "develop", "remote develop branch\n")
	workspaceDir := filepath.Join(root, "workspace")
	if err := os.WriteFile(filepath.Join(workspaceDir, "branch.txt"), []byte("local develop update\n"), 0o600); err != nil {
		t.Fatalf("write local branch marker: %v", err)
	}
	// commitAndPushLocked requires the store mutex to be held by the caller.
	reopened.mu.Lock()
	err := reopened.commitAndPushLocked("Update develop branch marker", "branch.txt")
	reopened.mu.Unlock()
	if err != nil {
		t.Fatalf("commitAndPushLocked: %v", err)
	}
	assertRepositoryHeadBranch(t, workspaceDir, "develop")
	// The push must update develop only; master stays at its seeded contents.
	assertRemoteBranchContents(t, remoteDir, "develop", "local develop update\n")
	assertRemoteBranchContents(t, remoteDir, "master", "remote master branch\n")
}
// TestEnsureRepositoryExistingRepoSwitchesToConfiguredBranchCreatedAfterClone
// verifies that a workspace can switch to a configured branch that did not exist
// on the remote at clone time (it requires a fresh fetch of remote refs).
func TestEnsureRepositoryExistingRepoSwitchesToConfiguredBranchCreatedAfterClone(t *testing.T) {
	root := t.TempDir()
	remoteDir := setupGitRemoteRepository(t, root, "master",
		testBranchSpec{name: "master", contents: "remote master branch\n"},
	)
	baseDir := filepath.Join(root, "workspace", "auths")
	store := NewGitTokenStore(remoteDir, "", "", "")
	store.SetBaseDir(baseDir)
	if err := store.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository initial clone: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "master", "remote master branch\n")
	// Create the branch on the remote only after the initial clone completed.
	advanceRemoteBranchFromNewBranch(t, filepath.Join(root, "seed"), remoteDir, "release/2026", "release branch\n", "create release")
	reopened := NewGitTokenStore(remoteDir, "", "", "release/2026")
	reopened.SetBaseDir(baseDir)
	if err := reopened.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository reopen: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "release/2026", "release branch\n")
}
// TestEnsureRepositoryResetsToRemoteDefaultWhenBranchUnset verifies that clearing the
// branch setting moves an existing workspace back to the remote's default branch,
// and that pushes from the unset-branch store land on that default branch.
func TestEnsureRepositoryResetsToRemoteDefaultWhenBranchUnset(t *testing.T) {
	root := t.TempDir()
	remoteDir := setupGitRemoteRepository(t, root, "master",
		testBranchSpec{name: "master", contents: "remote master branch\n"},
		testBranchSpec{name: "develop", contents: "remote develop branch\n"},
	)
	baseDir := filepath.Join(root, "workspace", "auths")
	// First store pins to develop and prepares local workspace
	storePinned := NewGitTokenStore(remoteDir, "", "", "develop")
	storePinned.SetBaseDir(baseDir)
	if err := storePinned.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository pinned: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "develop", "remote develop branch\n")
	// Second store has branch unset and should reset local workspace to remote default (master)
	storeDefault := NewGitTokenStore(remoteDir, "", "", "")
	storeDefault.SetBaseDir(baseDir)
	if err := storeDefault.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository default: %v", err)
	}
	// Local HEAD should now follow remote default (master)
	assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), "master")
	// Make a local change and push using the store with branch unset; push should update remote master
	workspaceDir := filepath.Join(root, "workspace")
	if err := os.WriteFile(filepath.Join(workspaceDir, "branch.txt"), []byte("local master update\n"), 0o600); err != nil {
		t.Fatalf("write local master marker: %v", err)
	}
	// commitAndPushLocked requires the store mutex to be held by the caller.
	storeDefault.mu.Lock()
	if err := storeDefault.commitAndPushLocked("Update master marker", "branch.txt"); err != nil {
		storeDefault.mu.Unlock()
		t.Fatalf("commitAndPushLocked: %v", err)
	}
	storeDefault.mu.Unlock()
	assertRemoteBranchContents(t, remoteDir, "master", "local master update\n")
}
// TestEnsureRepositoryFollowsRenamedRemoteDefaultBranchWhenAvailable verifies that
// when the remote's HEAD is repointed to a different branch (e.g. master -> main),
// an unset-branch store follows the new default branch on the next sync.
func TestEnsureRepositoryFollowsRenamedRemoteDefaultBranchWhenAvailable(t *testing.T) {
	root := t.TempDir()
	remoteDir := setupGitRemoteRepository(t, root, "master",
		testBranchSpec{name: "master", contents: "remote master branch\n"},
		testBranchSpec{name: "main", contents: "remote main branch\n"},
	)
	baseDir := filepath.Join(root, "workspace", "auths")
	store := NewGitTokenStore(remoteDir, "", "", "")
	store.SetBaseDir(baseDir)
	if err := store.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository initial clone: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "master", "remote master branch\n")
	// Repoint the remote default to main and advance it, simulating a rename.
	setRemoteHeadBranch(t, remoteDir, "main")
	advanceRemoteBranch(t, filepath.Join(root, "seed"), remoteDir, "main", "remote main branch updated\n", "advance main")
	reopened := NewGitTokenStore(remoteDir, "", "", "")
	reopened.SetBaseDir(baseDir)
	if err := reopened.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository after remote default rename: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "main", "remote main branch updated\n")
	assertRemoteHeadBranch(t, remoteDir, "main")
}
// TestEnsureRepositoryKeepsCurrentBranchWhenRemoteDefaultCannotBeResolved verifies
// the fallback path: when the origin remote rejects access (HTTP 401), an
// unset-branch store keeps the currently checked-out branch instead of failing.
func TestEnsureRepositoryKeepsCurrentBranchWhenRemoteDefaultCannotBeResolved(t *testing.T) {
	root := t.TempDir()
	remoteDir := setupGitRemoteRepository(t, root, "master",
		testBranchSpec{name: "master", contents: "remote master branch\n"},
		testBranchSpec{name: "develop", contents: "remote develop branch\n"},
	)
	baseDir := filepath.Join(root, "workspace", "auths")
	pinned := NewGitTokenStore(remoteDir, "", "", "develop")
	pinned.SetBaseDir(baseDir)
	if err := pinned.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository pinned: %v", err)
	}
	assertRepositoryBranchAndContents(t, filepath.Join(root, "workspace"), "develop", "remote develop branch\n")
	// Stand in for an unreachable remote: every request gets 401 + WWW-Authenticate.
	authServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("WWW-Authenticate", `Basic realm="git"`)
		http.Error(w, "auth required", http.StatusUnauthorized)
	}))
	defer authServer.Close()
	// Rewire origin of the existing workspace to the failing server.
	repo, err := git.PlainOpen(filepath.Join(root, "workspace"))
	if err != nil {
		t.Fatalf("open workspace repo: %v", err)
	}
	cfg, err := repo.Config()
	if err != nil {
		t.Fatalf("read repo config: %v", err)
	}
	cfg.Remotes["origin"].URLs = []string{authServer.URL}
	if err := repo.SetConfig(cfg); err != nil {
		t.Fatalf("set repo config: %v", err)
	}
	reopened := NewGitTokenStore(remoteDir, "", "", "")
	reopened.SetBaseDir(baseDir)
	if err := reopened.EnsureRepository(); err != nil {
		t.Fatalf("EnsureRepository default branch fallback: %v", err)
	}
	assertRepositoryHeadBranch(t, filepath.Join(root, "workspace"), "develop")
}
// setupGitRemoteRepository builds a bare remote under root/remote.git seeded from a
// working clone under root/seed. Each spec in branches becomes a branch whose
// branch.txt holds the spec's contents; defaultBranch becomes the remote HEAD.
// It returns the bare remote's path. The seed repo is left in place so tests can
// later advance branches via advanceRemoteBranch*.
func setupGitRemoteRepository(t *testing.T, root, defaultBranch string, branches ...testBranchSpec) string {
	t.Helper()
	remoteDir := filepath.Join(root, "remote.git")
	if _, err := git.PlainInit(remoteDir, true); err != nil {
		t.Fatalf("init bare remote: %v", err)
	}
	seedDir := filepath.Join(root, "seed")
	seedRepo, err := git.PlainInit(seedDir, false)
	if err != nil {
		t.Fatalf("init seed repo: %v", err)
	}
	// Point the seed HEAD at the desired default branch before the first commit.
	if err := seedRepo.Storer.SetReference(plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(defaultBranch))); err != nil {
		t.Fatalf("set seed HEAD: %v", err)
	}
	worktree, err := seedRepo.Worktree()
	if err != nil {
		t.Fatalf("open seed worktree: %v", err)
	}
	defaultSpec, ok := findBranchSpec(branches, defaultBranch)
	if !ok {
		t.Fatalf("missing default branch spec for %q", defaultBranch)
	}
	commitBranchMarker(t, seedDir, worktree, defaultSpec, "seed default branch")
	// Each non-default branch forks off the default branch with its own marker commit.
	for _, branch := range branches {
		if branch.name == defaultBranch {
			continue
		}
		if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(defaultBranch)}); err != nil {
			t.Fatalf("checkout default branch %s: %v", defaultBranch, err)
		}
		if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(branch.name), Create: true}); err != nil {
			t.Fatalf("create branch %s: %v", branch.name, err)
		}
		commitBranchMarker(t, seedDir, worktree, branch, "seed branch "+branch.name)
	}
	if _, err := seedRepo.CreateRemote(&gitconfig.RemoteConfig{Name: "origin", URLs: []string{remoteDir}}); err != nil {
		t.Fatalf("create origin remote: %v", err)
	}
	if err := seedRepo.Push(&git.PushOptions{
		RemoteName: "origin",
		RefSpecs:   []gitconfig.RefSpec{gitconfig.RefSpec("refs/heads/*:refs/heads/*")},
	}); err != nil {
		t.Fatalf("push seed branches: %v", err)
	}
	// The push does not set the bare remote's HEAD; fix it up explicitly.
	remoteRepo, err := git.PlainOpen(remoteDir)
	if err != nil {
		t.Fatalf("open remote repo: %v", err)
	}
	if err := remoteRepo.Storer.SetReference(plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(defaultBranch))); err != nil {
		t.Fatalf("set remote HEAD: %v", err)
	}
	return remoteDir
}
// commitBranchMarker writes branch.txt with the spec's contents into seedDir,
// stages it, and commits it on the currently checked-out branch with a fixed
// author and timestamp so commit hashes are deterministic across runs.
func commitBranchMarker(t *testing.T, seedDir string, worktree *git.Worktree, branch testBranchSpec, message string) {
	t.Helper()
	if err := os.WriteFile(filepath.Join(seedDir, "branch.txt"), []byte(branch.contents), 0o600); err != nil {
		t.Fatalf("write branch marker for %s: %v", branch.name, err)
	}
	if _, err := worktree.Add("branch.txt"); err != nil {
		t.Fatalf("add branch marker for %s: %v", branch.name, err)
	}
	if _, err := worktree.Commit(message, &git.CommitOptions{
		Author: &object.Signature{
			Name:  "CLIProxyAPI",
			Email: "cliproxy@local",
			// Fixed timestamp keeps commits reproducible.
			When: time.Unix(1711929600, 0),
		},
	}); err != nil {
		t.Fatalf("commit branch marker for %s: %v", branch.name, err)
	}
}
// advanceRemoteBranch checks out an existing branch in the seed repository, commits
// a new branch.txt with the given contents, and pushes only that branch to origin.
// remoteDir is used for error reporting only; the seed repo's origin remote points
// at the bare remote already.
func advanceRemoteBranch(t *testing.T, seedDir, remoteDir, branch, contents, message string) {
	t.Helper()
	seedRepo, err := git.PlainOpen(seedDir)
	if err != nil {
		t.Fatalf("open seed repo: %v", err)
	}
	worktree, err := seedRepo.Worktree()
	if err != nil {
		t.Fatalf("open seed worktree: %v", err)
	}
	if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(branch)}); err != nil {
		t.Fatalf("checkout branch %s: %v", branch, err)
	}
	commitBranchMarker(t, seedDir, worktree, testBranchSpec{name: branch, contents: contents}, message)
	if err := seedRepo.Push(&git.PushOptions{
		RemoteName: "origin",
		RefSpecs: []gitconfig.RefSpec{
			gitconfig.RefSpec(plumbing.NewBranchReferenceName(branch).String() + ":" + plumbing.NewBranchReferenceName(branch).String()),
		},
	}); err != nil {
		t.Fatalf("push branch %s update to %s: %v", branch, remoteDir, err)
	}
}
// advanceRemoteBranchFromNewBranch creates a brand-new branch off master in the seed
// repository, commits a branch.txt marker on it, and pushes only that branch to
// origin. Used to simulate a branch appearing on the remote after the initial clone.
// remoteDir is used for error reporting only.
func advanceRemoteBranchFromNewBranch(t *testing.T, seedDir, remoteDir, branch, contents, message string) {
	t.Helper()
	seedRepo, err := git.PlainOpen(seedDir)
	if err != nil {
		t.Fatalf("open seed repo: %v", err)
	}
	worktree, err := seedRepo.Worktree()
	if err != nil {
		t.Fatalf("open seed worktree: %v", err)
	}
	// Fork the new branch from master so it shares history with the default branch.
	if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName("master")}); err != nil {
		t.Fatalf("checkout master before creating %s: %v", branch, err)
	}
	if err := worktree.Checkout(&git.CheckoutOptions{Branch: plumbing.NewBranchReferenceName(branch), Create: true}); err != nil {
		t.Fatalf("create branch %s: %v", branch, err)
	}
	commitBranchMarker(t, seedDir, worktree, testBranchSpec{name: branch, contents: contents}, message)
	if err := seedRepo.Push(&git.PushOptions{
		RemoteName: "origin",
		RefSpecs: []gitconfig.RefSpec{
			gitconfig.RefSpec(plumbing.NewBranchReferenceName(branch).String() + ":" + plumbing.NewBranchReferenceName(branch).String()),
		},
	}); err != nil {
		t.Fatalf("push new branch %s update to %s: %v", branch, remoteDir, err)
	}
}
// findBranchSpec returns the spec whose name matches, and whether it was found.
func findBranchSpec(branches []testBranchSpec, name string) (testBranchSpec, bool) {
	for i := range branches {
		if branches[i].name == name {
			return branches[i], true
		}
	}
	return testBranchSpec{}, false
}
// assertRepositoryBranchAndContents fails the test unless the local repository at
// repoDir has HEAD on the given branch AND its working-tree branch.txt contains
// exactly wantContents.
func assertRepositoryBranchAndContents(t *testing.T, repoDir, branch, wantContents string) {
	t.Helper()
	repo, err := git.PlainOpen(repoDir)
	if err != nil {
		t.Fatalf("open local repo: %v", err)
	}
	head, err := repo.Head()
	if err != nil {
		t.Fatalf("local repo head: %v", err)
	}
	if got, want := head.Name(), plumbing.NewBranchReferenceName(branch); got != want {
		t.Fatalf("local head branch = %s, want %s", got, want)
	}
	contents, err := os.ReadFile(filepath.Join(repoDir, "branch.txt"))
	if err != nil {
		t.Fatalf("read branch marker: %v", err)
	}
	if got := string(contents); got != wantContents {
		t.Fatalf("branch marker contents = %q, want %q", got, wantContents)
	}
}
// assertRepositoryHeadBranch fails the test unless the local repository at repoDir
// has HEAD pointing at the given branch (working-tree contents are not checked).
func assertRepositoryHeadBranch(t *testing.T, repoDir, branch string) {
	t.Helper()
	repo, err := git.PlainOpen(repoDir)
	if err != nil {
		t.Fatalf("open local repo: %v", err)
	}
	head, err := repo.Head()
	if err != nil {
		t.Fatalf("local repo head: %v", err)
	}
	if got, want := head.Name(), plumbing.NewBranchReferenceName(branch); got != want {
		t.Fatalf("local head branch = %s, want %s", got, want)
	}
}
// assertRemoteHeadBranch fails the test unless the bare remote's symbolic HEAD
// targets the given branch. HEAD is read unresolved so its target can be inspected.
func assertRemoteHeadBranch(t *testing.T, remoteDir, branch string) {
	t.Helper()
	remoteRepo, err := git.PlainOpen(remoteDir)
	if err != nil {
		t.Fatalf("open remote repo: %v", err)
	}
	head, err := remoteRepo.Reference(plumbing.HEAD, false)
	if err != nil {
		t.Fatalf("read remote HEAD: %v", err)
	}
	if got, want := head.Target(), plumbing.NewBranchReferenceName(branch); got != want {
		t.Fatalf("remote HEAD target = %s, want %s", got, want)
	}
}
// setRemoteHeadBranch repoints the bare remote's symbolic HEAD at the given branch,
// simulating a default-branch rename on the server side.
func setRemoteHeadBranch(t *testing.T, remoteDir, branch string) {
	t.Helper()
	remoteRepo, err := git.PlainOpen(remoteDir)
	if err != nil {
		t.Fatalf("open remote repo: %v", err)
	}
	if err := remoteRepo.Storer.SetReference(plumbing.NewSymbolicReference(plumbing.HEAD, plumbing.NewBranchReferenceName(branch))); err != nil {
		t.Fatalf("set remote HEAD to %s: %v", branch, err)
	}
}
// assertRemoteBranchExistsWithCommit fails the test unless the bare remote has the
// given branch pointing at a non-zero commit hash.
func assertRemoteBranchExistsWithCommit(t *testing.T, remoteDir, branch string) {
	t.Helper()
	remoteRepo, err := git.PlainOpen(remoteDir)
	if err != nil {
		t.Fatalf("open remote repo: %v", err)
	}
	ref, err := remoteRepo.Reference(plumbing.NewBranchReferenceName(branch), false)
	if err != nil {
		t.Fatalf("read remote branch %s: %v", branch, err)
	}
	if got := ref.Hash(); got == plumbing.ZeroHash {
		t.Fatalf("remote branch %s hash = %s, want non-zero hash", branch, got)
	}
}
// assertRemoteBranchDoesNotExist fails the test unless the named branch is absent
// from the bare remote repository at remoteDir.
func assertRemoteBranchDoesNotExist(t *testing.T, remoteDir, branch string) {
	t.Helper()
	remoteRepo, err := git.PlainOpen(remoteDir)
	if err != nil {
		t.Fatalf("open remote repo: %v", err)
	}
	_, err = remoteRepo.Reference(plumbing.NewBranchReferenceName(branch), false)
	switch {
	case err == nil:
		t.Fatalf("remote branch %s exists, want missing", branch)
	case !errors.Is(err, plumbing.ErrReferenceNotFound):
		// errors.Is instead of == so a wrapped sentinel is still recognized.
		t.Fatalf("read remote branch %s: %v", branch, err)
	}
}
// assertRemoteBranchContents fails the test unless branch.txt at the tip commit of
// the given branch on the bare remote contains exactly wantContents. It walks
// ref -> commit -> tree -> file to read the blob without a working tree.
func assertRemoteBranchContents(t *testing.T, remoteDir, branch, wantContents string) {
	t.Helper()
	remoteRepo, err := git.PlainOpen(remoteDir)
	if err != nil {
		t.Fatalf("open remote repo: %v", err)
	}
	ref, err := remoteRepo.Reference(plumbing.NewBranchReferenceName(branch), false)
	if err != nil {
		t.Fatalf("read remote branch %s: %v", branch, err)
	}
	commit, err := remoteRepo.CommitObject(ref.Hash())
	if err != nil {
		t.Fatalf("read remote branch %s commit: %v", branch, err)
	}
	tree, err := commit.Tree()
	if err != nil {
		t.Fatalf("read remote branch %s tree: %v", branch, err)
	}
	file, err := tree.File("branch.txt")
	if err != nil {
		t.Fatalf("read remote branch %s file: %v", branch, err)
	}
	contents, err := file.Contents()
	if err != nil {
		t.Fatalf("read remote branch %s contents: %v", branch, err)
	}
	if contents != wantContents {
		t.Fatalf("remote branch %s contents = %q, want %q", branch, contents, wantContents)
	}
}

View File

@@ -174,7 +174,8 @@ func (a *Applier) normalizeClaudeBudget(body []byte, budgetTokens int, modelInfo
// Ensure the request satisfies Claude constraints:
// 1) Determine effective max_tokens (request overrides model default)
// 2) If budget_tokens >= max_tokens, reduce budget_tokens to max_tokens-1
// 3) If the adjusted budget falls below the model minimum, leave the request unchanged
// 3) If the adjusted budget falls below the model minimum, try raising max_tokens
// (clamped to MaxCompletionTokens); disable thinking if constraints are unsatisfiable
// 4) If max_tokens came from model default, write it back into the request
effectiveMax, setDefaultMax := a.effectiveMaxTokens(body, modelInfo)
@@ -193,8 +194,28 @@ func (a *Applier) normalizeClaudeBudget(body []byte, budgetTokens int, modelInfo
minBudget = modelInfo.Thinking.Min
}
if minBudget > 0 && adjustedBudget > 0 && adjustedBudget < minBudget {
// If enforcing the max_tokens constraint would push the budget below the model minimum,
// leave the request unchanged.
// Enforcing budget_tokens < max_tokens pushed the budget below the model minimum.
// Try raising max_tokens to fit the original budget.
needed := budgetTokens + 1
maxAllowed := 0
if modelInfo != nil {
maxAllowed = modelInfo.MaxCompletionTokens
}
if maxAllowed > 0 && needed > maxAllowed {
// Cannot use original budget; cap max_tokens at model limit.
needed = maxAllowed
}
cappedBudget := needed - 1
if cappedBudget < minBudget {
// Impossible to satisfy both budget >= minBudget and budget < max_tokens
// within the model's completion limit. Disable thinking entirely.
body, _ = sjson.DeleteBytes(body, "thinking")
return body
}
body, _ = sjson.SetBytes(body, "max_tokens", needed)
if cappedBudget != budgetTokens {
body, _ = sjson.SetBytes(body, "thinking.budget_tokens", cappedBudget)
}
return body
}

View File

@@ -0,0 +1,99 @@
package claude
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/tidwall/gjson"
)
// TestNormalizeClaudeBudget_RaisesMaxTokens verifies that max_tokens is lifted to
// budget_tokens+1 when the requested thinking budget exceeds the request's max_tokens.
func TestNormalizeClaudeBudget_RaisesMaxTokens(t *testing.T) {
	applier := &Applier{}
	info := &registry.ModelInfo{
		MaxCompletionTokens: 64000,
		Thinking:            &registry.ThinkingSupport{Min: 1024, Max: 128000},
	}
	req := []byte(`{"max_tokens":1000,"thinking":{"type":"enabled","budget_tokens":5000}}`)
	got := applier.normalizeClaudeBudget(req, 5000, info)
	if maxTokens := gjson.GetBytes(got, "max_tokens").Int(); maxTokens != 5001 {
		t.Fatalf("max_tokens = %d, want 5001, body=%s", maxTokens, string(got))
	}
}
// TestNormalizeClaudeBudget_ClampsToModelMax verifies that an oversized budget
// is capped: max_tokens lands on the model's completion limit and the budget
// becomes max_tokens-1.
func TestNormalizeClaudeBudget_ClampsToModelMax(t *testing.T) {
	applier := &Applier{}
	info := &registry.ModelInfo{
		MaxCompletionTokens: 64000,
		Thinking:            &registry.ThinkingSupport{Min: 1024, Max: 128000},
	}
	request := []byte(`{"max_tokens":500,"thinking":{"type":"enabled","budget_tokens":200000}}`)
	out := applier.normalizeClaudeBudget(request, 200000, info)
	if maxTok := gjson.GetBytes(out, "max_tokens").Int(); maxTok != 64000 {
		t.Fatalf("max_tokens = %d, want 64000 (capped to model limit), body=%s", maxTok, string(out))
	}
	if budget := gjson.GetBytes(out, "thinking.budget_tokens").Int(); budget != 63999 {
		t.Fatalf("budget_tokens = %d, want 63999 (max_tokens-1), body=%s", budget, string(out))
	}
}
// TestNormalizeClaudeBudget_DisablesThinkingWhenUnsatisfiable verifies the
// thinking block is stripped entirely when no budget can satisfy both the
// model minimum and the completion-token limit.
func TestNormalizeClaudeBudget_DisablesThinkingWhenUnsatisfiable(t *testing.T) {
	applier := &Applier{}
	info := &registry.ModelInfo{
		MaxCompletionTokens: 1000,
		Thinking:            &registry.ThinkingSupport{Min: 1024, Max: 128000},
	}
	request := []byte(`{"max_tokens":500,"thinking":{"type":"enabled","budget_tokens":2000}}`)
	out := applier.normalizeClaudeBudget(request, 2000, info)
	if gjson.GetBytes(out, "thinking").Exists() {
		t.Fatalf("thinking should be removed when constraints are unsatisfiable, body=%s", string(out))
	}
}
// TestNormalizeClaudeBudget_NoClamping verifies a request whose budget already
// fits under max_tokens passes through untouched.
func TestNormalizeClaudeBudget_NoClamping(t *testing.T) {
	applier := &Applier{}
	info := &registry.ModelInfo{
		MaxCompletionTokens: 64000,
		Thinking:            &registry.ThinkingSupport{Min: 1024, Max: 128000},
	}
	request := []byte(`{"max_tokens":32000,"thinking":{"type":"enabled","budget_tokens":16000}}`)
	out := applier.normalizeClaudeBudget(request, 16000, info)
	if maxTok := gjson.GetBytes(out, "max_tokens").Int(); maxTok != 32000 {
		t.Fatalf("max_tokens should remain 32000, got %d, body=%s", maxTok, string(out))
	}
	if budget := gjson.GetBytes(out, "thinking.budget_tokens").Int(); budget != 16000 {
		t.Fatalf("budget_tokens should remain 16000, got %d, body=%s", budget, string(out))
	}
}
// TestNormalizeClaudeBudget_AdjustsBudgetToMaxMinus1 verifies that when
// max_tokens already sits at the model's completion limit, the budget is
// trimmed down to max_tokens-1 instead of raising max_tokens further.
func TestNormalizeClaudeBudget_AdjustsBudgetToMaxMinus1(t *testing.T) {
	applier := &Applier{}
	info := &registry.ModelInfo{
		MaxCompletionTokens: 8192,
		Thinking:            &registry.ThinkingSupport{Min: 1024, Max: 128000},
	}
	request := []byte(`{"max_tokens":8192,"thinking":{"type":"enabled","budget_tokens":10000}}`)
	out := applier.normalizeClaudeBudget(request, 10000, info)
	if maxTok := gjson.GetBytes(out, "max_tokens").Int(); maxTok != 8192 {
		t.Fatalf("max_tokens = %d, want 8192 (unchanged), body=%s", maxTok, string(out))
	}
	if budget := gjson.GetBytes(out, "thinking.budget_tokens").Int(); budget != 8191 {
		t.Fatalf("budget_tokens = %d, want 8191 (max_tokens-1), body=%s", budget, string(out))
	}
}

View File

@@ -26,6 +26,9 @@ type ConvertCodexResponseToClaudeParams struct {
HasToolCall bool
BlockIndex int
HasReceivedArgumentsDelta bool
ThinkingBlockOpen bool
ThinkingStopPending bool
ThinkingSignature string
}
// ConvertCodexResponseToClaude performs sophisticated streaming response format conversion.
@@ -44,7 +47,7 @@ type ConvertCodexResponseToClaudeParams struct {
//
// Returns:
// - [][]byte: A slice of Claude Code-compatible JSON responses
func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, param *any) [][]byte {
if *param == nil {
*param = &ConvertCodexResponseToClaudeParams{
HasToolCall: false,
@@ -52,7 +55,6 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
}
}
// log.Debugf("rawJSON: %s", string(rawJSON))
if !bytes.HasPrefix(rawJSON, dataTag) {
return [][]byte{}
}
@@ -60,9 +62,18 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
output := make([]byte, 0, 512)
rootResult := gjson.ParseBytes(rawJSON)
params := (*param).(*ConvertCodexResponseToClaudeParams)
if params.ThinkingBlockOpen && params.ThinkingStopPending {
switch rootResult.Get("type").String() {
case "response.content_part.added", "response.completed":
output = append(output, finalizeCodexThinkingBlock(params)...)
}
}
typeResult := rootResult.Get("type")
typeStr := typeResult.String()
var template []byte
if typeStr == "response.created" {
template = []byte(`{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"claude-opus-4-1-20250805","stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0},"content":[],"stop_reason":null}}`)
template, _ = sjson.SetBytes(template, "message.model", rootResult.Get("response.model").String())
@@ -70,43 +81,46 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
output = translatorcommon.AppendSSEEventBytes(output, "message_start", template, 2)
} else if typeStr == "response.reasoning_summary_part.added" {
if params.ThinkingBlockOpen && params.ThinkingStopPending {
output = append(output, finalizeCodexThinkingBlock(params)...)
}
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
params.ThinkingBlockOpen = true
params.ThinkingStopPending = false
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
} else if typeStr == "response.reasoning_summary_text.delta" {
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
template, _ = sjson.SetBytes(template, "delta.thinking", rootResult.Get("delta").String())
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
} else if typeStr == "response.reasoning_summary_part.done" {
template = []byte(`{"type":"content_block_stop","index":0}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
params.ThinkingStopPending = true
if params.ThinkingSignature != "" {
output = append(output, finalizeCodexThinkingBlock(params)...)
}
} else if typeStr == "response.content_part.added" {
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
} else if typeStr == "response.output_text.delta" {
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
template, _ = sjson.SetBytes(template, "delta.text", rootResult.Get("delta").String())
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
} else if typeStr == "response.content_part.done" {
template = []byte(`{"type":"content_block_stop","index":0}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
params.BlockIndex++
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
} else if typeStr == "response.completed" {
template = []byte(`{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`)
p := (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall
p := params.HasToolCall
stopReason := rootResult.Get("response.stop_reason").String()
if p {
template, _ = sjson.SetBytes(template, "delta.stop_reason", "tool_use")
@@ -128,13 +142,13 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
itemResult := rootResult.Get("item")
itemType := itemResult.Get("type").String()
if itemType == "function_call" {
(*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
output = append(output, finalizeCodexThinkingBlock(params)...)
params.HasToolCall = true
params.HasReceivedArgumentsDelta = false
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
template, _ = sjson.SetBytes(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
{
// Restore original tool name if shortened
name := itemResult.Get("name").String()
rev := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON)
if orig, ok := rev[name]; ok {
@@ -146,37 +160,43 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
output = translatorcommon.AppendSSEEventBytes(output, "content_block_start", template, 2)
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
} else if itemType == "reasoning" {
params.ThinkingSignature = itemResult.Get("encrypted_content").String()
if params.ThinkingStopPending {
output = append(output, finalizeCodexThinkingBlock(params)...)
}
}
} else if typeStr == "response.output_item.done" {
itemResult := rootResult.Get("item")
itemType := itemResult.Get("type").String()
if itemType == "function_call" {
template = []byte(`{"type":"content_block_stop","index":0}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
params.BlockIndex++
output = translatorcommon.AppendSSEEventBytes(output, "content_block_stop", template, 2)
} else if itemType == "reasoning" {
if signature := itemResult.Get("encrypted_content").String(); signature != "" {
params.ThinkingSignature = signature
}
output = append(output, finalizeCodexThinkingBlock(params)...)
params.ThinkingSignature = ""
}
} else if typeStr == "response.function_call_arguments.delta" {
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = true
params.HasReceivedArgumentsDelta = true
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
template, _ = sjson.SetBytes(template, "delta.partial_json", rootResult.Get("delta").String())
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
} else if typeStr == "response.function_call_arguments.done" {
// Some models (e.g. gpt-5.3-codex-spark) send function call arguments
// in a single "done" event without preceding "delta" events.
// Emit the full arguments as a single input_json_delta so the
// downstream Claude client receives the complete tool input.
// When delta events were already received, skip to avoid duplicating arguments.
if !(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta {
if !params.HasReceivedArgumentsDelta {
if args := rootResult.Get("arguments").String(); args != "" {
template = []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`)
template, _ = sjson.SetBytes(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
template, _ = sjson.SetBytes(template, "delta.partial_json", args)
output = translatorcommon.AppendSSEEventBytes(output, "content_block_delta", template, 2)
@@ -191,15 +211,6 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
// This function processes the complete Codex response and transforms it into a single Claude Code-compatible
// JSON response. It handles message content, tool calls, reasoning content, and usage metadata, combining all
// the information into a single response that matches the Claude Code API format.
//
// Parameters:
// - ctx: The context for the request, used for cancellation and timeout handling
// - modelName: The name of the model being used for the response (unused in current implementation)
// - rawJSON: The raw JSON response from the Codex API
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - []byte: A Claude Code-compatible JSON response containing all message content and metadata
func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, _ *any) []byte {
revNames := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON)
@@ -230,6 +241,7 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
switch item.Get("type").String() {
case "reasoning":
thinkingBuilder := strings.Builder{}
signature := item.Get("encrypted_content").String()
if summary := item.Get("summary"); summary.Exists() {
if summary.IsArray() {
summary.ForEach(func(_, part gjson.Result) bool {
@@ -260,9 +272,12 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
}
}
}
if thinkingBuilder.Len() > 0 {
if thinkingBuilder.Len() > 0 || signature != "" {
block := []byte(`{"type":"thinking","thinking":""}`)
block, _ = sjson.SetBytes(block, "thinking", thinkingBuilder.String())
if signature != "" {
block, _ = sjson.SetBytes(block, "signature", signature)
}
out, _ = sjson.SetRawBytes(out, "content.-1", block)
}
case "message":
@@ -371,6 +386,30 @@ func buildReverseMapFromClaudeOriginalShortToOriginal(original []byte) map[strin
return rev
}
func ClaudeTokenCount(ctx context.Context, count int64) []byte {
func ClaudeTokenCount(_ context.Context, count int64) []byte {
return translatorcommon.ClaudeInputTokensJSON(count)
}
// finalizeCodexThinkingBlock closes the currently open thinking content
// block. If a signature was captured from the reasoning item, a
// signature_delta event is emitted ahead of the content_block_stop. The
// block index advances and the thinking-related state flags are cleared.
// Returns nil when no thinking block is open.
func finalizeCodexThinkingBlock(params *ConvertCodexResponseToClaudeParams) []byte {
	if !params.ThinkingBlockOpen {
		return nil
	}
	events := make([]byte, 0, 256)
	if sig := params.ThinkingSignature; sig != "" {
		delta := []byte(`{"type":"content_block_delta","index":0,"delta":{"type":"signature_delta","signature":""}}`)
		delta, _ = sjson.SetBytes(delta, "index", params.BlockIndex)
		delta, _ = sjson.SetBytes(delta, "delta.signature", sig)
		events = translatorcommon.AppendSSEEventBytes(events, "content_block_delta", delta, 2)
	}
	stop := []byte(`{"type":"content_block_stop","index":0}`)
	stop, _ = sjson.SetBytes(stop, "index", params.BlockIndex)
	events = translatorcommon.AppendSSEEventBytes(events, "content_block_stop", stop, 2)
	params.BlockIndex++
	params.ThinkingBlockOpen = false
	params.ThinkingStopPending = false
	return events
}

View File

@@ -0,0 +1,282 @@
package claude
import (
"context"
"strings"
"testing"
"github.com/tidwall/gjson"
)
// TestConvertCodexResponseToClaude_StreamThinkingIncludesSignature checks
// that a reasoning item's encrypted_content is surfaced as a signature_delta
// before the thinking block is closed, and that the start event carries no
// signature field.
func TestConvertCodexResponseToClaude_StreamThinkingIncludesSignature(t *testing.T) {
	ctx := context.Background()
	originalRequest := []byte(`{"messages":[]}`)
	var param any
	events := [][]byte{
		[]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_123\",\"model\":\"gpt-5\"}}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Let me think\"}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
		[]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"reasoning\",\"encrypted_content\":\"enc_sig_123\"}}"),
	}
	var frames [][]byte
	for _, evt := range events {
		frames = append(frames, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, evt, &param)...)
	}
	var startFound, signatureDeltaFound, stopFound bool
	for _, frame := range frames {
		for _, line := range strings.Split(string(frame), "\n") {
			if !strings.HasPrefix(line, "data: ") {
				continue
			}
			data := gjson.Parse(strings.TrimPrefix(line, "data: "))
			switch data.Get("type").String() {
			case "content_block_start":
				if data.Get("content_block.type").String() != "thinking" {
					continue
				}
				startFound = true
				if data.Get("content_block.signature").Exists() {
					t.Fatalf("thinking start block should NOT have signature field when signature is unknown: %s", line)
				}
			case "content_block_delta":
				if data.Get("delta.type").String() != "signature_delta" {
					continue
				}
				signatureDeltaFound = true
				if got := data.Get("delta.signature").String(); got != "enc_sig_123" {
					t.Fatalf("unexpected signature delta: %q", got)
				}
			case "content_block_stop":
				stopFound = true
			}
		}
	}
	if !startFound {
		t.Fatal("expected thinking content_block_start event")
	}
	if !signatureDeltaFound {
		t.Fatal("expected signature_delta event for thinking block")
	}
	if !stopFound {
		t.Fatal("expected content_block_stop event for thinking block")
	}
}
// TestConvertCodexResponseToClaude_StreamThinkingWithoutReasoningItemStillIncludesSignatureField
// verifies that, with no reasoning item carrying encrypted_content, the
// thinking block still starts and stops cleanly and no signature_delta is
// emitted.
func TestConvertCodexResponseToClaude_StreamThinkingWithoutReasoningItemStillIncludesSignatureField(t *testing.T) {
	ctx := context.Background()
	originalRequest := []byte(`{"messages":[]}`)
	var param any
	events := [][]byte{
		[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Let me think\"}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
		[]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}"),
	}
	var frames [][]byte
	for _, evt := range events {
		frames = append(frames, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, evt, &param)...)
	}
	var thinkingStartFound, thinkingStopFound, signatureDeltaFound bool
	for _, frame := range frames {
		for _, line := range strings.Split(string(frame), "\n") {
			if !strings.HasPrefix(line, "data: ") {
				continue
			}
			data := gjson.Parse(strings.TrimPrefix(line, "data: "))
			switch {
			case data.Get("type").String() == "content_block_start" && data.Get("content_block.type").String() == "thinking":
				thinkingStartFound = true
				if data.Get("content_block.signature").Exists() {
					t.Fatalf("thinking start block should NOT have signature field without encrypted_content: %s", line)
				}
			case data.Get("type").String() == "content_block_stop" && data.Get("index").Int() == 0:
				thinkingStopFound = true
			case data.Get("type").String() == "content_block_delta" && data.Get("delta.type").String() == "signature_delta":
				signatureDeltaFound = true
			}
		}
	}
	if !thinkingStartFound {
		t.Fatal("expected thinking content_block_start event")
	}
	if !thinkingStopFound {
		t.Fatal("expected thinking content_block_stop event")
	}
	if signatureDeltaFound {
		t.Fatal("did not expect signature_delta without encrypted_content")
	}
}
// TestConvertCodexResponseToClaude_StreamThinkingFinalizesPendingBlockBeforeNextSummaryPart
// verifies that a pending thinking block is closed exactly once before the
// next summary part opens a second block.
func TestConvertCodexResponseToClaude_StreamThinkingFinalizesPendingBlockBeforeNextSummaryPart(t *testing.T) {
	ctx := context.Background()
	originalRequest := []byte(`{"messages":[]}`)
	var param any
	events := [][]byte{
		[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"First part\"}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
	}
	var frames [][]byte
	for _, evt := range events {
		frames = append(frames, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, evt, &param)...)
	}
	// countEvents tallies SSE data payloads matching the predicate.
	countEvents := func(match func(gjson.Result) bool) int {
		n := 0
		for _, frame := range frames {
			for _, line := range strings.Split(string(frame), "\n") {
				if !strings.HasPrefix(line, "data: ") {
					continue
				}
				if match(gjson.Parse(strings.TrimPrefix(line, "data: "))) {
					n++
				}
			}
		}
		return n
	}
	startCount := countEvents(func(d gjson.Result) bool {
		return d.Get("type").String() == "content_block_start" && d.Get("content_block.type").String() == "thinking"
	})
	stopCount := countEvents(func(d gjson.Result) bool {
		return d.Get("type").String() == "content_block_stop"
	})
	if startCount != 2 {
		t.Fatalf("expected 2 thinking block starts, got %d", startCount)
	}
	if stopCount != 1 {
		t.Fatalf("expected pending thinking block to be finalized before second start, got %d stops", stopCount)
	}
}
// TestConvertCodexResponseToClaude_StreamThinkingRetainsSignatureAcrossMultipartReasoning
// verifies that a signature captured when the reasoning item is first added
// is replayed on every thinking block of a multipart summary.
func TestConvertCodexResponseToClaude_StreamThinkingRetainsSignatureAcrossMultipartReasoning(t *testing.T) {
	ctx := context.Background()
	originalRequest := []byte(`{"messages":[]}`)
	var param any
	events := [][]byte{
		[]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"reasoning\",\"encrypted_content\":\"enc_sig_multipart\"}}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"First part\"}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Second part\"}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_part.done\"}"),
		[]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"reasoning\"}}"),
	}
	var frames [][]byte
	for _, evt := range events {
		frames = append(frames, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, evt, &param)...)
	}
	signatureDeltaCount := 0
	for _, frame := range frames {
		for _, line := range strings.Split(string(frame), "\n") {
			if !strings.HasPrefix(line, "data: ") {
				continue
			}
			data := gjson.Parse(strings.TrimPrefix(line, "data: "))
			if data.Get("type").String() != "content_block_delta" || data.Get("delta.type").String() != "signature_delta" {
				continue
			}
			signatureDeltaCount++
			if got := data.Get("delta.signature").String(); got != "enc_sig_multipart" {
				t.Fatalf("unexpected signature delta: %q", got)
			}
		}
	}
	if signatureDeltaCount != 2 {
		t.Fatalf("expected signature_delta for both multipart thinking blocks, got %d", signatureDeltaCount)
	}
}
// TestConvertCodexResponseToClaude_StreamThinkingUsesEarlyCapturedSignatureWhenDoneOmitsIt
// verifies that the signature captured from output_item.added is used when
// the matching output_item.done omits encrypted_content.
func TestConvertCodexResponseToClaude_StreamThinkingUsesEarlyCapturedSignatureWhenDoneOmitsIt(t *testing.T) {
	ctx := context.Background()
	originalRequest := []byte(`{"messages":[]}`)
	var param any
	events := [][]byte{
		[]byte("data: {\"type\":\"response.output_item.added\",\"item\":{\"type\":\"reasoning\",\"encrypted_content\":\"enc_sig_early\"}}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_part.added\"}"),
		[]byte("data: {\"type\":\"response.reasoning_summary_text.delta\",\"delta\":\"Let me think\"}"),
		[]byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"reasoning\"}}"),
	}
	var frames [][]byte
	for _, evt := range events {
		frames = append(frames, ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, evt, &param)...)
	}
	signatureDeltaCount := 0
	for _, frame := range frames {
		for _, line := range strings.Split(string(frame), "\n") {
			if !strings.HasPrefix(line, "data: ") {
				continue
			}
			data := gjson.Parse(strings.TrimPrefix(line, "data: "))
			if data.Get("type").String() != "content_block_delta" || data.Get("delta.type").String() != "signature_delta" {
				continue
			}
			signatureDeltaCount++
			if got := data.Get("delta.signature").String(); got != "enc_sig_early" {
				t.Fatalf("unexpected signature delta: %q", got)
			}
		}
	}
	if signatureDeltaCount != 1 {
		t.Fatalf("expected signature_delta from early-captured signature, got %d", signatureDeltaCount)
	}
}
// TestConvertCodexResponseToClaudeNonStream_ThinkingIncludesSignature
// verifies the non-streaming converter copies encrypted_content into the
// thinking block's signature field and keeps the summary text.
func TestConvertCodexResponseToClaudeNonStream_ThinkingIncludesSignature(t *testing.T) {
	ctx := context.Background()
	originalRequest := []byte(`{"messages":[]}`)
	response := []byte(`{
		"type":"response.completed",
		"response":{
			"id":"resp_123",
			"model":"gpt-5",
			"usage":{"input_tokens":10,"output_tokens":20},
			"output":[
				{
					"type":"reasoning",
					"encrypted_content":"enc_sig_nonstream",
					"summary":[{"type":"summary_text","text":"internal reasoning"}]
				},
				{
					"type":"message",
					"content":[{"type":"output_text","text":"final answer"}]
				}
			]
		}
	}`)
	out := ConvertCodexResponseToClaudeNonStream(ctx, "", originalRequest, nil, response, nil)
	thinking := gjson.ParseBytes(out).Get("content.0")
	if thinking.Get("type").String() != "thinking" {
		t.Fatalf("expected first content block to be thinking, got %s", thinking.Raw)
	}
	if got := thinking.Get("signature").String(); got != "enc_sig_nonstream" {
		t.Fatalf("expected signature to be preserved, got %q", got)
	}
	if got := thinking.Get("thinking").String(); got != "internal reasoning" {
		t.Fatalf("unexpected thinking text: %q", got)
	}
}

View File

@@ -20,12 +20,14 @@ type oaiToResponsesStateReasoning struct {
OutputIndex int
}
type oaiToResponsesState struct {
Seq int
ResponseID string
Created int64
Started bool
ReasoningID string
ReasoningIndex int
Seq int
ResponseID string
Created int64
Started bool
CompletionPending bool
CompletedEmitted bool
ReasoningID string
ReasoningIndex int
// aggregation buffers for response.output
// Per-output message text buffers by index
MsgTextBuf map[int]*strings.Builder
@@ -60,6 +62,141 @@ func emitRespEvent(event string, payload []byte) []byte {
return translatorcommon.SSEEventData(event, payload)
}
// buildResponsesCompletedEvent assembles the terminal response.completed SSE
// event from the accumulated streaming translation state. It stamps the
// sequence number, id, and created_at from st, echoes selected fields of the
// original request back into the response object, materializes the buffered
// reasoning/message/function-call output items in original output-index
// order, and attaches usage totals when usage was observed on the stream.
//
// Parameters:
//   - st: accumulated per-stream translation state
//   - requestRawJSON: the original request body; may be nil, in which case no
//     request fields are echoed
//   - nextSeq: callback yielding the next monotonically increasing SSE
//     sequence number
//
// Returns the fully framed SSE bytes for the response.completed event.
func buildResponsesCompletedEvent(st *oaiToResponsesState, requestRawJSON []byte, nextSeq func() int) []byte {
	completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`)
	completed, _ = sjson.SetBytes(completed, "sequence_number", nextSeq())
	completed, _ = sjson.SetBytes(completed, "response.id", st.ResponseID)
	completed, _ = sjson.SetBytes(completed, "response.created_at", st.Created)
	// Inject original request fields into response as per docs/response.completed.json
	if requestRawJSON != nil {
		req := gjson.ParseBytes(requestRawJSON)
		// Each field is copied only when present in the request; typed getters
		// (String/Int/Bool/Float) are used where the schema fixes the type,
		// and Value() where the field is free-form JSON.
		if v := req.Get("instructions"); v.Exists() {
			completed, _ = sjson.SetBytes(completed, "response.instructions", v.String())
		}
		if v := req.Get("max_output_tokens"); v.Exists() {
			completed, _ = sjson.SetBytes(completed, "response.max_output_tokens", v.Int())
		}
		if v := req.Get("max_tool_calls"); v.Exists() {
			completed, _ = sjson.SetBytes(completed, "response.max_tool_calls", v.Int())
		}
		if v := req.Get("model"); v.Exists() {
			completed, _ = sjson.SetBytes(completed, "response.model", v.String())
		}
		if v := req.Get("parallel_tool_calls"); v.Exists() {
			completed, _ = sjson.SetBytes(completed, "response.parallel_tool_calls", v.Bool())
		}
		if v := req.Get("previous_response_id"); v.Exists() {
			completed, _ = sjson.SetBytes(completed, "response.previous_response_id", v.String())
		}
		if v := req.Get("prompt_cache_key"); v.Exists() {
			completed, _ = sjson.SetBytes(completed, "response.prompt_cache_key", v.String())
		}
		if v := req.Get("reasoning"); v.Exists() {
			completed, _ = sjson.SetBytes(completed, "response.reasoning", v.Value())
		}
		if v := req.Get("safety_identifier"); v.Exists() {
			completed, _ = sjson.SetBytes(completed, "response.safety_identifier", v.String())
		}
		if v := req.Get("service_tier"); v.Exists() {
			completed, _ = sjson.SetBytes(completed, "response.service_tier", v.String())
		}
		if v := req.Get("store"); v.Exists() {
			completed, _ = sjson.SetBytes(completed, "response.store", v.Bool())
		}
		if v := req.Get("temperature"); v.Exists() {
			completed, _ = sjson.SetBytes(completed, "response.temperature", v.Float())
		}
		if v := req.Get("text"); v.Exists() {
			completed, _ = sjson.SetBytes(completed, "response.text", v.Value())
		}
		if v := req.Get("tool_choice"); v.Exists() {
			completed, _ = sjson.SetBytes(completed, "response.tool_choice", v.Value())
		}
		if v := req.Get("tools"); v.Exists() {
			completed, _ = sjson.SetBytes(completed, "response.tools", v.Value())
		}
		if v := req.Get("top_logprobs"); v.Exists() {
			completed, _ = sjson.SetBytes(completed, "response.top_logprobs", v.Int())
		}
		if v := req.Get("top_p"); v.Exists() {
			completed, _ = sjson.SetBytes(completed, "response.top_p", v.Float())
		}
		if v := req.Get("truncation"); v.Exists() {
			completed, _ = sjson.SetBytes(completed, "response.truncation", v.String())
		}
		if v := req.Get("user"); v.Exists() {
			completed, _ = sjson.SetBytes(completed, "response.user", v.Value())
		}
		if v := req.Get("metadata"); v.Exists() {
			completed, _ = sjson.SetBytes(completed, "response.metadata", v.Value())
		}
	}
	// Collect the buffered output items; each carries the output index it was
	// streamed at so the final array can be restored to stream order.
	outputsWrapper := []byte(`{"arr":[]}`)
	type completedOutputItem struct {
		index int
		raw   []byte
	}
	outputItems := make([]completedOutputItem, 0, len(st.Reasonings)+len(st.MsgItemAdded)+len(st.FuncArgsBuf))
	if len(st.Reasonings) > 0 {
		// One reasoning item per buffered summary.
		for _, r := range st.Reasonings {
			item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
			item, _ = sjson.SetBytes(item, "id", r.ReasoningID)
			item, _ = sjson.SetBytes(item, "summary.0.text", r.ReasoningData)
			outputItems = append(outputItems, completedOutputItem{index: r.OutputIndex, raw: item})
		}
	}
	if len(st.MsgItemAdded) > 0 {
		// One assistant message item per started message index; text may be
		// empty when no deltas were buffered for that index.
		for i := range st.MsgItemAdded {
			txt := ""
			if b := st.MsgTextBuf[i]; b != nil {
				txt = b.String()
			}
			item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
			item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
			item, _ = sjson.SetBytes(item, "content.0.text", txt)
			outputItems = append(outputItems, completedOutputItem{index: st.MsgOutputIx[i], raw: item})
		}
	}
	if len(st.FuncArgsBuf) > 0 {
		// One function_call item per buffered tool call key.
		for key := range st.FuncArgsBuf {
			args := ""
			if b := st.FuncArgsBuf[key]; b != nil {
				args = b.String()
			}
			callID := st.FuncCallIDs[key]
			name := st.FuncNames[key]
			item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
			item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID))
			item, _ = sjson.SetBytes(item, "arguments", args)
			item, _ = sjson.SetBytes(item, "call_id", callID)
			item, _ = sjson.SetBytes(item, "name", name)
			outputItems = append(outputItems, completedOutputItem{index: st.FuncOutputIx[key], raw: item})
		}
	}
	// Restore stream order; map iteration above is unordered.
	sort.Slice(outputItems, func(i, j int) bool { return outputItems[i].index < outputItems[j].index })
	for _, item := range outputItems {
		outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item.raw)
	}
	// Only attach response.output when at least one item exists.
	if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
		completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
	}
	if st.UsageSeen {
		completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens", st.PromptTokens)
		completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens_details.cached_tokens", st.CachedTokens)
		completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens", st.CompletionTokens)
		if st.ReasoningTokens > 0 {
			completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens_details.reasoning_tokens", st.ReasoningTokens)
		}
		// Some upstreams omit total_tokens; derive it from the parts.
		total := st.TotalTokens
		if total == 0 {
			total = st.PromptTokens + st.CompletionTokens
		}
		completed, _ = sjson.SetBytes(completed, "response.usage.total_tokens", total)
	}
	return emitRespEvent("response.completed", completed)
}
// ConvertOpenAIChatCompletionsResponseToOpenAIResponses converts OpenAI Chat Completions streaming chunks
// to OpenAI Responses SSE events (response.*).
func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
@@ -90,6 +227,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
return [][]byte{}
}
if bytes.Equal(rawJSON, []byte("[DONE]")) {
if st.CompletionPending && !st.CompletedEmitted {
st.CompletedEmitted = true
return [][]byte{buildResponsesCompletedEvent(st, requestRawJSON, func() int { st.Seq++; return st.Seq })}
}
return [][]byte{}
}
@@ -165,6 +306,8 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
st.TotalTokens = 0
st.ReasoningTokens = 0
st.UsageSeen = false
st.CompletionPending = false
st.CompletedEmitted = false
// response.created
created := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`)
created, _ = sjson.SetBytes(created, "sequence_number", nextSeq())
@@ -374,8 +517,9 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
}
}
// finish_reason triggers finalization, including text done/content done/item done,
// reasoning done/part.done, function args done/item done, and completed
// finish_reason triggers item-level finalization. response.completed is
// deferred until the terminal [DONE] marker so late usage-only chunks can
// still populate response.usage.
if fr := choice.Get("finish_reason"); fr.Exists() && fr.String() != "" {
// Emit message done events for all indices that started a message
if len(st.MsgItemAdded) > 0 {
@@ -464,138 +608,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
st.FuncArgsDone[key] = true
}
}
completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`)
completed, _ = sjson.SetBytes(completed, "sequence_number", nextSeq())
completed, _ = sjson.SetBytes(completed, "response.id", st.ResponseID)
completed, _ = sjson.SetBytes(completed, "response.created_at", st.Created)
// Inject original request fields into response as per docs/response.completed.json
if requestRawJSON != nil {
req := gjson.ParseBytes(requestRawJSON)
if v := req.Get("instructions"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.instructions", v.String())
}
if v := req.Get("max_output_tokens"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.max_output_tokens", v.Int())
}
if v := req.Get("max_tool_calls"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.max_tool_calls", v.Int())
}
if v := req.Get("model"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.model", v.String())
}
if v := req.Get("parallel_tool_calls"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.parallel_tool_calls", v.Bool())
}
if v := req.Get("previous_response_id"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.previous_response_id", v.String())
}
if v := req.Get("prompt_cache_key"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.prompt_cache_key", v.String())
}
if v := req.Get("reasoning"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.reasoning", v.Value())
}
if v := req.Get("safety_identifier"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.safety_identifier", v.String())
}
if v := req.Get("service_tier"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.service_tier", v.String())
}
if v := req.Get("store"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.store", v.Bool())
}
if v := req.Get("temperature"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.temperature", v.Float())
}
if v := req.Get("text"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.text", v.Value())
}
if v := req.Get("tool_choice"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.tool_choice", v.Value())
}
if v := req.Get("tools"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.tools", v.Value())
}
if v := req.Get("top_logprobs"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.top_logprobs", v.Int())
}
if v := req.Get("top_p"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.top_p", v.Float())
}
if v := req.Get("truncation"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.truncation", v.String())
}
if v := req.Get("user"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.user", v.Value())
}
if v := req.Get("metadata"); v.Exists() {
completed, _ = sjson.SetBytes(completed, "response.metadata", v.Value())
}
}
// Build response.output using aggregated buffers
outputsWrapper := []byte(`{"arr":[]}`)
type completedOutputItem struct {
index int
raw []byte
}
outputItems := make([]completedOutputItem, 0, len(st.Reasonings)+len(st.MsgItemAdded)+len(st.FuncArgsBuf))
if len(st.Reasonings) > 0 {
for _, r := range st.Reasonings {
item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
item, _ = sjson.SetBytes(item, "id", r.ReasoningID)
item, _ = sjson.SetBytes(item, "summary.0.text", r.ReasoningData)
outputItems = append(outputItems, completedOutputItem{index: r.OutputIndex, raw: item})
}
}
if len(st.MsgItemAdded) > 0 {
for i := range st.MsgItemAdded {
txt := ""
if b := st.MsgTextBuf[i]; b != nil {
txt = b.String()
}
item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
item, _ = sjson.SetBytes(item, "content.0.text", txt)
outputItems = append(outputItems, completedOutputItem{index: st.MsgOutputIx[i], raw: item})
}
}
if len(st.FuncArgsBuf) > 0 {
for key := range st.FuncArgsBuf {
args := ""
if b := st.FuncArgsBuf[key]; b != nil {
args = b.String()
}
callID := st.FuncCallIDs[key]
name := st.FuncNames[key]
item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID))
item, _ = sjson.SetBytes(item, "arguments", args)
item, _ = sjson.SetBytes(item, "call_id", callID)
item, _ = sjson.SetBytes(item, "name", name)
outputItems = append(outputItems, completedOutputItem{index: st.FuncOutputIx[key], raw: item})
}
}
sort.Slice(outputItems, func(i, j int) bool { return outputItems[i].index < outputItems[j].index })
for _, item := range outputItems {
outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item.raw)
}
if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
}
if st.UsageSeen {
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens", st.PromptTokens)
completed, _ = sjson.SetBytes(completed, "response.usage.input_tokens_details.cached_tokens", st.CachedTokens)
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens", st.CompletionTokens)
if st.ReasoningTokens > 0 {
completed, _ = sjson.SetBytes(completed, "response.usage.output_tokens_details.reasoning_tokens", st.ReasoningTokens)
}
total := st.TotalTokens
if total == 0 {
total = st.PromptTokens + st.CompletionTokens
}
completed, _ = sjson.SetBytes(completed, "response.usage.total_tokens", total)
}
out = append(out, emitRespEvent("response.completed", completed))
st.CompletionPending = true
}
return true

View File

@@ -24,6 +24,120 @@ func parseOpenAIResponsesSSEEvent(t *testing.T, chunk []byte) (string, gjson.Res
return event, gjson.Parse(dataLine)
}
// TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_ResponseCompletedWaitsForDone
// verifies that the chat-completions-to-responses converter defers the single
// response.completed event until the terminal `[DONE]` marker, so usage data
// that arrives after finish_reason (or never arrives) is handled correctly.
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_ResponseCompletedWaitsForDone(t *testing.T) {
	t.Parallel()
	// Minimal request payload; its fields are echoed into the completed response.
	request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
	tests := []struct {
		name           string
		in             []string // Upstream SSE chunks fed to the converter in order.
		doneInputIndex int      // Index in tt.in where the terminal [DONE] chunk arrives and response.completed must be emitted.
		hasUsage       bool     // Whether the final completed event must carry a usage object.
		inputTokens    int64
		outputTokens   int64
		totalTokens    int64
	}{
		{
			// A provider may send finish_reason first and only attach usage in a later chunk (e.g. Vertex AI),
			// so response.completed must wait for [DONE] to include that usage.
			name: "late usage after finish reason",
			in: []string{
				`data: {"id":"resp_late_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_late_usage","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
				`data: {"id":"resp_late_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\"}"}}]},"finish_reason":"tool_calls"}]}`,
				`data: {"id":"resp_late_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[],"usage":{"prompt_tokens":11,"completion_tokens":7,"total_tokens":18}}`,
				`data: [DONE]`,
			},
			doneInputIndex: 3,
			hasUsage:       true,
			inputTokens:    11,
			outputTokens:   7,
			totalTokens:    18,
		},
		{
			// When usage arrives on the same chunk as finish_reason, we still expect a
			// single response.completed event and it should remain deferred until [DONE].
			name: "usage on finish reason chunk",
			in: []string{
				`data: {"id":"resp_usage_same_chunk","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_usage_same_chunk","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
				`data: {"id":"resp_usage_same_chunk","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\"}"}}]},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":13,"completion_tokens":5,"total_tokens":18}}`,
				`data: [DONE]`,
			},
			doneInputIndex: 2,
			hasUsage:       true,
			inputTokens:    13,
			outputTokens:   5,
			totalTokens:    18,
		},
		{
			// An OpenAI-compatible streams from a buggy server might never send usage, so response.completed should
			// still wait for [DONE] but omit the usage object entirely.
			name: "no usage chunk",
			in: []string{
				`data: {"id":"resp_no_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_no_usage","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
				`data: {"id":"resp_no_usage","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\"}"}}]},"finish_reason":"tool_calls"}]}`,
				`data: [DONE]`,
			},
			doneInputIndex: 2,
			hasUsage:       false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Track every response.completed emission: how many, on which
			// input chunk, and its parsed payload for the usage assertions.
			completedCount := 0
			completedInputIndex := -1
			var completedData gjson.Result
			// Reuse converter state across input lines to simulate one streaming response.
			var param any
			for i, line := range tt.in {
				// One upstream chunk can emit multiple downstream SSE events.
				for _, chunk := range ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), &param) {
					event, data := parseOpenAIResponsesSSEEvent(t, chunk)
					if event != "response.completed" {
						continue
					}
					completedCount++
					completedInputIndex = i
					completedData = data
					if i < tt.doneInputIndex {
						t.Fatalf("unexpected early response.completed on input index %d", i)
					}
				}
			}
			if completedCount != 1 {
				t.Fatalf("expected exactly 1 response.completed event, got %d", completedCount)
			}
			if completedInputIndex != tt.doneInputIndex {
				t.Fatalf("expected response.completed on terminal [DONE] chunk at input index %d, got %d", tt.doneInputIndex, completedInputIndex)
			}
			// Missing upstream usage should stay omitted in the final completed event.
			if !tt.hasUsage {
				if completedData.Get("response.usage").Exists() {
					t.Fatalf("expected response.completed to omit usage when none was provided, got %s", completedData.Get("response.usage").Raw)
				}
				return
			}
			// When usage is present, the final response.completed event must preserve the usage values.
			if got := completedData.Get("response.usage.input_tokens").Int(); got != tt.inputTokens {
				t.Fatalf("unexpected response.usage.input_tokens: got %d want %d", got, tt.inputTokens)
			}
			if got := completedData.Get("response.usage.output_tokens").Int(); got != tt.outputTokens {
				t.Fatalf("unexpected response.usage.output_tokens: got %d want %d", got, tt.outputTokens)
			}
			if got := completedData.Get("response.usage.total_tokens").Int(); got != tt.totalTokens {
				t.Fatalf("unexpected response.usage.total_tokens: got %d want %d", got, tt.totalTokens)
			}
		})
	}
}
func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultipleToolCallsRemainSeparate(t *testing.T) {
in := []string{
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_read","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
@@ -31,6 +145,7 @@ func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultipleToolCalls
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":1,"id":"call_glob","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null}]}`,
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.{yml,yaml}\"}"}}]},"finish_reason":null}]}`,
`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
`data: [DONE]`,
}
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
@@ -131,6 +246,7 @@ func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultiChoiceToolCa
`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice0","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice1","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.go\"}"}}]},"finish_reason":null},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":null}]}`,
`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
`data: [DONE]`,
}
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
@@ -213,6 +329,7 @@ func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MixedMessageAndTo
in := []string{
`data: {"id":"resp_mixed","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":"hello","reasoning_content":null,"tool_calls":null},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice1","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
`data: {"id":"resp_mixed","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"stop"},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
`data: [DONE]`,
}
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)
@@ -261,6 +378,7 @@ func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_FunctionCallDoneA
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":1,"id":"call_read","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":null}]}`,
`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
`data: [DONE]`,
}
request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)

View File

@@ -6,6 +6,7 @@ package handlers
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
@@ -493,6 +494,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
opts.Metadata = reqMeta
resp, err := h.AuthManager.Execute(ctx, providers, req, opts)
if err != nil {
err = enrichAuthSelectionError(err, providers, normalizedModel)
status := http.StatusInternalServerError
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
if code := se.StatusCode(); code > 0 {
@@ -539,6 +541,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
opts.Metadata = reqMeta
resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts)
if err != nil {
err = enrichAuthSelectionError(err, providers, normalizedModel)
status := http.StatusInternalServerError
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
if code := se.StatusCode(); code > 0 {
@@ -589,6 +592,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
opts.Metadata = reqMeta
streamResult, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
if err != nil {
err = enrichAuthSelectionError(err, providers, normalizedModel)
errChan := make(chan *interfaces.ErrorMessage, 1)
status := http.StatusInternalServerError
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
@@ -698,7 +702,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
chunks = retryResult.Chunks
continue outer
}
streamErr = retryErr
streamErr = enrichAuthSelectionError(retryErr, providers, normalizedModel)
}
}
@@ -841,6 +845,54 @@ func replaceHeader(dst http.Header, src http.Header) {
}
}
// enrichAuthSelectionError decorates auth selection failures with the provider
// list and model that were being resolved, so operators can diagnose routing
// problems straight from the error message. Only coreauth errors whose code is
// "auth_not_found" or "auth_unavailable" are rewritten; every other error is
// returned untouched. When the original error carries no positive HTTP status,
// the enriched error defaults to 503 Service Unavailable.
func enrichAuthSelectionError(err error, providers []string, model string) error {
	if err == nil {
		return nil
	}
	var selErr *coreauth.Error
	if !errors.As(err, &selErr) || selErr == nil {
		return err
	}
	switch strings.TrimSpace(selErr.Code) {
	case "auth_not_found", "auth_unavailable":
		// Selection failures: fall through and enrich below.
	default:
		return err
	}
	providerText := strings.Join(providers, ",")
	if len(providerText) == 0 {
		providerText = "unknown"
	}
	modelText := strings.TrimSpace(model)
	if len(modelText) == 0 {
		modelText = "unknown"
	}
	baseMessage := strings.TrimSpace(selErr.Message)
	if len(baseMessage) == 0 {
		baseMessage = "no auth available"
	}
	detail := fmt.Sprintf("%s (providers=%s, model=%s)", baseMessage, providerText, modelText)
	// Clarify the most common alias confusion between Anthropic route names and internal provider keys.
	if strings.Contains(","+providerText+",", ",claude,") {
		detail += "; check Claude auth/key session and cooldown state via /v0/management/auth-files"
	}
	status := selErr.HTTPStatus
	if status <= 0 {
		status = http.StatusServiceUnavailable
	}
	return &coreauth.Error{
		Code:       selErr.Code,
		Message:    detail,
		Retryable:  selErr.Retryable,
		HTTPStatus: status,
	}
}
// WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message.
func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) {
status := http.StatusInternalServerError

View File

@@ -5,10 +5,12 @@ import (
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
@@ -66,3 +68,46 @@ func TestWriteErrorResponse_AddonHeadersEnabled(t *testing.T) {
t.Fatalf("X-Request-Id = %#v, want %#v", got, []string{"new-1", "new-2"})
}
}
// TestEnrichAuthSelectionError_DefaultsTo503WithContext checks that an
// auth_not_found error with no explicit HTTP status surfaces as 503 and that
// the enriched message includes provider, model, and management hints.
func TestEnrichAuthSelectionError_DefaultsTo503WithContext(t *testing.T) {
	original := &coreauth.Error{Code: "auth_not_found", Message: "no auth available"}
	enriched := enrichAuthSelectionError(original, []string{"claude"}, "claude-sonnet-4-6")
	var got *coreauth.Error
	if !errors.As(enriched, &got) || got == nil {
		t.Fatalf("expected coreauth.Error, got %T", enriched)
	}
	if status := got.StatusCode(); status != http.StatusServiceUnavailable {
		t.Fatalf("status = %d, want %d", status, http.StatusServiceUnavailable)
	}
	switch {
	case !strings.Contains(got.Message, "providers=claude"):
		t.Fatalf("message missing provider context: %q", got.Message)
	case !strings.Contains(got.Message, "model=claude-sonnet-4-6"):
		t.Fatalf("message missing model context: %q", got.Message)
	case !strings.Contains(got.Message, "/v0/management/auth-files"):
		t.Fatalf("message missing management hint: %q", got.Message)
	}
}
func TestEnrichAuthSelectionError_PreservesExplicitStatus(t *testing.T) {
in := &coreauth.Error{Code: "auth_unavailable", Message: "no auth available", HTTPStatus: http.StatusTooManyRequests}
out := enrichAuthSelectionError(in, []string{"gemini"}, "gemini-2.5-pro")
var got *coreauth.Error
if !errors.As(out, &got) || got == nil {
t.Fatalf("expected coreauth.Error, got %T", out)
}
if got.StatusCode() != http.StatusTooManyRequests {
t.Fatalf("status = %d, want %d", got.StatusCode(), http.StatusTooManyRequests)
}
}
// TestEnrichAuthSelectionError_IgnoresOtherErrors checks that non-coreauth
// errors pass through enrichAuthSelectionError as the same error value.
func TestEnrichAuthSelectionError_IgnoresOtherErrors(t *testing.T) {
	boom := errors.New("boom")
	if got := enrichAuthSelectionError(boom, []string{"claude"}, "claude-sonnet-4-6"); got != boom {
		t.Fatalf("expected original error to be returned unchanged")
	}
}

View File

@@ -2,10 +2,13 @@ package handlers
import (
"context"
"errors"
"net/http"
"strings"
"sync"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
@@ -463,6 +466,76 @@ func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
}
}
// TestExecuteStreamWithAuthManager_EnrichesBootstrapRetryAuthUnavailableError
// checks that when the bootstrap-retry path cannot select another auth after a
// first-attempt failure, the terminal error is a coreauth auth_unavailable
// error enriched with provider/model context and surfaced as HTTP 503.
func TestExecuteStreamWithAuthManager_EnrichesBootstrapRetryAuthUnavailableError(t *testing.T) {
	// failOnceStreamExecutor fails the first upstream call, forcing the
	// bootstrap retry path to look for an alternate auth (and find none).
	executor := &failOnceStreamExecutor{}
	manager := coreauth.NewManager(nil, nil, nil)
	manager.RegisterExecutor(executor)
	auth1 := &coreauth.Auth{
		ID:       "auth1",
		Provider: "codex",
		Status:   coreauth.StatusActive,
		Metadata: map[string]any{"email": "test1@example.com"},
	}
	if _, err := manager.Register(context.Background(), auth1); err != nil {
		t.Fatalf("manager.Register(auth1): %v", err)
	}
	// Register the auth's model in the global registry so routing can find it;
	// clean up afterwards to keep global state isolated between tests.
	registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
	t.Cleanup(func() {
		registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
	})
	// BootstrapRetries: 1 allows exactly one retry attempt after the initial failure.
	handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
		Streaming: sdkconfig.StreamingConfig{
			BootstrapRetries: 1,
		},
	}, manager)
	dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
	if dataChan == nil || errChan == nil {
		t.Fatalf("expected non-nil channels")
	}
	// Drain the data channel fully; no payload bytes should have been emitted.
	var got []byte
	for chunk := range dataChan {
		got = append(got, chunk...)
	}
	if len(got) != 0 {
		t.Fatalf("expected empty payload, got %q", string(got))
	}
	// Collect the last (terminal) error message from the error channel.
	var gotErr *interfaces.ErrorMessage
	for msg := range errChan {
		if msg != nil {
			gotErr = msg
		}
	}
	if gotErr == nil {
		t.Fatalf("expected terminal error")
	}
	if gotErr.StatusCode != http.StatusServiceUnavailable {
		t.Fatalf("status = %d, want %d", gotErr.StatusCode, http.StatusServiceUnavailable)
	}
	var authErr *coreauth.Error
	if !errors.As(gotErr.Error, &authErr) || authErr == nil {
		t.Fatalf("expected coreauth.Error, got %T", gotErr.Error)
	}
	if authErr.Code != "auth_unavailable" {
		t.Fatalf("code = %q, want %q", authErr.Code, "auth_unavailable")
	}
	// The enriched message must carry both the provider list and model name.
	if !strings.Contains(authErr.Message, "providers=codex") {
		t.Fatalf("message missing provider context: %q", authErr.Message)
	}
	if !strings.Contains(authErr.Message, "model=test-model") {
		t.Fatalf("message missing model context: %q", authErr.Message)
	}
	if executor.Calls() != 1 {
		t.Fatalf("expected exactly one upstream call before retry path selection failure, got %d", executor.Calls())
	}
}
func TestExecuteStreamWithAuthManager_PinnedAuthKeepsSameUpstream(t *testing.T) {
executor := &authAwareStreamExecutor{}
manager := coreauth.NewManager(nil, nil, nil)

View File

@@ -234,6 +234,84 @@ func (m *Manager) RefreshSchedulerEntry(authID string) {
m.scheduler.upsertAuth(snapshot)
}
// ReconcileRegistryModelStates aligns per-model runtime state with the current
// registry snapshot for one auth.
//
// Supported models are reset to a clean state because re-registration already
// cleared the registry-side cooldown/suspension snapshot. ModelStates for
// models that are no longer present in the registry are pruned entirely so
// renamed/removed models cannot keep auth-level status stale.
// ReconcileRegistryModelStates aligns per-model runtime state with the current
// registry snapshot for one auth.
//
// Supported models are reset to a clean state because re-registration already
// cleared the registry-side cooldown/suspension snapshot. ModelStates for
// models that are no longer present in the registry are pruned entirely so
// renamed/removed models cannot keep auth-level status stale.
func (m *Manager) ReconcileRegistryModelStates(ctx context.Context, authID string) {
	if m == nil || authID == "" {
		return
	}
	// Build the set of canonical model keys the registry currently reports for
	// this auth; any recorded model state outside this set is considered stale.
	supportedModels := registry.GetGlobalRegistry().GetModelsForClient(authID)
	supported := make(map[string]struct{}, len(supportedModels))
	for _, model := range supportedModels {
		if model == nil {
			continue
		}
		modelKey := canonicalModelKey(model.ID)
		if modelKey == "" {
			continue
		}
		supported[modelKey] = struct{}{}
	}
	var snapshot *Auth
	now := time.Now()
	m.mu.Lock()
	auth, ok := m.auths[authID]
	if ok && auth != nil && len(auth.ModelStates) > 0 {
		changed := false
		// Deleting entries while ranging over a map is safe in Go.
		for modelKey, state := range auth.ModelStates {
			baseModel := canonicalModelKey(modelKey)
			if baseModel == "" {
				baseModel = strings.TrimSpace(modelKey)
			}
			if _, supportedModel := supported[baseModel]; !supportedModel {
				// Drop state for models that disappeared from the current registry
				// snapshot. Keeping them around leaks stale errors into auth-level
				// status, management output, and websocket fallback checks.
				delete(auth.ModelStates, modelKey)
				changed = true
				continue
			}
			if state == nil {
				continue
			}
			if modelStateIsClean(state) {
				// Already pristine: skip so we don't flag a change (and persist) for it.
				continue
			}
			resetModelState(state, now)
			changed = true
		}
		if len(auth.ModelStates) == 0 {
			auth.ModelStates = nil
		}
		if changed {
			updateAggregatedAvailability(auth, now)
			if !hasModelError(auth, now) {
				// No per-model failures remain, so clear auth-level error state too.
				auth.LastError = nil
				auth.StatusMessage = ""
				auth.Status = StatusActive
			}
			auth.UpdatedAt = now
			// Persist failures are logged but non-fatal; in-memory state stays updated.
			if errPersist := m.persist(ctx, auth); errPersist != nil {
				logEntryWithRequestID(ctx).WithField("auth_id", auth.ID).Warnf("failed to persist auth changes during model state reconciliation: %v", errPersist)
			}
			snapshot = auth.Clone()
		}
	}
	m.mu.Unlock()
	// Notify the scheduler with the cloned snapshot outside the lock to avoid
	// holding m.mu across scheduler work.
	if m.scheduler != nil && snapshot != nil {
		m.scheduler.upsertAuth(snapshot)
	}
}
func (m *Manager) SetSelector(selector Selector) {
if m == nil {
return
@@ -1838,6 +1916,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
} else {
if result.Model != "" {
if !isRequestScopedNotFoundResultError(result.Error) {
disableCooling := quotaCooldownDisabledForAuth(auth)
state := ensureModelState(auth, result.Model)
state.Unavailable = true
state.Status = StatusError
@@ -1858,31 +1937,45 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
} else {
switch statusCode {
case 401:
next := now.Add(30 * time.Minute)
state.NextRetryAfter = next
suspendReason = "unauthorized"
shouldSuspendModel = true
if disableCooling {
state.NextRetryAfter = time.Time{}
} else {
next := now.Add(30 * time.Minute)
state.NextRetryAfter = next
suspendReason = "unauthorized"
shouldSuspendModel = true
}
case 402, 403:
next := now.Add(30 * time.Minute)
state.NextRetryAfter = next
suspendReason = "payment_required"
shouldSuspendModel = true
if disableCooling {
state.NextRetryAfter = time.Time{}
} else {
next := now.Add(30 * time.Minute)
state.NextRetryAfter = next
suspendReason = "payment_required"
shouldSuspendModel = true
}
case 404:
next := now.Add(12 * time.Hour)
state.NextRetryAfter = next
suspendReason = "not_found"
shouldSuspendModel = true
if disableCooling {
state.NextRetryAfter = time.Time{}
} else {
next := now.Add(12 * time.Hour)
state.NextRetryAfter = next
suspendReason = "not_found"
shouldSuspendModel = true
}
case 429:
var next time.Time
backoffLevel := state.Quota.BackoffLevel
if result.RetryAfter != nil {
next = now.Add(*result.RetryAfter)
} else {
cooldown, nextLevel := nextQuotaCooldown(backoffLevel, quotaCooldownDisabledForAuth(auth))
if cooldown > 0 {
next = now.Add(cooldown)
if !disableCooling {
if result.RetryAfter != nil {
next = now.Add(*result.RetryAfter)
} else {
cooldown, nextLevel := nextQuotaCooldown(backoffLevel, disableCooling)
if cooldown > 0 {
next = now.Add(cooldown)
}
backoffLevel = nextLevel
}
backoffLevel = nextLevel
}
state.NextRetryAfter = next
state.Quota = QuotaState{
@@ -1891,11 +1984,13 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
NextRecoverAt: next,
BackoffLevel: backoffLevel,
}
suspendReason = "quota"
shouldSuspendModel = true
setModelQuota = true
if !disableCooling {
suspendReason = "quota"
shouldSuspendModel = true
setModelQuota = true
}
case 408, 500, 502, 503, 504:
if quotaCooldownDisabledForAuth(auth) {
if disableCooling {
state.NextRetryAfter = time.Time{}
} else {
next := now.Add(1 * time.Minute)
@@ -1966,8 +2061,28 @@ func resetModelState(state *ModelState, now time.Time) {
state.UpdatedAt = now
}
// modelStateIsClean reports whether state carries no residual failure
// information: an active status with no unavailability flag, status message,
// retry deadline, recorded error, or lingering quota state. A nil state is
// trivially clean.
func modelStateIsClean(state *ModelState) bool {
	if state == nil {
		return true
	}
	switch {
	case state.Status != StatusActive:
		return false
	case state.Unavailable,
		state.StatusMessage != "",
		!state.NextRetryAfter.IsZero(),
		state.LastError != nil:
		return false
	case state.Quota.Exceeded,
		state.Quota.Reason != "",
		!state.Quota.NextRecoverAt.IsZero(),
		state.Quota.BackoffLevel != 0:
		return false
	}
	return true
}
func updateAggregatedAvailability(auth *Auth, now time.Time) {
if auth == nil || len(auth.ModelStates) == 0 {
if auth == nil {
return
}
if len(auth.ModelStates) == 0 {
clearAggregatedAvailability(auth)
return
}
allUnavailable := true
@@ -1975,10 +2090,12 @@ func updateAggregatedAvailability(auth *Auth, now time.Time) {
quotaExceeded := false
quotaRecover := time.Time{}
maxBackoffLevel := 0
hasState := false
for _, state := range auth.ModelStates {
if state == nil {
continue
}
hasState = true
stateUnavailable := false
if state.Status == StatusDisabled {
stateUnavailable = true
@@ -2008,6 +2125,10 @@ func updateAggregatedAvailability(auth *Auth, now time.Time) {
}
}
}
if !hasState {
clearAggregatedAvailability(auth)
return
}
auth.Unavailable = allUnavailable
if allUnavailable {
auth.NextRetryAfter = earliestRetry
@@ -2027,6 +2148,15 @@ func updateAggregatedAvailability(auth *Auth, now time.Time) {
}
}
// clearAggregatedAvailability resets the auth-level availability aggregate to
// its zero (fully available) form. It is safe to call with a nil auth.
func clearAggregatedAvailability(auth *Auth) {
	if auth == nil {
		return
	}
	var zeroTime time.Time
	auth.Unavailable = false
	auth.NextRetryAfter = zeroTime
	auth.Quota = QuotaState{}
}
func hasModelError(auth *Auth, now time.Time) bool {
if auth == nil || len(auth.ModelStates) == 0 {
return false
@@ -2211,6 +2341,7 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
if isRequestScopedNotFoundResultError(resultErr) {
return
}
disableCooling := quotaCooldownDisabledForAuth(auth)
auth.Unavailable = true
auth.Status = StatusError
auth.UpdatedAt = now
@@ -2224,32 +2355,46 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
switch statusCode {
case 401:
auth.StatusMessage = "unauthorized"
auth.NextRetryAfter = now.Add(30 * time.Minute)
if disableCooling {
auth.NextRetryAfter = time.Time{}
} else {
auth.NextRetryAfter = now.Add(30 * time.Minute)
}
case 402, 403:
auth.StatusMessage = "payment_required"
auth.NextRetryAfter = now.Add(30 * time.Minute)
if disableCooling {
auth.NextRetryAfter = time.Time{}
} else {
auth.NextRetryAfter = now.Add(30 * time.Minute)
}
case 404:
auth.StatusMessage = "not_found"
auth.NextRetryAfter = now.Add(12 * time.Hour)
if disableCooling {
auth.NextRetryAfter = time.Time{}
} else {
auth.NextRetryAfter = now.Add(12 * time.Hour)
}
case 429:
auth.StatusMessage = "quota exhausted"
auth.Quota.Exceeded = true
auth.Quota.Reason = "quota"
var next time.Time
if retryAfter != nil {
next = now.Add(*retryAfter)
} else {
cooldown, nextLevel := nextQuotaCooldown(auth.Quota.BackoffLevel, quotaCooldownDisabledForAuth(auth))
if cooldown > 0 {
next = now.Add(cooldown)
if !disableCooling {
if retryAfter != nil {
next = now.Add(*retryAfter)
} else {
cooldown, nextLevel := nextQuotaCooldown(auth.Quota.BackoffLevel, disableCooling)
if cooldown > 0 {
next = now.Add(cooldown)
}
auth.Quota.BackoffLevel = nextLevel
}
auth.Quota.BackoffLevel = nextLevel
}
auth.Quota.NextRecoverAt = next
auth.NextRetryAfter = next
case 408, 500, 502, 503, 504:
auth.StatusMessage = "transient upstream error"
if quotaCooldownDisabledForAuth(auth) {
if disableCooling {
auth.NextRetryAfter = time.Time{}
} else {
auth.NextRetryAfter = now.Add(1 * time.Minute)

View File

@@ -180,6 +180,34 @@ func (e *authFallbackExecutor) StreamCalls() []string {
return out
}
// retryAfterStatusError is a test error that exposes an HTTP status code and
// a retry-after hint, mimicking an upstream throttling response.
type retryAfterStatusError struct {
	status     int
	message    string
	retryAfter time.Duration
}

// Error returns the stored message; a nil receiver yields the empty string.
func (r *retryAfterStatusError) Error() string {
	if r != nil {
		return r.message
	}
	return ""
}

// StatusCode returns the stored HTTP status; a nil receiver yields zero.
func (r *retryAfterStatusError) StatusCode() int {
	if r != nil {
		return r.status
	}
	return 0
}

// RetryAfter returns a pointer to a copy of the retry-after duration so
// callers never alias the receiver's field; a nil receiver yields nil.
func (r *retryAfterStatusError) RetryAfter() *time.Duration {
	if r == nil {
		return nil
	}
	dur := r.retryAfter
	return &dur
}
func newCredentialRetryLimitTestManager(t *testing.T, maxRetryCredentials int) (*Manager, *credentialRetryLimitExecutor) {
t.Helper()
@@ -450,6 +478,174 @@ func TestManager_MarkResult_RespectsAuthDisableCoolingOverride(t *testing.T) {
}
}
// TestManager_MarkResult_RespectsAuthDisableCoolingOverride_On403 verifies
// that the per-auth "disable_cooling" metadata flag suppresses the cooldown
// normally scheduled after a 403 failure: the model state must keep a zero
// NextRetryAfter and the model must remain registered (count > 0).
func TestManager_MarkResult_RespectsAuthDisableCoolingOverride_On403(t *testing.T) {
	// Pin the global cooldown kill-switch to false so only the per-auth
	// metadata override can explain the observed behavior; restore afterwards.
	prev := quotaCooldownDisabled.Load()
	quotaCooldownDisabled.Store(false)
	t.Cleanup(func() { quotaCooldownDisabled.Store(prev) })
	m := NewManager(nil, nil, nil)
	auth := &Auth{
		ID:       "auth-403",
		Provider: "claude",
		Metadata: map[string]any{
			"disable_cooling": true,
		},
	}
	if _, errRegister := m.Register(context.Background(), auth); errRegister != nil {
		t.Fatalf("register auth: %v", errRegister)
	}
	// Register a model for the auth in the global registry so MarkResult can
	// resolve and update its per-model availability state.
	model := "test-model-403"
	reg := registry.GetGlobalRegistry()
	reg.RegisterClient(auth.ID, "claude", []*registry.ModelInfo{{ID: model}})
	t.Cleanup(func() { reg.UnregisterClient(auth.ID) })
	// Report a 403 failure for the model.
	m.MarkResult(context.Background(), Result{
		AuthID:   auth.ID,
		Provider: "claude",
		Model:    model,
		Success:  false,
		Error:    &Error{HTTPStatus: http.StatusForbidden, Message: "forbidden"},
	})
	updated, ok := m.GetByID(auth.ID)
	if !ok || updated == nil {
		t.Fatalf("expected auth to be present")
	}
	state := updated.ModelStates[model]
	if state == nil {
		t.Fatalf("expected model state to be present")
	}
	// With cooling disabled, no retry blackout window may be scheduled...
	if !state.NextRetryAfter.IsZero() {
		t.Fatalf("expected NextRetryAfter to be zero when disable_cooling=true, got %v", state.NextRetryAfter)
	}
	// ...and the model must still be selectable via the registry.
	if count := reg.GetModelCount(model); count <= 0 {
		t.Fatalf("expected model count > 0 when disable_cooling=true, got %d", count)
	}
}
// TestManager_Execute_DisableCooling_DoesNotBlackoutAfter403 verifies that an
// auth carrying disable_cooling metadata keeps serving requests after a 403:
// the second Execute must reach the executor again (surfacing the same 403)
// instead of being short-circuited by a scheduled cooldown window.
func TestManager_Execute_DisableCooling_DoesNotBlackoutAfter403(t *testing.T) {
	// Keep the global cooldown switch off so only the per-auth override applies.
	prev := quotaCooldownDisabled.Load()
	quotaCooldownDisabled.Store(false)
	t.Cleanup(func() { quotaCooldownDisabled.Store(prev) })
	m := NewManager(nil, nil, nil)
	// Stub executor that always fails this auth with HTTP 403.
	executor := &authFallbackExecutor{
		id: "claude",
		executeErrors: map[string]error{
			"auth-403-exec": &Error{
				HTTPStatus: http.StatusForbidden,
				Message:    "forbidden",
			},
		},
	}
	m.RegisterExecutor(executor)
	auth := &Auth{
		ID:       "auth-403-exec",
		Provider: "claude",
		Metadata: map[string]any{
			"disable_cooling": true,
		},
	}
	if _, errRegister := m.Register(context.Background(), auth); errRegister != nil {
		t.Fatalf("register auth: %v", errRegister)
	}
	model := "test-model-403-exec"
	reg := registry.GetGlobalRegistry()
	reg.RegisterClient(auth.ID, "claude", []*registry.ModelInfo{{ID: model}})
	t.Cleanup(func() { reg.UnregisterClient(auth.ID) })
	req := cliproxyexecutor.Request{Model: model}
	// First call: the upstream 403 must propagate to the caller.
	_, errExecute1 := m.Execute(context.Background(), []string{"claude"}, req, cliproxyexecutor.Options{})
	if errExecute1 == nil {
		t.Fatal("expected first execute error")
	}
	if statusCodeFromError(errExecute1) != http.StatusForbidden {
		t.Fatalf("first execute status = %d, want %d", statusCodeFromError(errExecute1), http.StatusForbidden)
	}
	// Second call: the auth must not be blacked out, so the executor is hit
	// again and the same 403 surfaces (not a no-credentials style error).
	_, errExecute2 := m.Execute(context.Background(), []string{"claude"}, req, cliproxyexecutor.Options{})
	if errExecute2 == nil {
		t.Fatal("expected second execute error")
	}
	if statusCodeFromError(errExecute2) != http.StatusForbidden {
		t.Fatalf("second execute status = %d, want %d", statusCodeFromError(errExecute2), http.StatusForbidden)
	}
}
// TestManager_Execute_DisableCooling_DoesNotBlackoutAfter429RetryAfter
// verifies that disable_cooling also overrides an explicit Retry-After hint
// carried by a 429: successive Execute calls keep reaching the executor and
// no NextRetryAfter window is recorded on the model state.
func TestManager_Execute_DisableCooling_DoesNotBlackoutAfter429RetryAfter(t *testing.T) {
	// Keep the global cooldown switch off so only the per-auth override applies.
	prev := quotaCooldownDisabled.Load()
	quotaCooldownDisabled.Store(false)
	t.Cleanup(func() { quotaCooldownDisabled.Store(prev) })
	m := NewManager(nil, nil, nil)
	// Stub executor failing with 429 plus a 2-minute Retry-After hint, which
	// would normally schedule a cooldown window for the credential.
	executor := &authFallbackExecutor{
		id: "claude",
		executeErrors: map[string]error{
			"auth-429-exec": &retryAfterStatusError{
				status:     http.StatusTooManyRequests,
				message:    "quota exhausted",
				retryAfter: 2 * time.Minute,
			},
		},
	}
	m.RegisterExecutor(executor)
	auth := &Auth{
		ID:       "auth-429-exec",
		Provider: "claude",
		Metadata: map[string]any{
			"disable_cooling": true,
		},
	}
	if _, errRegister := m.Register(context.Background(), auth); errRegister != nil {
		t.Fatalf("register auth: %v", errRegister)
	}
	model := "test-model-429-exec"
	reg := registry.GetGlobalRegistry()
	reg.RegisterClient(auth.ID, "claude", []*registry.ModelInfo{{ID: model}})
	t.Cleanup(func() { reg.UnregisterClient(auth.ID) })
	req := cliproxyexecutor.Request{Model: model}
	_, errExecute1 := m.Execute(context.Background(), []string{"claude"}, req, cliproxyexecutor.Options{})
	if errExecute1 == nil {
		t.Fatal("expected first execute error")
	}
	if statusCodeFromError(errExecute1) != http.StatusTooManyRequests {
		t.Fatalf("first execute status = %d, want %d", statusCodeFromError(errExecute1), http.StatusTooManyRequests)
	}
	_, errExecute2 := m.Execute(context.Background(), []string{"claude"}, req, cliproxyexecutor.Options{})
	if errExecute2 == nil {
		t.Fatal("expected second execute error")
	}
	if statusCodeFromError(errExecute2) != http.StatusTooManyRequests {
		t.Fatalf("second execute status = %d, want %d", statusCodeFromError(errExecute2), http.StatusTooManyRequests)
	}
	// Both attempts must have reached the executor (no scheduler blackout).
	calls := executor.ExecuteCalls()
	if len(calls) != 2 {
		t.Fatalf("execute calls = %d, want 2", len(calls))
	}
	updated, ok := m.GetByID(auth.ID)
	if !ok || updated == nil {
		t.Fatalf("expected auth to be present")
	}
	state := updated.ModelStates[model]
	if state == nil {
		t.Fatalf("expected model state to be present")
	}
	// Even with an explicit Retry-After, no cooldown may be recorded.
	if !state.NextRetryAfter.IsZero() {
		t.Fatalf("expected NextRetryAfter to be zero when disable_cooling=true, got %v", state.NextRetryAfter)
	}
}
func TestManager_MarkResult_RequestScopedNotFoundDoesNotCooldownAuth(t *testing.T) {
m := NewManager(nil, nil, nil)

View File

@@ -324,6 +324,7 @@ func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.A
// This operation may block on network calls, but the auth configuration
// is already effective at this point.
s.registerModelsForAuth(auth)
s.coreManager.ReconcileRegistryModelStates(ctx, auth.ID)
// Refresh the scheduler entry so that the auth's supportedModelSet is rebuilt
// from the now-populated global model registry. Without this, newly added auths
@@ -1085,6 +1086,7 @@ func (s *Service) refreshModelRegistrationForAuth(current *coreauth.Auth) bool {
s.ensureExecutorsForAuth(current)
}
s.registerModelsForAuth(current)
s.coreManager.ReconcileRegistryModelStates(context.Background(), current.ID)
latest, ok := s.latestAuthForModelRegistration(current.ID)
if !ok || latest.Disabled {
@@ -1098,6 +1100,7 @@ func (s *Service) refreshModelRegistrationForAuth(current *coreauth.Auth) bool {
// no auth fields changed, but keeps the refresh path simple and correct.
s.ensureExecutorsForAuth(latest)
s.registerModelsForAuth(latest)
s.coreManager.ReconcileRegistryModelStates(context.Background(), latest.ID)
s.coreManager.RefreshSchedulerEntry(current.ID)
return true
}

View File

@@ -58,7 +58,7 @@ func Parse(raw string) (Setting, error) {
}
switch parsedURL.Scheme {
case "socks5", "http", "https":
case "socks5", "socks5h", "http", "https":
setting.Mode = ModeProxy
setting.URL = parsedURL
return setting, nil
@@ -95,7 +95,7 @@ func BuildHTTPTransport(raw string) (*http.Transport, Mode, error) {
case ModeDirect:
return NewDirectTransport(), setting.Mode, nil
case ModeProxy:
if setting.URL.Scheme == "socks5" {
if setting.URL.Scheme == "socks5" || setting.URL.Scheme == "socks5h" {
var proxyAuth *proxy.Auth
if setting.URL.User != nil {
username := setting.URL.User.Username()

View File

@@ -30,6 +30,7 @@ func TestParse(t *testing.T) {
{name: "http", input: "http://proxy.example.com:8080", want: ModeProxy},
{name: "https", input: "https://proxy.example.com:8443", want: ModeProxy},
{name: "socks5", input: "socks5://proxy.example.com:1080", want: ModeProxy},
{name: "socks5h", input: "socks5h://proxy.example.com:1080", want: ModeProxy},
{name: "invalid", input: "bad-value", want: ModeInvalid, wantErr: true},
}
@@ -137,3 +138,24 @@ func TestBuildHTTPTransportSOCKS5ProxyInheritsDefaultTransportSettings(t *testin
t.Fatalf("TLSHandshakeTimeout = %v, want %v", transport.TLSHandshakeTimeout, defaultTransport.TLSHandshakeTimeout)
}
}
// TestBuildHTTPTransportSOCKS5HProxy checks that a socks5h:// URL yields a
// proxy-mode transport that dials through a custom DialContext rather than
// the standard HTTP proxy hook.
func TestBuildHTTPTransportSOCKS5HProxy(t *testing.T) {
	t.Parallel()

	tr, mode, err := BuildHTTPTransport("socks5h://proxy.example.com:1080")
	if err != nil {
		t.Fatalf("BuildHTTPTransport returned error: %v", err)
	}
	if mode != ModeProxy {
		t.Fatalf("mode = %d, want %d", mode, ModeProxy)
	}
	if tr == nil {
		t.Fatal("expected transport, got nil")
	}
	// SOCKS5H routing happens in the dialer, so the HTTP proxy hook must be
	// unset while a custom DialContext must be installed.
	if tr.Proxy != nil {
		t.Fatal("expected SOCKS5H transport to bypass http proxy function")
	}
	if tr.DialContext == nil {
		t.Fatal("expected SOCKS5H transport to have custom DialContext")
	}
}

View File

@@ -0,0 +1,106 @@
package test
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
// jsonObject is shorthand for a decoded JSON object.
type jsonObject = map[string]any

// loadClaudeCodeSentinelFixture reads and decodes the named JSON fixture from
// testdata/claude_code_sentinels, failing the test on any I/O or parse error.
func loadClaudeCodeSentinelFixture(t *testing.T, name string) jsonObject {
	t.Helper()
	fixturePath := filepath.Join("testdata", "claude_code_sentinels", name)
	raw := mustReadFile(t, fixturePath)
	var payload jsonObject
	if err := json.Unmarshal(raw, &payload); err != nil {
		t.Fatalf("unmarshal %s: %v", name, err)
	}
	return payload
}

// mustReadFile returns the contents of path or aborts the test.
func mustReadFile(t *testing.T, path string) []byte {
	t.Helper()
	contents, err := os.ReadFile(path)
	if err != nil {
		t.Fatalf("read %s: %v", path, err)
	}
	return contents
}

// requireStringField asserts that obj[key] is a non-empty string and returns it.
func requireStringField(t *testing.T, obj jsonObject, key string) string {
	t.Helper()
	if value, ok := obj[key].(string); ok && value != "" {
		return value
	}
	t.Fatalf("field %q missing or empty: %#v", key, obj[key])
	return ""
}
// TestClaudeCodeSentinel_ToolProgressShape validates the tool_progress
// sentinel fixture: the type tag, the required identifier fields, and a
// numeric elapsed_time_seconds.
func TestClaudeCodeSentinel_ToolProgressShape(t *testing.T) {
	payload := loadClaudeCodeSentinelFixture(t, "tool_progress.json")
	if got := requireStringField(t, payload, "type"); got != "tool_progress" {
		t.Fatalf("type = %q, want tool_progress", got)
	}
	for _, key := range []string{"tool_use_id", "tool_name", "session_id"} {
		requireStringField(t, payload, key)
	}
	if _, ok := payload["elapsed_time_seconds"].(float64); !ok {
		t.Fatalf("elapsed_time_seconds missing or non-number: %#v", payload["elapsed_time_seconds"])
	}
}
// TestClaudeCodeSentinel_SessionStateShape validates the session_state_changed
// sentinel fixture: system type, matching subtype, a state drawn from the
// known set, and a session id.
func TestClaudeCodeSentinel_SessionStateShape(t *testing.T) {
	payload := loadClaudeCodeSentinelFixture(t, "session_state_changed.json")
	if got := requireStringField(t, payload, "type"); got != "system" {
		t.Fatalf("type = %q, want system", got)
	}
	if got := requireStringField(t, payload, "subtype"); got != "session_state_changed" {
		t.Fatalf("subtype = %q, want session_state_changed", got)
	}
	state := requireStringField(t, payload, "state")
	valid := false
	for _, allowed := range []string{"idle", "running", "requires_action"} {
		if state == allowed {
			valid = true
			break
		}
	}
	if !valid {
		t.Fatalf("unexpected session state %q", state)
	}
	requireStringField(t, payload, "session_id")
}
// TestClaudeCodeSentinel_ToolUseSummaryShape validates the tool_use_summary
// sentinel fixture: type tag, a summary string, and a non-empty list of
// non-empty preceding tool-use ids.
func TestClaudeCodeSentinel_ToolUseSummaryShape(t *testing.T) {
	payload := loadClaudeCodeSentinelFixture(t, "tool_use_summary.json")
	if got := requireStringField(t, payload, "type"); got != "tool_use_summary" {
		t.Fatalf("type = %q, want tool_use_summary", got)
	}
	requireStringField(t, payload, "summary")
	ids, ok := payload["preceding_tool_use_ids"].([]any)
	if !ok || len(ids) == 0 {
		t.Fatalf("preceding_tool_use_ids missing or empty: %#v", payload["preceding_tool_use_ids"])
	}
	for i, entry := range ids {
		id, isString := entry.(string)
		if !isString || id == "" {
			t.Fatalf("preceding_tool_use_ids[%d] invalid: %#v", i, entry)
		}
	}
}
// TestClaudeCodeSentinel_ControlRequestCanUseToolShape validates the
// control_request (can_use_tool) sentinel fixture: outer envelope fields plus
// the nested request object's subtype, tool identifiers, and non-empty input.
func TestClaudeCodeSentinel_ControlRequestCanUseToolShape(t *testing.T) {
	payload := loadClaudeCodeSentinelFixture(t, "control_request_can_use_tool.json")
	if got := requireStringField(t, payload, "type"); got != "control_request" {
		t.Fatalf("type = %q, want control_request", got)
	}
	requireStringField(t, payload, "request_id")
	req, ok := payload["request"].(jsonObject)
	if !ok {
		t.Fatalf("request missing or invalid: %#v", payload["request"])
	}
	if got := requireStringField(t, req, "subtype"); got != "can_use_tool" {
		t.Fatalf("request.subtype = %q, want can_use_tool", got)
	}
	for _, key := range []string{"tool_name", "tool_use_id"} {
		requireStringField(t, req, key)
	}
	if input, ok := req["input"].(jsonObject); !ok || len(input) == 0 {
		t.Fatalf("request.input missing or empty: %#v", req["input"])
	}
}

View File

@@ -0,0 +1,11 @@
{
"type": "control_request",
"request_id": "req_123",
"request": {
"subtype": "can_use_tool",
"tool_name": "Bash",
"input": {"command": "npm test"},
"tool_use_id": "toolu_123",
"description": "Running npm test"
}
}

View File

@@ -0,0 +1,7 @@
{
"type": "system",
"subtype": "session_state_changed",
"state": "requires_action",
"uuid": "22222222-2222-4222-8222-222222222222",
"session_id": "sess_123"
}

View File

@@ -0,0 +1,10 @@
{
"type": "tool_progress",
"tool_use_id": "toolu_123",
"tool_name": "Bash",
"parent_tool_use_id": null,
"elapsed_time_seconds": 2.5,
"task_id": "task_123",
"uuid": "11111111-1111-4111-8111-111111111111",
"session_id": "sess_123"
}

View File

@@ -0,0 +1,7 @@
{
"type": "tool_use_summary",
"summary": "Searched in auth/",
"preceding_tool_use_ids": ["toolu_1", "toolu_2"],
"uuid": "33333333-3333-4333-8333-333333333333",
"session_id": "sess_123"
}