Compare commits

...

31 Commits

Author SHA1 Message Date
Luis Pater
4eeec297de Merge pull request #288 from router-for-me/plus
v6.8.27
2026-02-25 01:04:57 +08:00
Luis Pater
77cc4ce3a0 Merge branch 'main' into plus 2026-02-25 01:04:15 +08:00
Luis Pater
37dfea1d3f Merge pull request #287 from possible055/main
fix(kiro): support OR-group field matching in truncation detector
2026-02-25 01:02:49 +08:00
Luis Pater
e6626c672a Merge pull request #269 from ClubWeGo/fix/filterOrphanedToolResults
fix: filter out orphaned tool results from history and current context
2026-02-25 01:02:11 +08:00
Luis Pater
c66cb0afd2 Merge pull request #1683 from dusty-du/codex/device-login-flow
Add additive Codex device-code login flow
2026-02-25 00:50:48 +08:00
Luis Pater
fb48eee973 Merge pull request #1680 from canxin121/fix/responses-stream-error-chunks
fix(responses): emit schema-valid SSE chunks
2026-02-25 00:49:06 +08:00
Luis Pater
bb44e5ec44 Merge pull request #1701 from router-for-me/openai
Revert "Merge pull request #1627 from thebtf/fix/reasoning-effort-clamping"
2026-02-25 00:46:13 +08:00
apparition
c785c1a3ca fix(kiro): support OR-group field matching in truncation detector
- Change RequiredFieldsByTool value type from []string to [][]string
- Outer slice = AND (all groups required); inner slice = OR (any one satisfies)
- Fix Bash entry to accept "cmd" or "command", resolving soft-truncation loop
- Update findMissingRequiredFields logic and inline docs accordingly
2026-02-24 22:48:05 +08:00
hkfires
0659ffab75 Revert "Merge pull request #1627 from thebtf/fix/reasoning-effort-clamping" 2026-02-24 19:47:53 +08:00
Luis Pater
7cb398d167 Merge pull request #1663 from rensumo/main
feat: implement credential-based round-robin for gemini-cli
2026-02-24 06:02:50 +08:00
Luis Pater
c3e12c5e58 Merge pull request #1654 from alexey-yanchenko/feature/pass-file-inputs
Pass file input from /chat/completions and /responses to codex and claude
2026-02-24 05:53:11 +08:00
Luis Pater
1825fc7503 Merge pull request #1643 from alexey-yanchenko/fix/gemini-prompt-tokens
Fix usage convertation from gemini response to openai format
2026-02-24 05:46:13 +08:00
Luis Pater
48732ba05e Merge pull request #1527 from HEUDavid/feat/auth-hook
feat(auth): add post-auth hook mechanism
2026-02-24 05:33:13 +08:00
canxin121
acf483c9e6 fix(responses): reject invalid SSE data JSON
Guard the openai-response streaming path against truncated/invalid SSE data payloads by validating data: JSON before forwarding; surface a 502 terminal error instead of letting clients crash with JSON parse errors.
2026-02-24 01:42:54 +08:00
test
492b9c46f0 Add additive Codex device-code login flow 2026-02-23 06:30:04 -05:00
Darley
6e634fe3f9 fix: filter out orphaned tool results from history and current context 2026-02-23 14:33:59 +08:00
canxin121
eb7571936c revert: translator changes (path guard)
CI blocks PRs that modify internal/translator. Revert translator edits and keep only the /v1/responses streaming error-chunk fix; file an issue for translator conformance work.
2026-02-23 13:30:43 +08:00
canxin121
5382764d8a fix(responses): include model and usage in translated streams
Ensure response.created and response.completed chunks produced by the OpenAI/Gemini/Claude translators always include required fields (response.model and response.usage) so clients validating Responses SSE do not fail schema validation.
2026-02-23 13:22:06 +08:00
canxin121
49c8ec69d0 fix(openai): emit valid responses stream error chunks
When /v1/responses streaming fails after headers are sent, we now emit a type=error chunk instead of an HTTP-style {error:{...}} payload, preventing AI SDK chunk validation errors.
2026-02-23 12:59:50 +08:00
rensumo
5936f9895c feat: implement credential-based round-robin for gemini-cli virtual auths
Changes the RoundRobinSelector to use two-level round-robin when
gemini-cli virtual auths are detected (via gemini_virtual_parent attr):
- Level 1: cycle across credential groups (parent accounts)
- Level 2: cycle within each group's project auths

Credentials start from a random offset (rand.IntN) for fair distribution.
Non-virtual auths and single-credential scenarios fall back to flat RR.

Adds 3 test cases covering multi-credential grouping, single-parent
fallback, and mixed virtual/non-virtual fallback.
2026-02-21 12:49:48 +08:00
Alexey Yanchenko
0cbfe7f457 Pass file input from /chat/completions and /responses to codex and claude 2026-02-20 10:25:44 +07:00
Alexey Yanchenko
b9ae4ab803 Fix usage convertation from gemini response to openai format 2026-02-19 15:34:59 +07:00
HEUDavid
65debb874f feat/auth-hook: refactor RequstInfo to preserve original HTTP semantics 2026-02-12 07:11:17 +08:00
HEUDavid
3caadac003 feat/auth-hook: add post auth hook [CR] 2026-02-12 07:11:17 +08:00
HEUDavid
6a9e3a6b84 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
269972440a feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
cce13e6ad2 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
8a565dcad8 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
d536110404 feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
48e957ddff feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
HEUDavid
94563d622c feat/auth-hook: add post auth hook 2026-02-12 07:11:17 +08:00
37 changed files with 1336 additions and 130 deletions

View File

@@ -72,6 +72,7 @@ func main() {
// Command-line flags to control the application's behavior.
var login bool
var codexLogin bool
var codexDeviceLogin bool
var claudeLogin bool
var qwenLogin bool
var kiloLogin bool
@@ -99,6 +100,7 @@ func main() {
// Define command-line flags for different operation modes.
flag.BoolVar(&login, "login", false, "Login Google Account")
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow")
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
@@ -502,6 +504,9 @@ func main() {
} else if codexLogin {
// Handle Codex login
cmd.DoCodexLogin(cfg, options)
} else if codexDeviceLogin {
// Handle Codex device-code login
cmd.DoCodexDeviceLogin(cfg, options)
} else if claudeLogin {
// Handle Claude login
cmd.DoClaudeLogin(cfg, options)

View File

@@ -951,11 +951,17 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
if store == nil {
return "", fmt.Errorf("token store unavailable")
}
if h.postAuthHook != nil {
if err := h.postAuthHook(ctx, record); err != nil {
return "", fmt.Errorf("post-auth hook failed: %w", err)
}
}
return store.Save(ctx, record)
}
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Claude authentication...")
@@ -1100,6 +1106,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
@@ -1358,6 +1365,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
func (h *Handler) RequestCodexToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Codex authentication...")
@@ -1503,6 +1511,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Antigravity authentication...")
@@ -1667,6 +1676,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
func (h *Handler) RequestQwenToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Qwen authentication...")
@@ -1722,6 +1732,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
func (h *Handler) RequestKimiToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Kimi authentication...")
@@ -1798,6 +1809,7 @@ func (h *Handler) RequestKimiToken(c *gin.Context) {
func (h *Handler) RequestIFlowToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing iFlow authentication...")
@@ -2521,6 +2533,14 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "wait"})
}
// PopulateAuthContext extracts request info and adds it to the context
func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context {
info := &coreauth.RequestInfo{
Query: c.Request.URL.Query(),
Headers: c.Request.Header,
}
return coreauth.WithRequestInfo(ctx, info)
}
const kiroCallbackPort = 9876
func (h *Handler) RequestKiroToken(c *gin.Context) {

View File

@@ -47,6 +47,7 @@ type Handler struct {
allowRemoteOverride bool
envSecret string
logDir string
postAuthHook coreauth.PostAuthHook
}
// NewHandler creates a new management handler instance.
@@ -128,6 +129,11 @@ func (h *Handler) SetLogDirectory(dir string) {
h.logDir = dir
}
// SetPostAuthHook registers a hook to be called after auth record creation but before persistence.
func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) {
h.postAuthHook = hook
}
// Middleware enforces access control for management endpoints.
// All requests (local and remote) require a valid management key.
// Additionally, remote access requires allow-remote-management=true.

