diff --git a/internal/auth/kiro/sso_oidc.go b/internal/auth/kiro/sso_oidc.go index 292f5bcf..ab44e55f 100644 --- a/internal/auth/kiro/sso_oidc.go +++ b/internal/auth/kiro/sso_oidc.go @@ -8,6 +8,7 @@ import ( "crypto/sha256" "encoding/base64" "encoding/json" + "errors" "fmt" "html" "io" @@ -35,13 +36,22 @@ const ( // Polling interval pollInterval = 5 * time.Second - + // Authorization code flow callback authCodeCallbackPath = "/oauth/callback" authCodeCallbackPort = 19877 - + // User-Agent to match official Kiro IDE kiroUserAgent = "KiroIDE" + + // IDC token refresh headers (matching Kiro IDE behavior) + idcAmzUserAgent = "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE" +) + +// Sentinel errors for OIDC token polling +var ( + ErrAuthorizationPending = errors.New("authorization_pending") + ErrSlowDown = errors.New("slow_down") ) // SSOOIDCClient handles AWS SSO OIDC authentication. @@ -104,7 +114,11 @@ func promptInput(prompt, defaultValue string) string { } else { fmt.Printf("%s: ", prompt) } - input, _ := reader.ReadString('\n') + input, err := reader.ReadString('\n') + if err != nil { + log.Warnf("Error reading input: %v", err) + return defaultValue + } input = strings.TrimSpace(input) if input == "" { return defaultValue @@ -112,24 +126,32 @@ func promptInput(prompt, defaultValue string) string { return input } -// promptSelect prompts the user to select from options using arrow keys or number input. +// promptSelect prompts the user to select from options using number input. func promptSelect(prompt string, options []string) int { - fmt.Println(prompt) - for i, opt := range options { - fmt.Printf(" %d) %s\n", i+1, opt) - } - fmt.Print("Enter selection (1-", len(options), "): ") - reader := bufio.NewReader(os.Stdin) - input, _ := reader.ReadString('\n') - input = strings.TrimSpace(input) - // Parse the selection - var selection int - if _, err := fmt.Sscanf(input, "%d", &selection); err != nil || selection < 1 || selection > len(options) { - return 0 // Default to first option + for { + fmt.Println(prompt) + for i, opt := range options { + fmt.Printf(" %d) %s\n", i+1, opt) + } + fmt.Printf("Enter selection (1-%d): ", len(options)) + + input, err := reader.ReadString('\n') + if err != nil { + log.Warnf("Error reading input: %v", err) + return 0 // Default to first option on error + } + input = strings.TrimSpace(input) + + // Parse the selection + var selection int + if _, err := fmt.Sscanf(input, "%d", &selection); err != nil || selection < 1 || selection > len(options) { + fmt.Printf("Invalid selection '%s'. Please enter a number between 1 and %d.\n\n", input, len(options)) + continue + } + return selection - 1 } - return selection - 1 } // RegisterClientWithRegion registers a new OIDC client with AWS using a specific region. @@ -266,10 +288,10 @@ func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, cli } if json.Unmarshal(respBody, &errResp) == nil { if errResp.Error == "authorization_pending" { - return nil, fmt.Errorf("authorization_pending") + return nil, ErrAuthorizationPending } if errResp.Error == "slow_down" { - return nil, fmt.Errorf("slow_down") + return nil, ErrSlowDown } } log.Debugf("create token failed: %s", string(respBody)) @@ -315,7 +337,7 @@ func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, cl req.Header.Set("Content-Type", "application/json") req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region)) req.Header.Set("Connection", "keep-alive") - req.Header.Set("x-amz-user-agent", "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE") + req.Header.Set("x-amz-user-agent", idcAmzUserAgent) req.Header.Set("Accept", "*/*") req.Header.Set("Accept-Language", "*") req.Header.Set("sec-fetch-mode", "cors") @@ -426,12 +448,11 @@ func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region strin case <-time.After(interval): tokenResp, err := c.CreateTokenWithRegion(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode, region) if err != nil { - errStr := err.Error() - if strings.Contains(errStr, "authorization_pending") { + if errors.Is(err, ErrAuthorizationPending) { fmt.Print(".") continue } - if strings.Contains(errStr, "slow_down") { + if errors.Is(err, ErrSlowDown) { interval += 5 * time.Second continue } @@ -639,10 +660,10 @@ func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret, } if json.Unmarshal(respBody, &errResp) == nil { if errResp.Error == "authorization_pending" { - return nil, fmt.Errorf("authorization_pending") + return nil, ErrAuthorizationPending } if errResp.Error == "slow_down" { - return nil, fmt.Errorf("slow_down") + return nil, ErrSlowDown } } log.Debugf("create token failed: %s", string(respBody)) @@ -787,12 +808,11 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, case <-time.After(interval): tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode) if err != nil { - errStr := err.Error() - if strings.Contains(errStr, "authorization_pending") { + if errors.Is(err, ErrAuthorizationPending) { fmt.Print(".") continue } - if strings.Contains(errStr, "slow_down") { + if errors.Is(err, ErrSlowDown) { interval += 5 * time.Second continue }