Compare commits

...

53 Commits

Author SHA1 Message Date
Luis Pater
d3100085b0 Merge pull request #392 from router-for-me/plus
v6.8.30
2026-02-26 23:16:26 +08:00
Luis Pater
f481d25133 Merge branch 'main' into plus 2026-02-26 23:16:17 +08:00
Luis Pater
8c6c90da74 fix(registry): clean up outdated model definitions in static data 2026-02-26 23:12:40 +08:00
Luis Pater
24bcfd9c03 Merge pull request #1699 from 123hi123/fix/antigravity-primary-model-fallback
fix(antigravity): keep primary model list and backfill empty auths
2026-02-26 04:28:29 +08:00
Luis Pater
816fb4c5da Merge pull request #1682 from sususu98/fix/tool-result-image-parts
fix(antigravity): place tool_result images in functionResponse.parts and unify mimeType
2026-02-25 23:14:35 +08:00
Luis Pater
c1bb77c7c9 Merge pull request #291 from howarddong711/feat/copilot-email-name
feat(copilot): fetch and persist user email and display name on login
2026-02-25 22:23:25 +08:00
Luis Pater
6bcac3a55a Merge branch 'router-for-me:main' into main 2026-02-25 22:21:31 +08:00
Howard Dong
fc346f4537 fix(copilot): add username fallback and consistent file name prefix
- Add 'github-user' fallback in WaitForAuthorization when FetchUserInfo
  returns empty Login (fixes malformed 'github-copilot-.json' filenames)
- Standardize Web API file name to 'github-copilot-<user>.json' to match
  CLI path convention (was 'github-<user>.json')

Addresses Gemini Code Assist review comments on PR #291.
2026-02-25 17:17:51 +08:00
Howard Dong
43e531a3b6 feat(copilot): fetch and persist user email and display name on login
- Expand OAuth scope to include read:user for full profile access
- Add GitHubUserInfo struct with Login, Email, Name fields
- Update FetchUserInfo to return complete user profile
- Add Email and Name fields to CopilotTokenStorage and CopilotAuthBundle
- Fix provider string bug: 'github' -> 'github-copilot' in auth_files.go
- Fix semantic bug: email field was storing username
- Update Label to prefer email over username in both CLI and Web API paths
- Add 9 unit tests covering new functionality
2026-02-25 17:09:40 +08:00
Luis Pater
d24ea4ce2a Merge pull request #1664 from ciberponk/pr/responses-compaction-compat
feat: add codex responses compatibility for compaction payloads
2026-02-25 01:21:59 +08:00
Luis Pater
2c30c981ae Merge pull request #1687 from lyd123qw2008/fix/codex-refresh-token-reused-no-retry
fix(codex): stop retrying refresh_token_reused errors
2026-02-25 01:19:30 +08:00
Luis Pater
aa1da8a858 Merge pull request #1685 from lyd123qw2008/fix/auth-auto-refresh-interval
fix(auth): respect configured auto-refresh interval
2026-02-25 01:13:47 +08:00
Luis Pater
f1e9a787d7 Merge pull request #1676 from piexian/feat/qwen-quota-handling-clean
feat(qwen): add rate limiting and quota error handling
2026-02-25 01:07:55 +08:00
Luis Pater
4eeec297de Merge pull request #288 from router-for-me/plus
v6.8.27
2026-02-25 01:04:57 +08:00
Luis Pater
77cc4ce3a0 Merge branch 'main' into plus 2026-02-25 01:04:15 +08:00
Luis Pater
37dfea1d3f Merge pull request #287 from possible055/main
fix(kiro): support OR-group field matching in truncation detector
2026-02-25 01:02:49 +08:00
Luis Pater
e6626c672a Merge pull request #269 from ClubWeGo/fix/filterOrphanedToolResults
fix: filter out orphaned tool results from history and current context
2026-02-25 01:02:11 +08:00
Luis Pater
c66cb0afd2 Merge pull request #1683 from dusty-du/codex/device-login-flow
Add additive Codex device-code login flow
2026-02-25 00:50:48 +08:00
Luis Pater
fb48eee973 Merge pull request #1680 from canxin121/fix/responses-stream-error-chunks
fix(responses): emit schema-valid SSE chunks
2026-02-25 00:49:06 +08:00
Luis Pater
bb44e5ec44 Merge pull request #1701 from router-for-me/openai
Revert "Merge pull request #1627 from thebtf/fix/reasoning-effort-clamping"
2026-02-25 00:46:13 +08:00
apparition
c785c1a3ca fix(kiro): support OR-group field matching in truncation detector
- Change RequiredFieldsByTool value type from []string to [][]string
- Outer slice = AND (all groups required); inner slice = OR (any one satisfies)
- Fix Bash entry to accept "cmd" or "command", resolving soft-truncation loop
- Update findMissingRequiredFields logic and inline docs accordingly
2026-02-24 22:48:05 +08:00
comalot
514ae341c8 fix(antigravity): deep copy cached model metadata 2026-02-24 20:14:01 +08:00
hkfires
0659ffab75 Revert "Merge pull request #1627 from thebtf/fix/reasoning-effort-clamping" 2026-02-24 19:47:53 +08:00
comalot
8ce07f38dd fix(antigravity): keep primary model list and backfill empty auths 2026-02-24 16:16:44 +08:00
Luis Pater
7cb398d167 Merge pull request #1663 from rensumo/main
feat: implement credential-based round-robin for gemini-cli
2026-02-24 06:02:50 +08:00
Luis Pater
c3e12c5e58 Merge pull request #1654 from alexey-yanchenko/feature/pass-file-inputs
Pass file input from /chat/completions and /responses to codex and claude
2026-02-24 05:53:11 +08:00
Luis Pater
1825fc7503 Merge pull request #1643 from alexey-yanchenko/fix/gemini-prompt-tokens
Fix usage convertation from gemini response to openai format
2026-02-24 05:46:13 +08:00
Luis Pater
48732ba05e Merge pull request #1527 from HEUDavid/feat/auth-hook
feat(auth): add post-auth hook mechanism
2026-02-24 05:33:13 +08:00
canxin121
acf483c9e6 fix(responses): reject invalid SSE data JSON
Guard the openai-response streaming path against truncated/invalid SSE data payloads by validating data: JSON before forwarding; surface a 502 terminal error instead of letting clients crash with JSON parse errors.
2026-02-24 01:42:54 +08:00
lyd123qw2008
3b3e0d1141 test(codex): log non-retryable refresh error and cover single-attempt behavior 2026-02-23 22:41:33 +08:00
lyd123qw2008
7acd428507 fix(codex): stop retrying refresh_token_reused errors 2026-02-23 22:31:30 +08:00
lyd123qw2008
450d1227bd fix(auth): respect configured auto-refresh interval 2026-02-23 22:07:50 +08:00
test
492b9c46f0 Add additive Codex device-code login flow 2026-02-23 06:30:04 -05:00
Darley
6e634fe3f9 fix: filter out orphaned tool results from history and current context 2026-02-23 14:33:59 +08:00
sususu98
4e26182d14 fix(antigravity): place tool_result images in functionResponse.parts and unify mimeType
Move base64 image data from Claude tool_result into functionResponse.parts
as inlineData instead of outer sibling parts, preventing context bloat.
Unify all inlineData field naming to camelCase mimeType across Claude,
OpenAI, and Gemini translators. Add comprehensive edge case tests and
Gemini-side regression test for functionResponse.parts preservation.
2026-02-23 13:38:21 +08:00
canxin121
eb7571936c revert: translator changes (path guard)
CI blocks PRs that modify internal/translator. Revert translator edits and keep only the /v1/responses streaming error-chunk fix; file an issue for translator conformance work.
2026-02-23 13:30:43 +08:00
canxin121
5382764d8a fix(responses): include model and usage in translated streams
Ensure response.created and response.completed chunks produced by the OpenAI/Gemini/Claude translators always include required fields (response.model and response.usage) so clients validating Responses SSE do not fail schema validation.
2026-02-23 13:22:06 +08:00
canxin121
49c8ec69d0 fix(openai): emit valid responses stream error chunks
When /v1/responses streaming fails after headers are sent, we now emit a type=error chunk instead of an HTTP-style {error:{...}} payload, preventing AI SDK chunk validation errors.
2026-02-23 12:59:50 +08:00
piexian
3b421c8181 feat(qwen): add rate limiting and quota error handling
- Add 60 requests/minute rate limiting per credential using sliding window
- Detect insufficient_quota errors and set cooldown until next day (Beijing time)
- Map quota errors (HTTP 403/429) to 429 with retryAfter for conductor integration
- Cache Beijing timezone at package level to avoid repeated syscalls
- Add redactAuthID function to protect credentials in logs
- Extract wrapQwenError helper to consolidate error handling
2026-02-23 00:38:46 +08:00
fan
afc8a0f9be refactor: simplify context_management compatibility handling 2026-02-21 22:20:48 +08:00
ciberponk
d693d7993b feat: support responses compaction payload compatibility for codex translator 2026-02-21 12:56:10 +08:00
rensumo
5936f9895c feat: implement credential-based round-robin for gemini-cli virtual auths
Changes the RoundRobinSelector to use two-level round-robin when
gemini-cli virtual auths are detected (via gemini_virtual_parent attr):
- Level 1: cycle across credential groups (parent accounts)
- Level 2: cycle within each group's project auths

Credentials start from a random offset (rand.IntN) for fair distribution.
Non-virtual auths and single-credential scenarios fall back to flat RR.

Adds 3 test cases covering multi-credential grouping, single-parent
fallback, and mixed virtual/non-virtual fallback.
2026-02-21 12:49:48 +08:00
Alexey Yanchenko
0cbfe7f457 Pass file input from /chat/completions and /responses to codex and claude 2026-02-20 10:25:44 +07:00
Alexey Yanchenko
b9ae4ab803 Fix usage convertation from gemini response to openai format 2026-02-19 15:34:59 +07:00
HEUDavid
65debb874f feat/auth-hook: refactor RequstInfo to preserve original HTTP semantics 2026-02-12 07:11:17 +08:00
HEUDavid
3caadac003 feat/auth-hook: add post auth hook [CR] 2026-02-12 07:11:17 +08:00
HEUDavid
6a9e3a6b84 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
269972440a feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
cce13e6ad2 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
8a565dcad8 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
d536110404 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
48e957ddff feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
94563d622c feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
56 changed files with 2851 additions and 210 deletions

View File

@@ -72,6 +72,7 @@ func main() {
// Command-line flags to control the application's behavior.
var login bool
var codexLogin bool
var codexDeviceLogin bool
var claudeLogin bool
var qwenLogin bool
var kiloLogin bool
@@ -99,6 +100,7 @@ func main() {
// Define command-line flags for different operation modes.
flag.BoolVar(&login, "login", false, "Login Google Account")
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow")
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
@@ -502,6 +504,9 @@ func main() {
} else if codexLogin {
// Handle Codex login
cmd.DoCodexLogin(cfg, options)
} else if codexDeviceLogin {
// Handle Codex device-code login
cmd.DoCodexDeviceLogin(cfg, options)
} else if claudeLogin {
// Handle Claude login
cmd.DoClaudeLogin(cfg, options)

View File

@@ -951,11 +951,17 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
if store == nil {
return "", fmt.Errorf("token store unavailable")
}
if h.postAuthHook != nil {
if err := h.postAuthHook(ctx, record); err != nil {
return "", fmt.Errorf("post-auth hook failed: %w", err)
}
}
return store.Save(ctx, record)
}
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Claude authentication...")
@@ -1100,6 +1106,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
@@ -1358,6 +1365,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
func (h *Handler) RequestCodexToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Codex authentication...")
@@ -1503,6 +1511,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Antigravity authentication...")
@@ -1667,6 +1676,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
func (h *Handler) RequestQwenToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Qwen authentication...")
@@ -1722,6 +1732,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
func (h *Handler) RequestKimiToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Kimi authentication...")
@@ -1798,6 +1809,7 @@ func (h *Handler) RequestKimiToken(c *gin.Context) {
func (h *Handler) RequestIFlowToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing iFlow authentication...")
@@ -1917,8 +1929,6 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
state := fmt.Sprintf("gh-%d", time.Now().UnixNano())
// Initialize Copilot auth service
// We need to import "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" first if not present
// Assuming copilot package is imported as "copilot"
deviceClient := copilot.NewDeviceFlowClient(h.cfg)
// Initiate device flow
@@ -1932,7 +1942,7 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
authURL := deviceCode.VerificationURI
userCode := deviceCode.UserCode
RegisterOAuthSession(state, "github")
RegisterOAuthSession(state, "github-copilot")
go func() {
fmt.Printf("Please visit %s and enter code: %s\n", authURL, userCode)
@@ -1944,9 +1954,13 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
return
}
username, errUser := deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
userInfo, errUser := deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
if errUser != nil {
log.Warnf("Failed to fetch user info: %v", errUser)
}
username := userInfo.Login
if username == "" {
username = "github-user"
}
@@ -1955,18 +1969,26 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
TokenType: tokenData.TokenType,
Scope: tokenData.Scope,
Username: username,
Email: userInfo.Email,
Name: userInfo.Name,
Type: "github-copilot",
}
fileName := fmt.Sprintf("github-%s.json", username)
fileName := fmt.Sprintf("github-copilot-%s.json", username)
label := userInfo.Email
if label == "" {
label = username
}
record := &coreauth.Auth{
ID: fileName,
Provider: "github",
Provider: "github-copilot",
Label: label,
FileName: fileName,
Storage: tokenStorage,
Metadata: map[string]any{
"email": username,
"email": userInfo.Email,
"username": username,
"name": userInfo.Name,
},
}
@@ -1980,7 +2002,7 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
fmt.Println("You can now use GitHub Copilot services through this CLI")
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("github")
CompleteOAuthSessionsByProvider("github-copilot")
}()
c.JSON(200, gin.H{
@@ -2521,6 +2543,14 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "wait"})
}
// PopulateAuthContext extracts request info and adds it to the context
func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context {
info := &coreauth.RequestInfo{
Query: c.Request.URL.Query(),
Headers: c.Request.Header,
}
return coreauth.WithRequestInfo(ctx, info)
}
const kiroCallbackPort = 9876
func (h *Handler) RequestKiroToken(c *gin.Context) {

View File

@@ -47,6 +47,7 @@ type Handler struct {
allowRemoteOverride bool
envSecret string
logDir string
postAuthHook coreauth.PostAuthHook
}
// NewHandler creates a new management handler instance.
@@ -128,6 +129,11 @@ func (h *Handler) SetLogDirectory(dir string) {
h.logDir = dir
}
// SetPostAuthHook registers a hook to be called after auth record creation but before persistence.
func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) {
h.postAuthHook = hook
}
// Middleware enforces access control for management endpoints.
// All requests (local and remote) require a valid management key.
// Additionally, remote access requires allow-remote-management=true.