View File

@@ -52,6 +52,7 @@ type serverOptionConfig struct {
keepAliveEnabled bool
keepAliveTimeout time.Duration
keepAliveOnTimeout func()
postAuthHook auth.PostAuthHook
}
// ServerOption customises HTTP server construction.
@@ -112,6 +113,13 @@ func WithRequestLoggerFactory(factory func(*config.Config, string) logging.Reque
}
}
// WithPostAuthHook registers a hook to be called after auth record creation.
func WithPostAuthHook(hook auth.PostAuthHook) ServerOption {
return func(cfg *serverOptionConfig) {
cfg.postAuthHook = hook
}
}
// Server represents the main API server.
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
type Server struct {
@@ -263,6 +271,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
}
logDir := logging.ResolveLogDirectory(cfg)
s.mgmt.SetLogDirectory(logDir)
if optionState.postAuthHook != nil {
s.mgmt.SetPostAuthHook(optionState.postAuthHook)
}
s.localPassword = optionState.localPassword
// Setup routes

View File

@@ -36,11 +36,21 @@ type ClaudeTokenStorage struct {
// Expire is the timestamp when the current access token expires.
Expire string `json:"expired"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *ClaudeTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Claude token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -65,8 +75,14 @@ func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
// Encode and write the token data as JSON
if err = json.NewEncoder(f).Encode(ts); err != nil {
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -71,16 +71,26 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
// It performs an HTTP POST request to the OpenAI token endpoint with the provided
// authorization code and PKCE verifier.
func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
return o.ExchangeCodeForTokensWithRedirect(ctx, code, RedirectURI, pkceCodes)
}
// ExchangeCodeForTokensWithRedirect exchanges an authorization code for tokens using
// a caller-provided redirect URI. This supports alternate auth flows such as device
// login while preserving the existing token parsing and storage behavior.
func (o *CodexAuth) ExchangeCodeForTokensWithRedirect(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
if pkceCodes == nil {
return nil, fmt.Errorf("PKCE codes are required for token exchange")
}
if strings.TrimSpace(redirectURI) == "" {
return nil, fmt.Errorf("redirect URI is required for token exchange")
}
// Prepare token exchange request
data := url.Values{
"grant_type": {"authorization_code"},
"client_id": {ClientID},
"code": {code},
"redirect_uri": {RedirectURI},
"redirect_uri": {strings.TrimSpace(redirectURI)},
"code_verifier": {pkceCodes.CodeVerifier},
}

View File

@@ -32,11 +32,21 @@ type CodexTokenStorage struct {
Type string `json:"type"`
// Expire is the timestamp when the current access token expires.
Expire string `json:"expired"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *CodexTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Codex token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -58,7 +68,13 @@ func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
if err = json.NewEncoder(f).Encode(ts); err != nil {
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -35,11 +35,21 @@ type GeminiTokenStorage struct {
// Type indicates the authentication provider type, always "gemini" for this storage.
Type string `json:"type"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *GeminiTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Gemini token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -49,6 +59,11 @@ type GeminiTokenStorage struct {
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
misc.LogSavingCredentials(authFilePath)
ts.Type = "gemini"
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
return fmt.Errorf("failed to create directory: %v", err)
}
@@ -63,7 +78,9 @@ func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
}
}()
if err = json.NewEncoder(f).Encode(ts); err != nil {
enc := json.NewEncoder(f)
enc.SetIndent("", " ")
if err := enc.Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -21,6 +21,15 @@ type IFlowTokenStorage struct {
Scope string `json:"scope"`
Cookie string `json:"cookie"`
Type string `json:"type"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *IFlowTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serialises the token storage to disk.
@@ -37,7 +46,13 @@ func (ts *IFlowTokenStorage) SaveTokenToFile(authFilePath string) error {
}
defer func() { _ = f.Close() }()
if err = json.NewEncoder(f).Encode(ts); err != nil {
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("iflow token: encode token failed: %w", err)
}
return nil

View File

@@ -29,6 +29,15 @@ type KimiTokenStorage struct {
Expired string `json:"expired,omitempty"`
// Type indicates the authentication provider type, always "kimi" for this storage.
Type string `json:"type"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *KimiTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// KimiTokenData holds the raw OAuth token response from Kimi.
@@ -86,9 +95,15 @@ func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
encoder := json.NewEncoder(f)
encoder.SetIndent("", " ")
if err = encoder.Encode(ts); err != nil {
if err = encoder.Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -30,11 +30,21 @@ type QwenTokenStorage struct {
Type string `json:"type"`
// Expire is the timestamp when the current access token expires.
Expire string `json:"expired"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
func (ts *QwenTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
// It merges any injected metadata into the top-level JSON object.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
@@ -56,7 +66,13 @@ func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error {
_ = f.Close()
}()
if err = json.NewEncoder(f).Encode(ts); err != nil {
// Merge metadata using helper
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("failed to merge metadata: %w", errMerge)
}
if err = json.NewEncoder(f).Encode(data); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil

View File

@@ -0,0 +1,60 @@
package cmd
import (
"context"
"errors"
"fmt"
"os"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
log "github.com/sirupsen/logrus"
)
const (
codexLoginModeMetadataKey = "codex_login_mode"
codexLoginModeDevice = "device"
)
// DoCodexDeviceLogin triggers the Codex device-code flow while keeping the
// existing codex-login OAuth callback flow intact.
func DoCodexDeviceLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{
codexLoginModeMetadataKey: codexLoginModeDevice,
},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
if err != nil {
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
log.Error(codex.GetUserFriendlyMessage(authErr))
if authErr.Type == codex.ErrPortInUse.Type {
os.Exit(codex.ErrPortInUse.Code)
}
return
}
fmt.Printf("Codex device authentication failed: %v\n", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
fmt.Println("Codex device authentication successful!")
}

View File

@@ -1,6 +1,7 @@
package misc
import (
"encoding/json"
"fmt"
"path/filepath"
"strings"
@@ -24,3 +25,37 @@ func LogSavingCredentials(path string) {
func LogCredentialSeparator() {
log.Debug(credentialSeparator)
}
// MergeMetadata serializes the source struct into a map and merges the provided metadata into it.
func MergeMetadata(source any, metadata map[string]any) (map[string]any, error) {
var data map[string]any
// Fast path: if source is already a map, just copy it to avoid mutation of original
if srcMap, ok := source.(map[string]any); ok {
data = make(map[string]any, len(srcMap)+len(metadata))
for k, v := range srcMap {
data[k] = v
}
} else {
// Slow path: marshal to JSON and back to map to respect JSON tags
temp, err := json.Marshal(source)
if err != nil {
return nil, fmt.Errorf("failed to marshal source: %w", err)
}
if err := json.Unmarshal(temp, &data); err != nil {
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
}
}
// Merge extra metadata
if metadata != nil {
if data == nil {
data = make(map[string]any)
}
for k, v := range metadata {
data[k] = v
}
}
return data, nil
}

View File

@@ -10,53 +10,10 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// validReasoningEffortLevels contains the standard values accepted by the
// OpenAI reasoning_effort field. Provider-specific extensions (xhigh, minimal,
// auto) are NOT in this set and must be clamped before use.
var validReasoningEffortLevels = map[string]struct{}{
"none": {},
"low": {},
"medium": {},
"high": {},
}
// clampReasoningEffort maps any thinking level string to a value that is safe
// to send as OpenAI reasoning_effort. Non-standard CPA-internal values are
// mapped to the nearest standard equivalent.
//
// Mapping rules:
// - none / low / medium / high → returned as-is (already valid)
// - xhigh → "high" (nearest lower standard level)
// - minimal → "low" (nearest higher standard level)
// - auto → "medium" (reasonable default)
// - anything else → "medium" (safe default)
func clampReasoningEffort(level string) string {
if _, ok := validReasoningEffortLevels[level]; ok {
return level
}
var clamped string
switch level {
case string(thinking.LevelXHigh):
clamped = string(thinking.LevelHigh)
case string(thinking.LevelMinimal):
clamped = string(thinking.LevelLow)
case string(thinking.LevelAuto):
clamped = string(thinking.LevelMedium)
default:
clamped = string(thinking.LevelMedium)
}
log.WithFields(log.Fields{
"original": level,
"clamped": clamped,
}).Debug("openai: reasoning_effort clamped to nearest valid standard value")
return clamped
}
// Applier implements thinking.ProviderApplier for OpenAI models.
//
// OpenAI-specific behavior:
@@ -101,7 +58,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
}
if config.Mode == thinking.ModeLevel {
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(string(config.Level)))
result, _ := sjson.SetBytes(body, "reasoning_effort", string(config.Level))
return result, nil
}
@@ -122,7 +79,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
return body, nil
}
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
return result, nil
}
@@ -157,7 +114,7 @@ func applyCompatibleOpenAI(body []byte, config thinking.ThinkingConfig) ([]byte,
return body, nil
}
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
return result, nil
}

View File

@@ -95,9 +95,9 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}

View File

@@ -199,6 +199,21 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
}
}
case "file":
fileData := part.Get("file.file_data").String()
if strings.HasPrefix(fileData, "data:") {
semicolonIdx := strings.Index(fileData, ";")
commaIdx := strings.Index(fileData, ",")
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
data := fileData[commaIdx+1:]
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
docPart, _ = sjson.Set(docPart, "source.data", data)
msg, _ = sjson.SetRaw(msg, "content.-1", docPart)
}
}
}
return true
})

