diff --git a/.gitignore b/.gitignore index 29cf765b..bab49132 100644 --- a/.gitignore +++ b/.gitignore @@ -50,3 +50,4 @@ _bmad-output/* # macOS .DS_Store ._* +*.bak diff --git a/internal/auth/kiro/aws.go.bak b/internal/auth/kiro/aws.go.bak deleted file mode 100644 index ba73af4d..00000000 --- a/internal/auth/kiro/aws.go.bak +++ /dev/null @@ -1,305 +0,0 @@ -// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API. -// It includes interfaces and implementations for token storage and authentication methods. -package kiro - -import ( - "encoding/base64" - "encoding/json" - "fmt" - "os" - "path/filepath" - "strings" -) - -// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow -type PKCECodes struct { - // CodeVerifier is the cryptographically random string used to correlate - // the authorization request to the token request - CodeVerifier string `json:"code_verifier"` - // CodeChallenge is the SHA256 hash of the code verifier, base64url-encoded - CodeChallenge string `json:"code_challenge"` -} - -// KiroTokenData holds OAuth token information from AWS CodeWhisperer (Kiro) -type KiroTokenData struct { - // AccessToken is the OAuth2 access token for API access - AccessToken string `json:"accessToken"` - // RefreshToken is used to obtain new access tokens - RefreshToken string `json:"refreshToken"` - // ProfileArn is the AWS CodeWhisperer profile ARN - ProfileArn string `json:"profileArn"` - // ExpiresAt is the timestamp when the token expires - ExpiresAt string `json:"expiresAt"` - // AuthMethod indicates the authentication method used (e.g., "builder-id", "social") - AuthMethod string `json:"authMethod"` - // Provider indicates the OAuth provider (e.g., "AWS", "Google") - Provider string `json:"provider"` - // ClientID is the OIDC client ID (needed for token refresh) - ClientID string `json:"clientId,omitempty"` - // ClientSecret is the OIDC client secret (needed for token refresh) - ClientSecret string `json:"clientSecret,omitempty"` - // Email is the user's email address (used for file naming) - Email string `json:"email,omitempty"` - // StartURL is the IDC/Identity Center start URL (only for IDC auth method) - StartURL string `json:"startUrl,omitempty"` - // Region is the AWS region for IDC authentication (only for IDC auth method) - Region string `json:"region,omitempty"` -} - -// KiroAuthBundle aggregates authentication data after OAuth flow completion -type KiroAuthBundle struct { - // TokenData contains the OAuth tokens from the authentication flow - TokenData KiroTokenData `json:"token_data"` - // LastRefresh is the timestamp of the last token refresh - LastRefresh string `json:"last_refresh"` -} - -// KiroUsageInfo represents usage information from CodeWhisperer API -type KiroUsageInfo struct { - // SubscriptionTitle is the subscription plan name (e.g., "KIRO FREE") - SubscriptionTitle string `json:"subscription_title"` - // CurrentUsage is the current credit usage - CurrentUsage float64 `json:"current_usage"` - // UsageLimit is the maximum credit limit - UsageLimit float64 `json:"usage_limit"` - // NextReset is the timestamp of the next usage reset - NextReset string `json:"next_reset"` -} - -// KiroModel represents a model available through the CodeWhisperer API -type KiroModel struct { - // ModelID is the unique identifier for the model - ModelID string `json:"modelId"` - // ModelName is the human-readable name - ModelName string `json:"modelName"` - // Description is the model description - Description string `json:"description"` - // RateMultiplier is the credit multiplier for this model - RateMultiplier float64 `json:"rateMultiplier"` - // RateUnit is the unit for rate calculation (e.g., "credit") - RateUnit string `json:"rateUnit"` - // MaxInputTokens is the maximum input token limit - MaxInputTokens int `json:"maxInputTokens,omitempty"` -} - -// KiroIDETokenFile is the default path to Kiro IDE's token file -const KiroIDETokenFile = ".aws/sso/cache/kiro-auth-token.json" - -// LoadKiroIDEToken loads token data from Kiro IDE's token file. -func LoadKiroIDEToken() (*KiroTokenData, error) { - homeDir, err := os.UserHomeDir() - if err != nil { - return nil, fmt.Errorf("failed to get home directory: %w", err) - } - - tokenPath := filepath.Join(homeDir, KiroIDETokenFile) - data, err := os.ReadFile(tokenPath) - if err != nil { - return nil, fmt.Errorf("failed to read Kiro IDE token file (%s): %w", tokenPath, err) - } - - var token KiroTokenData - if err := json.Unmarshal(data, &token); err != nil { - return nil, fmt.Errorf("failed to parse Kiro IDE token: %w", err) - } - - if token.AccessToken == "" { - return nil, fmt.Errorf("access token is empty in Kiro IDE token file") - } - - return &token, nil -} - -// LoadKiroTokenFromPath loads token data from a custom path. -// This supports multiple accounts by allowing different token files. -func LoadKiroTokenFromPath(tokenPath string) (*KiroTokenData, error) { - // Expand ~ to home directory - if len(tokenPath) > 0 && tokenPath[0] == '~' { - homeDir, err := os.UserHomeDir() - if err != nil { - return nil, fmt.Errorf("failed to get home directory: %w", err) - } - tokenPath = filepath.Join(homeDir, tokenPath[1:]) - } - - data, err := os.ReadFile(tokenPath) - if err != nil { - return nil, fmt.Errorf("failed to read token file (%s): %w", tokenPath, err) - } - - var token KiroTokenData - if err := json.Unmarshal(data, &token); err != nil { - return nil, fmt.Errorf("failed to parse token file: %w", err) - } - - if token.AccessToken == "" { - return nil, fmt.Errorf("access token is empty in token file") - } - - return &token, nil -} - -// ListKiroTokenFiles lists all Kiro token files in the cache directory. -// This supports multiple accounts by finding all token files. -func ListKiroTokenFiles() ([]string, error) { - homeDir, err := os.UserHomeDir() - if err != nil { - return nil, fmt.Errorf("failed to get home directory: %w", err) - } - - cacheDir := filepath.Join(homeDir, ".aws", "sso", "cache") - - // Check if directory exists - if _, err := os.Stat(cacheDir); os.IsNotExist(err) { - return nil, nil // No token files - } - - entries, err := os.ReadDir(cacheDir) - if err != nil { - return nil, fmt.Errorf("failed to read cache directory: %w", err) - } - - var tokenFiles []string - for _, entry := range entries { - if entry.IsDir() { - continue - } - name := entry.Name() - // Look for kiro token files only (avoid matching unrelated AWS SSO cache files) - if strings.HasSuffix(name, ".json") && strings.HasPrefix(name, "kiro") { - tokenFiles = append(tokenFiles, filepath.Join(cacheDir, name)) - } - } - - return tokenFiles, nil -} - -// LoadAllKiroTokens loads all Kiro tokens from the cache directory. -// This supports multiple accounts. -func LoadAllKiroTokens() ([]*KiroTokenData, error) { - files, err := ListKiroTokenFiles() - if err != nil { - return nil, err - } - - var tokens []*KiroTokenData - for _, file := range files { - token, err := LoadKiroTokenFromPath(file) - if err != nil { - // Skip invalid token files - continue - } - tokens = append(tokens, token) - } - - return tokens, nil -} - -// JWTClaims represents the claims we care about from a JWT token. -// JWT tokens from Kiro/AWS contain user information in the payload. -type JWTClaims struct { - Email string `json:"email,omitempty"` - Sub string `json:"sub,omitempty"` - PreferredUser string `json:"preferred_username,omitempty"` - Name string `json:"name,omitempty"` - Iss string `json:"iss,omitempty"` -} - -// ExtractEmailFromJWT extracts the user's email from a JWT access token. -// JWT tokens typically have format: header.payload.signature -// The payload is base64url-encoded JSON containing user claims. -func ExtractEmailFromJWT(accessToken string) string { - if accessToken == "" { - return "" - } - - // JWT format: header.payload.signature - parts := strings.Split(accessToken, ".") - if len(parts) != 3 { - return "" - } - - // Decode the payload (second part) - payload := parts[1] - - // Add padding if needed (base64url requires padding) - switch len(payload) % 4 { - case 2: - payload += "==" - case 3: - payload += "=" - } - - decoded, err := base64.URLEncoding.DecodeString(payload) - if err != nil { - // Try RawURLEncoding (no padding) - decoded, err = base64.RawURLEncoding.DecodeString(parts[1]) - if err != nil { - return "" - } - } - - var claims JWTClaims - if err := json.Unmarshal(decoded, &claims); err != nil { - return "" - } - - // Return email if available - if claims.Email != "" { - return claims.Email - } - - // Fallback to preferred_username (some providers use this) - if claims.PreferredUser != "" && strings.Contains(claims.PreferredUser, "@") { - return claims.PreferredUser - } - - // Fallback to sub if it looks like an email - if claims.Sub != "" && strings.Contains(claims.Sub, "@") { - return claims.Sub - } - - return "" -} - -// SanitizeEmailForFilename sanitizes an email address for use in a filename. -// Replaces special characters with underscores and prevents path traversal attacks. -// Also handles URL-encoded characters to prevent encoded path traversal attempts. -func SanitizeEmailForFilename(email string) string { - if email == "" { - return "" - } - - result := email - - // First, handle URL-encoded path traversal attempts (%2F, %2E, %5C, etc.) - // This prevents encoded characters from bypassing the sanitization. - // Note: We replace % last to catch any remaining encodings including double-encoding (%252F) - result = strings.ReplaceAll(result, "%2F", "_") // / - result = strings.ReplaceAll(result, "%2f", "_") - result = strings.ReplaceAll(result, "%5C", "_") // \ - result = strings.ReplaceAll(result, "%5c", "_") - result = strings.ReplaceAll(result, "%2E", "_") // . - result = strings.ReplaceAll(result, "%2e", "_") - result = strings.ReplaceAll(result, "%00", "_") // null byte - result = strings.ReplaceAll(result, "%", "_") // Catch remaining % to prevent double-encoding attacks - - // Replace characters that are problematic in filenames - // Keep @ and . in middle but replace other special characters - for _, char := range []string{"/", "\\", ":", "*", "?", "\"", "<", ">", "|", " ", "\x00"} { - result = strings.ReplaceAll(result, char, "_") - } - - // Prevent path traversal: replace leading dots in each path component - // This handles cases like "../../../etc/passwd" → "_.._.._.._etc_passwd" - parts := strings.Split(result, "_") - for i, part := range parts { - for strings.HasPrefix(part, ".") { - part = "_" + part[1:] - } - parts[i] = part - } - result = strings.Join(parts, "_") - - return result -} diff --git a/internal/auth/kiro/oauth_web.go.bak b/internal/auth/kiro/oauth_web.go.bak deleted file mode 100644 index 22d7809b..00000000 --- a/internal/auth/kiro/oauth_web.go.bak +++ /dev/null @@ -1,385 +0,0 @@ -// Package kiro provides OAuth Web authentication for Kiro. -package kiro - -import ( - "context" - "crypto/rand" - "encoding/base64" - "fmt" - "html/template" - "net/http" - "sync" - "time" - - "github.com/gin-gonic/gin" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - log "github.com/sirupsen/logrus" -) - -const ( - defaultSessionExpiry = 10 * time.Minute - pollIntervalSeconds = 5 -) - -type authSessionStatus string - -const ( - statusPending authSessionStatus = "pending" - statusSuccess authSessionStatus = "success" - statusFailed authSessionStatus = "failed" -) - -type webAuthSession struct { - stateID string - deviceCode string - userCode string - authURL string - verificationURI string - expiresIn int - interval int - status authSessionStatus - startedAt time.Time - completedAt time.Time - expiresAt time.Time - error string - tokenData *KiroTokenData - ssoClient *SSOOIDCClient - clientID string - clientSecret string - region string - cancelFunc context.CancelFunc -} - -type OAuthWebHandler struct { - cfg *config.Config - sessions map[string]*webAuthSession - mu sync.RWMutex - onTokenObtained func(*KiroTokenData) -} - -func NewOAuthWebHandler(cfg *config.Config) *OAuthWebHandler { - return &OAuthWebHandler{ - cfg: cfg, - sessions: make(map[string]*webAuthSession), - } -} - -func (h *OAuthWebHandler) SetTokenCallback(callback func(*KiroTokenData)) { - h.onTokenObtained = callback -} - -func (h *OAuthWebHandler) RegisterRoutes(router gin.IRouter) { - oauth := router.Group("/v0/oauth/kiro") - { - oauth.GET("/start", h.handleStart) - oauth.GET("/callback", h.handleCallback) - oauth.GET("/status", h.handleStatus) - } -} - -func generateStateID() (string, error) { - b := make([]byte, 16) - if _, err := rand.Read(b); err != nil { - return "", err - } - return base64.RawURLEncoding.EncodeToString(b), nil -} - -func (h *OAuthWebHandler) handleStart(c *gin.Context) { - stateID, err := generateStateID() - if err != nil { - h.renderError(c, "Failed to generate state parameter") - return - } - - region := defaultIDCRegion - startURL := builderIDStartURL - - ssoClient := NewSSOOIDCClient(h.cfg) - - regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region) - if err != nil { - log.Errorf("OAuth Web: failed to register client: %v", err) - h.renderError(c, fmt.Sprintf("Failed to register client: %v", err)) - return - } - - authResp, err := ssoClient.StartDeviceAuthorizationWithIDC( - c.Request.Context(), - regResp.ClientID, - regResp.ClientSecret, - startURL, - region, - ) - if err != nil { - log.Errorf("OAuth Web: failed to start device authorization: %v", err) - h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err)) - return - } - - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second) - - session := &webAuthSession{ - stateID: stateID, - deviceCode: authResp.DeviceCode, - userCode: authResp.UserCode, - authURL: authResp.VerificationURIComplete, - verificationURI: authResp.VerificationURI, - expiresIn: authResp.ExpiresIn, - interval: authResp.Interval, - status: statusPending, - startedAt: time.Now(), - ssoClient: ssoClient, - clientID: regResp.ClientID, - clientSecret: regResp.ClientSecret, - region: region, - cancelFunc: cancel, - } - - h.mu.Lock() - h.sessions[stateID] = session - h.mu.Unlock() - - go h.pollForToken(ctx, session) - - h.renderStartPage(c, session) -} - -func (h *OAuthWebHandler) pollForToken(ctx context.Context, session *webAuthSession) { - defer session.cancelFunc() - - interval := time.Duration(session.interval) * time.Second - if interval < time.Duration(pollIntervalSeconds)*time.Second { - interval = time.Duration(pollIntervalSeconds) * time.Second - } - - ticker := time.NewTicker(interval) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - h.mu.Lock() - if session.status == statusPending { - session.status = statusFailed - session.error = "Authentication timed out" - } - h.mu.Unlock() - return - case <-ticker.C: - tokenResp, err := h.ssoClient(session).CreateTokenWithRegion( - ctx, - session.clientID, - session.clientSecret, - session.deviceCode, - session.region, - ) - - if err != nil { - errStr := err.Error() - if errStr == ErrAuthorizationPending.Error() { - continue - } - if errStr == ErrSlowDown.Error() { - interval += 5 * time.Second - ticker.Reset(interval) - continue - } - - h.mu.Lock() - session.status = statusFailed - session.error = errStr - session.completedAt = time.Now() - h.mu.Unlock() - - log.Errorf("OAuth Web: token polling failed: %v", err) - return - } - - expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) - profileArn := session.ssoClient.fetchProfileArn(ctx, tokenResp.AccessToken) - email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken) - - tokenData := &KiroTokenData{ - AccessToken: tokenResp.AccessToken, - RefreshToken: tokenResp.RefreshToken, - ProfileArn: profileArn, - ExpiresAt: expiresAt.Format(time.RFC3339), - AuthMethod: "builder-id", - Provider: "AWS", - ClientID: session.clientID, - ClientSecret: session.clientSecret, - Email: email, - } - - h.mu.Lock() - session.status = statusSuccess - session.completedAt = time.Now() - session.expiresAt = expiresAt - session.tokenData = tokenData - h.mu.Unlock() - - if h.onTokenObtained != nil { - h.onTokenObtained(tokenData) - } - - log.Infof("OAuth Web: authentication successful for %s", email) - return - } - } -} - -func (h *OAuthWebHandler) ssoClient(session *webAuthSession) *SSOOIDCClient { - return session.ssoClient -} - -func (h *OAuthWebHandler) handleCallback(c *gin.Context) { - stateID := c.Query("state") - errParam := c.Query("error") - - if errParam != "" { - h.renderError(c, errParam) - return - } - - if stateID == "" { - h.renderError(c, "Missing state parameter") - return - } - - h.mu.RLock() - session, exists := h.sessions[stateID] - h.mu.RUnlock() - - if !exists { - h.renderError(c, "Invalid or expired session") - return - } - - if session.status == statusSuccess { - h.renderSuccess(c, session) - } else if session.status == statusFailed { - h.renderError(c, session.error) - } else { - c.Redirect(http.StatusFound, "/v0/oauth/kiro/start") - } -} - -func (h *OAuthWebHandler) handleStatus(c *gin.Context) { - stateID := c.Query("state") - if stateID == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "missing state parameter"}) - return - } - - h.mu.RLock() - session, exists := h.sessions[stateID] - h.mu.RUnlock() - - if !exists { - c.JSON(http.StatusNotFound, gin.H{"error": "session not found"}) - return - } - - response := gin.H{ - "status": string(session.status), - } - - switch session.status { - case statusPending: - elapsed := time.Since(session.startedAt).Seconds() - remaining := float64(session.expiresIn) - elapsed - if remaining < 0 { - remaining = 0 - } - response["remaining_seconds"] = int(remaining) - case statusSuccess: - response["completed_at"] = session.completedAt.Format(time.RFC3339) - response["expires_at"] = session.expiresAt.Format(time.RFC3339) - case statusFailed: - response["error"] = session.error - response["failed_at"] = session.completedAt.Format(time.RFC3339) - } - - c.JSON(http.StatusOK, response) -} - -func (h *OAuthWebHandler) renderStartPage(c *gin.Context, session *webAuthSession) { - tmpl, err := template.New("start").Parse(oauthWebStartPageHTML) - if err != nil { - log.Errorf("OAuth Web: failed to parse template: %v", err) - c.String(http.StatusInternalServerError, "Template error") - return - } - - data := map[string]interface{}{ - "AuthURL": session.authURL, - "UserCode": session.userCode, - "ExpiresIn": session.expiresIn, - "StateID": session.stateID, - } - - c.Header("Content-Type", "text/html; charset=utf-8") - if err := tmpl.Execute(c.Writer, data); err != nil { - log.Errorf("OAuth Web: failed to render template: %v", err) - } -} - -func (h *OAuthWebHandler) renderError(c *gin.Context, errMsg string) { - tmpl, err := template.New("error").Parse(oauthWebErrorPageHTML) - if err != nil { - log.Errorf("OAuth Web: failed to parse error template: %v", err) - c.String(http.StatusInternalServerError, "Template error") - return - } - - data := map[string]interface{}{ - "Error": errMsg, - } - - c.Header("Content-Type", "text/html; charset=utf-8") - c.Status(http.StatusBadRequest) - if err := tmpl.Execute(c.Writer, data); err != nil { - log.Errorf("OAuth Web: failed to render error template: %v", err) - } -} - -func (h *OAuthWebHandler) renderSuccess(c *gin.Context, session *webAuthSession) { - tmpl, err := template.New("success").Parse(oauthWebSuccessPageHTML) - if err != nil { - log.Errorf("OAuth Web: failed to parse success template: %v", err) - c.String(http.StatusInternalServerError, "Template error") - return - } - - data := map[string]interface{}{ - "ExpiresAt": session.expiresAt.Format(time.RFC3339), - } - - c.Header("Content-Type", "text/html; charset=utf-8") - if err := tmpl.Execute(c.Writer, data); err != nil { - log.Errorf("OAuth Web: failed to render success template: %v", err) - } -} - -func (h *OAuthWebHandler) CleanupExpiredSessions() { - h.mu.Lock() - defer h.mu.Unlock() - - now := time.Now() - for id, session := range h.sessions { - if session.status != statusPending && now.Sub(session.completedAt) > 30*time.Minute { - delete(h.sessions, id) - } else if session.status == statusPending && now.Sub(session.startedAt) > defaultSessionExpiry { - session.cancelFunc() - delete(h.sessions, id) - } - } -} - -func (h *OAuthWebHandler) GetSession(stateID string) (*webAuthSession, bool) { - h.mu.RLock() - defer h.mu.RUnlock() - session, exists := h.sessions[stateID] - return session, exists -} diff --git a/internal/auth/kiro/sso_oidc.go.bak b/internal/auth/kiro/sso_oidc.go.bak deleted file mode 100644 index ab44e55f..00000000 --- a/internal/auth/kiro/sso_oidc.go.bak +++ /dev/null @@ -1,1371 +0,0 @@ -// Package kiro provides AWS SSO OIDC authentication for Kiro. -package kiro - -import ( - "bufio" - "context" - "crypto/rand" - "crypto/sha256" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "html" - "io" - "net" - "net/http" - "os" - "strings" - "time" - - "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/util" - log "github.com/sirupsen/logrus" -) - -const ( - // AWS SSO OIDC endpoints - ssoOIDCEndpoint = "https://oidc.us-east-1.amazonaws.com" - - // Kiro's start URL for Builder ID - builderIDStartURL = "https://view.awsapps.com/start" - - // Default region for IDC - defaultIDCRegion = "us-east-1" - - // Polling interval - pollInterval = 5 * time.Second - - // Authorization code flow callback - authCodeCallbackPath = "/oauth/callback" - authCodeCallbackPort = 19877 - - // User-Agent to match official Kiro IDE - kiroUserAgent = "KiroIDE" - - // IDC token refresh headers (matching Kiro IDE behavior) - idcAmzUserAgent = "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE" -) - -// Sentinel errors for OIDC token polling -var ( - ErrAuthorizationPending = errors.New("authorization_pending") - ErrSlowDown = errors.New("slow_down") -) - -// SSOOIDCClient handles AWS SSO OIDC authentication. -type SSOOIDCClient struct { - httpClient *http.Client - cfg *config.Config -} - -// NewSSOOIDCClient creates a new SSO OIDC client. -func NewSSOOIDCClient(cfg *config.Config) *SSOOIDCClient { - client := &http.Client{Timeout: 30 * time.Second} - if cfg != nil { - client = util.SetProxy(&cfg.SDKConfig, client) - } - return &SSOOIDCClient{ - httpClient: client, - cfg: cfg, - } -} - -// RegisterClientResponse from AWS SSO OIDC. -type RegisterClientResponse struct { - ClientID string `json:"clientId"` - ClientSecret string `json:"clientSecret"` - ClientIDIssuedAt int64 `json:"clientIdIssuedAt"` - ClientSecretExpiresAt int64 `json:"clientSecretExpiresAt"` -} - -// StartDeviceAuthResponse from AWS SSO OIDC. -type StartDeviceAuthResponse struct { - DeviceCode string `json:"deviceCode"` - UserCode string `json:"userCode"` - VerificationURI string `json:"verificationUri"` - VerificationURIComplete string `json:"verificationUriComplete"` - ExpiresIn int `json:"expiresIn"` - Interval int `json:"interval"` -} - -// CreateTokenResponse from AWS SSO OIDC. -type CreateTokenResponse struct { - AccessToken string `json:"accessToken"` - TokenType string `json:"tokenType"` - ExpiresIn int `json:"expiresIn"` - RefreshToken string `json:"refreshToken"` -} - -// getOIDCEndpoint returns the OIDC endpoint for the given region. -func getOIDCEndpoint(region string) string { - if region == "" { - region = defaultIDCRegion - } - return fmt.Sprintf("https://oidc.%s.amazonaws.com", region) -} - -// promptInput prompts the user for input with an optional default value. -func promptInput(prompt, defaultValue string) string { - reader := bufio.NewReader(os.Stdin) - if defaultValue != "" { - fmt.Printf("%s [%s]: ", prompt, defaultValue) - } else { - fmt.Printf("%s: ", prompt) - } - input, err := reader.ReadString('\n') - if err != nil { - log.Warnf("Error reading input: %v", err) - return defaultValue - } - input = strings.TrimSpace(input) - if input == "" { - return defaultValue - } - return input -} - -// promptSelect prompts the user to select from options using number input. -func promptSelect(prompt string, options []string) int { - reader := bufio.NewReader(os.Stdin) - - for { - fmt.Println(prompt) - for i, opt := range options { - fmt.Printf(" %d) %s\n", i+1, opt) - } - fmt.Printf("Enter selection (1-%d): ", len(options)) - - input, err := reader.ReadString('\n') - if err != nil { - log.Warnf("Error reading input: %v", err) - return 0 // Default to first option on error - } - input = strings.TrimSpace(input) - - // Parse the selection - var selection int - if _, err := fmt.Sscanf(input, "%d", &selection); err != nil || selection < 1 || selection > len(options) { - fmt.Printf("Invalid selection '%s'. Please enter a number between 1 and %d.\n\n", input, len(options)) - continue - } - return selection - 1 - } -} - -// RegisterClientWithRegion registers a new OIDC client with AWS using a specific region. -func (c *SSOOIDCClient) RegisterClientWithRegion(ctx context.Context, region string) (*RegisterClientResponse, error) { - endpoint := getOIDCEndpoint(region) - - payload := map[string]interface{}{ - "clientName": "Kiro IDE", - "clientType": "public", - "scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"}, - "grantTypes": []string{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"}, - } - - body, err := json.Marshal(payload) - if err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/client/register", strings.NewReader(string(body))) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", kiroUserAgent) - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - if resp.StatusCode != http.StatusOK { - log.Debugf("register client failed (status %d): %s", resp.StatusCode, string(respBody)) - return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode) - } - - var result RegisterClientResponse - if err := json.Unmarshal(respBody, &result); err != nil { - return nil, err - } - - return &result, nil -} - -// StartDeviceAuthorizationWithIDC starts the device authorization flow for IDC. -func (c *SSOOIDCClient) StartDeviceAuthorizationWithIDC(ctx context.Context, clientID, clientSecret, startURL, region string) (*StartDeviceAuthResponse, error) { - endpoint := getOIDCEndpoint(region) - - payload := map[string]string{ - "clientId": clientID, - "clientSecret": clientSecret, - "startUrl": startURL, - } - - body, err := json.Marshal(payload) - if err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/device_authorization", strings.NewReader(string(body))) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", kiroUserAgent) - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - if resp.StatusCode != http.StatusOK { - log.Debugf("start device auth failed (status %d): %s", resp.StatusCode, string(respBody)) - return nil, fmt.Errorf("start device auth failed (status %d)", resp.StatusCode) - } - - var result StartDeviceAuthResponse - if err := json.Unmarshal(respBody, &result); err != nil { - return nil, err - } - - return &result, nil -} - -// CreateTokenWithRegion polls for the access token after user authorization using a specific region. -func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, clientSecret, deviceCode, region string) (*CreateTokenResponse, error) { - endpoint := getOIDCEndpoint(region) - - payload := map[string]string{ - "clientId": clientID, - "clientSecret": clientSecret, - "deviceCode": deviceCode, - "grantType": "urn:ietf:params:oauth:grant-type:device_code", - } - - body, err := json.Marshal(payload) - if err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body))) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", kiroUserAgent) - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - // Check for pending authorization - if resp.StatusCode == http.StatusBadRequest { - var errResp struct { - Error string `json:"error"` - } - if json.Unmarshal(respBody, &errResp) == nil { - if errResp.Error == "authorization_pending" { - return nil, ErrAuthorizationPending - } - if errResp.Error == "slow_down" { - return nil, ErrSlowDown - } - } - log.Debugf("create token failed: %s", string(respBody)) - return nil, fmt.Errorf("create token failed") - } - - if resp.StatusCode != http.StatusOK { - log.Debugf("create token failed (status %d): %s", resp.StatusCode, string(respBody)) - return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode) - } - - var result CreateTokenResponse - if err := json.Unmarshal(respBody, &result); err != nil { - return nil, err - } - - return &result, nil -} - -// RefreshTokenWithRegion refreshes an access token using the refresh token with a specific region. -func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, clientSecret, refreshToken, region, startURL string) (*KiroTokenData, error) { - endpoint := getOIDCEndpoint(region) - - payload := map[string]string{ - "clientId": clientID, - "clientSecret": clientSecret, - "refreshToken": refreshToken, - "grantType": "refresh_token", - } - - body, err := json.Marshal(payload) - if err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body))) - if err != nil { - return nil, err - } - - // Set headers matching kiro2api's IDC token refresh - // These headers are required for successful IDC token refresh - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region)) - req.Header.Set("Connection", "keep-alive") - req.Header.Set("x-amz-user-agent", idcAmzUserAgent) - req.Header.Set("Accept", "*/*") - req.Header.Set("Accept-Language", "*") - req.Header.Set("sec-fetch-mode", "cors") - req.Header.Set("User-Agent", "node") - req.Header.Set("Accept-Encoding", "br, gzip, deflate") - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - if resp.StatusCode != http.StatusOK { - log.Warnf("IDC token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) - return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode) - } - - var result CreateTokenResponse - if err := json.Unmarshal(respBody, &result); err != nil { - return nil, err - } - - expiresAt := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second) - - return &KiroTokenData{ - AccessToken: result.AccessToken, - RefreshToken: result.RefreshToken, - ExpiresAt: expiresAt.Format(time.RFC3339), - AuthMethod: "idc", - Provider: "AWS", - ClientID: clientID, - ClientSecret: clientSecret, - StartURL: startURL, - Region: region, - }, nil -} - -// LoginWithIDC performs the full device code flow for AWS Identity Center (IDC). -func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region string) (*KiroTokenData, error) { - fmt.Println("\n╔══════════════════════════════════════════════════════════╗") - fmt.Println("║ Kiro Authentication (AWS Identity Center) ║") - fmt.Println("╚══════════════════════════════════════════════════════════╝") - - // Step 1: Register client with the specified region - fmt.Println("\nRegistering client...") - regResp, err := c.RegisterClientWithRegion(ctx, region) - if err != nil { - return nil, fmt.Errorf("failed to register client: %w", err) - } - log.Debugf("Client registered: %s", regResp.ClientID) - - // Step 2: Start device authorization with IDC start URL - fmt.Println("Starting device authorization...") - authResp, err := c.StartDeviceAuthorizationWithIDC(ctx, regResp.ClientID, regResp.ClientSecret, startURL, region) - if err != nil { - return nil, fmt.Errorf("failed to start device auth: %w", err) - } - - // Step 3: Show user the verification URL - fmt.Printf("\n") - fmt.Println("════════════════════════════════════════════════════════════") - fmt.Printf(" Confirm the following code in the browser:\n") - fmt.Printf(" Code: %s\n", authResp.UserCode) - fmt.Println("════════════════════════════════════════════════════════════") - fmt.Printf("\n Open this URL: %s\n\n", authResp.VerificationURIComplete) - - // Set incognito mode based on config - if c.cfg != nil { - browser.SetIncognitoMode(c.cfg.IncognitoBrowser) - if !c.cfg.IncognitoBrowser { - log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.") - } else { - log.Debug("kiro: using incognito mode for multi-account support") - } - } else { - browser.SetIncognitoMode(true) - log.Debug("kiro: using incognito mode for multi-account support (default)") - } - - // Open browser - if err := browser.OpenURL(authResp.VerificationURIComplete); err != nil { - log.Warnf("Could not open browser automatically: %v", err) - fmt.Println(" Please open the URL manually in your browser.") - } else { - fmt.Println(" (Browser opened automatically)") - } - - // Step 4: Poll for token - fmt.Println("Waiting for authorization...") - - interval := pollInterval - if authResp.Interval > 0 { - interval = time.Duration(authResp.Interval) * time.Second - } - - deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second) - - for time.Now().Before(deadline) { - select { - case <-ctx.Done(): - browser.CloseBrowser() - return nil, ctx.Err() - case <-time.After(interval): - tokenResp, err := c.CreateTokenWithRegion(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode, region) - if err != nil { - if errors.Is(err, ErrAuthorizationPending) { - fmt.Print(".") - continue - } - if errors.Is(err, ErrSlowDown) { - interval += 5 * time.Second - continue - } - browser.CloseBrowser() - return nil, fmt.Errorf("token creation failed: %w", err) - } - - fmt.Println("\n\n✓ Authorization successful!") - - // Close the browser window - if err := browser.CloseBrowser(); err != nil { - log.Debugf("Failed to close browser: %v", err) - } - - // Step 5: Get profile ARN from CodeWhisperer API - fmt.Println("Fetching profile information...") - profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken) - - // Fetch user email - email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken) - if email != "" { - fmt.Printf(" Logged in as: %s\n", email) - } - - expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) - - return &KiroTokenData{ - AccessToken: tokenResp.AccessToken, - RefreshToken: tokenResp.RefreshToken, - ProfileArn: profileArn, - ExpiresAt: expiresAt.Format(time.RFC3339), - AuthMethod: "idc", - Provider: "AWS", - ClientID: regResp.ClientID, - ClientSecret: regResp.ClientSecret, - Email: email, - StartURL: startURL, - Region: region, - }, nil - } - } - - // Close browser on timeout - if err := browser.CloseBrowser(); err != nil { - log.Debugf("Failed to close browser on timeout: %v", err) - } - return nil, fmt.Errorf("authorization timed out") -} - -// LoginWithMethodSelection prompts the user to select between Builder ID and IDC, then performs the login. -func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context) (*KiroTokenData, error) { - fmt.Println("\n╔══════════════════════════════════════════════════════════╗") - fmt.Println("║ Kiro Authentication (AWS) ║") - fmt.Println("╚══════════════════════════════════════════════════════════╝") - - // Prompt for login method - options := []string{ - "Use with Builder ID (personal AWS account)", - "Use with IDC Account (organization SSO)", - } - selection := promptSelect("\n? Select login method:", options) - - if selection == 0 { - // Builder ID flow - use existing implementation - return c.LoginWithBuilderID(ctx) - } - - // IDC flow - prompt for start URL and region - fmt.Println() - startURL := promptInput("? Enter Start URL", "") - if startURL == "" { - return nil, fmt.Errorf("start URL is required for IDC login") - } - - region := promptInput("? Enter Region", defaultIDCRegion) - - return c.LoginWithIDC(ctx, startURL, region) -} - -// RegisterClient registers a new OIDC client with AWS. -func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResponse, error) { - payload := map[string]interface{}{ - "clientName": "Kiro IDE", - "clientType": "public", - "scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"}, - "grantTypes": []string{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"}, - } - - body, err := json.Marshal(payload) - if err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/client/register", strings.NewReader(string(body))) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", kiroUserAgent) - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - if resp.StatusCode != http.StatusOK { - log.Debugf("register client failed (status %d): %s", resp.StatusCode, string(respBody)) - return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode) - } - - var result RegisterClientResponse - if err := json.Unmarshal(respBody, &result); err != nil { - return nil, err - } - - return &result, nil -} - -// StartDeviceAuthorization starts the device authorization flow. -func (c *SSOOIDCClient) StartDeviceAuthorization(ctx context.Context, clientID, clientSecret string) (*StartDeviceAuthResponse, error) { - payload := map[string]string{ - "clientId": clientID, - "clientSecret": clientSecret, - "startUrl": builderIDStartURL, - } - - body, err := json.Marshal(payload) - if err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/device_authorization", strings.NewReader(string(body))) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", kiroUserAgent) - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - if resp.StatusCode != http.StatusOK { - log.Debugf("start device auth failed (status %d): %s", resp.StatusCode, string(respBody)) - return nil, fmt.Errorf("start device auth failed (status %d)", resp.StatusCode) - } - - var result StartDeviceAuthResponse - if err := json.Unmarshal(respBody, &result); err != nil { - return nil, err - } - - return &result, nil -} - -// CreateToken polls for the access token after user authorization. -func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret, deviceCode string) (*CreateTokenResponse, error) { - payload := map[string]string{ - "clientId": clientID, - "clientSecret": clientSecret, - "deviceCode": deviceCode, - "grantType": "urn:ietf:params:oauth:grant-type:device_code", - } - - body, err := json.Marshal(payload) - if err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body))) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", kiroUserAgent) - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - // Check for pending authorization - if resp.StatusCode == http.StatusBadRequest { - var errResp struct { - Error string `json:"error"` - } - if json.Unmarshal(respBody, &errResp) == nil { - if errResp.Error == "authorization_pending" { - return nil, ErrAuthorizationPending - } - if errResp.Error == "slow_down" { - return nil, ErrSlowDown - } - } - log.Debugf("create token failed: %s", string(respBody)) - return nil, fmt.Errorf("create token failed") - } - - if resp.StatusCode != http.StatusOK { - log.Debugf("create token failed (status %d): %s", resp.StatusCode, string(respBody)) - return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode) - } - - var result CreateTokenResponse - if err := json.Unmarshal(respBody, &result); err != nil { - return nil, err - } - - return &result, nil -} - -// RefreshToken refreshes an access token using the refresh token. -func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret, refreshToken string) (*KiroTokenData, error) { - payload := map[string]string{ - "clientId": clientID, - "clientSecret": clientSecret, - "refreshToken": refreshToken, - "grantType": "refresh_token", - } - - body, err := json.Marshal(payload) - if err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body))) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", kiroUserAgent) - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - if resp.StatusCode != http.StatusOK { - log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) - return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode) - } - - var result CreateTokenResponse - if err := json.Unmarshal(respBody, &result); err != nil { - return nil, err - } - - expiresAt := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second) - - return &KiroTokenData{ - AccessToken: result.AccessToken, - RefreshToken: result.RefreshToken, - ExpiresAt: expiresAt.Format(time.RFC3339), - AuthMethod: "builder-id", - Provider: "AWS", - ClientID: clientID, - ClientSecret: clientSecret, - }, nil -} - -// LoginWithBuilderID performs the full device code flow for AWS Builder ID. -func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, error) { - fmt.Println("\n╔══════════════════════════════════════════════════════════╗") - fmt.Println("║ Kiro Authentication (AWS Builder ID) ║") - fmt.Println("╚══════════════════════════════════════════════════════════╝") - - // Step 1: Register client - fmt.Println("\nRegistering client...") - regResp, err := c.RegisterClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to register client: %w", err) - } - log.Debugf("Client registered: %s", regResp.ClientID) - - // Step 2: Start device authorization - fmt.Println("Starting device authorization...") - authResp, err := c.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret) - if err != nil { - return nil, fmt.Errorf("failed to start device auth: %w", err) - } - - // Step 3: Show user the verification URL - fmt.Printf("\n") - fmt.Println("════════════════════════════════════════════════════════════") - fmt.Printf(" Open this URL in your browser:\n") - fmt.Printf(" %s\n", authResp.VerificationURIComplete) - fmt.Println("════════════════════════════════════════════════════════════") - fmt.Printf("\n Or go to: %s\n", authResp.VerificationURI) - fmt.Printf(" And enter code: %s\n\n", authResp.UserCode) - - // Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito) - // Incognito mode enables multi-account support by bypassing cached sessions - if c.cfg != nil { - browser.SetIncognitoMode(c.cfg.IncognitoBrowser) - if !c.cfg.IncognitoBrowser { - log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.") - } else { - log.Debug("kiro: using incognito mode for multi-account support") - } - } else { - browser.SetIncognitoMode(true) // Default to incognito if no config - log.Debug("kiro: using incognito mode for multi-account support (default)") - } - - // Open browser using cross-platform browser package - if err := browser.OpenURL(authResp.VerificationURIComplete); err != nil { - log.Warnf("Could not open browser automatically: %v", err) - fmt.Println(" Please open the URL manually in your browser.") - } else { - fmt.Println(" (Browser opened automatically)") - } - - // Step 4: Poll for token - fmt.Println("Waiting for authorization...") - - interval := pollInterval - if authResp.Interval > 0 { - interval = time.Duration(authResp.Interval) * time.Second - } - - deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second) - - for time.Now().Before(deadline) { - select { - case <-ctx.Done(): - browser.CloseBrowser() // Cleanup on cancel - return nil, ctx.Err() - case <-time.After(interval): - tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode) - if err != nil { - if errors.Is(err, ErrAuthorizationPending) { - fmt.Print(".") - continue - } - if errors.Is(err, ErrSlowDown) { - interval += 5 * time.Second - continue - } - // Close browser on error before returning - browser.CloseBrowser() - return nil, fmt.Errorf("token creation failed: %w", err) - } - - fmt.Println("\n\n✓ Authorization successful!") - - // Close the browser window - if err := browser.CloseBrowser(); err != nil { - log.Debugf("Failed to close browser: %v", err) - } - - // Step 5: Get profile ARN from CodeWhisperer API - fmt.Println("Fetching profile information...") - profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken) - - // Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing) - email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken) - if email != "" { - fmt.Printf(" Logged in as: %s\n", email) - } - - expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) - - return &KiroTokenData{ - AccessToken: tokenResp.AccessToken, - RefreshToken: tokenResp.RefreshToken, - ProfileArn: profileArn, - ExpiresAt: expiresAt.Format(time.RFC3339), - AuthMethod: "builder-id", - Provider: "AWS", - ClientID: regResp.ClientID, - ClientSecret: regResp.ClientSecret, - Email: email, - }, nil - } - } - - // Close browser on timeout for better UX - if err := browser.CloseBrowser(); err != nil { - log.Debugf("Failed to close browser on timeout: %v", err) - } - return nil, fmt.Errorf("authorization timed out") -} - -// FetchUserEmail retrieves the user's email from AWS SSO OIDC userinfo endpoint. -// Falls back to JWT parsing if userinfo fails. -func (c *SSOOIDCClient) FetchUserEmail(ctx context.Context, accessToken string) string { - // Method 1: Try userinfo endpoint (standard OIDC) - email := c.tryUserInfoEndpoint(ctx, accessToken) - if email != "" { - return email - } - - // Method 2: Fallback to JWT parsing - return ExtractEmailFromJWT(accessToken) -} - -// tryUserInfoEndpoint attempts to get user info from AWS SSO OIDC userinfo endpoint. -func (c *SSOOIDCClient) tryUserInfoEndpoint(ctx context.Context, accessToken string) string { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, ssoOIDCEndpoint+"/userinfo", nil) - if err != nil { - return "" - } - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Accept", "application/json") - - resp, err := c.httpClient.Do(req) - if err != nil { - log.Debugf("userinfo request failed: %v", err) - return "" - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - respBody, _ := io.ReadAll(resp.Body) - log.Debugf("userinfo endpoint returned status %d: %s", resp.StatusCode, string(respBody)) - return "" - } - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return "" - } - - log.Debugf("userinfo response: %s", string(respBody)) - - var userInfo struct { - Email string `json:"email"` - Sub string `json:"sub"` - PreferredUsername string `json:"preferred_username"` - Name string `json:"name"` - } - - if err := json.Unmarshal(respBody, &userInfo); err != nil { - return "" - } - - if userInfo.Email != "" { - return userInfo.Email - } - if userInfo.PreferredUsername != "" && strings.Contains(userInfo.PreferredUsername, "@") { - return userInfo.PreferredUsername - } - return "" -} - -// fetchProfileArn retrieves the profile ARN from CodeWhisperer API. -// This is needed for file naming since AWS SSO OIDC doesn't return profile info. -func (c *SSOOIDCClient) fetchProfileArn(ctx context.Context, accessToken string) string { - // Try ListProfiles API first - profileArn := c.tryListProfiles(ctx, accessToken) - if profileArn != "" { - return profileArn - } - - // Fallback: Try ListAvailableCustomizations - return c.tryListCustomizations(ctx, accessToken) -} - -func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string) string { - payload := map[string]interface{}{ - "origin": "AI_EDITOR", - } - - body, err := json.Marshal(payload) - if err != nil { - return "" - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body))) - if err != nil { - return "" - } - - req.Header.Set("Content-Type", "application/x-amz-json-1.0") - req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListProfiles") - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Accept", "application/json") - - resp, err := c.httpClient.Do(req) - if err != nil { - return "" - } - defer resp.Body.Close() - - respBody, _ := io.ReadAll(resp.Body) - - if resp.StatusCode != http.StatusOK { - log.Debugf("ListProfiles failed (status %d): %s", resp.StatusCode, string(respBody)) - return "" - } - - log.Debugf("ListProfiles response: %s", string(respBody)) - - var result struct { - Profiles []struct { - Arn string `json:"arn"` - } `json:"profiles"` - ProfileArn string `json:"profileArn"` - } - - if err := json.Unmarshal(respBody, &result); err != nil { - return "" - } - - if result.ProfileArn != "" { - return result.ProfileArn - } - - if len(result.Profiles) > 0 { - return result.Profiles[0].Arn - } - - return "" -} - -func (c *SSOOIDCClient) tryListCustomizations(ctx context.Context, accessToken string) string { - payload := map[string]interface{}{ - "origin": "AI_EDITOR", - } - - body, err := json.Marshal(payload) - if err != nil { - return "" - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body))) - if err != nil { - return "" - } - - req.Header.Set("Content-Type", "application/x-amz-json-1.0") - req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListAvailableCustomizations") - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Accept", "application/json") - - resp, err := c.httpClient.Do(req) - if err != nil { - return "" - } - defer resp.Body.Close() - - respBody, _ := io.ReadAll(resp.Body) - - if resp.StatusCode != http.StatusOK { - log.Debugf("ListAvailableCustomizations failed (status %d): %s", resp.StatusCode, string(respBody)) - return "" - } - - log.Debugf("ListAvailableCustomizations response: %s", string(respBody)) - - var result struct { - Customizations []struct { - Arn string `json:"arn"` - } `json:"customizations"` - ProfileArn string `json:"profileArn"` - } - - if err := json.Unmarshal(respBody, &result); err != nil { - return "" - } - - if result.ProfileArn != "" { - return result.ProfileArn - } - - if len(result.Customizations) > 0 { - return result.Customizations[0].Arn - } - - return "" -} - -// RegisterClientForAuthCode registers a new OIDC client for authorization code flow. -func (c *SSOOIDCClient) RegisterClientForAuthCode(ctx context.Context, redirectURI string) (*RegisterClientResponse, error) { - payload := map[string]interface{}{ - "clientName": "Kiro IDE", - "clientType": "public", - "scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"}, - "grantTypes": []string{"authorization_code", "refresh_token"}, - "redirectUris": []string{redirectURI}, - "issuerUrl": builderIDStartURL, - } - - body, err := json.Marshal(payload) - if err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/client/register", strings.NewReader(string(body))) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", kiroUserAgent) - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - if resp.StatusCode != http.StatusOK { - log.Debugf("register client for auth code failed (status %d): %s", resp.StatusCode, string(respBody)) - return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode) - } - - var result RegisterClientResponse - if err := json.Unmarshal(respBody, &result); err != nil { - return nil, err - } - - return &result, nil -} - -// AuthCodeCallbackResult contains the result from authorization code callback. -type AuthCodeCallbackResult struct { - Code string - State string - Error string -} - -// startAuthCodeCallbackServer starts a local HTTP server to receive the authorization code callback. -func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expectedState string) (string, <-chan AuthCodeCallbackResult, error) { - // Try to find an available port - listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", authCodeCallbackPort)) - if err != nil { - // Try with dynamic port - log.Warnf("sso oidc: default port %d is busy, falling back to dynamic port", authCodeCallbackPort) - listener, err = net.Listen("tcp", "127.0.0.1:0") - if err != nil { - return "", nil, fmt.Errorf("failed to start callback server: %w", err) - } - } - - port := listener.Addr().(*net.TCPAddr).Port - redirectURI := fmt.Sprintf("http://127.0.0.1:%d%s", port, authCodeCallbackPath) - resultChan := make(chan AuthCodeCallbackResult, 1) - - server := &http.Server{ - ReadHeaderTimeout: 10 * time.Second, - } - - mux := http.NewServeMux() - mux.HandleFunc(authCodeCallbackPath, func(w http.ResponseWriter, r *http.Request) { - code := r.URL.Query().Get("code") - state := r.URL.Query().Get("state") - errParam := r.URL.Query().Get("error") - - // Send response to browser - w.Header().Set("Content-Type", "text/html; charset=utf-8") - if errParam != "" { - w.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(w, ` -
Error: %s
You can close this window.
`, html.EscapeString(errParam)) - resultChan <- AuthCodeCallbackResult{Error: errParam} - return - } - - if state != expectedState { - w.WriteHeader(http.StatusBadRequest) - fmt.Fprint(w, ` -Invalid state parameter
You can close this window.
`) - resultChan <- AuthCodeCallbackResult{Error: "state mismatch"} - return - } - - fmt.Fprint(w, ` -You can close this window and return to the terminal.
-`) - resultChan <- AuthCodeCallbackResult{Code: code, State: state} - }) - - server.Handler = mux - - go func() { - if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { - log.Debugf("auth code callback server error: %v", err) - } - }() - - go func() { - select { - case <-ctx.Done(): - case <-time.After(10 * time.Minute): - case <-resultChan: - } - _ = server.Shutdown(context.Background()) - }() - - return redirectURI, resultChan, nil -} - -// generatePKCEForAuthCode generates PKCE code verifier and challenge for authorization code flow. -func generatePKCEForAuthCode() (verifier, challenge string, err error) { - b := make([]byte, 32) - if _, err := rand.Read(b); err != nil { - return "", "", fmt.Errorf("failed to generate random bytes: %w", err) - } - verifier = base64.RawURLEncoding.EncodeToString(b) - h := sha256.Sum256([]byte(verifier)) - challenge = base64.RawURLEncoding.EncodeToString(h[:]) - return verifier, challenge, nil -} - -// generateStateForAuthCode generates a random state parameter. -func generateStateForAuthCode() (string, error) { - b := make([]byte, 16) - if _, err := rand.Read(b); err != nil { - return "", err - } - return base64.RawURLEncoding.EncodeToString(b), nil -} - -// CreateTokenWithAuthCode exchanges authorization code for tokens. -func (c *SSOOIDCClient) CreateTokenWithAuthCode(ctx context.Context, clientID, clientSecret, code, codeVerifier, redirectURI string) (*CreateTokenResponse, error) { - payload := map[string]string{ - "clientId": clientID, - "clientSecret": clientSecret, - "code": code, - "codeVerifier": codeVerifier, - "redirectUri": redirectURI, - "grantType": "authorization_code", - } - - body, err := json.Marshal(payload) - if err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body))) - if err != nil { - return nil, err - } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", kiroUserAgent) - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - if resp.StatusCode != http.StatusOK { - log.Debugf("create token with auth code failed (status %d): %s", resp.StatusCode, string(respBody)) - return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode) - } - - var result CreateTokenResponse - if err := json.Unmarshal(respBody, &result); err != nil { - return nil, err - } - - return &result, nil -} - -// LoginWithBuilderIDAuthCode performs the authorization code flow for AWS Builder ID. -// This provides a better UX than device code flow as it uses automatic browser callback. -func (c *SSOOIDCClient) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTokenData, error) { - fmt.Println("\n╔══════════════════════════════════════════════════════════╗") - fmt.Println("║ Kiro Authentication (AWS Builder ID - Auth Code) ║") - fmt.Println("╚══════════════════════════════════════════════════════════╝") - - // Step 1: Generate PKCE and state - codeVerifier, codeChallenge, err := generatePKCEForAuthCode() - if err != nil { - return nil, fmt.Errorf("failed to generate PKCE: %w", err) - } - - state, err := generateStateForAuthCode() - if err != nil { - return nil, fmt.Errorf("failed to generate state: %w", err) - } - - // Step 2: Start callback server - fmt.Println("\nStarting callback server...") - redirectURI, resultChan, err := c.startAuthCodeCallbackServer(ctx, state) - if err != nil { - return nil, fmt.Errorf("failed to start callback server: %w", err) - } - log.Debugf("Callback server started, redirect URI: %s", redirectURI) - - // Step 3: Register client with auth code grant type - fmt.Println("Registering client...") - regResp, err := c.RegisterClientForAuthCode(ctx, redirectURI) - if err != nil { - return nil, fmt.Errorf("failed to register client: %w", err) - } - log.Debugf("Client registered: %s", regResp.ClientID) - - // Step 4: Build authorization URL - scopes := "codewhisperer:completions,codewhisperer:analysis,codewhisperer:conversations" - authURL := fmt.Sprintf("%s/authorize?response_type=code&client_id=%s&redirect_uri=%s&scopes=%s&state=%s&code_challenge=%s&code_challenge_method=S256", - ssoOIDCEndpoint, - regResp.ClientID, - redirectURI, - scopes, - state, - codeChallenge, - ) - - // Step 5: Open browser - fmt.Println("\n════════════════════════════════════════════════════════════") - fmt.Println(" Opening browser for authentication...") - fmt.Println("════════════════════════════════════════════════════════════") - fmt.Printf("\n URL: %s\n\n", authURL) - - // Set incognito mode - if c.cfg != nil { - browser.SetIncognitoMode(c.cfg.IncognitoBrowser) - } else { - browser.SetIncognitoMode(true) - } - - if err := browser.OpenURL(authURL); err != nil { - log.Warnf("Could not open browser automatically: %v", err) - fmt.Println(" ⚠ Could not open browser automatically.") - fmt.Println(" Please open the URL above in your browser manually.") - } else { - fmt.Println(" (Browser opened automatically)") - } - - fmt.Println("\n Waiting for authorization callback...") - - // Step 6: Wait for callback - select { - case <-ctx.Done(): - browser.CloseBrowser() - return nil, ctx.Err() - case <-time.After(10 * time.Minute): - browser.CloseBrowser() - return nil, fmt.Errorf("authorization timed out") - case result := <-resultChan: - if result.Error != "" { - browser.CloseBrowser() - return nil, fmt.Errorf("authorization failed: %s", result.Error) - } - - fmt.Println("\n✓ Authorization received!") - - // Close browser - if err := browser.CloseBrowser(); err != nil { - log.Debugf("Failed to close browser: %v", err) - } - - // Step 7: Exchange code for tokens - fmt.Println("Exchanging code for tokens...") - tokenResp, err := c.CreateTokenWithAuthCode(ctx, regResp.ClientID, regResp.ClientSecret, result.Code, codeVerifier, redirectURI) - if err != nil { - return nil, fmt.Errorf("failed to exchange code for tokens: %w", err) - } - - fmt.Println("\n✓ Authentication successful!") - - // Step 8: Get profile ARN - fmt.Println("Fetching profile information...") - profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken) - - // Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing) - email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken) - if email != "" { - fmt.Printf(" Logged in as: %s\n", email) - } - - expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) - - return &KiroTokenData{ - AccessToken: tokenResp.AccessToken, - RefreshToken: tokenResp.RefreshToken, - ProfileArn: profileArn, - ExpiresAt: expiresAt.Format(time.RFC3339), - AuthMethod: "builder-id", - Provider: "AWS", - ClientID: regResp.ClientID, - ClientSecret: regResp.ClientSecret, - Email: email, - }, nil - } -} diff --git a/internal/translator/kiro/common/utf8_stream.go b/internal/translator/kiro/common/utf8_stream.go deleted file mode 100644 index b8d24c82..00000000 --- a/internal/translator/kiro/common/utf8_stream.go +++ /dev/null @@ -1,97 +0,0 @@ -package common - -import ( - "unicode/utf8" -) - -type UTF8StreamParser struct { - buffer []byte -} - -func NewUTF8StreamParser() *UTF8StreamParser { - return &UTF8StreamParser{ - buffer: make([]byte, 0, 64), - } -} - -func (p *UTF8StreamParser) Write(data []byte) { - p.buffer = append(p.buffer, data...) -} - -func (p *UTF8StreamParser) Read() (string, bool) { - if len(p.buffer) == 0 { - return "", false - } - - validLen := p.findValidUTF8End(p.buffer) - if validLen == 0 { - return "", false - } - - result := string(p.buffer[:validLen]) - p.buffer = p.buffer[validLen:] - - return result, true -} - -func (p *UTF8StreamParser) Flush() string { - if len(p.buffer) == 0 { - return "" - } - result := string(p.buffer) - p.buffer = p.buffer[:0] - return result -} - -func (p *UTF8StreamParser) Reset() { - p.buffer = p.buffer[:0] -} - -func (p *UTF8StreamParser) findValidUTF8End(data []byte) int { - if len(data) == 0 { - return 0 - } - - end := len(data) - for i := 1; i <= 3 && i <= len(data); i++ { - b := data[len(data)-i] - if b&0x80 == 0 { - break - } - if b&0xC0 == 0xC0 { - size := p.utf8CharSize(b) - available := i - if size > available { - end = len(data) - i - } - break - } - } - - if end > 0 && !utf8.Valid(data[:end]) { - for i := end - 1; i >= 0; i-- { - if utf8.Valid(data[:i+1]) { - return i + 1 - } - } - return 0 - } - - return end -} - -func (p *UTF8StreamParser) utf8CharSize(b byte) int { - if b&0x80 == 0 { - return 1 - } - if b&0xE0 == 0xC0 { - return 2 - } - if b&0xF0 == 0xE0 { - return 3 - } - if b&0xF8 == 0xF0 { - return 4 - } - return 1 -} diff --git a/internal/translator/kiro/common/utf8_stream_test.go b/internal/translator/kiro/common/utf8_stream_test.go deleted file mode 100644 index 23e80989..00000000 --- a/internal/translator/kiro/common/utf8_stream_test.go +++ /dev/null @@ -1,402 +0,0 @@ -package common - -import ( - "strings" - "sync" - "testing" - "unicode/utf8" -) - -func TestNewUTF8StreamParser(t *testing.T) { - p := NewUTF8StreamParser() - if p == nil { - t.Fatal("expected non-nil UTF8StreamParser") - } - if p.buffer == nil { - t.Error("expected non-nil buffer") - } -} - -func TestWrite(t *testing.T) { - p := NewUTF8StreamParser() - p.Write([]byte("hello")) - - result, ok := p.Read() - if !ok { - t.Error("expected ok to be true") - } - if result != "hello" { - t.Errorf("expected 'hello', got '%s'", result) - } -} - -func TestWrite_MultipleWrites(t *testing.T) { - p := NewUTF8StreamParser() - p.Write([]byte("hel")) - p.Write([]byte("lo")) - - result, ok := p.Read() - if !ok { - t.Error("expected ok to be true") - } - if result != "hello" { - t.Errorf("expected 'hello', got '%s'", result) - } -} - -func TestRead_EmptyBuffer(t *testing.T) { - p := NewUTF8StreamParser() - result, ok := p.Read() - if ok { - t.Error("expected ok to be false for empty buffer") - } - if result != "" { - t.Errorf("expected empty string, got '%s'", result) - } -} - -func TestRead_IncompleteUTF8(t *testing.T) { - p := NewUTF8StreamParser() - - // Write incomplete multi-byte UTF-8 character - // 中 (U+4E2D) = E4 B8 AD - p.Write([]byte{0xE4, 0xB8}) - - result, ok := p.Read() - if ok { - t.Error("expected ok to be false for incomplete UTF-8") - } - if result != "" { - t.Errorf("expected empty string, got '%s'", result) - } - - // Complete the character - p.Write([]byte{0xAD}) - result, ok = p.Read() - if !ok { - t.Error("expected ok to be true after completing UTF-8") - } - if result != "中" { - t.Errorf("expected '中', got '%s'", result) - } -} - -func TestRead_MixedASCIIAndUTF8(t *testing.T) { - p := NewUTF8StreamParser() - p.Write([]byte("Hello 世界")) - - result, ok := p.Read() - if !ok { - t.Error("expected ok to be true") - } - if result != "Hello 世界" { - t.Errorf("expected 'Hello 世界', got '%s'", result) - } -} - -func TestRead_PartialMultibyteAtEnd(t *testing.T) { - p := NewUTF8StreamParser() - // "Hello" + partial "世" (E4 B8 96) - p.Write([]byte("Hello")) - p.Write([]byte{0xE4, 0xB8}) - - result, ok := p.Read() - if !ok { - t.Error("expected ok to be true for valid portion") - } - if result != "Hello" { - t.Errorf("expected 'Hello', got '%s'", result) - } - - // Complete the character - p.Write([]byte{0x96}) - result, ok = p.Read() - if !ok { - t.Error("expected ok to be true after completing") - } - if result != "世" { - t.Errorf("expected '世', got '%s'", result) - } -} - -func TestFlush(t *testing.T) { - p := NewUTF8StreamParser() - p.Write([]byte("hello")) - - result := p.Flush() - if result != "hello" { - t.Errorf("expected 'hello', got '%s'", result) - } - - // Verify buffer is cleared - result2, ok := p.Read() - if ok { - t.Error("expected ok to be false after flush") - } - if result2 != "" { - t.Errorf("expected empty string after flush, got '%s'", result2) - } -} - -func TestFlush_EmptyBuffer(t *testing.T) { - p := NewUTF8StreamParser() - result := p.Flush() - if result != "" { - t.Errorf("expected empty string, got '%s'", result) - } -} - -func TestFlush_IncompleteUTF8(t *testing.T) { - p := NewUTF8StreamParser() - p.Write([]byte{0xE4, 0xB8}) - - result := p.Flush() - // Flush returns everything including incomplete bytes - if len(result) != 2 { - t.Errorf("expected 2 bytes flushed, got %d", len(result)) - } -} - -func TestReset(t *testing.T) { - p := NewUTF8StreamParser() - p.Write([]byte("hello")) - p.Reset() - - result, ok := p.Read() - if ok { - t.Error("expected ok to be false after reset") - } - if result != "" { - t.Errorf("expected empty string after reset, got '%s'", result) - } -} - -func TestUtf8CharSize(t *testing.T) { - p := NewUTF8StreamParser() - - testCases := []struct { - b byte - expected int - }{ - {0x00, 1}, // ASCII - {0x7F, 1}, // ASCII max - {0xC0, 2}, // 2-byte start - {0xDF, 2}, // 2-byte start - {0xE0, 3}, // 3-byte start - {0xEF, 3}, // 3-byte start - {0xF0, 4}, // 4-byte start - {0xF7, 4}, // 4-byte start - {0x80, 1}, // Continuation byte (fallback) - } - - for _, tc := range testCases { - size := p.utf8CharSize(tc.b) - if size != tc.expected { - t.Errorf("utf8CharSize(0x%X) = %d, expected %d", tc.b, size, tc.expected) - } - } -} - -func TestStreamingScenario(t *testing.T) { - p := NewUTF8StreamParser() - - // Simulate streaming: "Hello, 世界! 🌍" - chunks := [][]byte{ - []byte("Hello, "), - {0xE4, 0xB8}, // partial 世 - {0x96, 0xE7}, // complete 世, partial 界 - {0x95, 0x8C}, // complete 界 - []byte("! "), - {0xF0, 0x9F}, // partial 🌍 - {0x8C, 0x8D}, // complete 🌍 - } - - var results []string - for _, chunk := range chunks { - p.Write(chunk) - if result, ok := p.Read(); ok { - results = append(results, result) - } - } - - combined := strings.Join(results, "") - if combined != "Hello, 世界! 🌍" { - t.Errorf("expected 'Hello, 世界! 🌍', got '%s'", combined) - } -} - -func TestValidUTF8Output(t *testing.T) { - p := NewUTF8StreamParser() - - testStrings := []string{ - "Hello World", - "你好世界", - "こんにちは", - "🎉🎊🎁", - "Mixed 混合 Текст ტექსტი", - } - - for _, s := range testStrings { - p.Reset() - p.Write([]byte(s)) - result, ok := p.Read() - if !ok { - t.Errorf("expected ok for '%s'", s) - } - if !utf8.ValidString(result) { - t.Errorf("invalid UTF-8 output for input '%s'", s) - } - if result != s { - t.Errorf("expected '%s', got '%s'", s, result) - } - } -} - -func TestLargeData(t *testing.T) { - p := NewUTF8StreamParser() - - // Generate large UTF-8 string - var builder strings.Builder - for i := 0; i < 1000; i++ { - builder.WriteString("Hello 世界! ") - } - largeString := builder.String() - - p.Write([]byte(largeString)) - result, ok := p.Read() - if !ok { - t.Error("expected ok for large data") - } - if result != largeString { - t.Error("large data mismatch") - } -} - -func TestByteByByteWriting(t *testing.T) { - p := NewUTF8StreamParser() - input := "Hello 世界" - inputBytes := []byte(input) - - var results []string - for _, b := range inputBytes { - p.Write([]byte{b}) - if result, ok := p.Read(); ok { - results = append(results, result) - } - } - - combined := strings.Join(results, "") - if combined != input { - t.Errorf("expected '%s', got '%s'", input, combined) - } -} - -func TestEmoji4ByteUTF8(t *testing.T) { - p := NewUTF8StreamParser() - - // 🎉 = F0 9F 8E 89 - emoji := "🎉" - emojiBytes := []byte(emoji) - - for i := 0; i < len(emojiBytes)-1; i++ { - p.Write(emojiBytes[i : i+1]) - result, ok := p.Read() - if ok && result != "" { - t.Errorf("unexpected output before emoji complete: '%s'", result) - } - } - - p.Write(emojiBytes[len(emojiBytes)-1:]) - result, ok := p.Read() - if !ok { - t.Error("expected ok after completing emoji") - } - if result != emoji { - t.Errorf("expected '%s', got '%s'", emoji, result) - } -} - -func TestContinuationBytesOnly(t *testing.T) { - p := NewUTF8StreamParser() - - // Write only continuation bytes (invalid UTF-8) - p.Write([]byte{0x80, 0x80, 0x80}) - - result, ok := p.Read() - // Should handle gracefully - either return nothing or return the bytes - _ = result - _ = ok -} - -func TestUTF8StreamParser_ConcurrentSafety(t *testing.T) { - // Note: UTF8StreamParser doesn't have built-in locks, - // so this test verifies it works with external synchronization - p := NewUTF8StreamParser() - var mu sync.Mutex - const numGoroutines = 10 - const numOperations = 100 - - var wg sync.WaitGroup - wg.Add(numGoroutines) - - for i := 0; i < numGoroutines; i++ { - go func() { - defer wg.Done() - for j := 0; j < numOperations; j++ { - mu.Lock() - switch j % 4 { - case 0: - p.Write([]byte("test")) - case 1: - p.Read() - case 2: - p.Flush() - case 3: - p.Reset() - } - mu.Unlock() - } - }() - } - - wg.Wait() -} - -func TestConsecutiveReads(t *testing.T) { - p := NewUTF8StreamParser() - p.Write([]byte("hello")) - - result1, ok1 := p.Read() - if !ok1 || result1 != "hello" { - t.Error("first read failed") - } - - result2, ok2 := p.Read() - if ok2 || result2 != "" { - t.Error("second read should return empty") - } -} - -func TestFlushThenWrite(t *testing.T) { - p := NewUTF8StreamParser() - p.Write([]byte("first")) - p.Flush() - p.Write([]byte("second")) - - result, ok := p.Read() - if !ok || result != "second" { - t.Errorf("expected 'second', got '%s'", result) - } -} - -func TestResetThenWrite(t *testing.T) { - p := NewUTF8StreamParser() - p.Write([]byte("first")) - p.Reset() - p.Write([]byte("second")) - - result, ok := p.Read() - if !ok || result != "second" { - t.Errorf("expected 'second', got '%s'", result) - } -} diff --git a/sdk/auth/kiro.go.bak b/sdk/auth/kiro.go.bak deleted file mode 100644 index b75cd28e..00000000 --- a/sdk/auth/kiro.go.bak +++ /dev/null @@ -1,470 +0,0 @@ -package auth - -import ( - "context" - "fmt" - "strings" - "time" - - kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" -) - -// extractKiroIdentifier extracts a meaningful identifier for file naming. -// Returns account name if provided, otherwise profile ARN ID. -// All extracted values are sanitized to prevent path injection attacks. -func extractKiroIdentifier(accountName, profileArn string) string { - // Priority 1: Use account name if provided - if accountName != "" { - return kiroauth.SanitizeEmailForFilename(accountName) - } - - // Priority 2: Use profile ARN ID part (sanitized to prevent path injection) - if profileArn != "" { - parts := strings.Split(profileArn, "/") - if len(parts) >= 2 { - // Sanitize the ARN component to prevent path traversal - return kiroauth.SanitizeEmailForFilename(parts[len(parts)-1]) - } - } - - // Fallback: timestamp - return fmt.Sprintf("%d", time.Now().UnixNano()%100000) -} - -// KiroAuthenticator implements OAuth authentication for Kiro with Google login. -type KiroAuthenticator struct{} - -// NewKiroAuthenticator constructs a Kiro authenticator. -func NewKiroAuthenticator() *KiroAuthenticator { - return &KiroAuthenticator{} -} - -// Provider returns the provider key for the authenticator. -func (a *KiroAuthenticator) Provider() string { - return "kiro" -} - -// RefreshLead indicates how soon before expiry a refresh should be attempted. -// Set to 5 minutes to match Antigravity and avoid frequent refresh checks while still ensuring timely token refresh. -func (a *KiroAuthenticator) RefreshLead() *time.Duration { - d := 5 * time.Minute - return &d -} - -// createAuthRecord creates an auth record from token data. -func (a *KiroAuthenticator) createAuthRecord(tokenData *kiroauth.KiroTokenData, source string) (*coreauth.Auth, error) { - // Parse expires_at - expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) - if err != nil { - expiresAt = time.Now().Add(1 * time.Hour) - } - - // Extract identifier for file naming - idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) - - // Determine label based on auth method - label := fmt.Sprintf("kiro-%s", source) - if tokenData.AuthMethod == "idc" { - label = "kiro-idc" - } - - now := time.Now() - fileName := fmt.Sprintf("%s-%s.json", label, idPart) - - metadata := map[string]any{ - "type": "kiro", - "access_token": tokenData.AccessToken, - "refresh_token": tokenData.RefreshToken, - "profile_arn": tokenData.ProfileArn, - "expires_at": tokenData.ExpiresAt, - "auth_method": tokenData.AuthMethod, - "provider": tokenData.Provider, - "client_id": tokenData.ClientID, - "client_secret": tokenData.ClientSecret, - "email": tokenData.Email, - } - - // Add IDC-specific fields if present - if tokenData.StartURL != "" { - metadata["start_url"] = tokenData.StartURL - } - if tokenData.Region != "" { - metadata["region"] = tokenData.Region - } - - attributes := map[string]string{ - "profile_arn": tokenData.ProfileArn, - "source": source, - "email": tokenData.Email, - } - - // Add IDC-specific attributes if present - if tokenData.AuthMethod == "idc" { - attributes["source"] = "aws-idc" - if tokenData.StartURL != "" { - attributes["start_url"] = tokenData.StartURL - } - if tokenData.Region != "" { - attributes["region"] = tokenData.Region - } - } - - record := &coreauth.Auth{ - ID: fileName, - Provider: "kiro", - FileName: fileName, - Label: label, - Status: coreauth.StatusActive, - CreatedAt: now, - UpdatedAt: now, - Metadata: metadata, - Attributes: attributes, - // NextRefreshAfter is aligned with RefreshLead (5min) - NextRefreshAfter: expiresAt.Add(-5 * time.Minute), - } - - if tokenData.Email != "" { - fmt.Printf("\n✓ Kiro authentication completed successfully! (Account: %s)\n", tokenData.Email) - } else { - fmt.Println("\n✓ Kiro authentication completed successfully!") - } - - return record, nil -} - -// Login performs OAuth login for Kiro with AWS (Builder ID or IDC). -// This shows a method selection prompt and handles both flows. -func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { - if cfg == nil { - return nil, fmt.Errorf("kiro auth: configuration is required") - } - - // Use the unified method selection flow (Builder ID or IDC) - ssoClient := kiroauth.NewSSOOIDCClient(cfg) - tokenData, err := ssoClient.LoginWithMethodSelection(ctx) - if err != nil { - return nil, fmt.Errorf("login failed: %w", err) - } - - return a.createAuthRecord(tokenData, "aws") -} - -// LoginWithAuthCode performs OAuth login for Kiro with AWS Builder ID using authorization code flow. -// This provides a better UX than device code flow as it uses automatic browser callback. -func (a *KiroAuthenticator) LoginWithAuthCode(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { - if cfg == nil { - return nil, fmt.Errorf("kiro auth: configuration is required") - } - - oauth := kiroauth.NewKiroOAuth(cfg) - - // Use AWS Builder ID authorization code flow - tokenData, err := oauth.LoginWithBuilderIDAuthCode(ctx) - if err != nil { - return nil, fmt.Errorf("login failed: %w", err) - } - - // Parse expires_at - expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) - if err != nil { - expiresAt = time.Now().Add(1 * time.Hour) - } - - // Extract identifier for file naming - idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) - - now := time.Now() - fileName := fmt.Sprintf("kiro-aws-%s.json", idPart) - - record := &coreauth.Auth{ - ID: fileName, - Provider: "kiro", - FileName: fileName, - Label: "kiro-aws", - Status: coreauth.StatusActive, - CreatedAt: now, - UpdatedAt: now, - Metadata: map[string]any{ - "type": "kiro", - "access_token": tokenData.AccessToken, - "refresh_token": tokenData.RefreshToken, - "profile_arn": tokenData.ProfileArn, - "expires_at": tokenData.ExpiresAt, - "auth_method": tokenData.AuthMethod, - "provider": tokenData.Provider, - "client_id": tokenData.ClientID, - "client_secret": tokenData.ClientSecret, - "email": tokenData.Email, - }, - Attributes: map[string]string{ - "profile_arn": tokenData.ProfileArn, - "source": "aws-builder-id-authcode", - "email": tokenData.Email, - }, - // NextRefreshAfter is aligned with RefreshLead (5min) - NextRefreshAfter: expiresAt.Add(-5 * time.Minute), - } - - if tokenData.Email != "" { - fmt.Printf("\n✓ Kiro authentication completed successfully! (Account: %s)\n", tokenData.Email) - } else { - fmt.Println("\n✓ Kiro authentication completed successfully!") - } - - return record, nil -} - -// LoginWithGoogle performs OAuth login for Kiro with Google. -// This uses a custom protocol handler (kiro://) to receive the callback. -func (a *KiroAuthenticator) LoginWithGoogle(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { - if cfg == nil { - return nil, fmt.Errorf("kiro auth: configuration is required") - } - - oauth := kiroauth.NewKiroOAuth(cfg) - - // Use Google OAuth flow with protocol handler - tokenData, err := oauth.LoginWithGoogle(ctx) - if err != nil { - return nil, fmt.Errorf("google login failed: %w", err) - } - - // Parse expires_at - expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) - if err != nil { - expiresAt = time.Now().Add(1 * time.Hour) - } - - // Extract identifier for file naming - idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) - - now := time.Now() - fileName := fmt.Sprintf("kiro-google-%s.json", idPart) - - record := &coreauth.Auth{ - ID: fileName, - Provider: "kiro", - FileName: fileName, - Label: "kiro-google", - Status: coreauth.StatusActive, - CreatedAt: now, - UpdatedAt: now, - Metadata: map[string]any{ - "type": "kiro", - "access_token": tokenData.AccessToken, - "refresh_token": tokenData.RefreshToken, - "profile_arn": tokenData.ProfileArn, - "expires_at": tokenData.ExpiresAt, - "auth_method": tokenData.AuthMethod, - "provider": tokenData.Provider, - "email": tokenData.Email, - }, - Attributes: map[string]string{ - "profile_arn": tokenData.ProfileArn, - "source": "google-oauth", - "email": tokenData.Email, - }, - // NextRefreshAfter is aligned with RefreshLead (5min) - NextRefreshAfter: expiresAt.Add(-5 * time.Minute), - } - - if tokenData.Email != "" { - fmt.Printf("\n✓ Kiro Google authentication completed successfully! (Account: %s)\n", tokenData.Email) - } else { - fmt.Println("\n✓ Kiro Google authentication completed successfully!") - } - - return record, nil -} - -// LoginWithGitHub performs OAuth login for Kiro with GitHub. -// This uses a custom protocol handler (kiro://) to receive the callback. -func (a *KiroAuthenticator) LoginWithGitHub(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { - if cfg == nil { - return nil, fmt.Errorf("kiro auth: configuration is required") - } - - oauth := kiroauth.NewKiroOAuth(cfg) - - // Use GitHub OAuth flow with protocol handler - tokenData, err := oauth.LoginWithGitHub(ctx) - if err != nil { - return nil, fmt.Errorf("github login failed: %w", err) - } - - // Parse expires_at - expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) - if err != nil { - expiresAt = time.Now().Add(1 * time.Hour) - } - - // Extract identifier for file naming - idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) - - now := time.Now() - fileName := fmt.Sprintf("kiro-github-%s.json", idPart) - - record := &coreauth.Auth{ - ID: fileName, - Provider: "kiro", - FileName: fileName, - Label: "kiro-github", - Status: coreauth.StatusActive, - CreatedAt: now, - UpdatedAt: now, - Metadata: map[string]any{ - "type": "kiro", - "access_token": tokenData.AccessToken, - "refresh_token": tokenData.RefreshToken, - "profile_arn": tokenData.ProfileArn, - "expires_at": tokenData.ExpiresAt, - "auth_method": tokenData.AuthMethod, - "provider": tokenData.Provider, - "email": tokenData.Email, - }, - Attributes: map[string]string{ - "profile_arn": tokenData.ProfileArn, - "source": "github-oauth", - "email": tokenData.Email, - }, - // NextRefreshAfter is aligned with RefreshLead (5min) - NextRefreshAfter: expiresAt.Add(-5 * time.Minute), - } - - if tokenData.Email != "" { - fmt.Printf("\n✓ Kiro GitHub authentication completed successfully! (Account: %s)\n", tokenData.Email) - } else { - fmt.Println("\n✓ Kiro GitHub authentication completed successfully!") - } - - return record, nil -} - -// ImportFromKiroIDE imports token from Kiro IDE's token file. -func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.Config) (*coreauth.Auth, error) { - tokenData, err := kiroauth.LoadKiroIDEToken() - if err != nil { - return nil, fmt.Errorf("failed to load Kiro IDE token: %w", err) - } - - // Parse expires_at - expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) - if err != nil { - expiresAt = time.Now().Add(1 * time.Hour) - } - - // Extract email from JWT if not already set (for imported tokens) - if tokenData.Email == "" { - tokenData.Email = kiroauth.ExtractEmailFromJWT(tokenData.AccessToken) - } - - // Extract identifier for file naming - idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) - // Sanitize provider to prevent path traversal (defense-in-depth) - provider := kiroauth.SanitizeEmailForFilename(strings.ToLower(strings.TrimSpace(tokenData.Provider))) - if provider == "" { - provider = "imported" // Fallback for legacy tokens without provider - } - - now := time.Now() - fileName := fmt.Sprintf("kiro-%s-%s.json", provider, idPart) - - record := &coreauth.Auth{ - ID: fileName, - Provider: "kiro", - FileName: fileName, - Label: fmt.Sprintf("kiro-%s", provider), - Status: coreauth.StatusActive, - CreatedAt: now, - UpdatedAt: now, - Metadata: map[string]any{ - "type": "kiro", - "access_token": tokenData.AccessToken, - "refresh_token": tokenData.RefreshToken, - "profile_arn": tokenData.ProfileArn, - "expires_at": tokenData.ExpiresAt, - "auth_method": tokenData.AuthMethod, - "provider": tokenData.Provider, - "email": tokenData.Email, - }, - Attributes: map[string]string{ - "profile_arn": tokenData.ProfileArn, - "source": "kiro-ide-import", - "email": tokenData.Email, - }, - // NextRefreshAfter is aligned with RefreshLead (5min) - NextRefreshAfter: expiresAt.Add(-5 * time.Minute), - } - - // Display the email if extracted - if tokenData.Email != "" { - fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s, Account: %s)\n", tokenData.Provider, tokenData.Email) - } else { - fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s)\n", tokenData.Provider) - } - - return record, nil -} - -// Refresh refreshes an expired Kiro token using AWS SSO OIDC. -func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, auth *coreauth.Auth) (*coreauth.Auth, error) { - if auth == nil || auth.Metadata == nil { - return nil, fmt.Errorf("invalid auth record") - } - - refreshToken, ok := auth.Metadata["refresh_token"].(string) - if !ok || refreshToken == "" { - return nil, fmt.Errorf("refresh token not found") - } - - clientID, _ := auth.Metadata["client_id"].(string) - clientSecret, _ := auth.Metadata["client_secret"].(string) - authMethod, _ := auth.Metadata["auth_method"].(string) - startURL, _ := auth.Metadata["start_url"].(string) - region, _ := auth.Metadata["region"].(string) - - var tokenData *kiroauth.KiroTokenData - var err error - - ssoClient := kiroauth.NewSSOOIDCClient(cfg) - - // Use SSO OIDC refresh for AWS Builder ID or IDC, otherwise use Kiro's OAuth refresh endpoint - switch { - case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "": - // IDC refresh with region-specific endpoint - tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL) - case clientID != "" && clientSecret != "" && authMethod == "builder-id": - // Builder ID refresh with default endpoint - tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken) - default: - // Fallback to Kiro's refresh endpoint (for social auth: Google/GitHub) - oauth := kiroauth.NewKiroOAuth(cfg) - tokenData, err = oauth.RefreshToken(ctx, refreshToken) - } - - if err != nil { - return nil, fmt.Errorf("token refresh failed: %w", err) - } - - // Parse expires_at - expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) - if err != nil { - expiresAt = time.Now().Add(1 * time.Hour) - } - - // Clone auth to avoid mutating the input parameter - updated := auth.Clone() - now := time.Now() - updated.UpdatedAt = now - updated.LastRefreshedAt = now - updated.Metadata["access_token"] = tokenData.AccessToken - updated.Metadata["refresh_token"] = tokenData.RefreshToken - updated.Metadata["expires_at"] = tokenData.ExpiresAt - updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization - // NextRefreshAfter is aligned with RefreshLead (5min) - updated.NextRefreshAfter = expiresAt.Add(-5 * time.Minute) - - return updated, nil -}