View File

@@ -52,6 +52,7 @@ type serverOptionConfig struct {
keepAliveEnabled bool
keepAliveTimeout time.Duration
keepAliveOnTimeout func()
postAuthHook auth.PostAuthHook
}
// ServerOption customises HTTP server construction.
@@ -112,6 +113,13 @@ func WithRequestLoggerFactory(factory func(*config.Config, string) logging.Reque
}
}
// WithPostAuthHook registers a hook to be called after auth record creation.
func WithPostAuthHook(hook auth.PostAuthHook) ServerOption {
return func(cfg *serverOptionConfig) {
cfg.postAuthHook = hook
}
}
// Server represents the main API server.
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
type Server struct {
@@ -263,6 +271,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
}
logDir := logging.ResolveLogDirectory(cfg)
s.mgmt.SetLogDirectory(logDir)
if optionState.postAuthHook != nil {
s.mgmt.SetPostAuthHook(optionState.postAuthHook)
}
s.localPassword = optionState.localPassword
// Setup routes

View File

@@ -36,11 +36,21 @@ type ClaudeTokenStorage struct {
// Expire is the timestamp when the current access token expires.
Expire string `json:"expired"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *ClaudeTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Claude token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -65,8 +75,14 @@ func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
// Encode and write the token data as JSON
if err = json.NewEncoder(f).Encode(ts); err != nil {
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -71,16 +71,26 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
// It performs an HTTP POST request to the OpenAI token endpoint with the provided
// authorization code and PKCE verifier.
func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
return o.ExchangeCodeForTokensWithRedirect(ctx, code, RedirectURI, pkceCodes)
}
// ExchangeCodeForTokensWithRedirect exchanges an authorization code for tokens using
// a caller-provided redirect URI. This supports alternate auth flows such as device
// login while preserving the existing token parsing and storage behavior.
func (o *CodexAuth) ExchangeCodeForTokensWithRedirect(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
if pkceCodes == nil {
return nil, fmt.Errorf("PKCE codes are required for token exchange")
}
if strings.TrimSpace(redirectURI) == "" {
return nil, fmt.Errorf("redirect URI is required for token exchange")
}
// Prepare token exchange request
data := url.Values{
"grant_type": {"authorization_code"},
"client_id": {ClientID},
"code": {code},
"redirect_uri": {RedirectURI},
"redirect_uri": {strings.TrimSpace(redirectURI)},
"code_verifier": {pkceCodes.CodeVerifier},
}
@@ -266,6 +276,10 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
if err == nil {
return tokenData, nil
}
if isNonRetryableRefreshErr(err) {
log.Warnf("Token refresh attempt %d failed with non-retryable error: %v", attempt+1, err)
return nil, err
}
lastErr = err
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
@@ -274,6 +288,14 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
}
func isNonRetryableRefreshErr(err error) bool {
if err == nil {
return false
}
raw := strings.ToLower(err.Error())
return strings.Contains(raw, "refresh_token_reused")
}
// UpdateTokenStorage updates an existing CodexTokenStorage with new token data.
// This is typically called after a successful token refresh to persist the new credentials.
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {

View File

@@ -0,0 +1,44 @@
package codex
import (
"context"
"io"
"net/http"
"strings"
"sync/atomic"
"testing"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) {
var calls int32
auth := &CodexAuth{
httpClient: &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
atomic.AddInt32(&calls, 1)
return &http.Response{
StatusCode: http.StatusBadRequest,
Body: io.NopCloser(strings.NewReader(`{"error":"invalid_grant","code":"refresh_token_reused"}`)),
Header: make(http.Header),
Request: req,
}, nil
}),
},
}
_, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
if err == nil {
t.Fatalf("expected error for non-retryable refresh failure")
}
if !strings.Contains(strings.ToLower(err.Error()), "refresh_token_reused") {
t.Fatalf("expected refresh_token_reused in error, got: %v", err)
}
if got := atomic.LoadInt32(&calls); got != 1 {
t.Fatalf("expected 1 refresh attempt, got %d", got)
}
}

View File

@@ -32,11 +32,21 @@ type CodexTokenStorage struct {
Type string `json:"type"`
// Expire is the timestamp when the current access token expires.
Expire string `json:"expired"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *CodexTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Codex token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -58,7 +68,13 @@ func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
if err = json.NewEncoder(f).Encode(ts); err != nil {
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -82,15 +82,21 @@ func (c *CopilotAuth) WaitForAuthorization(ctx context.Context, deviceCode *Devi
}
// Fetch the GitHub username
username, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
userInfo, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
if err != nil {
log.Warnf("copilot: failed to fetch user info: %v", err)
username = "unknown"
}
username := userInfo.Login
if username == "" {
username = "github-user"
}
return &CopilotAuthBundle{
TokenData: tokenData,
Username: username,
Email: userInfo.Email,
Name: userInfo.Name,
}, nil
}
@@ -150,12 +156,12 @@ func (c *CopilotAuth) ValidateToken(ctx context.Context, accessToken string) (bo
return false, "", nil
}
username, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
userInfo, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
if err != nil {
return false, "", err
}
return true, username, nil
return true, userInfo.Login, nil
}
// CreateTokenStorage creates a new CopilotTokenStorage from auth bundle.
@@ -165,6 +171,8 @@ func (c *CopilotAuth) CreateTokenStorage(bundle *CopilotAuthBundle) *CopilotToke
TokenType: bundle.TokenData.TokenType,
Scope: bundle.TokenData.Scope,
Username: bundle.Username,
Email: bundle.Email,
Name: bundle.Name,
Type: "github-copilot",
}
}

View File

@@ -53,7 +53,7 @@ func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient {
func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
data := url.Values{}
data.Set("client_id", copilotClientID)
data.Set("scope", "user:email")
data.Set("scope", "read:user user:email")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotDeviceCodeURL, strings.NewReader(data.Encode()))
if err != nil {
@@ -211,15 +211,25 @@ func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode st
}, nil
}
// FetchUserInfo retrieves the GitHub username for the authenticated user.
func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (string, error) {
// GitHubUserInfo holds GitHub user profile information.
type GitHubUserInfo struct {
// Login is the GitHub username.
Login string
// Email is the primary email address (may be empty if not public).
Email string
// Name is the display name.
Name string
}
// FetchUserInfo retrieves the GitHub user profile for the authenticated user.
func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (GitHubUserInfo, error) {
if accessToken == "" {
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotUserInfoURL, nil)
if err != nil {
return "", NewAuthenticationError(ErrUserInfoFailed, err)
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "application/json")
@@ -227,7 +237,7 @@ func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string
resp, err := c.httpClient.Do(req)
if err != nil {
return "", NewAuthenticationError(ErrUserInfoFailed, err)
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
@@ -237,19 +247,25 @@ func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string
if !isHTTPSuccess(resp.StatusCode) {
bodyBytes, _ := io.ReadAll(resp.Body)
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
}
var userInfo struct {
var raw struct {
Login string `json:"login"`
Email string `json:"email"`
Name string `json:"name"`
}
if err = json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
return "", NewAuthenticationError(ErrUserInfoFailed, err)
if err = json.NewDecoder(resp.Body).Decode(&raw); err != nil {
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
}
if userInfo.Login == "" {
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
if raw.Login == "" {
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
}
return userInfo.Login, nil
return GitHubUserInfo{
Login: raw.Login,
Email: raw.Email,
Name: raw.Name,
}, nil
}

View File

@@ -0,0 +1,213 @@
package copilot
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// roundTripFunc lets us inject a custom transport for testing.
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
// newTestClient returns an *http.Client whose requests are redirected to the given test server,
// regardless of the original URL host.
func newTestClient(srv *httptest.Server) *http.Client {
return &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
req2 := req.Clone(req.Context())
req2.URL.Scheme = "http"
req2.URL.Host = strings.TrimPrefix(srv.URL, "http://")
return srv.Client().Transport.RoundTrip(req2)
}),
}
}
// TestFetchUserInfo_FullProfile verifies that FetchUserInfo returns login, email, and name.
func TestFetchUserInfo_FullProfile(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") {
w.WriteHeader(http.StatusUnauthorized)
return
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{
"login": "octocat",
"email": "octocat@github.com",
"name": "The Octocat",
})
}))
defer srv.Close()
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
info, err := client.FetchUserInfo(context.Background(), "test-token")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info.Login != "octocat" {
t.Errorf("Login: got %q, want %q", info.Login, "octocat")
}
if info.Email != "octocat@github.com" {
t.Errorf("Email: got %q, want %q", info.Email, "octocat@github.com")
}
if info.Name != "The Octocat" {
t.Errorf("Name: got %q, want %q", info.Name, "The Octocat")
}
}
// TestFetchUserInfo_EmptyEmail verifies graceful handling when email is absent (private account).
func TestFetchUserInfo_EmptyEmail(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
// GitHub returns null for private emails.
_, _ = w.Write([]byte(`{"login":"privateuser","email":null,"name":"Private User"}`))
}))
defer srv.Close()
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
info, err := client.FetchUserInfo(context.Background(), "test-token")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info.Login != "privateuser" {
t.Errorf("Login: got %q, want %q", info.Login, "privateuser")
}
if info.Email != "" {
t.Errorf("Email: got %q, want empty string", info.Email)
}
if info.Name != "Private User" {
t.Errorf("Name: got %q, want %q", info.Name, "Private User")
}
}
// TestFetchUserInfo_EmptyToken verifies error is returned for empty access token.
func TestFetchUserInfo_EmptyToken(t *testing.T) {
client := &DeviceFlowClient{httpClient: http.DefaultClient}
_, err := client.FetchUserInfo(context.Background(), "")
if err == nil {
t.Fatal("expected error for empty token, got nil")
}
}
// TestFetchUserInfo_EmptyLogin verifies error is returned when API returns no login.
func TestFetchUserInfo_EmptyLogin(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"email":"someone@example.com","name":"No Login"}`))
}))
defer srv.Close()
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
_, err := client.FetchUserInfo(context.Background(), "test-token")
if err == nil {
t.Fatal("expected error for empty login, got nil")
}
}
// TestFetchUserInfo_HTTPError verifies error is returned on non-2xx response.
func TestFetchUserInfo_HTTPError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"message":"Bad credentials"}`))
}))
defer srv.Close()
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
_, err := client.FetchUserInfo(context.Background(), "bad-token")
if err == nil {
t.Fatal("expected error for 401 response, got nil")
}
}
// TestCopilotTokenStorage_EmailNameFields verifies Email and Name serialise correctly.
func TestCopilotTokenStorage_EmailNameFields(t *testing.T) {
ts := &CopilotTokenStorage{
AccessToken: "ghu_abc",
TokenType: "bearer",
Scope: "read:user user:email",
Username: "octocat",
Email: "octocat@github.com",
Name: "The Octocat",
Type: "github-copilot",
}
data, err := json.Marshal(ts)
if err != nil {
t.Fatalf("marshal error: %v", err)
}
var out map[string]any
if err = json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
for _, key := range []string{"access_token", "username", "email", "name", "type"} {
if _, ok := out[key]; !ok {
t.Errorf("expected key %q in JSON output, not found", key)
}
}
if out["email"] != "octocat@github.com" {
t.Errorf("email: got %v, want %q", out["email"], "octocat@github.com")
}
if out["name"] != "The Octocat" {
t.Errorf("name: got %v, want %q", out["name"], "The Octocat")
}
}
// TestCopilotTokenStorage_OmitEmptyEmailName verifies email/name are omitted when empty (omitempty).
func TestCopilotTokenStorage_OmitEmptyEmailName(t *testing.T) {
ts := &CopilotTokenStorage{
AccessToken: "ghu_abc",
Username: "octocat",
Type: "github-copilot",
}
data, err := json.Marshal(ts)
if err != nil {
t.Fatalf("marshal error: %v", err)
}
var out map[string]any
if err = json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
if _, ok := out["email"]; ok {
t.Error("email key should be omitted when empty (omitempty), but was present")
}
if _, ok := out["name"]; ok {
t.Error("name key should be omitted when empty (omitempty), but was present")
}
}
// TestCopilotAuthBundle_EmailNameFields verifies bundle carries email and name through the pipeline.
func TestCopilotAuthBundle_EmailNameFields(t *testing.T) {
bundle := &CopilotAuthBundle{
TokenData: &CopilotTokenData{AccessToken: "ghu_abc"},
Username: "octocat",
Email: "octocat@github.com",
Name: "The Octocat",
}
if bundle.Email != "octocat@github.com" {
t.Errorf("bundle.Email: got %q, want %q", bundle.Email, "octocat@github.com")
}
if bundle.Name != "The Octocat" {
t.Errorf("bundle.Name: got %q, want %q", bundle.Name, "The Octocat")
}
}
// TestGitHubUserInfo_Struct verifies the exported GitHubUserInfo struct fields are accessible.
func TestGitHubUserInfo_Struct(t *testing.T) {
info := GitHubUserInfo{
Login: "octocat",
Email: "octocat@github.com",
Name: "The Octocat",
}
if info.Login == "" || info.Email == "" || info.Name == "" {
t.Error("GitHubUserInfo fields should not be empty")
}
}

View File

@@ -26,6 +26,10 @@ type CopilotTokenStorage struct {
ExpiresAt string `json:"expires_at,omitempty"`
// Username is the GitHub username associated with this token.
Username string `json:"username"`
// Email is the GitHub email address associated with this token.
Email string `json:"email,omitempty"`
// Name is the GitHub display name associated with this token.
Name string `json:"name,omitempty"`
// Type indicates the authentication provider type, always "github-copilot" for this storage.
Type string `json:"type"`
}
@@ -46,6 +50,10 @@ type CopilotAuthBundle struct {
TokenData *CopilotTokenData
// Username is the GitHub username.
Username string
// Email is the GitHub email address.
Email string
// Name is the GitHub display name.
Name string
}
// DeviceCodeResponse represents GitHub's device code response.