View File

@@ -155,6 +155,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
var textAggregate strings.Builder
var partsJSON []string
hasImage := false
hasFile := false
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, part gjson.Result) bool {
ptype := part.Get("type").String()
@@ -207,6 +208,30 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
hasImage = true
}
}
case "input_file":
fileData := part.Get("file_data").String()
if fileData != "" {
mediaType := "application/octet-stream"
data := fileData
if strings.HasPrefix(fileData, "data:") {
trimmed := strings.TrimPrefix(fileData, "data:")
mediaAndData := strings.SplitN(trimmed, ";base64,", 2)
if len(mediaAndData) == 2 {
if mediaAndData[0] != "" {
mediaType = mediaAndData[0]
}
data = mediaAndData[1]
}
}
contentPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType)
contentPart, _ = sjson.Set(contentPart, "source.data", data)
partsJSON = append(partsJSON, contentPart)
if role == "" {
role = "user"
}
hasFile = true
}
}
return true
})
@@ -228,7 +253,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
if len(partsJSON) > 0 {
msg := `{"role":"","content":[]}`
msg, _ = sjson.Set(msg, "role", role)
if len(partsJSON) == 1 && !hasImage {
if len(partsJSON) == 1 && !hasImage && !hasFile {
// Preserve legacy behavior for single text content
msg, _ = sjson.Delete(msg, "content")
textPart := gjson.Parse(partsJSON[0])

View File

@@ -180,7 +180,19 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
msg, _ = sjson.SetRaw(msg, "content.-1", part)
}
case "file":
// Files are not specified in examples; skip for now
if role == "user" {
fileData := it.Get("file.file_data").String()
filename := it.Get("file.filename").String()
if fileData != "" {
part := `{}`
part, _ = sjson.Set(part, "type", "input_file")
part, _ = sjson.Set(part, "file_data", fileData)
if filename != "" {
part, _ = sjson.Set(part, "filename", filename)
}
msg, _ = sjson.SetRaw(msg, "content.-1", part)
}
}
}
}
}

View File

@@ -100,7 +100,7 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}

View File

