refactor: improve error handling and code quality

- Handle errors in promptInput instead of ignoring them
- Improve promptSelect to provide feedback on invalid input and re-prompt
- Use sentinel errors (ErrAuthorizationPending, ErrSlowDown) instead of
  string-based error checking with strings.Contains
- Move hardcoded x-amz-user-agent header to idcAmzUserAgent constant

Addresses code review feedback from Gemini Code Assist.
This commit is contained in:
Joao
2025-12-23 10:20:14 +00:00
parent 98db5aabd0
commit 349b2ba3af

View File

@@ -8,6 +8,7 @@ import (
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"html"
"io"
@@ -35,13 +36,22 @@ const (
// Polling interval
pollInterval = 5 * time.Second
// Authorization code flow callback
authCodeCallbackPath = "/oauth/callback"
authCodeCallbackPort = 19877
// User-Agent to match official Kiro IDE
kiroUserAgent = "KiroIDE"
// IDC token refresh headers (matching Kiro IDE behavior)
idcAmzUserAgent = "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE"
)
// Sentinel errors for OIDC token polling
var (
ErrAuthorizationPending = errors.New("authorization_pending")
ErrSlowDown = errors.New("slow_down")
)
// SSOOIDCClient handles AWS SSO OIDC authentication.
@@ -104,7 +114,11 @@ func promptInput(prompt, defaultValue string) string {
} else {
fmt.Printf("%s: ", prompt)
}
input, _ := reader.ReadString('\n')
input, err := reader.ReadString('\n')
if err != nil {
log.Warnf("Error reading input: %v", err)
return defaultValue
}
input = strings.TrimSpace(input)
if input == "" {
return defaultValue
@@ -112,24 +126,32 @@ func promptInput(prompt, defaultValue string) string {
return input
}
// promptSelect prompts the user to select from options using arrow keys or number input.
// promptSelect prompts the user to select from options using number input.
func promptSelect(prompt string, options []string) int {
fmt.Println(prompt)
for i, opt := range options {
fmt.Printf(" %d) %s\n", i+1, opt)
}
fmt.Print("Enter selection (1-", len(options), "): ")
reader := bufio.NewReader(os.Stdin)
input, _ := reader.ReadString('\n')
input = strings.TrimSpace(input)
// Parse the selection
var selection int
if _, err := fmt.Sscanf(input, "%d", &selection); err != nil || selection < 1 || selection > len(options) {
return 0 // Default to first option
for {
fmt.Println(prompt)
for i, opt := range options {
fmt.Printf(" %d) %s\n", i+1, opt)
}
fmt.Printf("Enter selection (1-%d): ", len(options))
input, err := reader.ReadString('\n')
if err != nil {
log.Warnf("Error reading input: %v", err)
return 0 // Default to first option on error
}
input = strings.TrimSpace(input)
// Parse the selection
var selection int
if _, err := fmt.Sscanf(input, "%d", &selection); err != nil || selection < 1 || selection > len(options) {
fmt.Printf("Invalid selection '%s'. Please enter a number between 1 and %d.\n\n", input, len(options))
continue
}
return selection - 1
}
return selection - 1
}
// RegisterClientWithRegion registers a new OIDC client with AWS using a specific region.
@@ -266,10 +288,10 @@ func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, cli
}
if json.Unmarshal(respBody, &errResp) == nil {
if errResp.Error == "authorization_pending" {
return nil, fmt.Errorf("authorization_pending")
return nil, ErrAuthorizationPending
}
if errResp.Error == "slow_down" {
return nil, fmt.Errorf("slow_down")
return nil, ErrSlowDown
}
}
log.Debugf("create token failed: %s", string(respBody))
@@ -315,7 +337,7 @@ func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, cl
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region))
req.Header.Set("Connection", "keep-alive")
req.Header.Set("x-amz-user-agent", "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE")
req.Header.Set("x-amz-user-agent", idcAmzUserAgent)
req.Header.Set("Accept", "*/*")
req.Header.Set("Accept-Language", "*")
req.Header.Set("sec-fetch-mode", "cors")
@@ -426,12 +448,11 @@ func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region strin
case <-time.After(interval):
tokenResp, err := c.CreateTokenWithRegion(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode, region)
if err != nil {
errStr := err.Error()
if strings.Contains(errStr, "authorization_pending") {
if errors.Is(err, ErrAuthorizationPending) {
fmt.Print(".")
continue
}
if strings.Contains(errStr, "slow_down") {
if errors.Is(err, ErrSlowDown) {
interval += 5 * time.Second
continue
}
@@ -639,10 +660,10 @@ func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret,
}
if json.Unmarshal(respBody, &errResp) == nil {
if errResp.Error == "authorization_pending" {
return nil, fmt.Errorf("authorization_pending")
return nil, ErrAuthorizationPending
}
if errResp.Error == "slow_down" {
return nil, fmt.Errorf("slow_down")
return nil, ErrSlowDown
}
}
log.Debugf("create token failed: %s", string(respBody))
@@ -787,12 +808,11 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
case <-time.After(interval):
tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
if err != nil {
errStr := err.Error()
if strings.Contains(errStr, "authorization_pending") {
if errors.Is(err, ErrAuthorizationPending) {
fmt.Print(".")
continue
}
if strings.Contains(errStr, "slow_down") {
if errors.Is(err, ErrSlowDown) {
interval += 5 * time.Second
continue
}