View File

@@ -35,11 +35,21 @@ type GeminiTokenStorage struct {
// Type indicates the authentication provider type, always "gemini" for this storage.
Type string `json:"type"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *GeminiTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Gemini token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -49,6 +59,11 @@ type GeminiTokenStorage struct {
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
misc.LogSavingCredentials(authFilePath)
ts.Type = "gemini"
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
return fmt.Errorf("failed to create directory: %v", err)
}
@@ -63,7 +78,9 @@ func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
}
}()
if err = json.NewEncoder(f).Encode(ts); err != nil {
enc := json.NewEncoder(f)
enc.SetIndent("", " ")
if err := enc.Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -21,6 +21,15 @@ type IFlowTokenStorage struct {
Scope string `json:"scope"`
Cookie string `json:"cookie"`
Type string `json:"type"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *IFlowTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serialises the token storage to disk.
@@ -37,7 +46,13 @@ func (ts *IFlowTokenStorage) SaveTokenToFile(authFilePath string) error {
}
defer func() { _ = f.Close() }()
if err = json.NewEncoder(f).Encode(ts); err != nil {
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("iflow token: encode token failed: %w", err)
}
return nil

View File

@@ -29,6 +29,15 @@ type KimiTokenStorage struct {
Expired string `json:"expired,omitempty"`
// Type indicates the authentication provider type, always "kimi" for this storage.
Type string `json:"type"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *KimiTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// KimiTokenData holds the raw OAuth token response from Kimi.
@@ -86,9 +95,15 @@ func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
encoder := json.NewEncoder(f)
encoder.SetIndent("", " ")
if err = encoder.Encode(ts); err != nil {
if err = encoder.Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -30,11 +30,21 @@ type QwenTokenStorage struct {
Type string `json:"type"`
// Expire is the timestamp when the current access token expires.
Expire string `json:"expired"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *QwenTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -56,7 +66,13 @@ func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
if err = json.NewEncoder(f).Encode(ts); err != nil {
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -0,0 +1,60 @@
package cmd
import (
"context"
"errors"
"fmt"
"os"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
log "github.com/sirupsen/logrus"
)
const (
codexLoginModeMetadataKey = "codex_login_mode"
codexLoginModeDevice = "device"
)
// DoCodexDeviceLogin triggers the Codex device-code flow while keeping the
// existing codex-login OAuth callback flow intact.
func DoCodexDeviceLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{
codexLoginModeMetadataKey: codexLoginModeDevice,
},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
if err != nil {
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
log.Error(codex.GetUserFriendlyMessage(authErr))
if authErr.Type == codex.ErrPortInUse.Type {
os.Exit(codex.ErrPortInUse.Code)
}
return
}
fmt.Printf("Codex device authentication failed: %v\n", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
fmt.Println("Codex device authentication successful!")
}

View File

@@ -1,6 +1,7 @@
package misc
import (
"encoding/json"
"fmt"
"path/filepath"
"strings"
@@ -24,3 +25,37 @@ func LogSavingCredentials(path string) {
func LogCredentialSeparator() {
log.Debug(credentialSeparator)
}
// MergeMetadata serializes the source struct into a map and merges the provided metadata into it.
func MergeMetadata(source any, metadata map[string]any) (map[string]any, error) {
var data map[string]any
// Fast path: if source is already a map, just copy it to avoid mutation of original
if srcMap, ok := source.(map[string]any); ok {
data = make(map[string]any, len(srcMap)+len(metadata))
for k, v := range srcMap {
data[k] = v
}
} else {
// Slow path: marshal to JSON and back to map to respect JSON tags
temp, err := json.Marshal(source)
if err != nil {
return nil, fmt.Errorf("failed to marshal source: %w", err)
}
if err := json.Unmarshal(temp, &data); err != nil {
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
}
}
// Merge extra metadata
if metadata != nil {
if data == nil {
data = make(map[string]any)
}
for k, v := range metadata {
data[k] = v
}
}
return data, nil
}

View File

@@ -916,19 +916,12 @@ func GetIFlowModels() []*ModelInfo {
Created int64
Thinking *ThinkingSupport
}{
{ID: "tstars2.0", DisplayName: "TStars-2.0", Description: "iFlow TStars-2.0 multimodal assistant", Created: 1746489600},
{ID: "qwen3-coder-plus", DisplayName: "Qwen3-Coder-Plus", Description: "Qwen3 Coder Plus code generation", Created: 1753228800},
{ID: "qwen3-max", DisplayName: "Qwen3-Max", Description: "Qwen3 flagship model", Created: 1758672000},
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400, Thinking: iFlowThinkingSupport},
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
{ID: "glm-5", DisplayName: "GLM-5", Description: "Zhipu GLM 5 general model", Created: 1770768000, Thinking: iFlowThinkingSupport},
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
{ID: "deepseek-v3.2-reasoner", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Reasoner", Created: 1764576000},
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000, Thinking: iFlowThinkingSupport},
{ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200, Thinking: iFlowThinkingSupport},
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
@@ -937,11 +930,7 @@ func GetIFlowModels() []*ModelInfo {
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
{ID: "minimax-m2.5", DisplayName: "MiniMax-M2.5", Description: "MiniMax M2.5", Created: 1770825600, Thinking: iFlowThinkingSupport},
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
{ID: "kimi-k2.5", DisplayName: "Kimi-K2.5", Description: "Moonshot Kimi K2.5", Created: 1769443200, Thinking: iFlowThinkingSupport},
}
models := make([]*ModelInfo, 0, len(entries))
for _, entry := range entries {

View File

@@ -54,8 +54,78 @@ const (
var (
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
randSourceMutex sync.Mutex
// antigravityPrimaryModelsCache keeps the latest non-empty model list fetched
// from any antigravity auth. Empty fetches never overwrite this cache.
antigravityPrimaryModelsCache struct {
mu sync.RWMutex
models []*registry.ModelInfo
}
)
func cloneAntigravityModels(models []*registry.ModelInfo) []*registry.ModelInfo {
if len(models) == 0 {
return nil
}
out := make([]*registry.ModelInfo, 0, len(models))
for _, model := range models {
if model == nil || strings.TrimSpace(model.ID) == "" {
continue
}
out = append(out, cloneAntigravityModelInfo(model))
}
if len(out) == 0 {
return nil
}
return out
}
func cloneAntigravityModelInfo(model *registry.ModelInfo) *registry.ModelInfo {
if model == nil {
return nil
}
clone := *model
if len(model.SupportedGenerationMethods) > 0 {
clone.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
}
if len(model.SupportedParameters) > 0 {
clone.SupportedParameters = append([]string(nil), model.SupportedParameters...)
}
if model.Thinking != nil {
thinkingClone := *model.Thinking
if len(model.Thinking.Levels) > 0 {
thinkingClone.Levels = append([]string(nil), model.Thinking.Levels...)
}
clone.Thinking = &thinkingClone
}
return &clone
}
func storeAntigravityPrimaryModels(models []*registry.ModelInfo) bool {
cloned := cloneAntigravityModels(models)
if len(cloned) == 0 {
return false
}
antigravityPrimaryModelsCache.mu.Lock()
antigravityPrimaryModelsCache.models = cloned
antigravityPrimaryModelsCache.mu.Unlock()
return true
}
func loadAntigravityPrimaryModels() []*registry.ModelInfo {
antigravityPrimaryModelsCache.mu.RLock()
cloned := cloneAntigravityModels(antigravityPrimaryModelsCache.models)
antigravityPrimaryModelsCache.mu.RUnlock()
return cloned
}
func fallbackAntigravityPrimaryModels() []*registry.ModelInfo {
models := loadAntigravityPrimaryModels()
if len(models) > 0 {
log.Debugf("antigravity executor: using cached primary model list (%d models)", len(models))
}
return models
}
// AntigravityExecutor proxies requests to the antigravity upstream.
type AntigravityExecutor struct {
cfg *config.Config
@@ -1006,14 +1076,9 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
exec := &AntigravityExecutor{cfg: cfg}
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
if errToken != nil {
log.Warnf("antigravity executor: fetch models failed for %s: token error: %v", auth.ID, errToken)
return nil
}
if token == "" {
log.Warnf("antigravity executor: fetch models failed for %s: got empty token", auth.ID)
return nil
}
if errToken != nil || token == "" {
return fallbackAntigravityPrimaryModels()
}
if updatedAuth != nil {
auth = updatedAuth
}
@@ -1025,8 +1090,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
modelsURL := baseURL + antigravityModelsPath
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`)))
if errReq != nil {
log.Warnf("antigravity executor: fetch models failed for %s: create request error: %v", auth.ID, errReq)
return nil
return fallbackAntigravityPrimaryModels()
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+token)
@@ -1038,15 +1102,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
log.Warnf("antigravity executor: fetch models failed for %s: context canceled: %v", auth.ID, errDo)
return nil
return fallbackAntigravityPrimaryModels()
}
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
log.Warnf("antigravity executor: fetch models failed for %s: request error: %v", auth.ID, errDo)
return nil
return fallbackAntigravityPrimaryModels()
}
bodyBytes, errRead := io.ReadAll(httpResp.Body)
@@ -1058,22 +1120,27 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
log.Warnf("antigravity executor: fetch models failed for %s: read body error: %v", auth.ID, errRead)
return nil
return fallbackAntigravityPrimaryModels()
}
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
log.Warnf("antigravity executor: fetch models failed for %s: unexpected status %d, body: %s", auth.ID, httpResp.StatusCode, string(bodyBytes))
return nil
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request failed with status %d on base url %s, retrying with fallback base url: %s", httpResp.StatusCode, baseURL, baseURLs[idx+1])
continue
}
return fallbackAntigravityPrimaryModels()
}
result := gjson.GetBytes(bodyBytes, "models")
if !result.Exists() {
log.Warnf("antigravity executor: fetch models failed for %s: no models field in response, body: %s", auth.ID, string(bodyBytes))
return nil
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models field missing on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
return fallbackAntigravityPrimaryModels()
}
now := time.Now().Unix()
@@ -1118,9 +1185,18 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
}
models = append(models, modelInfo)
}
if len(models) == 0 {
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: empty models list on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
log.Debug("antigravity executor: fetched empty model list; retaining cached primary model list")
return fallbackAntigravityPrimaryModels()
}
storeAntigravityPrimaryModels(models)
return models
}
return nil
return fallbackAntigravityPrimaryModels()
}
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {

View File

@@ -0,0 +1,90 @@
package executor
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
)
func resetAntigravityPrimaryModelsCacheForTest() {
antigravityPrimaryModelsCache.mu.Lock()
antigravityPrimaryModelsCache.models = nil
antigravityPrimaryModelsCache.mu.Unlock()
}
func TestStoreAntigravityPrimaryModels_EmptyDoesNotOverwrite(t *testing.T) {
resetAntigravityPrimaryModelsCacheForTest()
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
seed := []*registry.ModelInfo{
{ID: "claude-sonnet-4-5"},
{ID: "gemini-2.5-pro"},
}
if updated := storeAntigravityPrimaryModels(seed); !updated {
t.Fatal("expected non-empty model list to update primary cache")
}
if updated := storeAntigravityPrimaryModels(nil); updated {
t.Fatal("expected nil model list not to overwrite primary cache")
}
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{}); updated {
t.Fatal("expected empty model list not to overwrite primary cache")
}
got := loadAntigravityPrimaryModels()
if len(got) != 2 {
t.Fatalf("expected cached model count 2, got %d", len(got))
}
if got[0].ID != "claude-sonnet-4-5" || got[1].ID != "gemini-2.5-pro" {
t.Fatalf("unexpected cached model ids: %q, %q", got[0].ID, got[1].ID)
}
}
func TestLoadAntigravityPrimaryModels_ReturnsClone(t *testing.T) {
resetAntigravityPrimaryModelsCacheForTest()
t.Cleanup(resetAntigravityPrimaryModelsCacheForTest)
if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{{
ID: "gpt-5",
DisplayName: "GPT-5",
SupportedGenerationMethods: []string{"generateContent"},
SupportedParameters: []string{"temperature"},
Thinking: &registry.ThinkingSupport{
Levels: []string{"high"},
},
}}); !updated {
t.Fatal("expected model cache update")
}
got := loadAntigravityPrimaryModels()
if len(got) != 1 {
t.Fatalf("expected one cached model, got %d", len(got))
}
got[0].ID = "mutated-id"
if len(got[0].SupportedGenerationMethods) > 0 {
got[0].SupportedGenerationMethods[0] = "mutated-method"
}
if len(got[0].SupportedParameters) > 0 {
got[0].SupportedParameters[0] = "mutated-parameter"
}
if got[0].Thinking != nil && len(got[0].Thinking.Levels) > 0 {
got[0].Thinking.Levels[0] = "mutated-level"
}
again := loadAntigravityPrimaryModels()
if len(again) != 1 {
t.Fatalf("expected one cached model after mutation, got %d", len(again))
}
if again[0].ID != "gpt-5" {
t.Fatalf("expected cached model id to remain %q, got %q", "gpt-5", again[0].ID)
}
if len(again[0].SupportedGenerationMethods) == 0 || again[0].SupportedGenerationMethods[0] != "generateContent" {
t.Fatalf("expected cached generation methods to be unmutated, got %v", again[0].SupportedGenerationMethods)
}
if len(again[0].SupportedParameters) == 0 || again[0].SupportedParameters[0] != "temperature" {
t.Fatalf("expected cached supported parameters to be unmutated, got %v", again[0].SupportedParameters)
}
if again[0].Thinking == nil || len(again[0].Thinking.Levels) == 0 || again[0].Thinking.Levels[0] != "high" {
t.Fatalf("expected cached model thinking levels to be unmutated, got %v", again[0].Thinking)
}
}

View File

@@ -8,6 +8,7 @@ import (
"io"
"net/http"
"strings"
"sync"
"time"
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
@@ -22,9 +23,151 @@ import (
)
const (
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
qwenRateLimitPerMin = 60 // 60 requests per minute per credential
qwenRateLimitWindow = time.Minute // sliding window duration
)
// qwenBeijingLoc caches the Beijing timezone to avoid repeated LoadLocation syscalls.
var qwenBeijingLoc = func() *time.Location {
loc, err := time.LoadLocation("Asia/Shanghai")
if err != nil || loc == nil {
log.Warnf("qwen: failed to load Asia/Shanghai timezone: %v, using fixed UTC+8", err)
return time.FixedZone("CST", 8*3600)
}
return loc
}()
// qwenQuotaCodes is a package-level set of error codes that indicate quota exhaustion.
var qwenQuotaCodes = map[string]struct{}{
"insufficient_quota": {},
"quota_exceeded": {},
}
// qwenRateLimiter tracks request timestamps per credential for rate limiting.
// Qwen has a limit of 60 requests per minute per account.
var qwenRateLimiter = struct {
sync.Mutex
requests map[string][]time.Time // authID -> request timestamps
}{
requests: make(map[string][]time.Time),
}
// redactAuthID returns a redacted version of the auth ID for safe logging.
// Keeps a small prefix/suffix to allow correlation across events.
func redactAuthID(id string) string {
if id == "" {
return ""
}
if len(id) <= 8 {
return id
}
return id[:4] + "..." + id[len(id)-4:]
}
// checkQwenRateLimit checks if the credential has exceeded the rate limit.
// Returns nil if allowed, or a statusErr with retryAfter if rate limited.
func checkQwenRateLimit(authID string) error {
if authID == "" {
// Empty authID should not bypass rate limiting in production
// Use debug level to avoid log spam for certain auth flows
log.Debug("qwen rate limit check: empty authID, skipping rate limit")
return nil
}
now := time.Now()
windowStart := now.Add(-qwenRateLimitWindow)
qwenRateLimiter.Lock()
defer qwenRateLimiter.Unlock()
// Get and filter timestamps within the window
timestamps := qwenRateLimiter.requests[authID]
var validTimestamps []time.Time
for _, ts := range timestamps {
if ts.After(windowStart) {
validTimestamps = append(validTimestamps, ts)
}
}
// Always prune expired entries to prevent memory leak
// Delete empty entries, otherwise update with pruned slice
if len(validTimestamps) == 0 {
delete(qwenRateLimiter.requests, authID)
}
// Check if rate limit exceeded
if len(validTimestamps) >= qwenRateLimitPerMin {
// Calculate when the oldest request will expire
oldestInWindow := validTimestamps[0]
retryAfter := oldestInWindow.Add(qwenRateLimitWindow).Sub(now)
if retryAfter < time.Second {
retryAfter = time.Second
}
retryAfterSec := int(retryAfter.Seconds())
return statusErr{
code: http.StatusTooManyRequests,
msg: fmt.Sprintf(`{"error":{"code":"rate_limit_exceeded","message":"Qwen rate limit: %d requests/minute exceeded, retry after %ds","type":"rate_limit_exceeded"}}`, qwenRateLimitPerMin, retryAfterSec),
retryAfter: &retryAfter,
}
}
// Record this request and update the map with pruned timestamps
validTimestamps = append(validTimestamps, now)
qwenRateLimiter.requests[authID] = validTimestamps
return nil
}
// isQwenQuotaError checks if the error response indicates a quota exceeded error.
// Qwen returns HTTP 403 with error.code="insufficient_quota" when daily quota is exhausted.
func isQwenQuotaError(body []byte) bool {
code := strings.ToLower(gjson.GetBytes(body, "error.code").String())
errType := strings.ToLower(gjson.GetBytes(body, "error.type").String())
// Primary check: exact match on error.code or error.type (most reliable)
if _, ok := qwenQuotaCodes[code]; ok {
return true
}
if _, ok := qwenQuotaCodes[errType]; ok {
return true
}
// Fallback: check message only if code/type don't match (less reliable)
msg := strings.ToLower(gjson.GetBytes(body, "error.message").String())
if strings.Contains(msg, "insufficient_quota") || strings.Contains(msg, "quota exceeded") ||
strings.Contains(msg, "free allocated quota exceeded") {
return true
}
return false
}
// wrapQwenError wraps an HTTP error response, detecting quota errors and mapping them to 429.
// Returns the appropriate status code and retryAfter duration for statusErr.
// Only checks for quota errors when httpCode is 403 or 429 to avoid false positives.
func wrapQwenError(ctx context.Context, httpCode int, body []byte) (errCode int, retryAfter *time.Duration) {
errCode = httpCode
// Only check quota errors for expected status codes to avoid false positives
// Qwen returns 403 for quota errors, 429 for rate limits
if (httpCode == http.StatusForbidden || httpCode == http.StatusTooManyRequests) && isQwenQuotaError(body) {
errCode = http.StatusTooManyRequests // Map to 429 to trigger quota logic
cooldown := timeUntilNextDay()
retryAfter = &cooldown
logWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d), cooling down until tomorrow (%v)", httpCode, errCode, cooldown)
}
return errCode, retryAfter
}
// timeUntilNextDay returns duration until midnight Beijing time (UTC+8).
// Qwen's daily quota resets at 00:00 Beijing time.
func timeUntilNextDay() time.Duration {
now := time.Now()
nowLocal := now.In(qwenBeijingLoc)
tomorrow := time.Date(nowLocal.Year(), nowLocal.Month(), nowLocal.Day()+1, 0, 0, 0, 0, qwenBeijingLoc)
return tomorrow.Sub(now)
}
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
// If access token is unavailable, it falls back to legacy via ClientAdapter.
type QwenExecutor struct {
@@ -67,6 +210,17 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
if opts.Alt == "responses/compact" {
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
// Check rate limit before proceeding
var authID string
if auth != nil {
authID = auth.ID
}
if err := checkQwenRateLimit(authID); err != nil {
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
return resp, err
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
token, baseURL := qwenCreds(auth)
@@ -102,9 +256,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
return resp, err
}
applyQwenHeaders(httpReq, token, false)
var authID, authLabel, authType, authValue string
var authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
@@ -135,8 +288,10 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
@@ -158,6 +313,17 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
// Check rate limit before proceeding
var authID string
if auth != nil {
authID = auth.ID
}
if err := checkQwenRateLimit(authID); err != nil {
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
return nil, err
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
token, baseURL := qwenCreds(auth)
@@ -200,9 +366,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
return nil, err
}
applyQwenHeaders(httpReq, token, true)
var authID, authLabel, authType, authValue string
var authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
@@ -228,11 +393,13 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("qwen executor: close response body error: %v", errClose)
}
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)

View File

@@ -10,53 +10,10 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// validReasoningEffortLevels contains the standard values accepted by the
// OpenAI reasoning_effort field. Provider-specific extensions (xhigh, minimal,
// auto) are NOT in this set and must be clamped before use.
var validReasoningEffortLevels = map[string]struct{}{
"none": {},
"low": {},
"medium": {},
"high": {},
}
// clampReasoningEffort maps any thinking level string to a value that is safe
// to send as OpenAI reasoning_effort. Non-standard CPA-internal values are
// mapped to the nearest standard equivalent.
//
// Mapping rules:
// - none / low / medium / high → returned as-is (already valid)
// - xhigh → "high" (nearest lower standard level)
// - minimal → "low" (nearest higher standard level)
// - auto → "medium" (reasonable default)
// - anything else → "medium" (safe default)
func clampReasoningEffort(level string) string {
if _, ok := validReasoningEffortLevels[level]; ok {
return level
}
var clamped string
switch level {
case string(thinking.LevelXHigh):
clamped = string(thinking.LevelHigh)
case string(thinking.LevelMinimal):
clamped = string(thinking.LevelLow)
case string(thinking.LevelAuto):
clamped = string(thinking.LevelMedium)
default:
clamped = string(thinking.LevelMedium)
}
log.WithFields(log.Fields{
"original": level,
"clamped": clamped,
}).Debug("openai: reasoning_effort clamped to nearest valid standard value")
return clamped
}
// Applier implements thinking.ProviderApplier for OpenAI models.
//
// OpenAI-specific behavior:
@@ -101,7 +58,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
}
if config.Mode == thinking.ModeLevel {
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(string(config.Level)))
result, _ := sjson.SetBytes(body, "reasoning_effort", string(config.Level))
return result, nil
}
@@ -122,7 +79,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
return body, nil
}
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
return result, nil
}
@@ -157,7 +114,7 @@ func applyCompatibleOpenAI(body []byte, config thinking.ThinkingConfig) ([]byte,
return body, nil
}
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
return result, nil
}

View File

@@ -223,14 +223,65 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData)
} else if functionResponseResult.IsArray() {
frResults := functionResponseResult.Array()
if len(frResults) == 1 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", frResults[0].Raw)
nonImageCount := 0
lastNonImageRaw := ""
filteredJSON := "[]"
imagePartsJSON := "[]"
for _, fr := range frResults {
if fr.Get("type").String() == "image" && fr.Get("source.type").String() == "base64" {
inlineDataJSON := `{}`
if mimeType := fr.Get("source.media_type").String(); mimeType != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
}
if data := fr.Get("source.data").String(); data != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
}
imagePartJSON := `{}`
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
continue
}
nonImageCount++
lastNonImageRaw = fr.Raw
filteredJSON, _ = sjson.SetRaw(filteredJSON, "-1", fr.Raw)
}
if nonImageCount == 1 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", lastNonImageRaw)
} else if nonImageCount > 1 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", filteredJSON)
} else {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
}
// Place image data inside functionResponse.parts as inlineData
// instead of as sibling parts in the outer content, to avoid
// base64 data bloating the text context.
if gjson.Get(imagePartsJSON, "#").Int() > 0 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
}
} else if functionResponseResult.IsObject() {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
if functionResponseResult.Get("type").String() == "image" && functionResponseResult.Get("source.type").String() == "base64" {
inlineDataJSON := `{}`
if mimeType := functionResponseResult.Get("source.media_type").String(); mimeType != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
}
if data := functionResponseResult.Get("source.data").String(); data != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
}
imagePartJSON := `{}`
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
imagePartsJSON := "[]"
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
} else {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
}
} else if functionResponseResult.Raw != "" {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
} else {
@@ -248,7 +299,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
if sourceResult.Get("type").String() == "base64" {
inlineDataJSON := `{}`
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mime_type", mimeType)
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
}
if data := sourceResult.Get("data").String(); data != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)

View File

@@ -413,8 +413,8 @@ func TestConvertClaudeRequestToAntigravity_ImageContent(t *testing.T) {
if !inlineData.Exists() {
t.Error("inlineData should exist")
}
if inlineData.Get("mime_type").String() != "image/png" {
t.Error("mime_type mismatch")
if inlineData.Get("mimeType").String() != "image/png" {
t.Error("mimeType mismatch")
}
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
t.Error("data mismatch")
@@ -740,6 +740,429 @@ func TestConvertClaudeRequestToAntigravity_ToolResultNullContent(t *testing.T) {
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultWithImage(t *testing.T) {
// tool_result with array content containing text + image should place
// image data inside functionResponse.parts as inlineData, not as a
// sibling part in the outer content (to avoid base64 context bloat).
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "Read-123-456",
"content": [
{
"type": "text",
"text": "File content here"
},
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": "iVBORw0KGgoAAAANSUhEUg=="
}
}
]
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
}
// Image should be inside functionResponse.parts, not as outer sibling part
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
// Text content should be in response.result
resultText := funcResp.Get("response.result.text").String()
if resultText != "File content here" {
t.Errorf("Expected response.result.text = 'File content here', got '%s'", resultText)
}
// Image should be in functionResponse.parts[0].inlineData
inlineData := funcResp.Get("parts.0.inlineData")
if !inlineData.Exists() {
t.Fatal("functionResponse.parts[0].inlineData should exist")
}
if inlineData.Get("mimeType").String() != "image/png" {
t.Errorf("Expected mimeType 'image/png', got '%s'", inlineData.Get("mimeType").String())
}
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
t.Error("data mismatch")
}
// Image should NOT be in outer parts (only functionResponse part should exist)
outerParts := gjson.Get(outputStr, "request.contents.0.parts")
if outerParts.IsArray() && len(outerParts.Array()) > 1 {
t.Errorf("Expected only 1 outer part (functionResponse), got %d", len(outerParts.Array()))
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultWithSingleImage(t *testing.T) {
// tool_result with single image object as content should place
// image data inside functionResponse.parts, not as outer sibling part.
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "Read-789-012",
"content": {
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": "/9j/4AAQSkZJRgABAQ=="
}
}
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
}
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
// response.result should be empty (image only)
if funcResp.Get("response.result").String() != "" {
t.Errorf("Expected empty response.result for image-only content, got '%s'", funcResp.Get("response.result").String())
}
// Image should be in functionResponse.parts[0].inlineData
inlineData := funcResp.Get("parts.0.inlineData")
if !inlineData.Exists() {
t.Fatal("functionResponse.parts[0].inlineData should exist")
}
if inlineData.Get("mimeType").String() != "image/jpeg" {
t.Errorf("Expected mimeType 'image/jpeg', got '%s'", inlineData.Get("mimeType").String())
}
// Image should NOT be in outer parts
outerParts := gjson.Get(outputStr, "request.contents.0.parts")
if outerParts.IsArray() && len(outerParts.Array()) > 1 {
t.Errorf("Expected only 1 outer part, got %d", len(outerParts.Array()))
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultWithMultipleImagesAndTexts(t *testing.T) {
// tool_result with array content: 2 text items + 2 images
// All images go into functionResponse.parts, texts into response.result array
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "Multi-001",
"content": [
{"type": "text", "text": "First text"},
{
"type": "image",
"source": {"type": "base64", "media_type": "image/png", "data": "AAAA"}
},
{"type": "text", "text": "Second text"},
{
"type": "image",
"source": {"type": "base64", "media_type": "image/jpeg", "data": "BBBB"}
}
]
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
}
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
// Multiple text items => response.result is an array
resultArr := funcResp.Get("response.result")
if !resultArr.IsArray() {
t.Fatalf("Expected response.result to be an array, got: %s", resultArr.Raw)
}
results := resultArr.Array()
if len(results) != 2 {
t.Fatalf("Expected 2 result items, got %d", len(results))
}
// Both images should be in functionResponse.parts
imgParts := funcResp.Get("parts").Array()
if len(imgParts) != 2 {
t.Fatalf("Expected 2 image parts in functionResponse.parts, got %d", len(imgParts))
}
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
t.Errorf("Expected first image mimeType 'image/png', got '%s'", imgParts[0].Get("inlineData.mimeType").String())
}
if imgParts[0].Get("inlineData.data").String() != "AAAA" {
t.Errorf("Expected first image data 'AAAA', got '%s'", imgParts[0].Get("inlineData.data").String())
}
if imgParts[1].Get("inlineData.mimeType").String() != "image/jpeg" {
t.Errorf("Expected second image mimeType 'image/jpeg', got '%s'", imgParts[1].Get("inlineData.mimeType").String())
}
if imgParts[1].Get("inlineData.data").String() != "BBBB" {
t.Errorf("Expected second image data 'BBBB', got '%s'", imgParts[1].Get("inlineData.data").String())
}
// Only 1 outer part (the functionResponse itself)
outerParts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(outerParts) != 1 {
t.Errorf("Expected 1 outer part, got %d", len(outerParts))
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultWithOnlyMultipleImages(t *testing.T) {
// tool_result with only images (no text) — response.result should be empty string
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "ImgOnly-001",
"content": [
{
"type": "image",
"source": {"type": "base64", "media_type": "image/png", "data": "PNG1"}
},
{
"type": "image",
"source": {"type": "base64", "media_type": "image/gif", "data": "GIF1"}
}
]
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
}
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
// No text => response.result should be empty string
if funcResp.Get("response.result").String() != "" {
t.Errorf("Expected empty response.result, got '%s'", funcResp.Get("response.result").String())
}
// Both images in functionResponse.parts
imgParts := funcResp.Get("parts").Array()
if len(imgParts) != 2 {
t.Fatalf("Expected 2 image parts, got %d", len(imgParts))
}
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
t.Error("first image mimeType mismatch")
}
if imgParts[1].Get("inlineData.mimeType").String() != "image/gif" {
t.Error("second image mimeType mismatch")
}
// Only 1 outer part
outerParts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(outerParts) != 1 {
t.Errorf("Expected 1 outer part, got %d", len(outerParts))
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultImageNotBase64(t *testing.T) {
// image with source.type != "base64" should be treated as non-image (falls through)
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "NotB64-001",
"content": [
{"type": "text", "text": "some output"},
{
"type": "image",
"source": {"type": "url", "url": "https://example.com/img.png"}
}
]
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
}
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
// Non-base64 image is treated as non-image, so it goes into the filtered results
// along with the text item. Since there are 2 non-image items, result is array.
resultArr := funcResp.Get("response.result")
if !resultArr.IsArray() {
t.Fatalf("Expected response.result to be an array (2 non-image items), got: %s", resultArr.Raw)
}
results := resultArr.Array()
if len(results) != 2 {
t.Fatalf("Expected 2 result items, got %d", len(results))
}
// No functionResponse.parts (no base64 images collected)
if funcResp.Get("parts").Exists() {
t.Error("functionResponse.parts should NOT exist when no base64 images")
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingData(t *testing.T) {
// image with source.type=base64 but missing data field
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "NoData-001",
"content": [
{"type": "text", "text": "output"},
{
"type": "image",
"source": {"type": "base64", "media_type": "image/png"}
}
]
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
}
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
// The image is still classified as base64 image (type check passes),
// but data field is missing => inlineData has mimeType but no data
imgParts := funcResp.Get("parts").Array()
if len(imgParts) != 1 {
t.Fatalf("Expected 1 image part, got %d", len(imgParts))
}
if imgParts[0].Get("inlineData.mimeType").String() != "image/png" {
t.Error("mimeType should still be set")
}
if imgParts[0].Get("inlineData.data").Exists() {
t.Error("data should not exist when source.data is missing")
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingMediaType(t *testing.T) {
// image with source.type=base64 but missing media_type field
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "NoMime-001",
"content": [
{"type": "text", "text": "output"},
{
"type": "image",
"source": {"type": "base64", "data": "AAAA"}
}
]
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
if !gjson.Valid(outputStr) {
t.Fatalf("Result is not valid JSON:\n%s", outputStr)
}
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
// The image is still classified as base64 image,
// but media_type is missing => inlineData has data but no mimeType
imgParts := funcResp.Get("parts").Array()
if len(imgParts) != 1 {
t.Fatalf("Expected 1 image part, got %d", len(imgParts))
}
if imgParts[0].Get("inlineData.mimeType").Exists() {
t.Error("mimeType should not exist when media_type is missing")
}
if imgParts[0].Get("inlineData.data").String() != "AAAA" {
t.Error("data should still be set")
}
}
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
// When tools + thinking but no system instruction, should create one with hint
inputJSON := []byte(`{

View File

@@ -93,3 +93,81 @@ func TestConvertGeminiRequestToAntigravity_ParallelFunctionCalls(t *testing.T) {
}
}
}
func TestFixCLIToolResponse_PreservesFunctionResponseParts(t *testing.T) {
// When functionResponse contains a "parts" field with inlineData (from Claude
// translator's image embedding), fixCLIToolResponse should preserve it as-is.
// parseFunctionResponseRaw returns response.Raw for valid JSON objects,
// so extra fields like "parts" survive the pipeline.
input := `{
"model": "claude-opus-4-6-thinking",
"request": {
"contents": [
{
"role": "model",
"parts": [
{
"functionCall": {"name": "screenshot", "args": {}}
}
]
},
{
"role": "function",
"parts": [
{
"functionResponse": {
"id": "tool-001",
"name": "screenshot",
"response": {"result": "Screenshot taken"},
"parts": [
{"inlineData": {"mimeType": "image/png", "data": "iVBOR"}}
]
}
}
]
}
]
}
}`
result, err := fixCLIToolResponse(input)
if err != nil {
t.Fatalf("fixCLIToolResponse failed: %v", err)
}
// Find the function response content (role=function)
contents := gjson.Get(result, "request.contents").Array()
var funcContent gjson.Result
for _, c := range contents {
if c.Get("role").String() == "function" {
funcContent = c
break
}
}
if !funcContent.Exists() {
t.Fatal("function role content should exist in output")
}
// The functionResponse should be preserved with its parts field
funcResp := funcContent.Get("parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist in output")
}
// Verify the parts field with inlineData is preserved
inlineParts := funcResp.Get("parts").Array()
if len(inlineParts) != 1 {
t.Fatalf("Expected 1 inlineData part in functionResponse.parts, got %d", len(inlineParts))
}
if inlineParts[0].Get("inlineData.mimeType").String() != "image/png" {
t.Errorf("Expected mimeType 'image/png', got '%s'", inlineParts[0].Get("inlineData.mimeType").String())
}
if inlineParts[0].Get("inlineData.data").String() != "iVBOR" {
t.Errorf("Expected data 'iVBOR', got '%s'", inlineParts[0].Get("inlineData.data").String())
}
// Verify response.result is also preserved
if funcResp.Get("response.result").String() != "Screenshot taken" {
t.Errorf("Expected response.result 'Screenshot taken', got '%s'", funcResp.Get("response.result").String())
}
}

View File

@@ -187,7 +187,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
if len(pieces) == 2 && len(pieces[1]) > 7 {
mime := pieces[0]
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
p++
@@ -201,7 +201,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
ext = sp[len(sp)-1]
}
if mimeType, ok := misc.MimeTypes[ext]; ok {
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mimeType)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", fileData)
p++
} else {
@@ -235,7 +235,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
if len(pieces) == 2 && len(pieces[1]) > 7 {
mime := pieces[0]
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
p++

View File

@@ -95,9 +95,9 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}

View File

@@ -199,6 +199,21 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
}
}
case "file":
fileData := part.Get("file.file_data").String()
if strings.HasPrefix(fileData, "data:") {
semicolonIdx := strings.Index(fileData, ";")
commaIdx := strings.Index(fileData, ",")
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
data := fileData[commaIdx+1:]
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
docPart, _ = sjson.Set(docPart, "source.data", data)
msg, _ = sjson.SetRaw(msg, "content.-1", docPart)
}
}
}
return true
})