@@ -100,9 +100,9 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
baseTemplate, _ = sjson.Set(baseTemplate, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 {
baseTemplate, _ = sjson.Set(baseTemplate, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
@@ -297,7 +297,7 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}

View File

@@ -531,8 +531,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
// usage mapping
if um := root.Get("usageMetadata"); um.Exists() {
// input tokens = prompt + thoughts
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
// input tokens = prompt only (thoughts go to output)
input := um.Get("promptTokenCount").Int()
completed, _ = sjson.Set(completed, "response.usage.input_tokens", input)
// cached token details: align with OpenAI "cached_tokens" semantics.
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
@@ -737,8 +737,8 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
// usage mapping
if um := root.Get("usageMetadata"); um.Exists() {
// input tokens = prompt + thoughts
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
// input tokens = prompt only (thoughts go to output)
input := um.Get("promptTokenCount").Int()
resp, _ = sjson.Set(resp, "usage.input_tokens", input)
// cached token details: align with OpenAI "cached_tokens" semantics.
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())

View File

@@ -53,19 +53,25 @@ var KnownCommandTools = map[string]bool{
"execute_python": true,
}
// RequiredFieldsByTool maps tool names to their required fields.
// If any of these fields are missing, the tool input is considered truncated.
var RequiredFieldsByTool = map[string][]string{
"Write": {"file_path", "content"},
"write_to_file": {"path", "content"},
"fsWrite": {"path", "content"},
"create_file": {"path", "content"},
"edit_file": {"path"},
"apply_diff": {"path", "diff"},
"str_replace_editor": {"path", "old_str", "new_str"},
"Bash": {"command"},
"execute": {"command"},
"run_command": {"command"},
// RequiredFieldsByTool maps tool names to their required field groups.
// Each outer element is a required group; each inner slice lists alternative field names (OR logic).
// A group is satisfied when ANY one of its alternatives exists in the parsed input.
// All groups must be satisfied for the tool input to be considered valid.
//
// Example:
// {{"cmd", "command"}} means the tool needs EITHER "cmd" OR "command".
// {{"file_path"}, {"content"}} means the tool needs BOTH "file_path" AND "content".
var RequiredFieldsByTool = map[string][][]string{
"Write": {{"file_path"}, {"content"}},
"write_to_file": {{"path"}, {"content"}},
"fsWrite": {{"path"}, {"content"}},
"create_file": {{"path"}, {"content"}},
"edit_file": {{"path"}},
"apply_diff": {{"path"}, {"diff"}},
"str_replace_editor": {{"path"}, {"old_str"}, {"new_str"}},
"Bash": {{"cmd", "command"}},
"execute": {{"command"}},
"run_command": {{"command"}},
}
// DetectTruncation checks if the tool use input appears to be truncated.
@@ -104,9 +110,9 @@ func DetectTruncation(toolName, toolUseID, rawInput string, parsedInput map[stri
// Scenario 3: JSON parsed but critical fields are missing
if parsedInput != nil {
requiredFields, hasRequirements := RequiredFieldsByTool[toolName]
requiredGroups, hasRequirements := RequiredFieldsByTool[toolName]
if hasRequirements {
missingFields := findMissingRequiredFields(parsedInput, requiredFields)
missingFields := findMissingRequiredFields(parsedInput, requiredGroups)
if len(missingFields) > 0 {
info.IsTruncated = true
info.TruncationType = TruncationTypeMissingFields
@@ -253,12 +259,21 @@ func extractParsedFieldNames(parsed map[string]interface{}) map[string]string {
return fields
}
// findMissingRequiredFields checks which required fields are missing from the parsed input.
func findMissingRequiredFields(parsed map[string]interface{}, required []string) []string {
// findMissingRequiredFields checks which required field groups are unsatisfied.
// Each group is a slice of alternative field names; the group is satisfied when ANY alternative exists.
// Returns the list of unsatisfied groups (represented by their alternatives joined with "/").
func findMissingRequiredFields(parsed map[string]interface{}, requiredGroups [][]string) []string {
var missing []string
for _, field := range required {
if _, exists := parsed[field]; !exists {
missing = append(missing, field)
for _, group := range requiredGroups {
satisfied := false
for _, field := range group {
if _, exists := parsed[field]; exists {
satisfied = true
break
}
}
if !satisfied {
missing = append(missing, strings.Join(group, "/"))
}
}
return missing

View File

@@ -578,6 +578,7 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
// Truncate history if too long to prevent Kiro API errors
history = truncateHistoryIfNeeded(history)
history, currentToolResults = filterOrphanedToolResults(history, currentToolResults)
return history, currentUserMsg, currentToolResults
}
@@ -593,6 +594,61 @@ func truncateHistoryIfNeeded(history []KiroHistoryMessage) []KiroHistoryMessage
return history[len(history)-kiroMaxHistoryMessages:]
}
func filterOrphanedToolResults(history []KiroHistoryMessage, currentToolResults []KiroToolResult) ([]KiroHistoryMessage, []KiroToolResult) {
// Remove tool results with no matching tool_use in retained history.
// This happens after truncation when the assistant turn that produced tool_use
// is dropped but a later user/tool_result survives.
validToolUseIDs := make(map[string]bool)
for _, h := range history {
if h.AssistantResponseMessage == nil {
continue
}
for _, tu := range h.AssistantResponseMessage.ToolUses {
validToolUseIDs[tu.ToolUseID] = true
}
}
for i, h := range history {
if h.UserInputMessage == nil || h.UserInputMessage.UserInputMessageContext == nil {
continue
}
ctx := h.UserInputMessage.UserInputMessageContext
if len(ctx.ToolResults) == 0 {
continue
}
filtered := make([]KiroToolResult, 0, len(ctx.ToolResults))
for _, tr := range ctx.ToolResults {
if validToolUseIDs[tr.ToolUseID] {
filtered = append(filtered, tr)
continue
}
log.Debugf("kiro-openai: dropping orphaned tool_result in history[%d]: toolUseId=%s (no matching tool_use)", i, tr.ToolUseID)
}
ctx.ToolResults = filtered
if len(ctx.ToolResults) == 0 && len(ctx.Tools) == 0 {
h.UserInputMessage.UserInputMessageContext = nil
}
}
if len(currentToolResults) > 0 {
filtered := make([]KiroToolResult, 0, len(currentToolResults))
for _, tr := range currentToolResults {
if validToolUseIDs[tr.ToolUseID] {
filtered = append(filtered, tr)
continue
}
log.Debugf("kiro-openai: dropping orphaned tool_result in currentMessage: toolUseId=%s (no matching tool_use)", tr.ToolUseID)
}
if len(filtered) != len(currentToolResults) {
log.Infof("kiro-openai: dropped %d orphaned tool_result(s) from currentMessage", len(currentToolResults)-len(filtered))
}
currentToolResults = filtered
}
return history, currentToolResults
}
// buildUserMessageFromOpenAI builds a user message from OpenAI format and extracts tool results
func buildUserMessageFromOpenAI(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) {
content := msg.Get("content")

View File

@@ -384,3 +384,57 @@ func TestAssistantEndsConversation(t *testing.T) {
t.Error("Expected a 'Continue' message to be created when assistant is last")
}
}
func TestFilterOrphanedToolResults_RemovesHistoryAndCurrentOrphans(t *testing.T) {
history := []KiroHistoryMessage{
{
AssistantResponseMessage: &KiroAssistantResponseMessage{
Content: "assistant",
ToolUses: []KiroToolUse{
{ToolUseID: "keep-1", Name: "Read", Input: map[string]interface{}{}},
},
},
},
{
UserInputMessage: &KiroUserInputMessage{
Content: "user-with-mixed-results",
UserInputMessageContext: &KiroUserInputMessageContext{
ToolResults: []KiroToolResult{
{ToolUseID: "keep-1", Status: "success", Content: []KiroTextContent{{Text: "ok"}}},
{ToolUseID: "orphan-1", Status: "success", Content: []KiroTextContent{{Text: "bad"}}},
},
},
},
},
{
UserInputMessage: &KiroUserInputMessage{
Content: "user-only-orphans",
UserInputMessageContext: &KiroUserInputMessageContext{
ToolResults: []KiroToolResult{
{ToolUseID: "orphan-2", Status: "success", Content: []KiroTextContent{{Text: "bad"}}},
},
},
},
},
}
currentToolResults := []KiroToolResult{
{ToolUseID: "keep-1", Status: "success", Content: []KiroTextContent{{Text: "ok"}}},
{ToolUseID: "orphan-3", Status: "success", Content: []KiroTextContent{{Text: "bad"}}},
}
filteredHistory, filteredCurrent := filterOrphanedToolResults(history, currentToolResults)
ctx1 := filteredHistory[1].UserInputMessage.UserInputMessageContext
if ctx1 == nil || len(ctx1.ToolResults) != 1 || ctx1.ToolResults[0].ToolUseID != "keep-1" {
t.Fatalf("expected mixed history message to keep only keep-1, got: %+v", ctx1)
}
if filteredHistory[2].UserInputMessage.UserInputMessageContext != nil {
t.Fatalf("expected orphan-only history context to be removed")
}
if len(filteredCurrent) != 1 || filteredCurrent[0].ToolUseID != "keep-1" {
t.Fatalf("expected current tool results to keep only keep-1, got: %+v", filteredCurrent)
}
}

View File

@@ -717,6 +717,12 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
return
}
if len(chunk.Payload) > 0 {
if handlerType == "openai-response" {
if err := validateSSEDataJSON(chunk.Payload); err != nil {
_ = sendErr(&interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err})
return
}
}
sentPayload = true
if okSendData := sendData(cloneBytes(chunk.Payload)); !okSendData {
return
@@ -728,6 +734,35 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
return dataChan, upstreamHeaders, errChan
}
func validateSSEDataJSON(chunk []byte) error {
for _, line := range bytes.Split(chunk, []byte("\n")) {
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
if !bytes.HasPrefix(line, []byte("data:")) {
continue
}
data := bytes.TrimSpace(line[5:])
if len(data) == 0 {
continue
}
if bytes.Equal(data, []byte("[DONE]")) {
continue
}
if json.Valid(data) {
continue
}
const max = 512
preview := data
if len(preview) > max {
preview = preview[:max]
}
return fmt.Errorf("invalid SSE data JSON (len=%d): %q", len(data), preview)
}
return nil
}
func statusFromError(err error) int {
if err == nil {
return 0

View File

@@ -134,6 +134,37 @@ type authAwareStreamExecutor struct {
authIDs []string
}
type invalidJSONStreamExecutor struct{}
func (e *invalidJSONStreamExecutor) Identifier() string { return "codex" }
func (e *invalidJSONStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
}
func (e *invalidJSONStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
ch := make(chan coreexecutor.StreamChunk, 1)
ch <- coreexecutor.StreamChunk{Payload: []byte("event: response.completed\ndata: {\"type\"")}
close(ch)
return &coreexecutor.StreamResult{Chunks: ch}, nil
}
func (e *invalidJSONStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
return auth, nil
}
func (e *invalidJSONStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
}
func (e *invalidJSONStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
return nil, &coreauth.Error{
Code: "not_implemented",
Message: "HttpRequest not implemented",
HTTPStatus: http.StatusNotImplemented,
}
}
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
@@ -524,3 +555,55 @@ func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *test
t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2")
}
}
func TestExecuteStreamWithAuthManager_ValidatesOpenAIResponsesStreamDataJSON(t *testing.T) {
executor := &invalidJSONStreamExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(executor)
auth1 := &coreauth.Auth{
ID: "auth1",
Provider: "codex",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test1@example.com"},
}
if _, err := manager.Register(context.Background(), auth1); err != nil {
t.Fatalf("manager.Register(auth1): %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
})
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai-response", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels")
}
var got []byte
for chunk := range dataChan {
got = append(got, chunk...)
}
if len(got) != 0 {
t.Fatalf("expected empty payload, got %q", string(got))
}
gotErr := false
for msg := range errChan {
if msg == nil {
continue
}
if msg.StatusCode != http.StatusBadGateway {
t.Fatalf("expected status %d, got %d", http.StatusBadGateway, msg.StatusCode)
}
if msg.Error == nil {
t.Fatalf("expected error")
}
gotErr = true
}
if !gotErr {
t.Fatalf("expected terminal error")
}
}

View File

@@ -418,8 +418,8 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flush
if errMsg.Error != nil && errMsg.Error.Error() != "" {
errText = errMsg.Error.Error()
}
body := handlers.BuildErrorResponseBody(status, errText)
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body))
chunk := handlers.BuildOpenAIResponsesStreamErrorChunk(status, errText, 0)
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(chunk))
},
WriteDone: func() {
_, _ = c.Writer.Write([]byte("\n"))

View File

@@ -0,0 +1,43 @@
package openai
import (
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
func TestForwardResponsesStreamTerminalErrorUsesResponsesErrorChunk(t *testing.T) {
gin.SetMode(gin.TestMode)
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, nil)
h := NewOpenAIResponsesAPIHandler(base)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
t.Fatalf("expected gin writer to implement http.Flusher")
}
data := make(chan []byte)
errs := make(chan *interfaces.ErrorMessage, 1)
errs <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: errors.New("unexpected EOF")}
close(errs)
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs)
body := recorder.Body.String()
if !strings.Contains(body, `"type":"error"`) {
t.Fatalf("expected responses error chunk, got: %q", body)
}
if strings.Contains(body, `"error":{`) {
t.Fatalf("expected streaming error chunk (top-level type), got HTTP error body: %q", body)
}
}

View File

@@ -0,0 +1,119 @@
package handlers
import (
"encoding/json"
"fmt"
"net/http"
"strings"
)
type openAIResponsesStreamErrorChunk struct {
Type string `json:"type"`
Code string `json:"code"`
Message string `json:"message"`
SequenceNumber int `json:"sequence_number"`
}
func openAIResponsesStreamErrorCode(status int) string {
switch status {
case http.StatusUnauthorized:
return "invalid_api_key"
case http.StatusForbidden:
return "insufficient_quota"
case http.StatusTooManyRequests:
return "rate_limit_exceeded"
case http.StatusNotFound:
return "model_not_found"
case http.StatusRequestTimeout:
return "request_timeout"
default:
if status >= http.StatusInternalServerError {
return "internal_server_error"
}
if status >= http.StatusBadRequest {
return "invalid_request_error"
}
return "unknown_error"
}
}
// BuildOpenAIResponsesStreamErrorChunk builds an OpenAI Responses streaming error chunk.
//
// Important: OpenAI's HTTP error bodies are shaped like {"error":{...}}; those are valid for
// non-streaming responses, but streaming clients validate SSE `data:` payloads against a union
// of chunks that requires a top-level `type` field.
func BuildOpenAIResponsesStreamErrorChunk(status int, errText string, sequenceNumber int) []byte {
if status <= 0 {
status = http.StatusInternalServerError
}
if sequenceNumber < 0 {
sequenceNumber = 0
}
message := strings.TrimSpace(errText)
if message == "" {
message = http.StatusText(status)
}
code := openAIResponsesStreamErrorCode(status)
trimmed := strings.TrimSpace(errText)
if trimmed != "" && json.Valid([]byte(trimmed)) {
var payload map[string]any
if err := json.Unmarshal([]byte(trimmed), &payload); err == nil {
if t, ok := payload["type"].(string); ok && strings.TrimSpace(t) == "error" {
if m, ok := payload["message"].(string); ok && strings.TrimSpace(m) != "" {
message = strings.TrimSpace(m)
}
if v, ok := payload["code"]; ok && v != nil {
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
code = strings.TrimSpace(c)
} else {
code = strings.TrimSpace(fmt.Sprint(v))
}
}
if v, ok := payload["sequence_number"].(float64); ok && sequenceNumber == 0 {
sequenceNumber = int(v)
}
}
if e, ok := payload["error"].(map[string]any); ok {
if m, ok := e["message"].(string); ok && strings.TrimSpace(m) != "" {
message = strings.TrimSpace(m)
}
if v, ok := e["code"]; ok && v != nil {
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
code = strings.TrimSpace(c)
} else {
code = strings.TrimSpace(fmt.Sprint(v))
}
}
}
}
}
if strings.TrimSpace(code) == "" {
code = "unknown_error"
}
data, err := json.Marshal(openAIResponsesStreamErrorChunk{
Type: "error",
Code: code,
Message: message,
SequenceNumber: sequenceNumber,
})
if err == nil {
return data
}
// Extremely defensive fallback.
data, _ = json.Marshal(openAIResponsesStreamErrorChunk{
Type: "error",
Code: "internal_server_error",
Message: message,
SequenceNumber: sequenceNumber,
})
if len(data) > 0 {
return data
}
return []byte(`{"type":"error","code":"internal_server_error","message":"internal error","sequence_number":0}`)
}

