diff --git a/cmd/server/main.go b/cmd/server/main.go index 2ef8c339..db95f6b3 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -72,6 +72,7 @@ func main() { // Command-line flags to control the application's behavior. var login bool var codexLogin bool + var codexDeviceLogin bool var claudeLogin bool var qwenLogin bool var kiloLogin bool @@ -99,6 +100,7 @@ func main() { // Define command-line flags for different operation modes. flag.BoolVar(&login, "login", false, "Login Google Account") flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth") + flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow") flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth") flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth") flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow") @@ -502,6 +504,9 @@ func main() { } else if codexLogin { // Handle Codex login cmd.DoCodexLogin(cfg, options) + } else if codexDeviceLogin { + // Handle Codex device-code login + cmd.DoCodexDeviceLogin(cfg, options) } else if claudeLogin { // Handle Claude login cmd.DoClaudeLogin(cfg, options) diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index bd1338a2..d0900492 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -951,11 +951,17 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s if store == nil { return "", fmt.Errorf("token store unavailable") } + if h.postAuthHook != nil { + if err := h.postAuthHook(ctx, record); err != nil { + return "", fmt.Errorf("post-auth hook failed: %w", err) + } + } return store.Save(ctx, record) } func (h *Handler) RequestAnthropicToken(c *gin.Context) { ctx := context.Background() + ctx = PopulateAuthContext(ctx, c) fmt.Println("Initializing Claude authentication...") @@ -1100,6 +1106,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { ctx := context.Background() + ctx = PopulateAuthContext(ctx, c) proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{}) ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient) @@ -1358,6 +1365,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { func (h *Handler) RequestCodexToken(c *gin.Context) { ctx := context.Background() + ctx = PopulateAuthContext(ctx, c) fmt.Println("Initializing Codex authentication...") @@ -1503,6 +1511,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { func (h *Handler) RequestAntigravityToken(c *gin.Context) { ctx := context.Background() + ctx = PopulateAuthContext(ctx, c) fmt.Println("Initializing Antigravity authentication...") @@ -1667,6 +1676,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { func (h *Handler) RequestQwenToken(c *gin.Context) { ctx := context.Background() + ctx = PopulateAuthContext(ctx, c) fmt.Println("Initializing Qwen authentication...") @@ -1722,6 +1732,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) { func (h *Handler) RequestKimiToken(c *gin.Context) { ctx := context.Background() + ctx = PopulateAuthContext(ctx, c) fmt.Println("Initializing Kimi authentication...") @@ -1798,6 +1809,7 @@ func (h *Handler) RequestKimiToken(c *gin.Context) { func (h *Handler) RequestIFlowToken(c *gin.Context) { ctx := context.Background() + ctx = PopulateAuthContext(ctx, c) fmt.Println("Initializing iFlow authentication...") @@ -2521,6 +2533,14 @@ func (h *Handler) GetAuthStatus(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "wait"}) } +// PopulateAuthContext extracts request info and adds it to the context +func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context { + info := &coreauth.RequestInfo{ + Query: c.Request.URL.Query(), + Headers: c.Request.Header, + } + return coreauth.WithRequestInfo(ctx, info) +} const kiroCallbackPort = 9876 func (h *Handler) RequestKiroToken(c *gin.Context) { diff --git a/internal/api/handlers/management/handler.go b/internal/api/handlers/management/handler.go index 613c9841..45786b9d 100644 --- a/internal/api/handlers/management/handler.go +++ b/internal/api/handlers/management/handler.go @@ -47,6 +47,7 @@ type Handler struct { allowRemoteOverride bool envSecret string logDir string + postAuthHook coreauth.PostAuthHook } // NewHandler creates a new management handler instance. @@ -128,6 +129,11 @@ func (h *Handler) SetLogDirectory(dir string) { h.logDir = dir } +// SetPostAuthHook registers a hook to be called after auth record creation but before persistence. +func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) { + h.postAuthHook = hook +} + // Middleware enforces access control for management endpoints. // All requests (local and remote) require a valid management key. // Additionally, remote access requires allow-remote-management=true. diff --git a/internal/api/server.go b/internal/api/server.go index 98041b8b..1a0147f6 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -52,6 +52,7 @@ type serverOptionConfig struct { keepAliveEnabled bool keepAliveTimeout time.Duration keepAliveOnTimeout func() + postAuthHook auth.PostAuthHook } // ServerOption customises HTTP server construction. @@ -112,6 +113,13 @@ func WithRequestLoggerFactory(factory func(*config.Config, string) logging.Reque } } +// WithPostAuthHook registers a hook to be called after auth record creation. +func WithPostAuthHook(hook auth.PostAuthHook) ServerOption { + return func(cfg *serverOptionConfig) { + cfg.postAuthHook = hook + } +} + // Server represents the main API server. // It encapsulates the Gin engine, HTTP server, handlers, and configuration. type Server struct { @@ -263,6 +271,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk } logDir := logging.ResolveLogDirectory(cfg) s.mgmt.SetLogDirectory(logDir) + if optionState.postAuthHook != nil { + s.mgmt.SetPostAuthHook(optionState.postAuthHook) + } s.localPassword = optionState.localPassword // Setup routes diff --git a/internal/auth/claude/token.go b/internal/auth/claude/token.go index cda10d58..6ebb0f2f 100644 --- a/internal/auth/claude/token.go +++ b/internal/auth/claude/token.go @@ -36,11 +36,21 @@ type ClaudeTokenStorage struct { // Expire is the timestamp when the current access token expires. Expire string `json:"expired"` + + // Metadata holds arbitrary key-value pairs injected via hooks. + // It is not exported to JSON directly to allow flattening during serialization. + Metadata map[string]any `json:"-"` +} + +// SetMetadata allows external callers to inject metadata into the storage before saving. +func (ts *ClaudeTokenStorage) SetMetadata(meta map[string]any) { + ts.Metadata = meta } // SaveTokenToFile serializes the Claude token storage to a JSON file. // This method creates the necessary directory structure and writes the token // data in JSON format to the specified file path for persistent storage. +// It merges any injected metadata into the top-level JSON object. // // Parameters: // - authFilePath: The full path where the token file should be saved @@ -65,8 +75,14 @@ func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error { _ = f.Close() }() + // Merge metadata using helper + data, errMerge := misc.MergeMetadata(ts, ts.Metadata) + if errMerge != nil { + return fmt.Errorf("failed to merge metadata: %w", errMerge) + } + // Encode and write the token data as JSON - if err = json.NewEncoder(f).Encode(ts); err != nil { + if err = json.NewEncoder(f).Encode(data); err != nil { return fmt.Errorf("failed to write token to file: %w", err) } return nil diff --git a/internal/auth/codex/openai_auth.go b/internal/auth/codex/openai_auth.go index 89deeadb..c273acae 100644 --- a/internal/auth/codex/openai_auth.go +++ b/internal/auth/codex/openai_auth.go @@ -71,16 +71,26 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string, // It performs an HTTP POST request to the OpenAI token endpoint with the provided // authorization code and PKCE verifier. func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) { + return o.ExchangeCodeForTokensWithRedirect(ctx, code, RedirectURI, pkceCodes) +} + +// ExchangeCodeForTokensWithRedirect exchanges an authorization code for tokens using +// a caller-provided redirect URI. This supports alternate auth flows such as device +// login while preserving the existing token parsing and storage behavior. +func (o *CodexAuth) ExchangeCodeForTokensWithRedirect(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) { if pkceCodes == nil { return nil, fmt.Errorf("PKCE codes are required for token exchange") } + if strings.TrimSpace(redirectURI) == "" { + return nil, fmt.Errorf("redirect URI is required for token exchange") + } // Prepare token exchange request data := url.Values{ "grant_type": {"authorization_code"}, "client_id": {ClientID}, "code": {code}, - "redirect_uri": {RedirectURI}, + "redirect_uri": {strings.TrimSpace(redirectURI)}, "code_verifier": {pkceCodes.CodeVerifier}, } diff --git a/internal/auth/codex/token.go b/internal/auth/codex/token.go index e93fc417..7f032071 100644 --- a/internal/auth/codex/token.go +++ b/internal/auth/codex/token.go @@ -32,11 +32,21 @@ type CodexTokenStorage struct { Type string `json:"type"` // Expire is the timestamp when the current access token expires. Expire string `json:"expired"` + + // Metadata holds arbitrary key-value pairs injected via hooks. + // It is not exported to JSON directly to allow flattening during serialization. + Metadata map[string]any `json:"-"` +} + +// SetMetadata allows external callers to inject metadata into the storage before saving. +func (ts *CodexTokenStorage) SetMetadata(meta map[string]any) { + ts.Metadata = meta } // SaveTokenToFile serializes the Codex token storage to a JSON file. // This method creates the necessary directory structure and writes the token // data in JSON format to the specified file path for persistent storage. +// It merges any injected metadata into the top-level JSON object. // // Parameters: // - authFilePath: The full path where the token file should be saved @@ -58,7 +68,13 @@ func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error { _ = f.Close() }() - if err = json.NewEncoder(f).Encode(ts); err != nil { + // Merge metadata using helper + data, errMerge := misc.MergeMetadata(ts, ts.Metadata) + if errMerge != nil { + return fmt.Errorf("failed to merge metadata: %w", errMerge) + } + + if err = json.NewEncoder(f).Encode(data); err != nil { return fmt.Errorf("failed to write token to file: %w", err) } return nil diff --git a/internal/auth/gemini/gemini_token.go b/internal/auth/gemini/gemini_token.go index 0ec7da17..6848b708 100644 --- a/internal/auth/gemini/gemini_token.go +++ b/internal/auth/gemini/gemini_token.go @@ -35,11 +35,21 @@ type GeminiTokenStorage struct { // Type indicates the authentication provider type, always "gemini" for this storage. Type string `json:"type"` + + // Metadata holds arbitrary key-value pairs injected via hooks. + // It is not exported to JSON directly to allow flattening during serialization. + Metadata map[string]any `json:"-"` +} + +// SetMetadata allows external callers to inject metadata into the storage before saving. +func (ts *GeminiTokenStorage) SetMetadata(meta map[string]any) { + ts.Metadata = meta } // SaveTokenToFile serializes the Gemini token storage to a JSON file. // This method creates the necessary directory structure and writes the token // data in JSON format to the specified file path for persistent storage. +// It merges any injected metadata into the top-level JSON object. // // Parameters: // - authFilePath: The full path where the token file should be saved @@ -49,6 +59,11 @@ type GeminiTokenStorage struct { func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error { misc.LogSavingCredentials(authFilePath) ts.Type = "gemini" + // Merge metadata using helper + data, errMerge := misc.MergeMetadata(ts, ts.Metadata) + if errMerge != nil { + return fmt.Errorf("failed to merge metadata: %w", errMerge) + } if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil { return fmt.Errorf("failed to create directory: %v", err) } @@ -63,7 +78,9 @@ func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error { } }() - if err = json.NewEncoder(f).Encode(ts); err != nil { + enc := json.NewEncoder(f) + enc.SetIndent("", " ") + if err := enc.Encode(data); err != nil { return fmt.Errorf("failed to write token to file: %w", err) } return nil diff --git a/internal/auth/iflow/iflow_token.go b/internal/auth/iflow/iflow_token.go index 6d2beb39..a515c926 100644 --- a/internal/auth/iflow/iflow_token.go +++ b/internal/auth/iflow/iflow_token.go @@ -21,6 +21,15 @@ type IFlowTokenStorage struct { Scope string `json:"scope"` Cookie string `json:"cookie"` Type string `json:"type"` + + // Metadata holds arbitrary key-value pairs injected via hooks. + // It is not exported to JSON directly to allow flattening during serialization. + Metadata map[string]any `json:"-"` +} + +// SetMetadata allows external callers to inject metadata into the storage before saving. +func (ts *IFlowTokenStorage) SetMetadata(meta map[string]any) { + ts.Metadata = meta } // SaveTokenToFile serialises the token storage to disk. @@ -37,7 +46,13 @@ func (ts *IFlowTokenStorage) SaveTokenToFile(authFilePath string) error { } defer func() { _ = f.Close() }() - if err = json.NewEncoder(f).Encode(ts); err != nil { + // Merge metadata using helper + data, errMerge := misc.MergeMetadata(ts, ts.Metadata) + if errMerge != nil { + return fmt.Errorf("failed to merge metadata: %w", errMerge) + } + + if err = json.NewEncoder(f).Encode(data); err != nil { return fmt.Errorf("iflow token: encode token failed: %w", err) } return nil diff --git a/internal/auth/kimi/token.go b/internal/auth/kimi/token.go index d4d06b64..7320d760 100644 --- a/internal/auth/kimi/token.go +++ b/internal/auth/kimi/token.go @@ -29,6 +29,15 @@ type KimiTokenStorage struct { Expired string `json:"expired,omitempty"` // Type indicates the authentication provider type, always "kimi" for this storage. Type string `json:"type"` + + // Metadata holds arbitrary key-value pairs injected via hooks. + // It is not exported to JSON directly to allow flattening during serialization. + Metadata map[string]any `json:"-"` +} + +// SetMetadata allows external callers to inject metadata into the storage before saving. +func (ts *KimiTokenStorage) SetMetadata(meta map[string]any) { + ts.Metadata = meta } // KimiTokenData holds the raw OAuth token response from Kimi. @@ -86,9 +95,15 @@ func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error { _ = f.Close() }() + // Merge metadata using helper + data, errMerge := misc.MergeMetadata(ts, ts.Metadata) + if errMerge != nil { + return fmt.Errorf("failed to merge metadata: %w", errMerge) + } + encoder := json.NewEncoder(f) encoder.SetIndent("", " ") - if err = encoder.Encode(ts); err != nil { + if err = encoder.Encode(data); err != nil { return fmt.Errorf("failed to write token to file: %w", err) } return nil diff --git a/internal/auth/qwen/qwen_token.go b/internal/auth/qwen/qwen_token.go index 4a2b3a2d..276c8b40 100644 --- a/internal/auth/qwen/qwen_token.go +++ b/internal/auth/qwen/qwen_token.go @@ -30,11 +30,21 @@ type QwenTokenStorage struct { Type string `json:"type"` // Expire is the timestamp when the current access token expires. Expire string `json:"expired"` + + // Metadata holds arbitrary key-value pairs injected via hooks. + // It is not exported to JSON directly to allow flattening during serialization. + Metadata map[string]any `json:"-"` +} + +// SetMetadata allows external callers to inject metadata into the storage before saving. +func (ts *QwenTokenStorage) SetMetadata(meta map[string]any) { + ts.Metadata = meta } // SaveTokenToFile serializes the Qwen token storage to a JSON file. // This method creates the necessary directory structure and writes the token // data in JSON format to the specified file path for persistent storage. +// It merges any injected metadata into the top-level JSON object. // // Parameters: // - authFilePath: The full path where the token file should be saved @@ -56,7 +66,13 @@ func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error { _ = f.Close() }() - if err = json.NewEncoder(f).Encode(ts); err != nil { + // Merge metadata using helper + data, errMerge := misc.MergeMetadata(ts, ts.Metadata) + if errMerge != nil { + return fmt.Errorf("failed to merge metadata: %w", errMerge) + } + + if err = json.NewEncoder(f).Encode(data); err != nil { return fmt.Errorf("failed to write token to file: %w", err) } return nil diff --git a/internal/cmd/openai_device_login.go b/internal/cmd/openai_device_login.go new file mode 100644 index 00000000..1b7351e6 --- /dev/null +++ b/internal/cmd/openai_device_login.go @@ -0,0 +1,60 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "os" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" + log "github.com/sirupsen/logrus" +) + +const ( + codexLoginModeMetadataKey = "codex_login_mode" + codexLoginModeDevice = "device" +) + +// DoCodexDeviceLogin triggers the Codex device-code flow while keeping the +// existing codex-login OAuth callback flow intact. +func DoCodexDeviceLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + promptFn := options.Prompt + if promptFn == nil { + promptFn = defaultProjectPrompt() + } + + manager := newAuthManager() + + authOpts := &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + CallbackPort: options.CallbackPort, + Metadata: map[string]string{ + codexLoginModeMetadataKey: codexLoginModeDevice, + }, + Prompt: promptFn, + } + + _, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts) + if err != nil { + if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok { + log.Error(codex.GetUserFriendlyMessage(authErr)) + if authErr.Type == codex.ErrPortInUse.Type { + os.Exit(codex.ErrPortInUse.Code) + } + return + } + fmt.Printf("Codex device authentication failed: %v\n", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + fmt.Println("Codex device authentication successful!") +} diff --git a/internal/misc/credentials.go b/internal/misc/credentials.go index b03cd788..6b4f9ced 100644 --- a/internal/misc/credentials.go +++ b/internal/misc/credentials.go @@ -1,6 +1,7 @@ package misc import ( + "encoding/json" "fmt" "path/filepath" "strings" @@ -24,3 +25,37 @@ func LogSavingCredentials(path string) { func LogCredentialSeparator() { log.Debug(credentialSeparator) } + +// MergeMetadata serializes the source struct into a map and merges the provided metadata into it. +func MergeMetadata(source any, metadata map[string]any) (map[string]any, error) { + var data map[string]any + + // Fast path: if source is already a map, just copy it to avoid mutation of original + if srcMap, ok := source.(map[string]any); ok { + data = make(map[string]any, len(srcMap)+len(metadata)) + for k, v := range srcMap { + data[k] = v + } + } else { + // Slow path: marshal to JSON and back to map to respect JSON tags + temp, err := json.Marshal(source) + if err != nil { + return nil, fmt.Errorf("failed to marshal source: %w", err) + } + if err := json.Unmarshal(temp, &data); err != nil { + return nil, fmt.Errorf("failed to unmarshal to map: %w", err) + } + } + + // Merge extra metadata + if metadata != nil { + if data == nil { + data = make(map[string]any) + } + for k, v := range metadata { + data[k] = v + } + } + + return data, nil +} diff --git a/internal/thinking/provider/openai/apply.go b/internal/thinking/provider/openai/apply.go index e8a2562f..eaad30ee 100644 --- a/internal/thinking/provider/openai/apply.go +++ b/internal/thinking/provider/openai/apply.go @@ -10,53 +10,10 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" - log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) -// validReasoningEffortLevels contains the standard values accepted by the -// OpenAI reasoning_effort field. Provider-specific extensions (xhigh, minimal, -// auto) are NOT in this set and must be clamped before use. -var validReasoningEffortLevels = map[string]struct{}{ - "none": {}, - "low": {}, - "medium": {}, - "high": {}, -} - -// clampReasoningEffort maps any thinking level string to a value that is safe -// to send as OpenAI reasoning_effort. Non-standard CPA-internal values are -// mapped to the nearest standard equivalent. -// -// Mapping rules: -// - none / low / medium / high → returned as-is (already valid) -// - xhigh → "high" (nearest lower standard level) -// - minimal → "low" (nearest higher standard level) -// - auto → "medium" (reasonable default) -// - anything else → "medium" (safe default) -func clampReasoningEffort(level string) string { - if _, ok := validReasoningEffortLevels[level]; ok { - return level - } - var clamped string - switch level { - case string(thinking.LevelXHigh): - clamped = string(thinking.LevelHigh) - case string(thinking.LevelMinimal): - clamped = string(thinking.LevelLow) - case string(thinking.LevelAuto): - clamped = string(thinking.LevelMedium) - default: - clamped = string(thinking.LevelMedium) - } - log.WithFields(log.Fields{ - "original": level, - "clamped": clamped, - }).Debug("openai: reasoning_effort clamped to nearest valid standard value") - return clamped -} - // Applier implements thinking.ProviderApplier for OpenAI models. // // OpenAI-specific behavior: @@ -101,7 +58,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo * } if config.Mode == thinking.ModeLevel { - result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(string(config.Level))) + result, _ := sjson.SetBytes(body, "reasoning_effort", string(config.Level)) return result, nil } @@ -122,7 +79,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo * return body, nil } - result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort)) + result, _ := sjson.SetBytes(body, "reasoning_effort", effort) return result, nil } @@ -157,7 +114,7 @@ func applyCompatibleOpenAI(body []byte, config thinking.ThinkingConfig) ([]byte, return body, nil } - result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort)) + result, _ := sjson.SetBytes(body, "reasoning_effort", effort) return result, nil } diff --git a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go index af9ffef1..91bc0423 100644 --- a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go +++ b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go @@ -95,9 +95,9 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() { template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int()) } - promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount + promptTokenCount := usageResult.Get("promptTokenCount").Int() thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() - template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) + template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount) if thoughtsTokenCount > 0 { template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) } diff --git a/internal/translator/claude/openai/chat-completions/claude_openai_request.go b/internal/translator/claude/openai/chat-completions/claude_openai_request.go index 3cad1882..f94825b2 100644 --- a/internal/translator/claude/openai/chat-completions/claude_openai_request.go +++ b/internal/translator/claude/openai/chat-completions/claude_openai_request.go @@ -199,6 +199,21 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream msg, _ = sjson.SetRaw(msg, "content.-1", imagePart) } } + + case "file": + fileData := part.Get("file.file_data").String() + if strings.HasPrefix(fileData, "data:") { + semicolonIdx := strings.Index(fileData, ";") + commaIdx := strings.Index(fileData, ",") + if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx { + mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:") + data := fileData[commaIdx+1:] + docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}` + docPart, _ = sjson.Set(docPart, "source.media_type", mediaType) + docPart, _ = sjson.Set(docPart, "source.data", data) + msg, _ = sjson.SetRaw(msg, "content.-1", docPart) + } + } } return true }) diff --git a/internal/translator/claude/openai/responses/claude_openai-responses_request.go b/internal/translator/claude/openai/responses/claude_openai-responses_request.go index 337f9be9..33a81124 100644 --- a/internal/translator/claude/openai/responses/claude_openai-responses_request.go +++ b/internal/translator/claude/openai/responses/claude_openai-responses_request.go @@ -155,6 +155,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte var textAggregate strings.Builder var partsJSON []string hasImage := false + hasFile := false if parts := item.Get("content"); parts.Exists() && parts.IsArray() { parts.ForEach(func(_, part gjson.Result) bool { ptype := part.Get("type").String() @@ -207,6 +208,30 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte hasImage = true } } + case "input_file": + fileData := part.Get("file_data").String() + if fileData != "" { + mediaType := "application/octet-stream" + data := fileData + if strings.HasPrefix(fileData, "data:") { + trimmed := strings.TrimPrefix(fileData, "data:") + mediaAndData := strings.SplitN(trimmed, ";base64,", 2) + if len(mediaAndData) == 2 { + if mediaAndData[0] != "" { + mediaType = mediaAndData[0] + } + data = mediaAndData[1] + } + } + contentPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}` + contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType) + contentPart, _ = sjson.Set(contentPart, "source.data", data) + partsJSON = append(partsJSON, contentPart) + if role == "" { + role = "user" + } + hasFile = true + } } return true }) @@ -228,7 +253,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte if len(partsJSON) > 0 { msg := `{"role":"","content":[]}` msg, _ = sjson.Set(msg, "role", role) - if len(partsJSON) == 1 && !hasImage { + if len(partsJSON) == 1 && !hasImage && !hasFile { // Preserve legacy behavior for single text content msg, _ = sjson.Delete(msg, "content") textPart := gjson.Parse(partsJSON[0]) diff --git a/internal/translator/codex/openai/chat-completions/codex_openai_request.go b/internal/translator/codex/openai/chat-completions/codex_openai_request.go index e79f97cd..1ea9ca4b 100644 --- a/internal/translator/codex/openai/chat-completions/codex_openai_request.go +++ b/internal/translator/codex/openai/chat-completions/codex_openai_request.go @@ -180,7 +180,19 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b msg, _ = sjson.SetRaw(msg, "content.-1", part) } case "file": - // Files are not specified in examples; skip for now + if role == "user" { + fileData := it.Get("file.file_data").String() + filename := it.Get("file.filename").String() + if fileData != "" { + part := `{}` + part, _ = sjson.Set(part, "type", "input_file") + part, _ = sjson.Set(part, "file_data", fileData) + if filename != "" { + part, _ = sjson.Set(part, "filename", filename) + } + msg, _ = sjson.SetRaw(msg, "content.-1", part) + } + } } } } diff --git a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go index 0415e014..b26d431f 100644 --- a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go +++ b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go @@ -100,7 +100,7 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ } promptTokenCount := usageResult.Get("promptTokenCount").Int() thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() - template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) + template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount) if thoughtsTokenCount > 0 { template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) } diff --git a/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go b/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go index ee581c46..aeec5e9e 100644 --- a/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go +++ b/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go @@ -100,9 +100,9 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() { baseTemplate, _ = sjson.Set(baseTemplate, "usage.total_tokens", totalTokenCountResult.Int()) } - promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount + promptTokenCount := usageResult.Get("promptTokenCount").Int() thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() - baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) + baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount) if thoughtsTokenCount > 0 { baseTemplate, _ = sjson.Set(baseTemplate, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) } @@ -297,7 +297,7 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina promptTokenCount := usageResult.Get("promptTokenCount").Int() thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int() - template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) + template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount) if thoughtsTokenCount > 0 { template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) } diff --git a/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go b/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go index 985897fa..73609be7 100644 --- a/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go +++ b/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go @@ -531,8 +531,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, // usage mapping if um := root.Get("usageMetadata"); um.Exists() { - // input tokens = prompt + thoughts - input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int() + // input tokens = prompt only (thoughts go to output) + input := um.Get("promptTokenCount").Int() completed, _ = sjson.Set(completed, "response.usage.input_tokens", input) // cached token details: align with OpenAI "cached_tokens" semantics. completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int()) @@ -737,8 +737,8 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string // usage mapping if um := root.Get("usageMetadata"); um.Exists() { - // input tokens = prompt + thoughts - input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int() + // input tokens = prompt only (thoughts go to output) + input := um.Get("promptTokenCount").Int() resp, _ = sjson.Set(resp, "usage.input_tokens", input) // cached token details: align with OpenAI "cached_tokens" semantics. resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int()) diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index c86651c5..f099116d 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -717,6 +717,12 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl return } if len(chunk.Payload) > 0 { + if handlerType == "openai-response" { + if err := validateSSEDataJSON(chunk.Payload); err != nil { + _ = sendErr(&interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err}) + return + } + } sentPayload = true if okSendData := sendData(cloneBytes(chunk.Payload)); !okSendData { return @@ -728,6 +734,35 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl return dataChan, upstreamHeaders, errChan } +func validateSSEDataJSON(chunk []byte) error { + for _, line := range bytes.Split(chunk, []byte("\n")) { + line = bytes.TrimSpace(line) + if len(line) == 0 { + continue + } + if !bytes.HasPrefix(line, []byte("data:")) { + continue + } + data := bytes.TrimSpace(line[5:]) + if len(data) == 0 { + continue + } + if bytes.Equal(data, []byte("[DONE]")) { + continue + } + if json.Valid(data) { + continue + } + const max = 512 + preview := data + if len(preview) > max { + preview = preview[:max] + } + return fmt.Errorf("invalid SSE data JSON (len=%d): %q", len(data), preview) + } + return nil +} + func statusFromError(err error) int { if err == nil { return 0 diff --git a/sdk/api/handlers/handlers_stream_bootstrap_test.go b/sdk/api/handlers/handlers_stream_bootstrap_test.go index ba9dcac5..b08e3a99 100644 --- a/sdk/api/handlers/handlers_stream_bootstrap_test.go +++ b/sdk/api/handlers/handlers_stream_bootstrap_test.go @@ -134,6 +134,37 @@ type authAwareStreamExecutor struct { authIDs []string } +type invalidJSONStreamExecutor struct{} + +func (e *invalidJSONStreamExecutor) Identifier() string { return "codex" } + +func (e *invalidJSONStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"} +} + +func (e *invalidJSONStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) { + ch := make(chan coreexecutor.StreamChunk, 1) + ch <- coreexecutor.StreamChunk{Payload: []byte("event: response.completed\ndata: {\"type\"")} + close(ch) + return &coreexecutor.StreamResult{Chunks: ch}, nil +} + +func (e *invalidJSONStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { + return auth, nil +} + +func (e *invalidJSONStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"} +} + +func (e *invalidJSONStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) { + return nil, &coreauth.Error{ + Code: "not_implemented", + Message: "HttpRequest not implemented", + HTTPStatus: http.StatusNotImplemented, + } +} + func (e *authAwareStreamExecutor) Identifier() string { return "codex" } func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { @@ -524,3 +555,55 @@ func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *test t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2") } } + +func TestExecuteStreamWithAuthManager_ValidatesOpenAIResponsesStreamDataJSON(t *testing.T) { + executor := &invalidJSONStreamExecutor{} + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(executor) + + auth1 := &coreauth.Auth{ + ID: "auth1", + Provider: "codex", + Status: coreauth.StatusActive, + Metadata: map[string]any{"email": "test1@example.com"}, + } + if _, err := manager.Register(context.Background(), auth1); err != nil { + t.Fatalf("manager.Register(auth1): %v", err) + } + + registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth1.ID) + }) + + handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai-response", "test-model", []byte(`{"model":"test-model"}`), "") + if dataChan == nil || errChan == nil { + t.Fatalf("expected non-nil channels") + } + + var got []byte + for chunk := range dataChan { + got = append(got, chunk...) + } + if len(got) != 0 { + t.Fatalf("expected empty payload, got %q", string(got)) + } + + gotErr := false + for msg := range errChan { + if msg == nil { + continue + } + if msg.StatusCode != http.StatusBadGateway { + t.Fatalf("expected status %d, got %d", http.StatusBadGateway, msg.StatusCode) + } + if msg.Error == nil { + t.Fatalf("expected error") + } + gotErr = true + } + if !gotErr { + t.Fatalf("expected terminal error") + } +} diff --git a/sdk/api/handlers/openai/openai_responses_handlers.go b/sdk/api/handlers/openai/openai_responses_handlers.go index f10e8d51..da04a1ae 100644 --- a/sdk/api/handlers/openai/openai_responses_handlers.go +++ b/sdk/api/handlers/openai/openai_responses_handlers.go @@ -418,8 +418,8 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flush if errMsg.Error != nil && errMsg.Error.Error() != "" { errText = errMsg.Error.Error() } - body := handlers.BuildErrorResponseBody(status, errText) - _, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body)) + chunk := handlers.BuildOpenAIResponsesStreamErrorChunk(status, errText, 0) + _, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(chunk)) }, WriteDone: func() { _, _ = c.Writer.Write([]byte("\n")) diff --git a/sdk/api/handlers/openai/openai_responses_handlers_stream_error_test.go b/sdk/api/handlers/openai/openai_responses_handlers_stream_error_test.go new file mode 100644 index 00000000..dce73807 --- /dev/null +++ b/sdk/api/handlers/openai/openai_responses_handlers_stream_error_test.go @@ -0,0 +1,43 @@ +package openai + +import ( + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" + sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" +) + +func TestForwardResponsesStreamTerminalErrorUsesResponsesErrorChunk(t *testing.T) { + gin.SetMode(gin.TestMode) + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, nil) + h := NewOpenAIResponsesAPIHandler(base) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + + flusher, ok := c.Writer.(http.Flusher) + if !ok { + t.Fatalf("expected gin writer to implement http.Flusher") + } + + data := make(chan []byte) + errs := make(chan *interfaces.ErrorMessage, 1) + errs <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: errors.New("unexpected EOF")} + close(errs) + + h.forwardResponsesStream(c, flusher, func(error) {}, data, errs) + body := recorder.Body.String() + if !strings.Contains(body, `"type":"error"`) { + t.Fatalf("expected responses error chunk, got: %q", body) + } + if strings.Contains(body, `"error":{`) { + t.Fatalf("expected streaming error chunk (top-level type), got HTTP error body: %q", body) + } +} diff --git a/sdk/api/handlers/openai_responses_stream_error.go b/sdk/api/handlers/openai_responses_stream_error.go new file mode 100644 index 00000000..e7760bd0 --- /dev/null +++ b/sdk/api/handlers/openai_responses_stream_error.go @@ -0,0 +1,119 @@ +package handlers + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" +) + +type openAIResponsesStreamErrorChunk struct { + Type string `json:"type"` + Code string `json:"code"` + Message string `json:"message"` + SequenceNumber int `json:"sequence_number"` +} + +func openAIResponsesStreamErrorCode(status int) string { + switch status { + case http.StatusUnauthorized: + return "invalid_api_key" + case http.StatusForbidden: + return "insufficient_quota" + case http.StatusTooManyRequests: + return "rate_limit_exceeded" + case http.StatusNotFound: + return "model_not_found" + case http.StatusRequestTimeout: + return "request_timeout" + default: + if status >= http.StatusInternalServerError { + return "internal_server_error" + } + if status >= http.StatusBadRequest { + return "invalid_request_error" + } + return "unknown_error" + } +} + +// BuildOpenAIResponsesStreamErrorChunk builds an OpenAI Responses streaming error chunk. +// +// Important: OpenAI's HTTP error bodies are shaped like {"error":{...}}; those are valid for +// non-streaming responses, but streaming clients validate SSE `data:` payloads against a union +// of chunks that requires a top-level `type` field. +func BuildOpenAIResponsesStreamErrorChunk(status int, errText string, sequenceNumber int) []byte { + if status <= 0 { + status = http.StatusInternalServerError + } + if sequenceNumber < 0 { + sequenceNumber = 0 + } + + message := strings.TrimSpace(errText) + if message == "" { + message = http.StatusText(status) + } + + code := openAIResponsesStreamErrorCode(status) + + trimmed := strings.TrimSpace(errText) + if trimmed != "" && json.Valid([]byte(trimmed)) { + var payload map[string]any + if err := json.Unmarshal([]byte(trimmed), &payload); err == nil { + if t, ok := payload["type"].(string); ok && strings.TrimSpace(t) == "error" { + if m, ok := payload["message"].(string); ok && strings.TrimSpace(m) != "" { + message = strings.TrimSpace(m) + } + if v, ok := payload["code"]; ok && v != nil { + if c, ok := v.(string); ok && strings.TrimSpace(c) != "" { + code = strings.TrimSpace(c) + } else { + code = strings.TrimSpace(fmt.Sprint(v)) + } + } + if v, ok := payload["sequence_number"].(float64); ok && sequenceNumber == 0 { + sequenceNumber = int(v) + } + } + if e, ok := payload["error"].(map[string]any); ok { + if m, ok := e["message"].(string); ok && strings.TrimSpace(m) != "" { + message = strings.TrimSpace(m) + } + if v, ok := e["code"]; ok && v != nil { + if c, ok := v.(string); ok && strings.TrimSpace(c) != "" { + code = strings.TrimSpace(c) + } else { + code = strings.TrimSpace(fmt.Sprint(v)) + } + } + } + } + } + + if strings.TrimSpace(code) == "" { + code = "unknown_error" + } + + data, err := json.Marshal(openAIResponsesStreamErrorChunk{ + Type: "error", + Code: code, + Message: message, + SequenceNumber: sequenceNumber, + }) + if err == nil { + return data + } + + // Extremely defensive fallback. + data, _ = json.Marshal(openAIResponsesStreamErrorChunk{ + Type: "error", + Code: "internal_server_error", + Message: message, + SequenceNumber: sequenceNumber, + }) + if len(data) > 0 { + return data + } + return []byte(`{"type":"error","code":"internal_server_error","message":"internal error","sequence_number":0}`) +} diff --git a/sdk/api/handlers/openai_responses_stream_error_test.go b/sdk/api/handlers/openai_responses_stream_error_test.go new file mode 100644 index 00000000..90b2c667 --- /dev/null +++ b/sdk/api/handlers/openai_responses_stream_error_test.go @@ -0,0 +1,48 @@ +package handlers + +import ( + "encoding/json" + "net/http" + "testing" +) + +func TestBuildOpenAIResponsesStreamErrorChunk(t *testing.T) { + chunk := BuildOpenAIResponsesStreamErrorChunk(http.StatusInternalServerError, "unexpected EOF", 0) + var payload map[string]any + if err := json.Unmarshal(chunk, &payload); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if payload["type"] != "error" { + t.Fatalf("type = %v, want %q", payload["type"], "error") + } + if payload["code"] != "internal_server_error" { + t.Fatalf("code = %v, want %q", payload["code"], "internal_server_error") + } + if payload["message"] != "unexpected EOF" { + t.Fatalf("message = %v, want %q", payload["message"], "unexpected EOF") + } + if payload["sequence_number"] != float64(0) { + t.Fatalf("sequence_number = %v, want %v", payload["sequence_number"], 0) + } +} + +func TestBuildOpenAIResponsesStreamErrorChunkExtractsHTTPErrorBody(t *testing.T) { + chunk := BuildOpenAIResponsesStreamErrorChunk( + http.StatusInternalServerError, + `{"error":{"message":"oops","type":"server_error","code":"internal_server_error"}}`, + 0, + ) + var payload map[string]any + if err := json.Unmarshal(chunk, &payload); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if payload["type"] != "error" { + t.Fatalf("type = %v, want %q", payload["type"], "error") + } + if payload["code"] != "internal_server_error" { + t.Fatalf("code = %v, want %q", payload["code"], "internal_server_error") + } + if payload["message"] != "oops" { + t.Fatalf("message = %v, want %q", payload["message"], "oops") + } +} diff --git a/sdk/auth/codex.go b/sdk/auth/codex.go index c81842eb..1af36936 100644 --- a/sdk/auth/codex.go +++ b/sdk/auth/codex.go @@ -2,8 +2,6 @@ package auth import ( "context" - "crypto/sha256" - "encoding/hex" "fmt" "net/http" "strings" @@ -48,6 +46,10 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts opts = &LoginOptions{} } + if shouldUseCodexDeviceFlow(opts) { + return a.loginWithDeviceFlow(ctx, cfg, opts) + } + callbackPort := a.CallbackPort if opts.CallbackPort > 0 { callbackPort = opts.CallbackPort @@ -186,39 +188,5 @@ waitForCallback: return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err) } - tokenStorage := authSvc.CreateTokenStorage(authBundle) - - if tokenStorage == nil || tokenStorage.Email == "" { - return nil, fmt.Errorf("codex token storage missing account information") - } - - planType := "" - hashAccountID := "" - if tokenStorage.IDToken != "" { - if claims, errParse := codex.ParseJWTToken(tokenStorage.IDToken); errParse == nil && claims != nil { - planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType) - accountID := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID) - if accountID != "" { - digest := sha256.Sum256([]byte(accountID)) - hashAccountID = hex.EncodeToString(digest[:])[:8] - } - } - } - fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true) - metadata := map[string]any{ - "email": tokenStorage.Email, - } - - fmt.Println("Codex authentication successful") - if authBundle.APIKey != "" { - fmt.Println("Codex API key obtained and stored") - } - - return &coreauth.Auth{ - ID: fileName, - Provider: a.Provider(), - FileName: fileName, - Storage: tokenStorage, - Metadata: metadata, - }, nil + return a.buildAuthRecord(authSvc, authBundle) } diff --git a/sdk/auth/codex_device.go b/sdk/auth/codex_device.go new file mode 100644 index 00000000..78a95af8 --- /dev/null +++ b/sdk/auth/codex_device.go @@ -0,0 +1,291 @@ +package auth + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" + "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + log "github.com/sirupsen/logrus" +) + +const ( + codexLoginModeMetadataKey = "codex_login_mode" + codexLoginModeDevice = "device" + codexDeviceUserCodeURL = "https://auth.openai.com/api/accounts/deviceauth/usercode" + codexDeviceTokenURL = "https://auth.openai.com/api/accounts/deviceauth/token" + codexDeviceVerificationURL = "https://auth.openai.com/codex/device" + codexDeviceTokenExchangeRedirectURI = "https://auth.openai.com/deviceauth/callback" + codexDeviceTimeout = 15 * time.Minute + codexDeviceDefaultPollIntervalSeconds = 5 +) + +type codexDeviceUserCodeRequest struct { + ClientID string `json:"client_id"` +} + +type codexDeviceUserCodeResponse struct { + DeviceAuthID string `json:"device_auth_id"` + UserCode string `json:"user_code"` + UserCodeAlt string `json:"usercode"` + Interval json.RawMessage `json:"interval"` +} + +type codexDeviceTokenRequest struct { + DeviceAuthID string `json:"device_auth_id"` + UserCode string `json:"user_code"` +} + +type codexDeviceTokenResponse struct { + AuthorizationCode string `json:"authorization_code"` + CodeVerifier string `json:"code_verifier"` + CodeChallenge string `json:"code_challenge"` +} + +func shouldUseCodexDeviceFlow(opts *LoginOptions) bool { + if opts == nil || opts.Metadata == nil { + return false + } + return strings.EqualFold(strings.TrimSpace(opts.Metadata[codexLoginModeMetadataKey]), codexLoginModeDevice) +} + +func (a *CodexAuthenticator) loginWithDeviceFlow(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if ctx == nil { + ctx = context.Background() + } + + httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{}) + + userCodeResp, err := requestCodexDeviceUserCode(ctx, httpClient) + if err != nil { + return nil, err + } + + deviceCode := strings.TrimSpace(userCodeResp.UserCode) + if deviceCode == "" { + deviceCode = strings.TrimSpace(userCodeResp.UserCodeAlt) + } + deviceAuthID := strings.TrimSpace(userCodeResp.DeviceAuthID) + if deviceCode == "" || deviceAuthID == "" { + return nil, fmt.Errorf("codex device flow did not return required fields") + } + + pollInterval := parseCodexDevicePollInterval(userCodeResp.Interval) + + fmt.Println("Starting Codex device authentication...") + fmt.Printf("Codex device URL: %s\n", codexDeviceVerificationURL) + fmt.Printf("Codex device code: %s\n", deviceCode) + + if !opts.NoBrowser { + if !browser.IsAvailable() { + log.Warn("No browser available; please open the device URL manually") + } else if errOpen := browser.OpenURL(codexDeviceVerificationURL); errOpen != nil { + log.Warnf("Failed to open browser automatically: %v", errOpen) + } + } + + tokenResp, err := pollCodexDeviceToken(ctx, httpClient, deviceAuthID, deviceCode, pollInterval) + if err != nil { + return nil, err + } + + authCode := strings.TrimSpace(tokenResp.AuthorizationCode) + codeVerifier := strings.TrimSpace(tokenResp.CodeVerifier) + codeChallenge := strings.TrimSpace(tokenResp.CodeChallenge) + if authCode == "" || codeVerifier == "" || codeChallenge == "" { + return nil, fmt.Errorf("codex device flow token response missing required fields") + } + + authSvc := codex.NewCodexAuth(cfg) + authBundle, err := authSvc.ExchangeCodeForTokensWithRedirect( + ctx, + authCode, + codexDeviceTokenExchangeRedirectURI, + &codex.PKCECodes{ + CodeVerifier: codeVerifier, + CodeChallenge: codeChallenge, + }, + ) + if err != nil { + return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err) + } + + return a.buildAuthRecord(authSvc, authBundle) +} + +func requestCodexDeviceUserCode(ctx context.Context, client *http.Client) (*codexDeviceUserCodeResponse, error) { + body, err := json.Marshal(codexDeviceUserCodeRequest{ClientID: codex.ClientID}) + if err != nil { + return nil, fmt.Errorf("failed to encode codex device request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, codexDeviceUserCodeURL, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("failed to create codex device request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to request codex device code: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read codex device code response: %w", err) + } + + if !codexDeviceIsSuccessStatus(resp.StatusCode) { + trimmed := strings.TrimSpace(string(respBody)) + if resp.StatusCode == http.StatusNotFound { + return nil, fmt.Errorf("codex device endpoint is unavailable (status %d)", resp.StatusCode) + } + if trimmed == "" { + trimmed = "empty response body" + } + return nil, fmt.Errorf("codex device code request failed with status %d: %s", resp.StatusCode, trimmed) + } + + var parsed codexDeviceUserCodeResponse + if err := json.Unmarshal(respBody, &parsed); err != nil { + return nil, fmt.Errorf("failed to decode codex device code response: %w", err) + } + + return &parsed, nil +} + +func pollCodexDeviceToken(ctx context.Context, client *http.Client, deviceAuthID, userCode string, interval time.Duration) (*codexDeviceTokenResponse, error) { + deadline := time.Now().Add(codexDeviceTimeout) + + for { + if time.Now().After(deadline) { + return nil, fmt.Errorf("codex device authentication timed out after 15 minutes") + } + + body, err := json.Marshal(codexDeviceTokenRequest{ + DeviceAuthID: deviceAuthID, + UserCode: userCode, + }) + if err != nil { + return nil, fmt.Errorf("failed to encode codex device poll request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, codexDeviceTokenURL, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("failed to create codex device poll request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to poll codex device token: %w", err) + } + + respBody, readErr := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if readErr != nil { + return nil, fmt.Errorf("failed to read codex device poll response: %w", readErr) + } + + switch { + case codexDeviceIsSuccessStatus(resp.StatusCode): + var parsed codexDeviceTokenResponse + if err := json.Unmarshal(respBody, &parsed); err != nil { + return nil, fmt.Errorf("failed to decode codex device token response: %w", err) + } + return &parsed, nil + case resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusNotFound: + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(interval): + continue + } + default: + trimmed := strings.TrimSpace(string(respBody)) + if trimmed == "" { + trimmed = "empty response body" + } + return nil, fmt.Errorf("codex device token polling failed with status %d: %s", resp.StatusCode, trimmed) + } + } +} + +func parseCodexDevicePollInterval(raw json.RawMessage) time.Duration { + defaultInterval := time.Duration(codexDeviceDefaultPollIntervalSeconds) * time.Second + if len(raw) == 0 { + return defaultInterval + } + + var asString string + if err := json.Unmarshal(raw, &asString); err == nil { + if seconds, convErr := strconv.Atoi(strings.TrimSpace(asString)); convErr == nil && seconds > 0 { + return time.Duration(seconds) * time.Second + } + } + + var asInt int + if err := json.Unmarshal(raw, &asInt); err == nil && asInt > 0 { + return time.Duration(asInt) * time.Second + } + + return defaultInterval +} + +func codexDeviceIsSuccessStatus(code int) bool { + return code >= 200 && code < 300 +} + +func (a *CodexAuthenticator) buildAuthRecord(authSvc *codex.CodexAuth, authBundle *codex.CodexAuthBundle) (*coreauth.Auth, error) { + tokenStorage := authSvc.CreateTokenStorage(authBundle) + + if tokenStorage == nil || tokenStorage.Email == "" { + return nil, fmt.Errorf("codex token storage missing account information") + } + + planType := "" + hashAccountID := "" + if tokenStorage.IDToken != "" { + if claims, errParse := codex.ParseJWTToken(tokenStorage.IDToken); errParse == nil && claims != nil { + planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType) + accountID := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID) + if accountID != "" { + digest := sha256.Sum256([]byte(accountID)) + hashAccountID = hex.EncodeToString(digest[:])[:8] + } + } + } + + fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true) + metadata := map[string]any{ + "email": tokenStorage.Email, + } + + fmt.Println("Codex authentication successful") + if authBundle.APIKey != "" { + fmt.Println("Codex API key obtained and stored") + } + + return &coreauth.Auth{ + ID: fileName, + Provider: a.Provider(), + FileName: fileName, + Storage: tokenStorage, + Metadata: metadata, + }, nil +} diff --git a/sdk/auth/filestore.go b/sdk/auth/filestore.go index 4715d7f7..9e5a2ddb 100644 --- a/sdk/auth/filestore.go +++ b/sdk/auth/filestore.go @@ -64,8 +64,16 @@ func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (str return "", fmt.Errorf("auth filestore: create dir failed: %w", err) } + // metadataSetter is a private interface for TokenStorage implementations that support metadata injection. + type metadataSetter interface { + SetMetadata(map[string]any) + } + switch { case auth.Storage != nil: + if setter, ok := auth.Storage.(metadataSetter); ok { + setter.SetMetadata(auth.Metadata) + } if err = auth.Storage.SaveTokenToFile(path); err != nil { return "", err } diff --git a/sdk/cliproxy/auth/selector.go b/sdk/cliproxy/auth/selector.go index a173ed01..cf79e173 100644 --- a/sdk/cliproxy/auth/selector.go +++ b/sdk/cliproxy/auth/selector.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "math" + "math/rand/v2" "net/http" "sort" "strconv" @@ -248,6 +249,9 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([] } // Pick selects the next available auth for the provider in a round-robin manner. +// For gemini-cli virtual auths (identified by the gemini_virtual_parent attribute), +// a two-level round-robin is used: first cycling across credential groups (parent +// accounts), then cycling within each group's project auths. func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) { _ = opts now := time.Now() @@ -265,21 +269,87 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o if limit <= 0 { limit = 4096 } - if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit { - s.cursors = make(map[string]int) - } - index := s.cursors[key] + // Check if any available auth has gemini_virtual_parent attribute, + // indicating gemini-cli virtual auths that should use credential-level polling. + groups, parentOrder := groupByVirtualParent(available) + if len(parentOrder) > 1 { + // Two-level round-robin: first select a credential group, then pick within it. + groupKey := key + "::group" + s.ensureCursorKey(groupKey, limit) + if _, exists := s.cursors[groupKey]; !exists { + // Seed with a random initial offset so the starting credential is randomized. + s.cursors[groupKey] = rand.IntN(len(parentOrder)) + } + groupIndex := s.cursors[groupKey] + if groupIndex >= 2_147_483_640 { + groupIndex = 0 + } + s.cursors[groupKey] = groupIndex + 1 + + selectedParent := parentOrder[groupIndex%len(parentOrder)] + group := groups[selectedParent] + + // Second level: round-robin within the selected credential group. + innerKey := key + "::cred:" + selectedParent + s.ensureCursorKey(innerKey, limit) + innerIndex := s.cursors[innerKey] + if innerIndex >= 2_147_483_640 { + innerIndex = 0 + } + s.cursors[innerKey] = innerIndex + 1 + s.mu.Unlock() + return group[innerIndex%len(group)], nil + } + + // Flat round-robin for non-grouped auths (original behavior). + s.ensureCursorKey(key, limit) + index := s.cursors[key] if index >= 2_147_483_640 { index = 0 } - s.cursors[key] = index + 1 s.mu.Unlock() - // log.Debugf("available: %d, index: %d, key: %d", len(available), index, index%len(available)) return available[index%len(available)], nil } +// ensureCursorKey ensures the cursor map has capacity for the given key. +// Must be called with s.mu held. +func (s *RoundRobinSelector) ensureCursorKey(key string, limit int) { + if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit { + s.cursors = make(map[string]int) + } +} + +// groupByVirtualParent groups auths by their gemini_virtual_parent attribute. +// Returns a map of parentID -> auths and a sorted slice of parent IDs for stable iteration. +// Only auths with a non-empty gemini_virtual_parent are grouped; if any auth lacks +// this attribute, nil/nil is returned so the caller falls back to flat round-robin. +func groupByVirtualParent(auths []*Auth) (map[string][]*Auth, []string) { + if len(auths) == 0 { + return nil, nil + } + groups := make(map[string][]*Auth) + for _, a := range auths { + parent := "" + if a.Attributes != nil { + parent = strings.TrimSpace(a.Attributes["gemini_virtual_parent"]) + } + if parent == "" { + // Non-virtual auth present; fall back to flat round-robin. + return nil, nil + } + groups[parent] = append(groups[parent], a) + } + // Collect parent IDs in sorted order for stable cursor indexing. + parentOrder := make([]string, 0, len(groups)) + for p := range groups { + parentOrder = append(parentOrder, p) + } + sort.Strings(parentOrder) + return groups, parentOrder +} + // Pick selects the first available auth for the provider in a deterministic manner. func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) { _ = opts diff --git a/sdk/cliproxy/auth/selector_test.go b/sdk/cliproxy/auth/selector_test.go index fe1cf15e..79431a9a 100644 --- a/sdk/cliproxy/auth/selector_test.go +++ b/sdk/cliproxy/auth/selector_test.go @@ -402,3 +402,128 @@ func TestRoundRobinSelectorPick_CursorKeyCap(t *testing.T) { t.Fatalf("selector.cursors missing key %q", "gemini:m3") } } + +func TestRoundRobinSelectorPick_GeminiCLICredentialGrouping(t *testing.T) { + t.Parallel() + + selector := &RoundRobinSelector{} + + // Simulate two gemini-cli credentials, each with multiple projects: + // Credential A (parent = "cred-a.json") has 3 projects + // Credential B (parent = "cred-b.json") has 2 projects + auths := []*Auth{ + {ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, + {ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, + {ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, + {ID: "cred-b.json::proj-b1", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}}, + {ID: "cred-b.json::proj-b2", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}}, + } + + // Two-level round-robin: consecutive picks must alternate between credentials. + // Credential group order is randomized, but within each call the group cursor + // advances by 1, so consecutive picks should cycle through different parents. + picks := make([]string, 6) + parents := make([]string, 6) + for i := 0; i < 6; i++ { + got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths) + if err != nil { + t.Fatalf("Pick() #%d error = %v", i, err) + } + if got == nil { + t.Fatalf("Pick() #%d auth = nil", i) + } + picks[i] = got.ID + parents[i] = got.Attributes["gemini_virtual_parent"] + } + + // Verify property: consecutive picks must alternate between credential groups. + for i := 1; i < len(parents); i++ { + if parents[i] == parents[i-1] { + t.Fatalf("Pick() #%d and #%d both from same parent %q (IDs: %q, %q); expected alternating credentials", + i-1, i, parents[i], picks[i-1], picks[i]) + } + } + + // Verify property: each credential's projects are picked in sequence (round-robin within group). + credPicks := map[string][]string{} + for i, id := range picks { + credPicks[parents[i]] = append(credPicks[parents[i]], id) + } + for parent, ids := range credPicks { + for i := 1; i < len(ids); i++ { + if ids[i] == ids[i-1] { + t.Fatalf("Credential %q picked same project %q twice in a row", parent, ids[i]) + } + } + } +} + +func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) { + t.Parallel() + + selector := &RoundRobinSelector{} + + // All auths from the same parent - should fall back to flat round-robin + // because there's only one credential group (no benefit from two-level). + auths := []*Auth{ + {ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, + {ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, + {ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, + } + + // With single parent group, parentOrder has length 1, so it uses flat round-robin. + // Sorted by ID: proj-a1, proj-a2, proj-a3 + want := []string{ + "cred-a.json::proj-a1", + "cred-a.json::proj-a2", + "cred-a.json::proj-a3", + "cred-a.json::proj-a1", + } + + for i, expectedID := range want { + got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths) + if err != nil { + t.Fatalf("Pick() #%d error = %v", i, err) + } + if got == nil { + t.Fatalf("Pick() #%d auth = nil", i) + } + if got.ID != expectedID { + t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID) + } + } +} + +func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *testing.T) { + t.Parallel() + + selector := &RoundRobinSelector{} + + // Mix of virtual and non-virtual auths (e.g., a regular gemini-cli auth without projects + // alongside virtual ones). Should fall back to flat round-robin. + auths := []*Auth{ + {ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, + {ID: "cred-regular.json"}, // no gemini_virtual_parent + } + + // groupByVirtualParent returns nil when any auth lacks the attribute, + // so flat round-robin is used. Sorted by ID: cred-a.json::proj-a1, cred-regular.json + want := []string{ + "cred-a.json::proj-a1", + "cred-regular.json", + "cred-a.json::proj-a1", + } + + for i, expectedID := range want { + got, err := selector.Pick(context.Background(), "gemini-cli", "", cliproxyexecutor.Options{}, auths) + if err != nil { + t.Fatalf("Pick() #%d error = %v", i, err) + } + if got == nil { + t.Fatalf("Pick() #%d auth = nil", i) + } + if got.ID != expectedID { + t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID) + } + } +} diff --git a/sdk/cliproxy/auth/types.go b/sdk/cliproxy/auth/types.go index 88d0ea52..e24919f5 100644 --- a/sdk/cliproxy/auth/types.go +++ b/sdk/cliproxy/auth/types.go @@ -1,9 +1,12 @@ package auth import ( + "context" "crypto/sha256" "encoding/hex" "encoding/json" + "net/http" + "net/url" "strconv" "strings" "sync" @@ -12,6 +15,33 @@ import ( baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth" ) +// PostAuthHook defines a function that is called after an Auth record is created +// but before it is persisted to storage. This allows for modification of the +// Auth record (e.g., injecting metadata) based on external context. +type PostAuthHook func(context.Context, *Auth) error + +// RequestInfo holds information extracted from the HTTP request. +// It is injected into the context passed to PostAuthHook. +type RequestInfo struct { + Query url.Values + Headers http.Header +} + +type requestInfoKey struct{} + +// WithRequestInfo returns a new context with the given RequestInfo attached. +func WithRequestInfo(ctx context.Context, info *RequestInfo) context.Context { + return context.WithValue(ctx, requestInfoKey{}, info) +} + +// GetRequestInfo retrieves the RequestInfo from the context, if present. +func GetRequestInfo(ctx context.Context) *RequestInfo { + if val, ok := ctx.Value(requestInfoKey{}).(*RequestInfo); ok { + return val + } + return nil +} + // Auth encapsulates the runtime state and metadata associated with a single credential. type Auth struct { // ID uniquely identifies the auth record across restarts. diff --git a/sdk/cliproxy/builder.go b/sdk/cliproxy/builder.go index 60ca07f5..0e6d1421 100644 --- a/sdk/cliproxy/builder.go +++ b/sdk/cliproxy/builder.go @@ -153,6 +153,16 @@ func (b *Builder) WithLocalManagementPassword(password string) *Builder { return b } +// WithPostAuthHook registers a hook to be called after an Auth record is created +// but before it is persisted to storage. +func (b *Builder) WithPostAuthHook(hook coreauth.PostAuthHook) *Builder { + if hook == nil { + return b + } + b.serverOptions = append(b.serverOptions, api.WithPostAuthHook(hook)) + return b +} + // Build validates inputs, applies defaults, and returns a ready-to-run service. func (b *Builder) Build() (*Service, error) { if b.cfg == nil {