View File

@@ -155,6 +155,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
var textAggregate strings.Builder
var partsJSON []string
hasImage := false
hasFile := false
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, part gjson.Result) bool {
ptype := part.Get("type").String()
@@ -207,6 +208,30 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
hasImage = true
}
}
case "input_file":
fileData := part.Get("file_data").String()
if fileData != "" {
mediaType := "application/octet-stream"
data := fileData
if strings.HasPrefix(fileData, "data:") {
trimmed := strings.TrimPrefix(fileData, "data:")
mediaAndData := strings.SplitN(trimmed, ";base64,", 2)
if len(mediaAndData) == 2 {
if mediaAndData[0] != "" {
mediaType = mediaAndData[0]
}
data = mediaAndData[1]
}
}
contentPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType)
contentPart, _ = sjson.Set(contentPart, "source.data", data)
partsJSON = append(partsJSON, contentPart)
if role == "" {
role = "user"
}
hasFile = true
}
}
return true
})
@@ -228,7 +253,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
if len(partsJSON) > 0 {
msg := `{"role":"","content":[]}`
msg, _ = sjson.Set(msg, "role", role)
if len(partsJSON) == 1 && !hasImage {
if len(partsJSON) == 1 && !hasImage && !hasFile {
// Preserve legacy behavior for single text content
msg, _ = sjson.Delete(msg, "content")
textPart := gjson.Parse(partsJSON[0])

View File

@@ -180,7 +180,19 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
msg, _ = sjson.SetRaw(msg, "content.-1", part)
}
case "file":
// Files are not specified in examples; skip for now
if role == "user" {
fileData := it.Get("file.file_data").String()
filename := it.Get("file.filename").String()
if fileData != "" {
part := `{}`
part, _ = sjson.Set(part, "type", "input_file")
part, _ = sjson.Set(part, "file_data", fileData)
if filename != "" {
part, _ = sjson.Set(part, "filename", filename)
}
msg, _ = sjson.SetRaw(msg, "content.-1", part)
}
}
}
}
}