View File

@@ -0,0 +1,48 @@
package handlers
import (
"encoding/json"
"net/http"
"testing"
)
func TestBuildOpenAIResponsesStreamErrorChunk(t *testing.T) {
chunk := BuildOpenAIResponsesStreamErrorChunk(http.StatusInternalServerError, "unexpected EOF", 0)
var payload map[string]any
if err := json.Unmarshal(chunk, &payload); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if payload["type"] != "error" {
t.Fatalf("type = %v, want %q", payload["type"], "error")
}
if payload["code"] != "internal_server_error" {
t.Fatalf("code = %v, want %q", payload["code"], "internal_server_error")
}
if payload["message"] != "unexpected EOF" {
t.Fatalf("message = %v, want %q", payload["message"], "unexpected EOF")
}
if payload["sequence_number"] != float64(0) {
t.Fatalf("sequence_number = %v, want %v", payload["sequence_number"], 0)
}
}
func TestBuildOpenAIResponsesStreamErrorChunkExtractsHTTPErrorBody(t *testing.T) {
chunk := BuildOpenAIResponsesStreamErrorChunk(
http.StatusInternalServerError,
`{"error":{"message":"oops","type":"server_error","code":"internal_server_error"}}`,
0,
)
var payload map[string]any
if err := json.Unmarshal(chunk, &payload); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if payload["type"] != "error" {
t.Fatalf("type = %v, want %q", payload["type"], "error")
}
if payload["code"] != "internal_server_error" {
t.Fatalf("code = %v, want %q", payload["code"], "internal_server_error")
}
if payload["message"] != "oops" {
t.Fatalf("message = %v, want %q", payload["message"], "oops")
}
}

