diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index d35570ce..265b4f8c 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -3,6 +3,9 @@ package management import ( "bytes" "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -23,6 +26,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini" iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow" + kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" @@ -2154,9 +2158,35 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec func (h *Handler) GetAuthStatus(c *gin.Context) { state := c.Query("state") - if err, ok := getOAuthStatus(state); ok { - if err != "" { - c.JSON(200, gin.H{"status": "error", "error": err}) + if statusValue, ok := getOAuthStatus(state); ok { + if statusValue != "" { + // Check for device_code prefix (Kiro AWS Builder ID flow) + // Format: "device_code|verification_url|user_code" + // Using "|" as separator because URLs contain ":" + if strings.HasPrefix(statusValue, "device_code|") { + parts := strings.SplitN(statusValue, "|", 3) + if len(parts) == 3 { + c.JSON(200, gin.H{ + "status": "device_code", + "verification_url": parts[1], + "user_code": parts[2], + }) + return + } + } + // Check for auth_url prefix (Kiro social auth flow) + // Format: "auth_url|url" + // Using "|" as separator because URLs contain ":" + if strings.HasPrefix(statusValue, "auth_url|") { + authURL := strings.TrimPrefix(statusValue, "auth_url|") + c.JSON(200, gin.H{ + "status": "auth_url", + "url": authURL, + }) + return + } + // Otherwise treat as error + c.JSON(200, gin.H{"status": "error", "error": statusValue}) } else { c.JSON(200, gin.H{"status": "wait"}) return @@ -2166,3 +2196,295 @@ func (h *Handler) GetAuthStatus(c *gin.Context) { } deleteOAuthStatus(state) } + +const kiroCallbackPort = 9876 + +func (h *Handler) RequestKiroToken(c *gin.Context) { + ctx := context.Background() + + // Get the login method from query parameter (default: aws for device code flow) + method := strings.ToLower(strings.TrimSpace(c.Query("method"))) + if method == "" { + method = "aws" + } + + fmt.Println("Initializing Kiro authentication...") + + state := fmt.Sprintf("kiro-%d", time.Now().UnixNano()) + + switch method { + case "aws", "builder-id": + // AWS Builder ID uses device code flow (no callback needed) + go func() { + ssoClient := kiroauth.NewSSOOIDCClient(h.cfg) + + // Step 1: Register client + fmt.Println("Registering client...") + regResp, err := ssoClient.RegisterClient(ctx) + if err != nil { + log.Errorf("Failed to register client: %v", err) + setOAuthStatus(state, "Failed to register client") + return + } + + // Step 2: Start device authorization + fmt.Println("Starting device authorization...") + authResp, err := ssoClient.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret) + if err != nil { + log.Errorf("Failed to start device auth: %v", err) + setOAuthStatus(state, "Failed to start device authorization") + return + } + + // Store the verification URL for the frontend to display + // Using "|" as separator because URLs contain ":" + setOAuthStatus(state, "device_code|"+authResp.VerificationURIComplete+"|"+authResp.UserCode) + + // Step 3: Poll for token + fmt.Println("Waiting for authorization...") + interval := 5 * time.Second + if authResp.Interval > 0 { + interval = time.Duration(authResp.Interval) * time.Second + } + deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second) + + for time.Now().Before(deadline) { + select { + case <-ctx.Done(): + setOAuthStatus(state, "Authorization cancelled") + return + case <-time.After(interval): + tokenResp, err := ssoClient.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode) + if err != nil { + errStr := err.Error() + if strings.Contains(errStr, "authorization_pending") { + continue + } + if strings.Contains(errStr, "slow_down") { + interval += 5 * time.Second + continue + } + log.Errorf("Token creation failed: %v", err) + setOAuthStatus(state, "Token creation failed") + return + } + + // Success! Save the token + expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken) + + idPart := kiroauth.SanitizeEmailForFilename(email) + if idPart == "" { + idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000) + } + + now := time.Now() + fileName := fmt.Sprintf("kiro-aws-%s.json", idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenResp.AccessToken, + "refresh_token": tokenResp.RefreshToken, + "expires_at": expiresAt.Format(time.RFC3339), + "auth_method": "builder-id", + "provider": "AWS", + "client_id": regResp.ClientID, + "client_secret": regResp.ClientSecret, + "email": email, + "last_refresh": now.Format(time.RFC3339), + }, + } + + savedPath, errSave := h.saveTokenRecord(ctx, record) + if errSave != nil { + log.Errorf("Failed to save authentication tokens: %v", errSave) + setOAuthStatus(state, "Failed to save authentication tokens") + return + } + + fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) + if email != "" { + fmt.Printf("Authenticated as: %s\n", email) + } + deleteOAuthStatus(state) + return + } + } + + setOAuthStatus(state, "Authorization timed out") + }() + + // Return immediately with the state for polling + c.JSON(200, gin.H{"status": "ok", "state": state, "method": "device_code"}) + + case "google", "github": + // Social auth uses protocol handler - for WEB UI we use a callback forwarder + provider := "Google" + if method == "github" { + provider = "Github" + } + + isWebUI := isWebUIRequest(c) + if isWebUI { + targetURL, errTarget := h.managementCallbackURL("/kiro/callback") + if errTarget != nil { + log.WithError(errTarget).Error("failed to compute kiro callback target") + c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) + return + } + if _, errStart := startCallbackForwarder(kiroCallbackPort, "kiro", targetURL); errStart != nil { + log.WithError(errStart).Error("failed to start kiro callback forwarder") + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) + return + } + } + + go func() { + if isWebUI { + defer stopCallbackForwarder(kiroCallbackPort) + } + + socialClient := kiroauth.NewSocialAuthClient(h.cfg) + + // Generate PKCE codes + codeVerifier, codeChallenge, err := generateKiroPKCE() + if err != nil { + log.Errorf("Failed to generate PKCE: %v", err) + setOAuthStatus(state, "Failed to generate PKCE") + return + } + + // Build login URL + authURL := fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account", + "https://prod.us-east-1.auth.desktop.kiro.dev", + provider, + url.QueryEscape(kiroauth.KiroRedirectURI), + codeChallenge, + state, + ) + + // Store auth URL for frontend + // Using "|" as separator because URLs contain ":" + setOAuthStatus(state, "auth_url|"+authURL) + + // Wait for callback file + waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-kiro-%s.oauth", state)) + deadline := time.Now().Add(5 * time.Minute) + + for { + if time.Now().After(deadline) { + log.Error("oauth flow timed out") + setOAuthStatus(state, "OAuth flow timed out") + return + } + if data, errR := os.ReadFile(waitFile); errR == nil { + var m map[string]string + _ = json.Unmarshal(data, &m) + _ = os.Remove(waitFile) + if errStr := m["error"]; errStr != "" { + log.Errorf("Authentication failed: %s", errStr) + setOAuthStatus(state, "Authentication failed") + return + } + if m["state"] != state { + log.Errorf("State mismatch") + setOAuthStatus(state, "State mismatch") + return + } + code := m["code"] + if code == "" { + log.Error("No authorization code received") + setOAuthStatus(state, "No authorization code received") + return + } + + // Exchange code for tokens + tokenReq := &kiroauth.CreateTokenRequest{ + Code: code, + CodeVerifier: codeVerifier, + RedirectURI: kiroauth.KiroRedirectURI, + } + + tokenResp, errToken := socialClient.CreateToken(ctx, tokenReq) + if errToken != nil { + log.Errorf("Failed to exchange code for tokens: %v", errToken) + setOAuthStatus(state, "Failed to exchange code for tokens") + return + } + + // Save the token + expiresIn := tokenResp.ExpiresIn + if expiresIn <= 0 { + expiresIn = 3600 + } + expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second) + email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken) + + idPart := kiroauth.SanitizeEmailForFilename(email) + if idPart == "" { + idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000) + } + + now := time.Now() + fileName := fmt.Sprintf("kiro-%s-%s.json", strings.ToLower(provider), idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenResp.AccessToken, + "refresh_token": tokenResp.RefreshToken, + "profile_arn": tokenResp.ProfileArn, + "expires_at": expiresAt.Format(time.RFC3339), + "auth_method": "social", + "provider": provider, + "email": email, + "last_refresh": now.Format(time.RFC3339), + }, + } + + savedPath, errSave := h.saveTokenRecord(ctx, record) + if errSave != nil { + log.Errorf("Failed to save authentication tokens: %v", errSave) + setOAuthStatus(state, "Failed to save authentication tokens") + return + } + + fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) + if email != "" { + fmt.Printf("Authenticated as: %s\n", email) + } + deleteOAuthStatus(state) + return + } + time.Sleep(500 * time.Millisecond) + } + }() + + setOAuthStatus(state, "") + c.JSON(200, gin.H{"status": "ok", "state": state, "method": "social"}) + + default: + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid method, use 'aws', 'google', or 'github'"}) + } +} + +// generateKiroPKCE generates PKCE code verifier and challenge for Kiro OAuth. +func generateKiroPKCE() (verifier, challenge string, err error) { + b := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, b); err != nil { + return "", "", fmt.Errorf("failed to generate random bytes: %w", err) + } + verifier = base64.RawURLEncoding.EncodeToString(b) + + h := sha256.Sum256([]byte(verifier)) + challenge = base64.RawURLEncoding.EncodeToString(h[:]) + + return verifier, challenge, nil +} diff --git a/internal/api/modules/amp/proxy.go b/internal/api/modules/amp/proxy.go index 5a3f2081..6ea092c4 100644 --- a/internal/api/modules/amp/proxy.go +++ b/internal/api/modules/amp/proxy.go @@ -7,13 +7,11 @@ import ( "errors" "fmt" "io" - "net" "net/http" "net/http/httputil" "net/url" "strconv" "strings" - "time" "github.com/gin-gonic/gin" log "github.com/sirupsen/logrus" @@ -38,22 +36,6 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi } proxy := httputil.NewSingleHostReverseProxy(parsed) - - // Configure custom Transport with optimized connection pooling for high concurrency - proxy.Transport = &http.Transport{ - MaxIdleConns: 100, - MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users - MaxConnsPerHost: 0, // No limit on max concurrent connections per host - IdleConnTimeout: 90 * time.Second, - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - TLSHandshakeTimeout: 10 * time.Second, - ResponseHeaderTimeout: 60 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - } - originalDirector := proxy.Director // Modify outgoing requests to inject API key and fix routing diff --git a/internal/api/server.go b/internal/api/server.go index ade08fef..d702551e 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -421,6 +421,18 @@ func (s *Server) setupRoutes() { c.String(http.StatusOK, oauthCallbackSuccessHTML) }) + s.engine.GET("/kiro/callback", func(c *gin.Context) { + code := c.Query("code") + state := c.Query("state") + errStr := c.Query("error") + if state != "" { + file := fmt.Sprintf("%s/.oauth-kiro-%s.oauth", s.cfg.AuthDir, state) + _ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600) + } + c.Header("Content-Type", "text/html; charset=utf-8") + c.String(http.StatusOK, oauthCallbackSuccessHTML) + }) + // Management routes are registered lazily by registerManagementRoutes when a secret is configured. } @@ -586,6 +598,7 @@ func (s *Server) registerManagementRoutes() { mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken) mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken) mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken) + mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken) mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus) } } diff --git a/internal/auth/kiro/social_auth.go b/internal/auth/kiro/social_auth.go index 61c67886..2ac29bf8 100644 --- a/internal/auth/kiro/social_auth.go +++ b/internal/auth/kiro/social_auth.go @@ -126,8 +126,8 @@ func (c *SocialAuthClient) buildLoginURL(provider, redirectURI, codeChallenge, s ) } -// createToken exchanges the authorization code for tokens. -func (c *SocialAuthClient) createToken(ctx context.Context, req *CreateTokenRequest) (*SocialTokenResponse, error) { +// CreateToken exchanges the authorization code for tokens. +func (c *SocialAuthClient) CreateToken(ctx context.Context, req *CreateTokenRequest) (*SocialTokenResponse, error) { body, err := json.Marshal(req) if err != nil { return nil, fmt.Errorf("failed to marshal token request: %w", err) @@ -326,7 +326,7 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP RedirectURI: KiroRedirectURI, } - tokenResp, err := c.createToken(ctx, tokenReq) + tokenResp, err := c.CreateToken(ctx, tokenReq) if err != nil { return nil, fmt.Errorf("failed to exchange code for tokens: %w", err) } diff --git a/internal/runtime/executor/proxy_helpers.go b/internal/runtime/executor/proxy_helpers.go index 8ac91e03..4cda7b16 100644 --- a/internal/runtime/executor/proxy_helpers.go +++ b/internal/runtime/executor/proxy_helpers.go @@ -7,7 +7,6 @@ import ( "net/url" "strings" "sync" - "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" @@ -137,25 +136,15 @@ func buildProxyTransport(proxyURL string) *http.Transport { log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5) return nil } - // Set up a custom transport using the SOCKS5 dialer with optimized connection pooling + // Set up a custom transport using the SOCKS5 dialer transport = &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return dialer.Dial(network, addr) }, - MaxIdleConns: 100, - MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users - MaxConnsPerHost: 0, // No limit on max concurrent connections per host - IdleConnTimeout: 90 * time.Second, } } else if parsedURL.Scheme == "http" || parsedURL.Scheme == "https" { - // Configure HTTP or HTTPS proxy with optimized connection pooling - transport = &http.Transport{ - Proxy: http.ProxyURL(parsedURL), - MaxIdleConns: 100, - MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users - MaxConnsPerHost: 0, // No limit on max concurrent connections per host - IdleConnTimeout: 90 * time.Second, - } + // Configure HTTP or HTTPS proxy + transport = &http.Transport{Proxy: http.ProxyURL(parsedURL)} } else { log.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme) return nil diff --git a/internal/util/proxy.go b/internal/util/proxy.go index e5ac7cd6..aea52ba8 100644 --- a/internal/util/proxy.go +++ b/internal/util/proxy.go @@ -8,7 +8,6 @@ import ( "net" "net/http" "net/url" - "time" "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" log "github.com/sirupsen/logrus" @@ -37,25 +36,15 @@ func SetProxy(cfg *config.SDKConfig, httpClient *http.Client) *http.Client { log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5) return httpClient } - // Set up a custom transport using the SOCKS5 dialer with optimized connection pooling + // Set up a custom transport using the SOCKS5 dialer. transport = &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return dialer.Dial(network, addr) }, - MaxIdleConns: 100, - MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users - MaxConnsPerHost: 0, // No limit on max concurrent connections per host - IdleConnTimeout: 90 * time.Second, } } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { - // Configure HTTP or HTTPS proxy with optimized connection pooling - transport = &http.Transport{ - Proxy: http.ProxyURL(proxyURL), - MaxIdleConns: 100, - MaxIdleConnsPerHost: 20, // Increased from default 2 to support more concurrent users - MaxConnsPerHost: 0, // No limit on max concurrent connections per host - IdleConnTimeout: 90 * time.Second, - } + // Configure HTTP or HTTPS proxy. + transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)} } } // If a new transport was created, apply it to the HTTP client.