View File

@@ -26,6 +26,8 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature")
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
rawJSON, _ = sjson.DeleteBytes(rawJSON, "truncation")
rawJSON = applyResponsesCompactionCompatibility(rawJSON)
// Delete the user field as it is not supported by the Codex upstream.
rawJSON, _ = sjson.DeleteBytes(rawJSON, "user")
@@ -36,6 +38,23 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
return rawJSON
}
// applyResponsesCompactionCompatibility handles OpenAI Responses context_management.compaction
// for Codex upstream compatibility.
//
// Codex /responses currently rejects context_management with:
// {"detail":"Unsupported parameter: context_management"}.
//
// Compatibility strategy:
// 1) Remove context_management before forwarding to Codex upstream.
func applyResponsesCompactionCompatibility(rawJSON []byte) []byte {
if !gjson.GetBytes(rawJSON, "context_management").Exists() {
return rawJSON
}
rawJSON, _ = sjson.DeleteBytes(rawJSON, "context_management")
return rawJSON
}
// convertSystemRoleToDeveloper traverses the input array and converts any message items
// with role "system" to role "developer". This is necessary because Codex API does not
// accept "system" role in the input array.

View File

@@ -280,3 +280,41 @@ func TestUserFieldDeletion(t *testing.T) {
t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw)
}
}
func TestContextManagementCompactionCompatibility(t *testing.T) {
inputJSON := []byte(`{
"model": "gpt-5.2",
"context_management": [
{
"type": "compaction",
"compact_threshold": 12000
}
],
"input": [{"role":"user","content":"hello"}]
}`)
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
outputStr := string(output)
if gjson.Get(outputStr, "context_management").Exists() {
t.Fatalf("context_management should be removed for Codex compatibility")
}
if gjson.Get(outputStr, "truncation").Exists() {
t.Fatalf("truncation should be removed for Codex compatibility")
}
}
func TestTruncationRemovedForCodexCompatibility(t *testing.T) {
inputJSON := []byte(`{
"model": "gpt-5.2",
"truncation": "disabled",
"input": [{"role":"user","content":"hello"}]
}`)
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
outputStr := string(output)
if gjson.Get(outputStr, "truncation").Exists() {
t.Fatalf("truncation should be removed for Codex compatibility")
}
}

View File

@@ -100,7 +100,7 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}

View File