View File

@@ -2,8 +2,6 @@ package auth
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"strings"
@@ -48,6 +46,10 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
opts = &LoginOptions{}
}
if shouldUseCodexDeviceFlow(opts) {
return a.loginWithDeviceFlow(ctx, cfg, opts)
}
callbackPort := a.CallbackPort
if opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
@@ -186,39 +188,5 @@ waitForCallback:
return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err)
}
tokenStorage := authSvc.CreateTokenStorage(authBundle)
if tokenStorage == nil || tokenStorage.Email == "" {
return nil, fmt.Errorf("codex token storage missing account information")
}
planType := ""
hashAccountID := ""
if tokenStorage.IDToken != "" {
if claims, errParse := codex.ParseJWTToken(tokenStorage.IDToken); errParse == nil && claims != nil {
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
accountID := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID)
if accountID != "" {
digest := sha256.Sum256([]byte(accountID))
hashAccountID = hex.EncodeToString(digest[:])[:8]
}
}
}
fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true)
metadata := map[string]any{
"email": tokenStorage.Email,
}
fmt.Println("Codex authentication successful")
if authBundle.APIKey != "" {
fmt.Println("Codex API key obtained and stored")
}
return &coreauth.Auth{
ID: fileName,
Provider: a.Provider(),
FileName: fileName,
Storage: tokenStorage,
Metadata: metadata,
}, nil
return a.buildAuthRecord(authSvc, authBundle)
}

291
sdk/auth/codex_device.go Normal file
View File