@@ -100,9 +100,9 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
baseTemplate, _ = sjson.Set(baseTemplate, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 {
baseTemplate, _ = sjson.Set(baseTemplate, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
@@ -297,7 +297,7 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}

View File

@@ -531,8 +531,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
// usage mapping
if um := root.Get("usageMetadata"); um.Exists() {
// input tokens = prompt + thoughts
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
// input tokens = prompt only (thoughts go to output)
input := um.Get("promptTokenCount").Int()
completed, _ = sjson.Set(completed, "response.usage.input_tokens", input)
// cached token details: align with OpenAI "cached_tokens" semantics.
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
@@ -737,8 +737,8 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
// usage mapping
if um := root.Get("usageMetadata"); um.Exists() {
// input tokens = prompt + thoughts
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
// input tokens = prompt only (thoughts go to output)
input := um.Get("promptTokenCount").Int()
resp, _ = sjson.Set(resp, "usage.input_tokens", input)
// cached token details: align with OpenAI "cached_tokens" semantics.
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())

View File

@@ -53,19 +53,25 @@ var KnownCommandTools = map[string]bool{
"execute_python": true,
}
// RequiredFieldsByTool maps tool names to their required fields.
// If any of these fields are missing, the tool input is considered truncated.
var RequiredFieldsByTool = map[string][]string{
"Write": {"file_path", "content"},
"write_to_file": {"path", "content"},
"fsWrite": {"path", "content"},
"create_file": {"path", "content"},
"edit_file": {"path"},
"apply_diff": {"path", "diff"},
"str_replace_editor": {"path", "old_str", "new_str"},
"Bash": {"command"},
"execute": {"command"},
"run_command": {"command"},
// RequiredFieldsByTool maps tool names to their required field groups.
// Each outer element is a required group; each inner slice lists alternative field names (OR logic).
// A group is satisfied when ANY one of its alternatives exists in the parsed input.
// All groups must be satisfied for the tool input to be considered valid.
//
// Example:
// {{"cmd", "command"}} means the tool needs EITHER "cmd" OR "command".
// {{"file_path"}, {"content"}} means the tool needs BOTH "file_path" AND "content".
var RequiredFieldsByTool = map[string][][]string{
"Write": {{"file_path"}, {"content"}},
"write_to_file": {{"path"}, {"content"}},
"fsWrite": {{"path"}, {"content"}},
"create_file": {{"path"}, {"content"}},
"edit_file": {{"path"}},
"apply_diff": {{"path"}, {"diff"}},
"str_replace_editor": {{"path"}, {"old_str"}, {"new_str"}},
"Bash": {{"cmd", "command"}},
"execute": {{"command"}},
"run_command": {{"command"}},
}
// DetectTruncation checks if the tool use input appears to be truncated.
@@ -104,9 +110,9 @@ func DetectTruncation(toolName, toolUseID, rawInput string, parsedInput map[stri
// Scenario 3: JSON parsed but critical fields are missing
if parsedInput != nil {
requiredFields, hasRequirements := RequiredFieldsByTool[toolName]
requiredGroups, hasRequirements := RequiredFieldsByTool[toolName]
if hasRequirements {
missingFields := findMissingRequiredFields(parsedInput, requiredFields)
missingFields := findMissingRequiredFields(parsedInput, requiredGroups)
if len(missingFields) > 0 {
info.IsTruncated = true
info.TruncationType = TruncationTypeMissingFields
@@ -253,12 +259,21 @@ func extractParsedFieldNames(parsed map[string]interface{}) map[string]string {
return fields
}
// findMissingRequiredFields checks which required fields are missing from the parsed input.
func findMissingRequiredFields(parsed map[string]interface{}, required []string) []string {
// findMissingRequiredFields checks which required field groups are unsatisfied.
// Each group is a slice of alternative field names; the group is satisfied when ANY alternative exists.
// Returns the list of unsatisfied groups (represented by their alternatives joined with "/").
func findMissingRequiredFields(parsed map[string]interface{}, requiredGroups [][]string) []string {
var missing []string
for _, field := range required {
if _, exists := parsed[field]; !exists {
missing = append(missing, field)
for _, group := range requiredGroups {
satisfied := false
for _, field := range group {
if _, exists := parsed[field]; exists {
satisfied = true
break
}
}
if !satisfied {
missing = append(missing, strings.Join(group, "/"))
}
}
return missing

View File

@@ -578,6 +578,7 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
// Truncate history if too long to prevent Kiro API errors
history = truncateHistoryIfNeeded(history)
history, currentToolResults = filterOrphanedToolResults(history, currentToolResults)
return history, currentUserMsg, currentToolResults
}
@@ -593,6 +594,61 @@ func truncateHistoryIfNeeded(history []KiroHistoryMessage) []KiroHistoryMessage
return history[len(history)-kiroMaxHistoryMessages:]
}
func filterOrphanedToolResults(history []KiroHistoryMessage, currentToolResults []KiroToolResult) ([]KiroHistoryMessage, []KiroToolResult) {
// Remove tool results with no matching tool_use in retained history.
// This happens after truncation when the assistant turn that produced tool_use
// is dropped but a later user/tool_result survives.
validToolUseIDs := make(map[string]bool)
for _, h := range history {
if h.AssistantResponseMessage == nil {
continue
}
for _, tu := range h.AssistantResponseMessage.ToolUses {
validToolUseIDs[tu.ToolUseID] = true
}
}
for i, h := range history {
if h.UserInputMessage == nil || h.UserInputMessage.UserInputMessageContext == nil {
continue
}
ctx := h.UserInputMessage.UserInputMessageContext
if len(ctx.ToolResults) == 0 {
continue
}
filtered := make([]KiroToolResult, 0, len(ctx.ToolResults))
for _, tr := range ctx.ToolResults {
if validToolUseIDs[tr.ToolUseID] {
filtered = append(filtered, tr)
continue
}
log.Debugf("kiro-openai: dropping orphaned tool_result in history[%d]: toolUseId=%s (no matching tool_use)", i, tr.ToolUseID)
}
ctx.ToolResults = filtered
if len(ctx.ToolResults) == 0 && len(ctx.Tools) == 0 {
h.UserInputMessage.UserInputMessageContext = nil
}
}
if len(currentToolResults) > 0 {
filtered := make([]KiroToolResult, 0, len(currentToolResults))
for _, tr := range currentToolResults {
if validToolUseIDs[tr.ToolUseID] {
filtered = append(filtered, tr)
continue
}
log.Debugf("kiro-openai: dropping orphaned tool_result in currentMessage: toolUseId=%s (no matching tool_use)", tr.ToolUseID)
}
if len(filtered) != len(currentToolResults) {
log.Infof("kiro-openai: dropped %d orphaned tool_result(s) from currentMessage", len(currentToolResults)-len(filtered))
}
currentToolResults = filtered
}
return history, currentToolResults
}
// buildUserMessageFromOpenAI builds a user message from OpenAI format and extracts tool results
func buildUserMessageFromOpenAI(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) {
content := msg.Get("content")

View File

@@ -384,3 +384,57 @@ func TestAssistantEndsConversation(t *testing.T) {
t.Error("Expected a 'Continue' message to be created when assistant is last")
}
}
func TestFilterOrphanedToolResults_RemovesHistoryAndCurrentOrphans(t *testing.T) {
history := []KiroHistoryMessage{
{
AssistantResponseMessage: &KiroAssistantResponseMessage{
Content: "assistant",
ToolUses: []KiroToolUse{
{ToolUseID: "keep-1", Name: "Read", Input: map[string]interface{}{}},
},
},
},
{
UserInputMessage: &KiroUserInputMessage{
Content: "user-with-mixed-results",
UserInputMessageContext: &KiroUserInputMessageContext{
ToolResults: []KiroToolResult{
{ToolUseID: "keep-1", Status: "success", Content: []KiroTextContent{{Text: "ok"}}},
{ToolUseID: "orphan-1", Status: "success", Content: []KiroTextContent{{Text: "bad"}}},
},
},
},
},
{
UserInputMessage: &KiroUserInputMessage{
Content: "user-only-orphans",
UserInputMessageContext: &KiroUserInputMessageContext{
ToolResults: []KiroToolResult{
{ToolUseID: "orphan-2", Status: "success", Content: []KiroTextContent{{Text: "bad"}}},
},
},
},
},
}
currentToolResults := []KiroToolResult{
{ToolUseID: "keep-1", Status: "success", Content: []KiroTextContent{{Text: "ok"}}},
{ToolUseID: "orphan-3", Status: "success", Content: []KiroTextContent{{Text: "bad"}}},
}
filteredHistory, filteredCurrent := filterOrphanedToolResults(history, currentToolResults)
ctx1 := filteredHistory[1].UserInputMessage.UserInputMessageContext
if ctx1 == nil || len(ctx1.ToolResults) != 1 || ctx1.ToolResults[0].ToolUseID != "keep-1" {
t.Fatalf("expected mixed history message to keep only keep-1, got: %+v", ctx1)
}
if filteredHistory[2].UserInputMessage.UserInputMessageContext != nil {
t.Fatalf("expected orphan-only history context to be removed")
}
if len(filteredCurrent) != 1 || filteredCurrent[0].ToolUseID != "keep-1" {
t.Fatalf("expected current tool results to keep only keep-1, got: %+v", filteredCurrent)
}
}

View File

@@ -717,6 +717,12 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
return
}
if len(chunk.Payload) > 0 {
if handlerType == "openai-response" {
if err := validateSSEDataJSON(chunk.Payload); err != nil {
_ = sendErr(&interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err})
return
}
}
sentPayload = true
if okSendData := sendData(cloneBytes(chunk.Payload)); !okSendData {
return
@@ -728,6 +734,35 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
return dataChan, upstreamHeaders, errChan
}
func validateSSEDataJSON(chunk []byte) error {
for _, line := range bytes.Split(chunk, []byte("\n")) {
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
if !bytes.HasPrefix(line, []byte("data:")) {
continue
}
data := bytes.TrimSpace(line[5:])
if len(data) == 0 {
continue
}
if bytes.Equal(data, []byte("[DONE]")) {
continue
}
if json.Valid(data) {
continue
}
const max = 512
preview := data
if len(preview) > max {
preview = preview[:max]
}
return fmt.Errorf("invalid SSE data JSON (len=%d): %q", len(data), preview)
}
return nil
}
func statusFromError(err error) int {
if err == nil {
return 0

View File

@@ -134,6 +134,37 @@ type authAwareStreamExecutor struct {
authIDs []string
}
type invalidJSONStreamExecutor struct{}
func (e *invalidJSONStreamExecutor) Identifier() string { return "codex" }
func (e *invalidJSONStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
}
func (e *invalidJSONStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
ch := make(chan coreexecutor.StreamChunk, 1)
ch <- coreexecutor.StreamChunk{Payload: []byte("event: response.completed\ndata: {\"type\"")}
close(ch)
return &coreexecutor.StreamResult{Chunks: ch}, nil
}
func (e *invalidJSONStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
return auth, nil
}
func (e *invalidJSONStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
}
func (e *invalidJSONStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
return nil, &coreauth.Error{
Code: "not_implemented",
Message: "HttpRequest not implemented",
HTTPStatus: http.StatusNotImplemented,
}
}
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
@@ -524,3 +555,55 @@ func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *test
t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2")
}
}
func TestExecuteStreamWithAuthManager_ValidatesOpenAIResponsesStreamDataJSON(t *testing.T) {
executor := &invalidJSONStreamExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(executor)
auth1 := &coreauth.Auth{
ID: "auth1",
Provider: "codex",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test1@example.com"},
}
if _, err := manager.Register(context.Background(), auth1); err != nil {
t.Fatalf("manager.Register(auth1): %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
})
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai-response", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels")
}
var got []byte
for chunk := range dataChan {
got = append(got, chunk...)
}
if len(got) != 0 {
t.Fatalf("expected empty payload, got %q", string(got))
}
gotErr := false
for msg := range errChan {
if msg == nil {
continue
}
if msg.StatusCode != http.StatusBadGateway {
t.Fatalf("expected status %d, got %d", http.StatusBadGateway, msg.StatusCode)
}
if msg.Error == nil {
t.Fatalf("expected error")
}
gotErr = true
}
if !gotErr {
t.Fatalf("expected terminal error")
}
}

View File

@@ -418,8 +418,8 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flush
if errMsg.Error != nil && errMsg.Error.Error() != "" {
errText = errMsg.Error.Error()
}
body := handlers.BuildErrorResponseBody(status, errText)
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body))
chunk := handlers.BuildOpenAIResponsesStreamErrorChunk(status, errText, 0)
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(chunk))
},
WriteDone: func() {
_, _ = c.Writer.Write([]byte("\n"))

View File

@@ -0,0 +1,43 @@
package openai
import (
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
func TestForwardResponsesStreamTerminalErrorUsesResponsesErrorChunk(t *testing.T) {
gin.SetMode(gin.TestMode)
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, nil)
h := NewOpenAIResponsesAPIHandler(base)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
t.Fatalf("expected gin writer to implement http.Flusher")
}
data := make(chan []byte)
errs := make(chan *interfaces.ErrorMessage, 1)
errs <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: errors.New("unexpected EOF")}
close(errs)
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs)
body := recorder.Body.String()
if !strings.Contains(body, `"type":"error"`) {
t.Fatalf("expected responses error chunk, got: %q", body)
}
if strings.Contains(body, `"error":{`) {
t.Fatalf("expected streaming error chunk (top-level type), got HTTP error body: %q", body)
}
}

View File

@@ -0,0 +1,119 @@
package handlers
import (
"encoding/json"
"fmt"
"net/http"
"strings"
)
type openAIResponsesStreamErrorChunk struct {
Type string `json:"type"`
Code string `json:"code"`
Message string `json:"message"`
SequenceNumber int `json:"sequence_number"`
}
func openAIResponsesStreamErrorCode(status int) string {
switch status {
case http.StatusUnauthorized:
return "invalid_api_key"
case http.StatusForbidden:
return "insufficient_quota"
case http.StatusTooManyRequests:
return "rate_limit_exceeded"
case http.StatusNotFound:
return "model_not_found"
case http.StatusRequestTimeout:
return "request_timeout"
default:
if status >= http.StatusInternalServerError {
return "internal_server_error"
}
if status >= http.StatusBadRequest {
return "invalid_request_error"
}
return "unknown_error"
}
}
// BuildOpenAIResponsesStreamErrorChunk builds an OpenAI Responses streaming error chunk.
//
// Important: OpenAI's HTTP error bodies are shaped like {"error":{...}}; those are valid for
// non-streaming responses, but streaming clients validate SSE `data:` payloads against a union
// of chunks that requires a top-level `type` field.
func BuildOpenAIResponsesStreamErrorChunk(status int, errText string, sequenceNumber int) []byte {
if status <= 0 {
status = http.StatusInternalServerError
}
if sequenceNumber < 0 {
sequenceNumber = 0
}
message := strings.TrimSpace(errText)
if message == "" {
message = http.StatusText(status)
}
code := openAIResponsesStreamErrorCode(status)
trimmed := strings.TrimSpace(errText)
if trimmed != "" && json.Valid([]byte(trimmed)) {
var payload map[string]any
if err := json.Unmarshal([]byte(trimmed), &payload); err == nil {
if t, ok := payload["type"].(string); ok && strings.TrimSpace(t) == "error" {
if m, ok := payload["message"].(string); ok && strings.TrimSpace(m) != "" {
message = strings.TrimSpace(m)
}
if v, ok := payload["code"]; ok && v != nil {
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
code = strings.TrimSpace(c)
} else {
code = strings.TrimSpace(fmt.Sprint(v))
}
}
if v, ok := payload["sequence_number"].(float64); ok && sequenceNumber == 0 {
sequenceNumber = int(v)
}
}
if e, ok := payload["error"].(map[string]any); ok {
if m, ok := e["message"].(string); ok && strings.TrimSpace(m) != "" {
message = strings.TrimSpace(m)
}
if v, ok := e["code"]; ok && v != nil {
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
code = strings.TrimSpace(c)
} else {
code = strings.TrimSpace(fmt.Sprint(v))
}
}
}
}
}
if strings.TrimSpace(code) == "" {
code = "unknown_error"
}
data, err := json.Marshal(openAIResponsesStreamErrorChunk{
Type: "error",
Code: code,
Message: message,
SequenceNumber: sequenceNumber,
})
if err == nil {
return data
}
// Extremely defensive fallback.
data, _ = json.Marshal(openAIResponsesStreamErrorChunk{
Type: "error",
Code: "internal_server_error",
Message: message,
SequenceNumber: sequenceNumber,
})
if len(data) > 0 {
return data
}
return []byte(`{"type":"error","code":"internal_server_error","message":"internal error","sequence_number":0}`)
}

View File

@@ -0,0 +1,48 @@
package handlers
import (
"encoding/json"
"net/http"
"testing"
)
func TestBuildOpenAIResponsesStreamErrorChunk(t *testing.T) {
chunk := BuildOpenAIResponsesStreamErrorChunk(http.StatusInternalServerError, "unexpected EOF", 0)
var payload map[string]any
if err := json.Unmarshal(chunk, &payload); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if payload["type"] != "error" {
t.Fatalf("type = %v, want %q", payload["type"], "error")
}
if payload["code"] != "internal_server_error" {
t.Fatalf("code = %v, want %q", payload["code"], "internal_server_error")
}
if payload["message"] != "unexpected EOF" {
t.Fatalf("message = %v, want %q", payload["message"], "unexpected EOF")
}
if payload["sequence_number"] != float64(0) {
t.Fatalf("sequence_number = %v, want %v", payload["sequence_number"], 0)
}
}
func TestBuildOpenAIResponsesStreamErrorChunkExtractsHTTPErrorBody(t *testing.T) {
chunk := BuildOpenAIResponsesStreamErrorChunk(
http.StatusInternalServerError,
`{"error":{"message":"oops","type":"server_error","code":"internal_server_error"}}`,
0,
)
var payload map[string]any
if err := json.Unmarshal(chunk, &payload); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if payload["type"] != "error" {
t.Fatalf("type = %v, want %q", payload["type"], "error")
}
if payload["code"] != "internal_server_error" {
t.Fatalf("code = %v, want %q", payload["code"], "internal_server_error")
}
if payload["message"] != "oops" {
t.Fatalf("message = %v, want %q", payload["message"], "oops")
}
}

View File

@@ -2,8 +2,6 @@ package auth
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"strings"
@@ -48,6 +46,10 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
opts = &LoginOptions{}
}
if shouldUseCodexDeviceFlow(opts) {
return a.loginWithDeviceFlow(ctx, cfg, opts)
}
callbackPort := a.CallbackPort
if opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
@@ -186,39 +188,5 @@ waitForCallback:
return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err)
}
tokenStorage := authSvc.CreateTokenStorage(authBundle)
if tokenStorage == nil || tokenStorage.Email == "" {
return nil, fmt.Errorf("codex token storage missing account information")
}
planType := ""
hashAccountID := ""
if tokenStorage.IDToken != "" {
if claims, errParse := codex.ParseJWTToken(tokenStorage.IDToken); errParse == nil && claims != nil {
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
accountID := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID)
if accountID != "" {
digest := sha256.Sum256([]byte(accountID))
hashAccountID = hex.EncodeToString(digest[:])[:8]
}
}
}
fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true)
metadata := map[string]any{
"email": tokenStorage.Email,
}
fmt.Println("Codex authentication successful")
if authBundle.APIKey != "" {
fmt.Println("Codex API key obtained and stored")
}
return &coreauth.Auth{
ID: fileName,
Provider: a.Provider(),
FileName: fileName,
Storage: tokenStorage,
Metadata: metadata,
}, nil
return a.buildAuthRecord(authSvc, authBundle)
}

291
sdk/auth/codex_device.go Normal file
View File

@@ -0,0 +1,291 @@
package auth
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
const (
codexLoginModeMetadataKey = "codex_login_mode"
codexLoginModeDevice = "device"
codexDeviceUserCodeURL = "https://auth.openai.com/api/accounts/deviceauth/usercode"
codexDeviceTokenURL = "https://auth.openai.com/api/accounts/deviceauth/token"
codexDeviceVerificationURL = "https://auth.openai.com/codex/device"
codexDeviceTokenExchangeRedirectURI = "https://auth.openai.com/deviceauth/callback"
codexDeviceTimeout = 15 * time.Minute
codexDeviceDefaultPollIntervalSeconds = 5
)
type codexDeviceUserCodeRequest struct {
ClientID string `json:"client_id"`
}
type codexDeviceUserCodeResponse struct {
DeviceAuthID string `json:"device_auth_id"`
UserCode string `json:"user_code"`
UserCodeAlt string `json:"usercode"`
Interval json.RawMessage `json:"interval"`
}
type codexDeviceTokenRequest struct {
DeviceAuthID string `json:"device_auth_id"`
UserCode string `json:"user_code"`
}
type codexDeviceTokenResponse struct {
AuthorizationCode string `json:"authorization_code"`
CodeVerifier string `json:"code_verifier"`
CodeChallenge string `json:"code_challenge"`
}
func shouldUseCodexDeviceFlow(opts *LoginOptions) bool {
if opts == nil || opts.Metadata == nil {
return false
}
return strings.EqualFold(strings.TrimSpace(opts.Metadata[codexLoginModeMetadataKey]), codexLoginModeDevice)
}
func (a *CodexAuthenticator) loginWithDeviceFlow(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
if ctx == nil {
ctx = context.Background()
}
httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{})
userCodeResp, err := requestCodexDeviceUserCode(ctx, httpClient)
if err != nil {
return nil, err
}
deviceCode := strings.TrimSpace(userCodeResp.UserCode)
if deviceCode == "" {
deviceCode = strings.TrimSpace(userCodeResp.UserCodeAlt)
}
deviceAuthID := strings.TrimSpace(userCodeResp.DeviceAuthID)
if deviceCode == "" || deviceAuthID == "" {
return nil, fmt.Errorf("codex device flow did not return required fields")
}
pollInterval := parseCodexDevicePollInterval(userCodeResp.Interval)
fmt.Println("Starting Codex device authentication...")
fmt.Printf("Codex device URL: %s\n", codexDeviceVerificationURL)
fmt.Printf("Codex device code: %s\n", deviceCode)
if !opts.NoBrowser {
if !browser.IsAvailable() {
log.Warn("No browser available; please open the device URL manually")
} else if errOpen := browser.OpenURL(codexDeviceVerificationURL); errOpen != nil {
log.Warnf("Failed to open browser automatically: %v", errOpen)
}
}
tokenResp, err := pollCodexDeviceToken(ctx, httpClient, deviceAuthID, deviceCode, pollInterval)
if err != nil {
return nil, err
}
authCode := strings.TrimSpace(tokenResp.AuthorizationCode)
codeVerifier := strings.TrimSpace(tokenResp.CodeVerifier)
codeChallenge := strings.TrimSpace(tokenResp.CodeChallenge)
if authCode == "" || codeVerifier == "" || codeChallenge == "" {
return nil, fmt.Errorf("codex device flow token response missing required fields")
}
authSvc := codex.NewCodexAuth(cfg)
authBundle, err := authSvc.ExchangeCodeForTokensWithRedirect(
ctx,
authCode,
codexDeviceTokenExchangeRedirectURI,
&codex.PKCECodes{
CodeVerifier: codeVerifier,
CodeChallenge: codeChallenge,
},
)
if err != nil {
return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err)
}
return a.buildAuthRecord(authSvc, authBundle)
}
func requestCodexDeviceUserCode(ctx context.Context, client *http.Client) (*codexDeviceUserCodeResponse, error) {
body, err := json.Marshal(codexDeviceUserCodeRequest{ClientID: codex.ClientID})
if err != nil {
return nil, fmt.Errorf("failed to encode codex device request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, codexDeviceUserCodeURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("failed to create codex device request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to request codex device code: %w", err)
}
defer func() { _ = resp.Body.Close() }()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read codex device code response: %w", err)
}
if !codexDeviceIsSuccessStatus(resp.StatusCode) {
trimmed := strings.TrimSpace(string(respBody))
if resp.StatusCode == http.StatusNotFound {
return nil, fmt.Errorf("codex device endpoint is unavailable (status %d)", resp.StatusCode)
}
if trimmed == "" {
trimmed = "empty response body"
}
return nil, fmt.Errorf("codex device code request failed with status %d: %s", resp.StatusCode, trimmed)
}
var parsed codexDeviceUserCodeResponse
if err := json.Unmarshal(respBody, &parsed); err != nil {
return nil, fmt.Errorf("failed to decode codex device code response: %w", err)
}
return &parsed, nil
}
func pollCodexDeviceToken(ctx context.Context, client *http.Client, deviceAuthID, userCode string, interval time.Duration) (*codexDeviceTokenResponse, error) {
deadline := time.Now().Add(codexDeviceTimeout)
for {
if time.Now().After(deadline) {
return nil, fmt.Errorf("codex device authentication timed out after 15 minutes")
}
body, err := json.Marshal(codexDeviceTokenRequest{
DeviceAuthID: deviceAuthID,
UserCode: userCode,
})
if err != nil {
return nil, fmt.Errorf("failed to encode codex device poll request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, codexDeviceTokenURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("failed to create codex device poll request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to poll codex device token: %w", err)
}
respBody, readErr := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if readErr != nil {
return nil, fmt.Errorf("failed to read codex device poll response: %w", readErr)
}
switch {
case codexDeviceIsSuccessStatus(resp.StatusCode):
var parsed codexDeviceTokenResponse
if err := json.Unmarshal(respBody, &parsed); err != nil {
return nil, fmt.Errorf("failed to decode codex device token response: %w", err)
}
return &parsed, nil
case resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusNotFound:
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(interval):
continue
}
default:
trimmed := strings.TrimSpace(string(respBody))
if trimmed == "" {
trimmed = "empty response body"
}
return nil, fmt.Errorf("codex device token polling failed with status %d: %s", resp.StatusCode, trimmed)
}
}
}
func parseCodexDevicePollInterval(raw json.RawMessage) time.Duration {
defaultInterval := time.Duration(codexDeviceDefaultPollIntervalSeconds) * time.Second
if len(raw) == 0 {
return defaultInterval
}
var asString string
if err := json.Unmarshal(raw, &asString); err == nil {
if seconds, convErr := strconv.Atoi(strings.TrimSpace(asString)); convErr == nil && seconds > 0 {
return time.Duration(seconds) * time.Second
}
}
var asInt int
if err := json.Unmarshal(raw, &asInt); err == nil && asInt > 0 {
return time.Duration(asInt) * time.Second
}
return defaultInterval
}
func codexDeviceIsSuccessStatus(code int) bool {
return code >= 200 && code < 300
}
func (a *CodexAuthenticator) buildAuthRecord(authSvc *codex.CodexAuth, authBundle *codex.CodexAuthBundle) (*coreauth.Auth, error) {
tokenStorage := authSvc.CreateTokenStorage(authBundle)
if tokenStorage == nil || tokenStorage.Email == "" {
return nil, fmt.Errorf("codex token storage missing account information")
}
planType := ""
hashAccountID := ""
if tokenStorage.IDToken != "" {
if claims, errParse := codex.ParseJWTToken(tokenStorage.IDToken); errParse == nil && claims != nil {
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
accountID := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID)
if accountID != "" {
digest := sha256.Sum256([]byte(accountID))
hashAccountID = hex.EncodeToString(digest[:])[:8]
}
}
}
fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true)
metadata := map[string]any{
"email": tokenStorage.Email,
}
fmt.Println("Codex authentication successful")
if authBundle.APIKey != "" {
fmt.Println("Codex API key obtained and stored")
}
return &coreauth.Auth{
ID: fileName,
Provider: a.Provider(),
FileName: fileName,
Storage: tokenStorage,
Metadata: metadata,
}, nil
}

View File