@@ -0,0 +1,291 @@
package auth
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
const (
codexLoginModeMetadataKey = "codex_login_mode"
codexLoginModeDevice = "device"
codexDeviceUserCodeURL = "https://auth.openai.com/api/accounts/deviceauth/usercode"
codexDeviceTokenURL = "https://auth.openai.com/api/accounts/deviceauth/token"
codexDeviceVerificationURL = "https://auth.openai.com/codex/device"
codexDeviceTokenExchangeRedirectURI = "https://auth.openai.com/deviceauth/callback"
codexDeviceTimeout = 15 * time.Minute
codexDeviceDefaultPollIntervalSeconds = 5
)
type codexDeviceUserCodeRequest struct {
ClientID string `json:"client_id"`
}
type codexDeviceUserCodeResponse struct {
DeviceAuthID string `json:"device_auth_id"`
UserCode string `json:"user_code"`
UserCodeAlt string `json:"usercode"`
Interval json.RawMessage `json:"interval"`
}
type codexDeviceTokenRequest struct {
DeviceAuthID string `json:"device_auth_id"`
UserCode string `json:"user_code"`
}
type codexDeviceTokenResponse struct {
AuthorizationCode string `json:"authorization_code"`
CodeVerifier string `json:"code_verifier"`
CodeChallenge string `json:"code_challenge"`
}
func shouldUseCodexDeviceFlow(opts *LoginOptions) bool {
if opts == nil || opts.Metadata == nil {
return false
}
return strings.EqualFold(strings.TrimSpace(opts.Metadata[codexLoginModeMetadataKey]), codexLoginModeDevice)
}
func (a *CodexAuthenticator) loginWithDeviceFlow(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
if ctx == nil {
ctx = context.Background()
}
httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{})
userCodeResp, err := requestCodexDeviceUserCode(ctx, httpClient)
if err != nil {
return nil, err
}
deviceCode := strings.TrimSpace(userCodeResp.UserCode)
if deviceCode == "" {
deviceCode = strings.TrimSpace(userCodeResp.UserCodeAlt)
}
deviceAuthID := strings.TrimSpace(userCodeResp.DeviceAuthID)
if deviceCode == "" || deviceAuthID == "" {
return nil, fmt.Errorf("codex device flow did not return required fields")
}
pollInterval := parseCodexDevicePollInterval(userCodeResp.Interval)
fmt.Println("Starting Codex device authentication...")
fmt.Printf("Codex device URL: %s\n", codexDeviceVerificationURL)
fmt.Printf("Codex device code: %s\n", deviceCode)
if !opts.NoBrowser {
if !browser.IsAvailable() {
log.Warn("No browser available; please open the device URL manually")
} else if errOpen := browser.OpenURL(codexDeviceVerificationURL); errOpen != nil {
log.Warnf("Failed to open browser automatically: %v", errOpen)
}
}
tokenResp, err := pollCodexDeviceToken(ctx, httpClient, deviceAuthID, deviceCode, pollInterval)
if err != nil {
return nil, err
}
authCode := strings.TrimSpace(tokenResp.AuthorizationCode)
codeVerifier := strings.TrimSpace(tokenResp.CodeVerifier)
codeChallenge := strings.TrimSpace(tokenResp.CodeChallenge)
if authCode == "" || codeVerifier == "" || codeChallenge == "" {
return nil, fmt.Errorf("codex device flow token response missing required fields")
}
authSvc := codex.NewCodexAuth(cfg)
authBundle, err := authSvc.ExchangeCodeForTokensWithRedirect(
ctx,
authCode,
codexDeviceTokenExchangeRedirectURI,
&codex.PKCECodes{
CodeVerifier: codeVerifier,
CodeChallenge: codeChallenge,
},
)
if err != nil {
return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err)
}
return a.buildAuthRecord(authSvc, authBundle)
}
func requestCodexDeviceUserCode(ctx context.Context, client *http.Client) (*codexDeviceUserCodeResponse, error) {
body, err := json.Marshal(codexDeviceUserCodeRequest{ClientID: codex.ClientID})
if err != nil {
return nil, fmt.Errorf("failed to encode codex device request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, codexDeviceUserCodeURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("failed to create codex device request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to request codex device code: %w", err)
}
defer func() { _ = resp.Body.Close() }()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read codex device code response: %w", err)
}
if !codexDeviceIsSuccessStatus(resp.StatusCode) {
trimmed := strings.TrimSpace(string(respBody))
if resp.StatusCode == http.StatusNotFound {
return nil, fmt.Errorf("codex device endpoint is unavailable (status %d)", resp.StatusCode)
}
if trimmed == "" {
trimmed = "empty response body"
}
return nil, fmt.Errorf("codex device code request failed with status %d: %s", resp.StatusCode, trimmed)
}
var parsed codexDeviceUserCodeResponse
if err := json.Unmarshal(respBody, &parsed); err != nil {
return nil, fmt.Errorf("failed to decode codex device code response: %w", err)
}
return &parsed, nil
}
func pollCodexDeviceToken(ctx context.Context, client *http.Client, deviceAuthID, userCode string, interval time.Duration) (*codexDeviceTokenResponse, error) {
deadline := time.Now().Add(codexDeviceTimeout)
for {
if time.Now().After(deadline) {
return nil, fmt.Errorf("codex device authentication timed out after 15 minutes")
}
body, err := json.Marshal(codexDeviceTokenRequest{
DeviceAuthID: deviceAuthID,
UserCode: userCode,
})
if err != nil {
return nil, fmt.Errorf("failed to encode codex device poll request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, codexDeviceTokenURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("failed to create codex device poll request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to poll codex device token: %w", err)
}
respBody, readErr := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if readErr != nil {
return nil, fmt.Errorf("failed to read codex device poll response: %w", readErr)
}
switch {
case codexDeviceIsSuccessStatus(resp.StatusCode):
var parsed codexDeviceTokenResponse
if err := json.Unmarshal(respBody, &parsed); err != nil {
return nil, fmt.Errorf("failed to decode codex device token response: %w", err)
}
return &parsed, nil
case resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusNotFound:
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(interval):
continue
}
default:
trimmed := strings.TrimSpace(string(respBody))
if trimmed == "" {
trimmed = "empty response body"
}
return nil, fmt.Errorf("codex device token polling failed with status %d: %s", resp.StatusCode, trimmed)
}
}
}
func parseCodexDevicePollInterval(raw json.RawMessage) time.Duration {
defaultInterval := time.Duration(codexDeviceDefaultPollIntervalSeconds) * time.Second
if len(raw) == 0 {
return defaultInterval
}
var asString string
if err := json.Unmarshal(raw, &asString); err == nil {
if seconds, convErr := strconv.Atoi(strings.TrimSpace(asString)); convErr == nil && seconds > 0 {
return time.Duration(seconds) * time.Second
}
}
var asInt int
if err := json.Unmarshal(raw, &asInt); err == nil && asInt > 0 {
return time.Duration(asInt) * time.Second
}
return defaultInterval
}
func codexDeviceIsSuccessStatus(code int) bool {
return code >= 200 && code < 300
}
func (a *CodexAuthenticator) buildAuthRecord(authSvc *codex.CodexAuth, authBundle *codex.CodexAuthBundle) (*coreauth.Auth, error) {
tokenStorage := authSvc.CreateTokenStorage(authBundle)
if tokenStorage == nil || tokenStorage.Email == "" {
return nil, fmt.Errorf("codex token storage missing account information")
}
planType := ""
hashAccountID := ""
if tokenStorage.IDToken != "" {
if claims, errParse := codex.ParseJWTToken(tokenStorage.IDToken); errParse == nil && claims != nil {
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
accountID := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID)
if accountID != "" {
digest := sha256.Sum256([]byte(accountID))
hashAccountID = hex.EncodeToString(digest[:])[:8]
}
}
}
fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true)
metadata := map[string]any{
"email": tokenStorage.Email,
}
fmt.Println("Codex authentication successful")
if authBundle.APIKey != "" {
fmt.Println("Codex API key obtained and stored")
}
return &coreauth.Auth{
ID: fileName,
Provider: a.Provider(),
FileName: fileName,
Storage: tokenStorage,
Metadata: metadata,
}, nil
}

View File

@@ -64,8 +64,16 @@ func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (str
return "", fmt.Errorf("auth filestore: create dir failed: %w", err)
}
// metadataSetter is a private interface for TokenStorage implementations that support metadata injection.
type metadataSetter interface {
SetMetadata(map[string]any)
}
switch {
case auth.Storage != nil:
if setter, ok := auth.Storage.(metadataSetter); ok {
setter.SetMetadata(auth.Metadata)
}
if err = auth.Storage.SaveTokenToFile(path); err != nil {
return "", err
}

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"math"
"math/rand/v2"
"net/http"
"sort"
"strconv"
@@ -248,6 +249,9 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]
}
// Pick selects the next available auth for the provider in a round-robin manner.
// For gemini-cli virtual auths (identified by the gemini_virtual_parent attribute),
// a two-level round-robin is used: first cycling across credential groups (parent
// accounts), then cycling within each group's project auths.
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
_ = opts
now := time.Now()
@@ -265,21 +269,87 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
if limit <= 0 {
limit = 4096
}
if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit {
s.cursors = make(map[string]int)
}
index := s.cursors[key]
// Check if any available auth has gemini_virtual_parent attribute,
// indicating gemini-cli virtual auths that should use credential-level polling.
groups, parentOrder := groupByVirtualParent(available)
if len(parentOrder) > 1 {
// Two-level round-robin: first select a credential group, then pick within it.
groupKey := key + "::group"
s.ensureCursorKey(groupKey, limit)
if _, exists := s.cursors[groupKey]; !exists {
// Seed with a random initial offset so the starting credential is randomized.
s.cursors[groupKey] = rand.IntN(len(parentOrder))
}
groupIndex := s.cursors[groupKey]
if groupIndex >= 2_147_483_640 {
groupIndex = 0
}
s.cursors[groupKey] = groupIndex + 1
selectedParent := parentOrder[groupIndex%len(parentOrder)]
group := groups[selectedParent]
// Second level: round-robin within the selected credential group.
innerKey := key + "::cred:" + selectedParent
s.ensureCursorKey(innerKey, limit)
innerIndex := s.cursors[innerKey]
if innerIndex >= 2_147_483_640 {
innerIndex = 0
}
s.cursors[innerKey] = innerIndex + 1
s.mu.Unlock()
return group[innerIndex%len(group)], nil
}
// Flat round-robin for non-grouped auths (original behavior).
s.ensureCursorKey(key, limit)
index := s.cursors[key]
if index >= 2_147_483_640 {
index = 0
}
s.cursors[key] = index + 1
s.mu.Unlock()
// log.Debugf("available: %d, index: %d, key: %d", len(available), index, index%len(available))
return available[index%len(available)], nil
}
// ensureCursorKey ensures the cursor map has capacity for the given key.
// Must be called with s.mu held.
func (s *RoundRobinSelector) ensureCursorKey(key string, limit int) {
if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit {
s.cursors = make(map[string]int)
}
}
// groupByVirtualParent groups auths by their gemini_virtual_parent attribute.
// Returns a map of parentID -> auths and a sorted slice of parent IDs for stable iteration.
// Only auths with a non-empty gemini_virtual_parent are grouped; if any auth lacks
// this attribute, nil/nil is returned so the caller falls back to flat round-robin.
func groupByVirtualParent(auths []*Auth) (map[string][]*Auth, []string) {
if len(auths) == 0 {
return nil, nil
}
groups := make(map[string][]*Auth)
for _, a := range auths {
parent := ""
if a.Attributes != nil {
parent = strings.TrimSpace(a.Attributes["gemini_virtual_parent"])
}
if parent == "" {
// Non-virtual auth present; fall back to flat round-robin.
return nil, nil
}
groups[parent] = append(groups[parent], a)
}
// Collect parent IDs in sorted order for stable cursor indexing.
parentOrder := make([]string, 0, len(groups))
for p := range groups {
parentOrder = append(parentOrder, p)
}
sort.Strings(parentOrder)
return groups, parentOrder
}
// Pick selects the first available auth for the provider in a deterministic manner.
func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
_ = opts