@@ -64,8 +64,16 @@ func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (str
return "", fmt.Errorf("auth filestore: create dir failed: %w", err)
}
// metadataSetter is a private interface for TokenStorage implementations that support metadata injection.
type metadataSetter interface {
SetMetadata(map[string]any)
}
switch {
case auth.Storage != nil:
if setter, ok := auth.Storage.(metadataSetter); ok {
setter.SetMetadata(auth.Metadata)
}
if err = auth.Storage.SaveTokenToFile(path); err != nil {
return "", err
}

View File

@@ -86,6 +86,8 @@ func (a GitHubCopilotAuthenticator) Login(ctx context.Context, cfg *config.Confi
metadata := map[string]any{
"type": "github-copilot",
"username": authBundle.Username,
"email": authBundle.Email,
"name": authBundle.Name,
"access_token": authBundle.TokenData.AccessToken,
"token_type": authBundle.TokenData.TokenType,
"scope": authBundle.TokenData.Scope,
@@ -98,13 +100,18 @@ func (a GitHubCopilotAuthenticator) Login(ctx context.Context, cfg *config.Confi
fileName := fmt.Sprintf("github-copilot-%s.json", authBundle.Username)
label := authBundle.Email
if label == "" {
label = authBundle.Username
}
fmt.Printf("\nGitHub Copilot authentication successful for user: %s\n", authBundle.Username)
return &coreauth.Auth{
ID: fileName,
Provider: a.Provider(),
FileName: fileName,
Label: authBundle.Username,
Label: label,
Storage: tokenStorage,
Metadata: metadata,
}, nil

View File

@@ -1828,9 +1828,7 @@ func (m *Manager) persist(ctx context.Context, auth *Auth) error {
// every few seconds and triggers refresh operations when required.
// Only one loop is kept alive; starting a new one cancels the previous run.
func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duration) {
if interval <= 0 || interval > refreshCheckInterval {
interval = refreshCheckInterval
} else {
if interval <= 0 {
interval = refreshCheckInterval
}
if m.refreshCancel != nil {

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"math"
"math/rand/v2"
"net/http"
"sort"
"strconv"
@@ -248,6 +249,9 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]
}
// Pick selects the next available auth for the provider in a round-robin manner.
// For gemini-cli virtual auths (identified by the gemini_virtual_parent attribute),
// a two-level round-robin is used: first cycling across credential groups (parent
// accounts), then cycling within each group's project auths.
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
_ = opts
now := time.Now()
@@ -265,21 +269,87 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
if limit <= 0 {
limit = 4096
}
if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit {
s.cursors = make(map[string]int)
}
index := s.cursors[key]
// Check if any available auth has gemini_virtual_parent attribute,
// indicating gemini-cli virtual auths that should use credential-level polling.
groups, parentOrder := groupByVirtualParent(available)
if len(parentOrder) > 1 {
// Two-level round-robin: first select a credential group, then pick within it.
groupKey := key + "::group"
s.ensureCursorKey(groupKey, limit)
if _, exists := s.cursors[groupKey]; !exists {
// Seed with a random initial offset so the starting credential is randomized.
s.cursors[groupKey] = rand.IntN(len(parentOrder))
}
groupIndex := s.cursors[groupKey]
if groupIndex >= 2_147_483_640 {
groupIndex = 0
}
s.cursors[groupKey] = groupIndex + 1
selectedParent := parentOrder[groupIndex%len(parentOrder)]
group := groups[selectedParent]
// Second level: round-robin within the selected credential group.
innerKey := key + "::cred:" + selectedParent
s.ensureCursorKey(innerKey, limit)
innerIndex := s.cursors[innerKey]
if innerIndex >= 2_147_483_640 {
innerIndex = 0
}
s.cursors[innerKey] = innerIndex + 1
s.mu.Unlock()
return group[innerIndex%len(group)], nil
}
// Flat round-robin for non-grouped auths (original behavior).
s.ensureCursorKey(key, limit)
index := s.cursors[key]
if index >= 2_147_483_640 {
index = 0
}
s.cursors[key] = index + 1
s.mu.Unlock()
// log.Debugf("available: %d, index: %d, key: %d", len(available), index, index%len(available))
return available[index%len(available)], nil
}
// ensureCursorKey ensures the cursor map has capacity for the given key.
// Must be called with s.mu held.
func (s *RoundRobinSelector) ensureCursorKey(key string, limit int) {
if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit {
s.cursors = make(map[string]int)
}
}
// groupByVirtualParent groups auths by their gemini_virtual_parent attribute.
// Returns a map of parentID -> auths and a sorted slice of parent IDs for stable iteration.
// Only auths with a non-empty gemini_virtual_parent are grouped; if any auth lacks
// this attribute, nil/nil is returned so the caller falls back to flat round-robin.
func groupByVirtualParent(auths []*Auth) (map[string][]*Auth, []string) {
if len(auths) == 0 {
return nil, nil
}
groups := make(map[string][]*Auth)
for _, a := range auths {
parent := ""
if a.Attributes != nil {
parent = strings.TrimSpace(a.Attributes["gemini_virtual_parent"])
}
if parent == "" {
// Non-virtual auth present; fall back to flat round-robin.
return nil, nil
}
groups[parent] = append(groups[parent], a)
}
// Collect parent IDs in sorted order for stable cursor indexing.
parentOrder := make([]string, 0, len(groups))
for p := range groups {
parentOrder = append(parentOrder, p)
}
sort.Strings(parentOrder)
return groups, parentOrder
}
// Pick selects the first available auth for the provider in a deterministic manner.
func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
_ = opts

View File

@@ -402,3 +402,128 @@ func TestRoundRobinSelectorPick_CursorKeyCap(t *testing.T) {
t.Fatalf("selector.cursors missing key %q", "gemini:m3")
}
}
func TestRoundRobinSelectorPick_GeminiCLICredentialGrouping(t *testing.T) {
t.Parallel()
selector := &RoundRobinSelector{}
// Simulate two gemini-cli credentials, each with multiple projects:
// Credential A (parent = "cred-a.json") has 3 projects
// Credential B (parent = "cred-b.json") has 2 projects
auths := []*Auth{
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-b.json::proj-b1", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}},
{ID: "cred-b.json::proj-b2", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}},
}
// Two-level round-robin: consecutive picks must alternate between credentials.
// Credential group order is randomized, but within each call the group cursor
// advances by 1, so consecutive picks should cycle through different parents.
picks := make([]string, 6)
parents := make([]string, 6)
for i := 0; i < 6; i++ {
got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths)
if err != nil {
t.Fatalf("Pick() #%d error = %v", i, err)
}
if got == nil {
t.Fatalf("Pick() #%d auth = nil", i)
}
picks[i] = got.ID
parents[i] = got.Attributes["gemini_virtual_parent"]
}
// Verify property: consecutive picks must alternate between credential groups.
for i := 1; i < len(parents); i++ {
if parents[i] == parents[i-1] {
t.Fatalf("Pick() #%d and #%d both from same parent %q (IDs: %q, %q); expected alternating credentials",
i-1, i, parents[i], picks[i-1], picks[i])
}
}
// Verify property: each credential's projects are picked in sequence (round-robin within group).
credPicks := map[string][]string{}
for i, id := range picks {
credPicks[parents[i]] = append(credPicks[parents[i]], id)
}
for parent, ids := range credPicks {
for i := 1; i < len(ids); i++ {
if ids[i] == ids[i-1] {
t.Fatalf("Credential %q picked same project %q twice in a row", parent, ids[i])
}
}
}
}
func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) {
t.Parallel()
selector := &RoundRobinSelector{}
// All auths from the same parent - should fall back to flat round-robin
// because there's only one credential group (no benefit from two-level).
auths := []*Auth{
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
}
// With single parent group, parentOrder has length 1, so it uses flat round-robin.
// Sorted by ID: proj-a1, proj-a2, proj-a3
want := []string{
"cred-a.json::proj-a1",
"cred-a.json::proj-a2",
"cred-a.json::proj-a3",
"cred-a.json::proj-a1",
}
for i, expectedID := range want {
got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths)
if err != nil {
t.Fatalf("Pick() #%d error = %v", i, err)
}
if got == nil {
t.Fatalf("Pick() #%d auth = nil", i)
}
if got.ID != expectedID {
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID)
}
}
}
func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *testing.T) {
t.Parallel()
selector := &RoundRobinSelector{}
// Mix of virtual and non-virtual auths (e.g., a regular gemini-cli auth without projects
// alongside virtual ones). Should fall back to flat round-robin.
auths := []*Auth{
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-regular.json"}, // no gemini_virtual_parent
}
// groupByVirtualParent returns nil when any auth lacks the attribute,
// so flat round-robin is used. Sorted by ID: cred-a.json::proj-a1, cred-regular.json
want := []string{
"cred-a.json::proj-a1",
"cred-regular.json",
"cred-a.json::proj-a1",
}
for i, expectedID := range want {
got, err := selector.Pick(context.Background(), "gemini-cli", "", cliproxyexecutor.Options{}, auths)
if err != nil {
t.Fatalf("Pick() #%d error = %v", i, err)
}
if got == nil {
t.Fatalf("Pick() #%d auth = nil", i)
}
if got.ID != expectedID {
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID)
}
}
}

View File

@@ -1,9 +1,12 @@
package auth
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
@@ -12,6 +15,33 @@ import (
baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth"
)
// PostAuthHook defines a function that is called after an Auth record is created
// but before it is persisted to storage. This allows for modification of the
// Auth record (e.g., injecting metadata) based on external context.
type PostAuthHook func(context.Context, *Auth) error
// RequestInfo holds information extracted from the HTTP request.
// It is injected into the context passed to PostAuthHook.
type RequestInfo struct {
Query url.Values
Headers http.Header
}
type requestInfoKey struct{}
// WithRequestInfo returns a new context with the given RequestInfo attached.
func WithRequestInfo(ctx context.Context, info *RequestInfo) context.Context {
return context.WithValue(ctx, requestInfoKey{}, info)
}
// GetRequestInfo retrieves the RequestInfo from the context, if present.
func GetRequestInfo(ctx context.Context) *RequestInfo {
if val, ok := ctx.Value(requestInfoKey{}).(*RequestInfo); ok {
return val
}
return nil
}
// Auth encapsulates the runtime state and metadata associated with a single credential.
type Auth struct {
// ID uniquely identifies the auth record across restarts.

View File

@@ -153,6 +153,16 @@ func (b *Builder) WithLocalManagementPassword(password string) *Builder {
return b
}
// WithPostAuthHook registers a hook to be called after an Auth record is created
// but before it is persisted to storage.
func (b *Builder) WithPostAuthHook(hook coreauth.PostAuthHook) *Builder {
if hook == nil {
return b
}
b.serverOptions = append(b.serverOptions, api.WithPostAuthHook(hook))
return b
}
// Build validates inputs, applies defaults, and returns a ready-to-run service.
func (b *Builder) Build() (*Service, error) {
if b.cfg == nil {

View File

@@ -963,6 +963,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
key = strings.ToLower(strings.TrimSpace(a.Provider))
}
GlobalModelRegistry().RegisterClient(a.ID, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
if provider == "antigravity" {
s.backfillAntigravityModels(a, models)
}
return
}
@@ -1107,6 +1110,56 @@ func (s *Service) oauthExcludedModels(provider, authKind string) []string {
return cfg.OAuthExcludedModels[providerKey]
}
func (s *Service) backfillAntigravityModels(source *coreauth.Auth, primaryModels []*ModelInfo) {
if s == nil || s.coreManager == nil || len(primaryModels) == 0 {
return
}
sourceID := ""
if source != nil {
sourceID = strings.TrimSpace(source.ID)
}
reg := registry.GetGlobalRegistry()
for _, candidate := range s.coreManager.List() {
if candidate == nil || candidate.Disabled {
continue
}
candidateID := strings.TrimSpace(candidate.ID)
if candidateID == "" || candidateID == sourceID {
continue
}
if !strings.EqualFold(strings.TrimSpace(candidate.Provider), "antigravity") {
continue
}
if len(reg.GetModelsForClient(candidateID)) > 0 {
continue
}
authKind := strings.ToLower(strings.TrimSpace(candidate.Attributes["auth_kind"]))
if authKind == "" {
if kind, _ := candidate.AccountInfo(); strings.EqualFold(kind, "api_key") {
authKind = "apikey"
}
}
excluded := s.oauthExcludedModels("antigravity", authKind)
if candidate.Attributes != nil {
if val, ok := candidate.Attributes["excluded_models"]; ok && strings.TrimSpace(val) != "" {
excluded = strings.Split(val, ",")
}
}
models := applyExcludedModels(primaryModels, excluded)
models = applyOAuthModelAlias(s.cfg, "antigravity", authKind, models)
if len(models) == 0 {
continue
}
reg.RegisterClient(candidateID, "antigravity", applyModelPrefixes(models, candidate.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
log.Debugf("antigravity models backfilled for auth %s using primary model list", candidateID)
}
}
func applyExcludedModels(models []*ModelInfo, excluded []string) []*ModelInfo {
if len(models) == 0 || len(excluded) == 0 {
return models

View File

@@ -0,0 +1,135 @@
package cliproxy
import (
"context"
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
func TestBackfillAntigravityModels_RegistersMissingAuth(t *testing.T) {
source := &coreauth.Auth{
ID: "ag-backfill-source",
Provider: "antigravity",
Status: coreauth.StatusActive,
Attributes: map[string]string{
"auth_kind": "oauth",
},
}
target := &coreauth.Auth{
ID: "ag-backfill-target",
Provider: "antigravity",
Status: coreauth.StatusActive,
Attributes: map[string]string{
"auth_kind": "oauth",
},
}
manager := coreauth.NewManager(nil, nil, nil)
if _, err := manager.Register(context.Background(), source); err != nil {
t.Fatalf("register source auth: %v", err)
}
if _, err := manager.Register(context.Background(), target); err != nil {
t.Fatalf("register target auth: %v", err)
}
service := &Service{
cfg: &config.Config{},
coreManager: manager,
}
reg := registry.GetGlobalRegistry()
reg.UnregisterClient(source.ID)
reg.UnregisterClient(target.ID)
t.Cleanup(func() {
reg.UnregisterClient(source.ID)
reg.UnregisterClient(target.ID)
})
primary := []*ModelInfo{
{ID: "claude-sonnet-4-5"},
{ID: "gemini-2.5-pro"},
}
reg.RegisterClient(source.ID, "antigravity", primary)
service.backfillAntigravityModels(source, primary)
got := reg.GetModelsForClient(target.ID)
if len(got) != 2 {
t.Fatalf("expected target auth to be backfilled with 2 models, got %d", len(got))
}
ids := make(map[string]struct{}, len(got))
for _, model := range got {
if model == nil {
continue
}
ids[strings.ToLower(strings.TrimSpace(model.ID))] = struct{}{}
}
if _, ok := ids["claude-sonnet-4-5"]; !ok {
t.Fatal("expected backfilled model claude-sonnet-4-5")
}
if _, ok := ids["gemini-2.5-pro"]; !ok {
t.Fatal("expected backfilled model gemini-2.5-pro")
}
}
func TestBackfillAntigravityModels_RespectsExcludedModels(t *testing.T) {
source := &coreauth.Auth{
ID: "ag-backfill-source-excluded",
Provider: "antigravity",
Status: coreauth.StatusActive,
Attributes: map[string]string{
"auth_kind": "oauth",
},
}
target := &coreauth.Auth{
ID: "ag-backfill-target-excluded",
Provider: "antigravity",
Status: coreauth.StatusActive,
Attributes: map[string]string{
"auth_kind": "oauth",
"excluded_models": "gemini-2.5-pro",
},
}
manager := coreauth.NewManager(nil, nil, nil)
if _, err := manager.Register(context.Background(), source); err != nil {
t.Fatalf("register source auth: %v", err)
}
if _, err := manager.Register(context.Background(), target); err != nil {
t.Fatalf("register target auth: %v", err)
}
service := &Service{
cfg: &config.Config{},
coreManager: manager,
}
reg := registry.GetGlobalRegistry()
reg.UnregisterClient(source.ID)
reg.UnregisterClient(target.ID)
t.Cleanup(func() {
reg.UnregisterClient(source.ID)
reg.UnregisterClient(target.ID)
})
primary := []*ModelInfo{
{ID: "claude-sonnet-4-5"},
{ID: "gemini-2.5-pro"},
}
reg.RegisterClient(source.ID, "antigravity", primary)
service.backfillAntigravityModels(source, primary)
got := reg.GetModelsForClient(target.ID)
if len(got) != 1 {
t.Fatalf("expected 1 model after exclusion, got %d", len(got))
}
if got[0] == nil || !strings.EqualFold(strings.TrimSpace(got[0].ID), "claude-sonnet-4-5") {
t.Fatalf("expected remaining model %q, got %+v", "claude-sonnet-4-5", got[0])
}
}