View File

@@ -402,3 +402,128 @@ func TestRoundRobinSelectorPick_CursorKeyCap(t *testing.T) {
t.Fatalf("selector.cursors missing key %q", "gemini:m3")
}
}
func TestRoundRobinSelectorPick_GeminiCLICredentialGrouping(t *testing.T) {
t.Parallel()
selector := &RoundRobinSelector{}
// Simulate two gemini-cli credentials, each with multiple projects:
// Credential A (parent = "cred-a.json") has 3 projects
// Credential B (parent = "cred-b.json") has 2 projects
auths := []*Auth{
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-b.json::proj-b1", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}},
{ID: "cred-b.json::proj-b2", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}},
}
// Two-level round-robin: consecutive picks must alternate between credentials.
// Credential group order is randomized, but within each call the group cursor
// advances by 1, so consecutive picks should cycle through different parents.
picks := make([]string, 6)
parents := make([]string, 6)
for i := 0; i < 6; i++ {
got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths)
if err != nil {
t.Fatalf("Pick() #%d error = %v", i, err)
}
if got == nil {
t.Fatalf("Pick() #%d auth = nil", i)
}
picks[i] = got.ID
parents[i] = got.Attributes["gemini_virtual_parent"]
}
// Verify property: consecutive picks must alternate between credential groups.
for i := 1; i < len(parents); i++ {
if parents[i] == parents[i-1] {
t.Fatalf("Pick() #%d and #%d both from same parent %q (IDs: %q, %q); expected alternating credentials",
i-1, i, parents[i], picks[i-1], picks[i])
}
}
// Verify property: each credential's projects are picked in sequence (round-robin within group).
credPicks := map[string][]string{}
for i, id := range picks {
credPicks[parents[i]] = append(credPicks[parents[i]], id)
}
for parent, ids := range credPicks {
for i := 1; i < len(ids); i++ {
if ids[i] == ids[i-1] {
t.Fatalf("Credential %q picked same project %q twice in a row", parent, ids[i])
}
}
}
}
func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) {
t.Parallel()
selector := &RoundRobinSelector{}
// All auths from the same parent - should fall back to flat round-robin
// because there's only one credential group (no benefit from two-level).
auths := []*Auth{
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
}
// With single parent group, parentOrder has length 1, so it uses flat round-robin.
// Sorted by ID: proj-a1, proj-a2, proj-a3
want := []string{
"cred-a.json::proj-a1",
"cred-a.json::proj-a2",
"cred-a.json::proj-a3",
"cred-a.json::proj-a1",
}
for i, expectedID := range want {
got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths)
if err != nil {
t.Fatalf("Pick() #%d error = %v", i, err)
}
if got == nil {
t.Fatalf("Pick() #%d auth = nil", i)
}
if got.ID != expectedID {
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID)
}
}
}
func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *testing.T) {
t.Parallel()
selector := &RoundRobinSelector{}
// Mix of virtual and non-virtual auths (e.g., a regular gemini-cli auth without projects
// alongside virtual ones). Should fall back to flat round-robin.
auths := []*Auth{
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-regular.json"}, // no gemini_virtual_parent
}
// groupByVirtualParent returns nil when any auth lacks the attribute,
// so flat round-robin is used. Sorted by ID: cred-a.json::proj-a1, cred-regular.json
want := []string{
"cred-a.json::proj-a1",
"cred-regular.json",
"cred-a.json::proj-a1",
}
for i, expectedID := range want {
got, err := selector.Pick(context.Background(), "gemini-cli", "", cliproxyexecutor.Options{}, auths)
if err != nil {
t.Fatalf("Pick() #%d error = %v", i, err)
}
if got == nil {
t.Fatalf("Pick() #%d auth = nil", i)
}
if got.ID != expectedID {
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID)
}
}
}

View File

@@ -1,9 +1,12 @@
package auth
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
@@ -12,6 +15,33 @@ import (
baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth"
)
// PostAuthHook defines a function that is called after an Auth record is created
// but before it is persisted to storage. This allows for modification of the
// Auth record (e.g., injecting metadata) based on external context.
type PostAuthHook func(context.Context, *Auth) error
// RequestInfo holds information extracted from the HTTP request.
// It is injected into the context passed to PostAuthHook.
type RequestInfo struct {
Query url.Values
Headers http.Header
}
type requestInfoKey struct{}
// WithRequestInfo returns a new context with the given RequestInfo attached.
func WithRequestInfo(ctx context.Context, info *RequestInfo) context.Context {
return context.WithValue(ctx, requestInfoKey{}, info)
}
// GetRequestInfo retrieves the RequestInfo from the context, if present.
func GetRequestInfo(ctx context.Context) *RequestInfo {
if val, ok := ctx.Value(requestInfoKey{}).(*RequestInfo); ok {
return val
}
return nil
}
// Auth encapsulates the runtime state and metadata associated with a single credential.
type Auth struct {
// ID uniquely identifies the auth record across restarts.

View File

@@ -153,6 +153,16 @@ func (b *Builder) WithLocalManagementPassword(password string) *Builder {
return b
}
// WithPostAuthHook registers a hook to be called after an Auth record is created
// but before it is persisted to storage.
func (b *Builder) WithPostAuthHook(hook coreauth.PostAuthHook) *Builder {
if hook == nil {
return b
}
b.serverOptions = append(b.serverOptions, api.WithPostAuthHook(hook))
return b
}
// Build validates inputs, applies defaults, and returns a ready-to-run service.
func (b *Builder) Build() (*Service, error) {
if b.cfg == nil {