From b18b2ebe9ff756bfc254c1da111e8197de52a599 Mon Sep 17 00:00:00 2001 From: CheesesNguyen Date: Wed, 28 Jan 2026 14:47:04 +0700 Subject: [PATCH] fix: Implement graceful token refresh degradation and enhance IDC SSO support with device registration loading for Kiro. --- internal/auth/kiro/aws.go | 76 +++++++++- internal/auth/kiro/background_refresh.go | 72 ++++++---- internal/auth/kiro/oauth.go | 30 +++- internal/auth/kiro/rate_limiter.go | 12 +- internal/auth/kiro/refresh_utils.go | 159 +++++++++++++++++++++ internal/auth/kiro/social_auth.go | 2 +- internal/auth/kiro/sso_oidc.go | 12 +- internal/runtime/executor/kiro_executor.go | 46 +++++- sdk/auth/kiro.go | 89 ++++++++++-- 9 files changed, 431 insertions(+), 67 deletions(-) create mode 100644 internal/auth/kiro/refresh_utils.go diff --git a/internal/auth/kiro/aws.go b/internal/auth/kiro/aws.go index ef775d05..247a365c 100644 --- a/internal/auth/kiro/aws.go +++ b/internal/auth/kiro/aws.go @@ -32,14 +32,17 @@ type KiroTokenData struct { ProfileArn string `json:"profileArn"` // ExpiresAt is the timestamp when the token expires ExpiresAt string `json:"expiresAt"` - // AuthMethod indicates the authentication method used (e.g., "builder-id", "social") + // AuthMethod indicates the authentication method used (e.g., "builder-id", "social", "idc") AuthMethod string `json:"authMethod"` - // Provider indicates the OAuth provider (e.g., "AWS", "Google") + // Provider indicates the OAuth provider (e.g., "AWS", "Google", "Enterprise") Provider string `json:"provider"` // ClientID is the OIDC client ID (needed for token refresh) ClientID string `json:"clientId,omitempty"` // ClientSecret is the OIDC client secret (needed for token refresh) ClientSecret string `json:"clientSecret,omitempty"` + // ClientIDHash is the hash of client ID used to locate device registration file + // (Enterprise Kiro IDE stores clientId/clientSecret in ~/.aws/sso/cache/{clientIdHash}.json) + ClientIDHash string `json:"clientIdHash,omitempty"` // Email is the user's email address (used for file naming) Email string `json:"email,omitempty"` // StartURL is the IDC/Identity Center start URL (only for IDC auth method) @@ -169,6 +172,8 @@ func LoadKiroIDETokenWithRetry(maxAttempts int, baseDelay time.Duration) (*KiroT } // LoadKiroIDEToken loads token data from Kiro IDE's token file. +// For Enterprise Kiro IDE (IDC auth), it also loads clientId and clientSecret +// from the device registration file referenced by clientIdHash. func LoadKiroIDEToken() (*KiroTokenData, error) { homeDir, err := os.UserHomeDir() if err != nil { @@ -193,18 +198,69 @@ func LoadKiroIDEToken() (*KiroTokenData, error) { // Normalize AuthMethod to lowercase (Kiro IDE uses "IdC" but we expect "idc") token.AuthMethod = strings.ToLower(token.AuthMethod) + // For Enterprise Kiro IDE (IDC auth), load clientId and clientSecret from device registration + // The device registration file is located at ~/.aws/sso/cache/{clientIdHash}.json + if token.ClientIDHash != "" && token.ClientID == "" { + if err := loadDeviceRegistration(homeDir, token.ClientIDHash, &token); err != nil { + // Log warning but don't fail - token might still work for some operations + fmt.Printf("warning: failed to load device registration for clientIdHash %s: %v\n", token.ClientIDHash, err) + } + } + return &token, nil } +// loadDeviceRegistration loads clientId and clientSecret from the device registration file. +// Enterprise Kiro IDE stores these in ~/.aws/sso/cache/{clientIdHash}.json +func loadDeviceRegistration(homeDir, clientIDHash string, token *KiroTokenData) error { + if clientIDHash == "" { + return fmt.Errorf("clientIdHash is empty") + } + + // Sanitize clientIdHash to prevent path traversal + if strings.Contains(clientIDHash, "/") || strings.Contains(clientIDHash, "\\") || strings.Contains(clientIDHash, "..") { + return fmt.Errorf("invalid clientIdHash: contains path separator") + } + + deviceRegPath := filepath.Join(homeDir, ".aws", "sso", "cache", clientIDHash+".json") + data, err := os.ReadFile(deviceRegPath) + if err != nil { + return fmt.Errorf("failed to read device registration file (%s): %w", deviceRegPath, err) + } + + // Device registration file structure + var deviceReg struct { + ClientID string `json:"clientId"` + ClientSecret string `json:"clientSecret"` + ExpiresAt string `json:"expiresAt"` + } + + if err := json.Unmarshal(data, &deviceReg); err != nil { + return fmt.Errorf("failed to parse device registration: %w", err) + } + + if deviceReg.ClientID == "" || deviceReg.ClientSecret == "" { + return fmt.Errorf("device registration missing clientId or clientSecret") + } + + token.ClientID = deviceReg.ClientID + token.ClientSecret = deviceReg.ClientSecret + + return nil +} + // LoadKiroTokenFromPath loads token data from a custom path. // This supports multiple accounts by allowing different token files. +// For Enterprise Kiro IDE (IDC auth), it also loads clientId and clientSecret +// from the device registration file referenced by clientIdHash. func LoadKiroTokenFromPath(tokenPath string) (*KiroTokenData, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("failed to get home directory: %w", err) + } + // Expand ~ to home directory if len(tokenPath) > 0 && tokenPath[0] == '~' { - homeDir, err := os.UserHomeDir() - if err != nil { - return nil, fmt.Errorf("failed to get home directory: %w", err) - } tokenPath = filepath.Join(homeDir, tokenPath[1:]) } @@ -225,6 +281,14 @@ func LoadKiroTokenFromPath(tokenPath string) (*KiroTokenData, error) { // Normalize AuthMethod to lowercase (Kiro IDE uses "IdC" but we expect "idc") token.AuthMethod = strings.ToLower(token.AuthMethod) + // For Enterprise Kiro IDE (IDC auth), load clientId and clientSecret from device registration + if token.ClientIDHash != "" && token.ClientID == "" { + if err := loadDeviceRegistration(homeDir, token.ClientIDHash, &token); err != nil { + // Log warning but don't fail - token might still work for some operations + fmt.Printf("warning: failed to load device registration for clientIdHash %s: %v\n", token.ClientIDHash, err) + } + } + return &token, nil } diff --git a/internal/auth/kiro/background_refresh.go b/internal/auth/kiro/background_refresh.go index bd1f048f..2b4a161c 100644 --- a/internal/auth/kiro/background_refresh.go +++ b/internal/auth/kiro/background_refresh.go @@ -161,40 +161,56 @@ func (r *BackgroundRefresher) refreshBatch(ctx context.Context) { } func (r *BackgroundRefresher) refreshSingle(ctx context.Context, token *Token) { - var newTokenData *KiroTokenData - var err error - - // Normalize auth method to lowercase for case-insensitive matching - authMethod := strings.ToLower(token.AuthMethod) - - switch authMethod { - case "idc": - newTokenData, err = r.ssoClient.RefreshTokenWithRegion( - ctx, - token.ClientID, - token.ClientSecret, - token.RefreshToken, - token.Region, - token.StartURL, - ) - case "builder-id": - newTokenData, err = r.ssoClient.RefreshToken( - ctx, - token.ClientID, - token.ClientSecret, - token.RefreshToken, - ) - default: - newTokenData, err = r.oauth.RefreshToken(ctx, token.RefreshToken) + // Create refresh function based on auth method + refreshFunc := func(ctx context.Context) (*KiroTokenData, error) { + switch token.AuthMethod { + case "idc": + return r.ssoClient.RefreshTokenWithRegion( + ctx, + token.ClientID, + token.ClientSecret, + token.RefreshToken, + token.Region, + token.StartURL, + ) + case "builder-id": + return r.ssoClient.RefreshToken( + ctx, + token.ClientID, + token.ClientSecret, + token.RefreshToken, + ) + default: + return r.oauth.RefreshTokenWithFingerprint(ctx, token.RefreshToken, token.ID) + } } - if err != nil { - log.Printf("failed to refresh token %s: %v", token.ID, err) + // Use graceful degradation for better reliability + result := RefreshWithGracefulDegradation( + ctx, + refreshFunc, + token.AccessToken, + token.ExpiresAt, + ) + + if result.Error != nil { + log.Printf("failed to refresh token %s: %v", token.ID, result.Error) + return + } + + newTokenData := result.TokenData + if result.UsedFallback { + log.Printf("token %s: using existing token as fallback (refresh failed but token still valid)", token.ID) + // Don't update the token file if we're using fallback + // Just update LastVerified to prevent immediate re-check + token.LastVerified = time.Now() return } token.AccessToken = newTokenData.AccessToken - token.RefreshToken = newTokenData.RefreshToken + if newTokenData.RefreshToken != "" { + token.RefreshToken = newTokenData.RefreshToken + } token.LastVerified = time.Now() if newTokenData.ExpiresAt != "" { diff --git a/internal/auth/kiro/oauth.go b/internal/auth/kiro/oauth.go index 0609610f..a286cf42 100644 --- a/internal/auth/kiro/oauth.go +++ b/internal/auth/kiro/oauth.go @@ -190,7 +190,7 @@ func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier } req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", "cli-proxy-api/1.0.0") + req.Header.Set("User-Agent", "KiroIDE-0.7.45-cli-proxy-api") resp, err := o.httpClient.Do(req) if err != nil { @@ -232,7 +232,14 @@ func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier } // RefreshToken refreshes an expired access token. +// Uses KiroIDE-style User-Agent to match official Kiro IDE behavior. func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) { + return o.RefreshTokenWithFingerprint(ctx, refreshToken, "") +} + +// RefreshTokenWithFingerprint refreshes an expired access token with a specific fingerprint. +// tokenKey is used to generate a consistent fingerprint for the token. +func (o *KiroOAuth) RefreshTokenWithFingerprint(ctx context.Context, refreshToken, tokenKey string) (*KiroTokenData, error) { payload := map[string]string{ "refreshToken": refreshToken, } @@ -249,7 +256,11 @@ func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*Kir } req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", "cli-proxy-api/1.0.0") + + // Use KiroIDE-style User-Agent to match official Kiro IDE behavior + // This helps avoid 403 errors from server-side User-Agent validation + userAgent := buildKiroUserAgent(tokenKey) + req.Header.Set("User-Agent", userAgent) resp, err := o.httpClient.Do(req) if err != nil { @@ -264,7 +275,7 @@ func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*Kir if resp.StatusCode != http.StatusOK { log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) - return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode) + return nil, fmt.Errorf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) } var tokenResp KiroTokenResponse @@ -290,6 +301,19 @@ func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*Kir }, nil } +// buildKiroUserAgent builds a KiroIDE-style User-Agent string. +// If tokenKey is provided, uses fingerprint manager for consistent fingerprint. +// Otherwise generates a simple KiroIDE User-Agent. +func buildKiroUserAgent(tokenKey string) string { + if tokenKey != "" { + fm := NewFingerprintManager() + fp := fm.GetFingerprint(tokenKey) + return fmt.Sprintf("KiroIDE-%s-%s", fp.KiroVersion, fp.KiroHash[:16]) + } + // Default KiroIDE User-Agent matching kiro-openai-gateway format + return "KiroIDE-0.7.45-cli-proxy-api" +} + // LoginWithGoogle performs OAuth login with Google using Kiro's social auth. // This uses a custom protocol handler (kiro://) to receive the callback. func (o *KiroOAuth) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) { diff --git a/internal/auth/kiro/rate_limiter.go b/internal/auth/kiro/rate_limiter.go index 3c240ebe..52bb24af 100644 --- a/internal/auth/kiro/rate_limiter.go +++ b/internal/auth/kiro/rate_limiter.go @@ -9,14 +9,14 @@ import ( ) const ( - DefaultMinTokenInterval = 10 * time.Second - DefaultMaxTokenInterval = 30 * time.Second + DefaultMinTokenInterval = 1 * time.Second + DefaultMaxTokenInterval = 2 * time.Second DefaultDailyMaxRequests = 500 DefaultJitterPercent = 0.3 - DefaultBackoffBase = 2 * time.Minute - DefaultBackoffMax = 60 * time.Minute - DefaultBackoffMultiplier = 2.0 - DefaultSuspendCooldown = 24 * time.Hour + DefaultBackoffBase = 30 * time.Second + DefaultBackoffMax = 5 * time.Minute + DefaultBackoffMultiplier = 1.5 + DefaultSuspendCooldown = 1 * time.Hour ) // TokenState Token 状态 diff --git a/internal/auth/kiro/refresh_utils.go b/internal/auth/kiro/refresh_utils.go new file mode 100644 index 00000000..5abb714c --- /dev/null +++ b/internal/auth/kiro/refresh_utils.go @@ -0,0 +1,159 @@ +// Package kiro provides refresh utilities for Kiro token management. +package kiro + +import ( + "context" + "fmt" + "time" + + log "github.com/sirupsen/logrus" +) + +// RefreshResult contains the result of a token refresh attempt. +type RefreshResult struct { + TokenData *KiroTokenData + Error error + UsedFallback bool // True if we used the existing token as fallback +} + +// RefreshWithGracefulDegradation attempts to refresh a token with graceful degradation. +// If refresh fails but the existing access token is still valid, it returns the existing token. +// This matches kiro-openai-gateway's behavior for better reliability. +// +// Parameters: +// - ctx: Context for the request +// - refreshFunc: Function to perform the actual refresh +// - existingAccessToken: Current access token (for fallback) +// - expiresAt: Expiration time of the existing token +// +// Returns: +// - RefreshResult containing the new or existing token data +func RefreshWithGracefulDegradation( + ctx context.Context, + refreshFunc func(ctx context.Context) (*KiroTokenData, error), + existingAccessToken string, + expiresAt time.Time, +) RefreshResult { + // Try to refresh the token + newTokenData, err := refreshFunc(ctx) + if err == nil { + return RefreshResult{ + TokenData: newTokenData, + Error: nil, + UsedFallback: false, + } + } + + // Refresh failed - check if we can use the existing token + log.Warnf("kiro: token refresh failed: %v", err) + + // Check if existing token is still valid (not expired) + if existingAccessToken != "" && time.Now().Before(expiresAt) { + remainingTime := time.Until(expiresAt) + log.Warnf("kiro: using existing access token (expires in %v). Will retry refresh later.", remainingTime.Round(time.Second)) + + return RefreshResult{ + TokenData: &KiroTokenData{ + AccessToken: existingAccessToken, + ExpiresAt: expiresAt.Format(time.RFC3339), + }, + Error: nil, + UsedFallback: true, + } + } + + // Token is expired and refresh failed - return the error + return RefreshResult{ + TokenData: nil, + Error: fmt.Errorf("token refresh failed and existing token is expired: %w", err), + UsedFallback: false, + } +} + +// IsTokenExpiringSoon checks if a token is expiring within the given threshold. +// Default threshold is 5 minutes if not specified. +func IsTokenExpiringSoon(expiresAt time.Time, threshold time.Duration) bool { + if threshold == 0 { + threshold = 5 * time.Minute + } + return time.Now().Add(threshold).After(expiresAt) +} + +// IsTokenExpired checks if a token has already expired. +func IsTokenExpired(expiresAt time.Time) bool { + return time.Now().After(expiresAt) +} + +// ParseExpiresAt parses an expiration time string in RFC3339 format. +// Returns zero time if parsing fails. +func ParseExpiresAt(expiresAtStr string) time.Time { + if expiresAtStr == "" { + return time.Time{} + } + t, err := time.Parse(time.RFC3339, expiresAtStr) + if err != nil { + log.Debugf("kiro: failed to parse expiresAt '%s': %v", expiresAtStr, err) + return time.Time{} + } + return t +} + +// RefreshConfig contains configuration for token refresh behavior. +type RefreshConfig struct { + // MaxRetries is the maximum number of refresh attempts (default: 1) + MaxRetries int + // RetryDelay is the delay between retry attempts (default: 1 second) + RetryDelay time.Duration + // RefreshThreshold is how early to refresh before expiration (default: 5 minutes) + RefreshThreshold time.Duration + // EnableGracefulDegradation allows using existing token if refresh fails (default: true) + EnableGracefulDegradation bool +} + +// DefaultRefreshConfig returns the default refresh configuration. +func DefaultRefreshConfig() RefreshConfig { + return RefreshConfig{ + MaxRetries: 1, + RetryDelay: time.Second, + RefreshThreshold: 5 * time.Minute, + EnableGracefulDegradation: true, + } +} + +// RefreshWithRetry attempts to refresh a token with retry logic. +func RefreshWithRetry( + ctx context.Context, + refreshFunc func(ctx context.Context) (*KiroTokenData, error), + config RefreshConfig, +) (*KiroTokenData, error) { + var lastErr error + + maxAttempts := config.MaxRetries + 1 + if maxAttempts < 1 { + maxAttempts = 1 + } + + for attempt := 1; attempt <= maxAttempts; attempt++ { + tokenData, err := refreshFunc(ctx) + if err == nil { + if attempt > 1 { + log.Infof("kiro: token refresh succeeded on attempt %d", attempt) + } + return tokenData, nil + } + + lastErr = err + log.Warnf("kiro: token refresh attempt %d/%d failed: %v", attempt, maxAttempts, err) + + // Don't sleep after the last attempt + if attempt < maxAttempts { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(config.RetryDelay): + } + } + } + + return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxAttempts, lastErr) +} diff --git a/internal/auth/kiro/social_auth.go b/internal/auth/kiro/social_auth.go index 277b83db..65f31ba4 100644 --- a/internal/auth/kiro/social_auth.go +++ b/internal/auth/kiro/social_auth.go @@ -229,7 +229,7 @@ func (c *SocialAuthClient) CreateToken(ctx context.Context, req *CreateTokenRequ } httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("User-Agent", "cli-proxy-api/1.0.0") + httpReq.Header.Set("User-Agent", "KiroIDE-0.7.45-cli-proxy-api") resp, err := c.httpClient.Do(httpReq) if err != nil { diff --git a/internal/auth/kiro/sso_oidc.go b/internal/auth/kiro/sso_oidc.go index ba15dac9..60fb8871 100644 --- a/internal/auth/kiro/sso_oidc.go +++ b/internal/auth/kiro/sso_oidc.go @@ -684,6 +684,7 @@ func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret, } // RefreshToken refreshes an access token using the refresh token. +// Includes retry logic and improved error handling for better reliability. func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret, refreshToken string) (*KiroTokenData, error) { payload := map[string]string{ "clientId": clientID, @@ -701,8 +702,13 @@ func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret if err != nil { return nil, err } + + // Set headers matching Kiro IDE behavior for better compatibility req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", kiroUserAgent) + req.Header.Set("Host", "oidc.us-east-1.amazonaws.com") + req.Header.Set("x-amz-user-agent", idcAmzUserAgent) + req.Header.Set("User-Agent", "node") + req.Header.Set("Accept", "*/*") resp, err := c.httpClient.Do(req) if err != nil { @@ -716,8 +722,8 @@ func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret } if resp.StatusCode != http.StatusOK { - log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) - return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode) + log.Warnf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) } var result CreateTokenResponse diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index 57574268..e0362fe5 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -1537,11 +1537,27 @@ func determineAgenticMode(model string) (isAgentic, isChatOnly bool) { } // getEffectiveProfileArn determines if profileArn should be included based on auth method. -// profileArn is only needed for social auth (Google OAuth), not for builder-id (AWS SSO). +// profileArn is only needed for social auth (Google OAuth), not for AWS SSO OIDC (Builder ID/IDC). +// +// Detection logic (matching kiro-openai-gateway): +// 1. Check auth_method field: "builder-id" or "idc" +// 2. Check auth_type field: "aws_sso_oidc" (from kiro-cli tokens) +// 3. Check for client_id + client_secret presence (AWS SSO OIDC signature) func getEffectiveProfileArn(auth *cliproxyauth.Auth, profileArn string) string { if auth != nil && auth.Metadata != nil { - if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" { - return "" // Don't include profileArn for builder-id auth + // Check 1: auth_method field (from CLIProxyAPI tokens) + if authMethod, ok := auth.Metadata["auth_method"].(string); ok && (authMethod == "builder-id" || authMethod == "idc") { + return "" // AWS SSO OIDC - don't include profileArn + } + // Check 2: auth_type field (from kiro-cli tokens) + if authType, ok := auth.Metadata["auth_type"].(string); ok && authType == "aws_sso_oidc" { + return "" // AWS SSO OIDC - don't include profileArn + } + // Check 3: client_id + client_secret presence (AWS SSO OIDC signature) + _, hasClientID := auth.Metadata["client_id"].(string) + _, hasClientSecret := auth.Metadata["client_secret"].(string) + if hasClientID && hasClientSecret { + return "" // AWS SSO OIDC - don't include profileArn } } return profileArn @@ -1550,14 +1566,32 @@ func getEffectiveProfileArn(auth *cliproxyauth.Auth, profileArn string) string { // getEffectiveProfileArnWithWarning determines if profileArn should be included based on auth method, // and logs a warning if profileArn is missing for non-builder-id auth. // This consolidates the auth_method check that was previously done separately. +// +// AWS SSO OIDC (Builder ID/IDC) users don't need profileArn - sending it causes 403 errors. +// Only Kiro Desktop (social auth like Google/GitHub) users need profileArn. +// +// Detection logic (matching kiro-openai-gateway): +// 1. Check auth_method field: "builder-id" or "idc" +// 2. Check auth_type field: "aws_sso_oidc" (from kiro-cli tokens) +// 3. Check for client_id + client_secret presence (AWS SSO OIDC signature) func getEffectiveProfileArnWithWarning(auth *cliproxyauth.Auth, profileArn string) string { if auth != nil && auth.Metadata != nil { + // Check 1: auth_method field (from CLIProxyAPI tokens) if authMethod, ok := auth.Metadata["auth_method"].(string); ok && (authMethod == "builder-id" || authMethod == "idc") { - // builder-id and idc auth don't need profileArn - return "" + return "" // AWS SSO OIDC - don't include profileArn + } + // Check 2: auth_type field (from kiro-cli tokens) + if authType, ok := auth.Metadata["auth_type"].(string); ok && authType == "aws_sso_oidc" { + return "" // AWS SSO OIDC - don't include profileArn + } + // Check 3: client_id + client_secret presence (AWS SSO OIDC signature, like kiro-openai-gateway) + _, hasClientID := auth.Metadata["client_id"].(string) + _, hasClientSecret := auth.Metadata["client_secret"].(string) + if hasClientID && hasClientSecret { + return "" // AWS SSO OIDC - don't include profileArn } } - // For non-builder-id/idc auth (social auth), profileArn is required + // For social auth (Kiro Desktop), profileArn is required if profileArn == "" { log.Warnf("kiro: profile ARN not found in auth, API calls may fail") } diff --git a/sdk/auth/kiro.go b/sdk/auth/kiro.go index f66be461..b6a13265 100644 --- a/sdk/auth/kiro.go +++ b/sdk/auth/kiro.go @@ -2,7 +2,10 @@ package auth import ( "context" + "encoding/json" "fmt" + "os" + "path/filepath" "strings" "time" @@ -279,18 +282,19 @@ func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.C CreatedAt: now, UpdatedAt: now, Metadata: map[string]any{ - "type": "kiro", - "access_token": tokenData.AccessToken, - "refresh_token": tokenData.RefreshToken, - "profile_arn": tokenData.ProfileArn, - "expires_at": tokenData.ExpiresAt, - "auth_method": tokenData.AuthMethod, - "provider": tokenData.Provider, - "client_id": tokenData.ClientID, - "client_secret": tokenData.ClientSecret, - "email": tokenData.Email, - "region": tokenData.Region, - "start_url": tokenData.StartURL, + "type": "kiro", + "access_token": tokenData.AccessToken, + "refresh_token": tokenData.RefreshToken, + "profile_arn": tokenData.ProfileArn, + "expires_at": tokenData.ExpiresAt, + "auth_method": tokenData.AuthMethod, + "provider": tokenData.Provider, + "client_id": tokenData.ClientID, + "client_secret": tokenData.ClientSecret, + "client_id_hash": tokenData.ClientIDHash, + "email": tokenData.Email, + "region": tokenData.Region, + "start_url": tokenData.StartURL, }, Attributes: map[string]string{ "profile_arn": tokenData.ProfileArn, @@ -325,10 +329,21 @@ func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, aut clientID, _ := auth.Metadata["client_id"].(string) clientSecret, _ := auth.Metadata["client_secret"].(string) + clientIDHash, _ := auth.Metadata["client_id_hash"].(string) authMethod, _ := auth.Metadata["auth_method"].(string) startURL, _ := auth.Metadata["start_url"].(string) region, _ := auth.Metadata["region"].(string) + // For Enterprise Kiro IDE (IDC auth), try to load clientId/clientSecret from device registration + // if they are missing from metadata. This handles the case where token was imported without + // clientId/clientSecret but has clientIdHash. + if (clientID == "" || clientSecret == "") && clientIDHash != "" { + if loadedClientID, loadedClientSecret, err := loadDeviceRegistrationCredentials(clientIDHash); err == nil { + clientID = loadedClientID + clientSecret = loadedClientSecret + } + } + var tokenData *kiroauth.KiroTokenData var err error @@ -339,8 +354,8 @@ func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, aut case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "": // IDC refresh with region-specific endpoint tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL) - case clientID != "" && clientSecret != "" && authMethod == "builder-id": - // Builder ID refresh with default endpoint + case clientID != "" && clientSecret != "" && (authMethod == "builder-id" || authMethod == "idc"): + // Builder ID or IDC refresh with default endpoint (us-east-1) tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken) default: // Fallback to Kiro's refresh endpoint (for social auth: Google/GitHub) @@ -367,8 +382,54 @@ func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, aut updated.Metadata["refresh_token"] = tokenData.RefreshToken updated.Metadata["expires_at"] = tokenData.ExpiresAt updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization + // Store clientId/clientSecret if they were loaded from device registration + if clientID != "" && updated.Metadata["client_id"] == nil { + updated.Metadata["client_id"] = clientID + } + if clientSecret != "" && updated.Metadata["client_secret"] == nil { + updated.Metadata["client_secret"] = clientSecret + } // NextRefreshAfter: 20 minutes before expiry updated.NextRefreshAfter = expiresAt.Add(-20 * time.Minute) return updated, nil } + +// loadDeviceRegistrationCredentials loads clientId and clientSecret from device registration file. +// This is used when refreshing tokens that were imported without clientId/clientSecret. +func loadDeviceRegistrationCredentials(clientIDHash string) (clientID, clientSecret string, err error) { + if clientIDHash == "" { + return "", "", fmt.Errorf("clientIdHash is empty") + } + + // Sanitize clientIdHash to prevent path traversal + if strings.Contains(clientIDHash, "/") || strings.Contains(clientIDHash, "\\") || strings.Contains(clientIDHash, "..") { + return "", "", fmt.Errorf("invalid clientIdHash: contains path separator") + } + + homeDir, err := os.UserHomeDir() + if err != nil { + return "", "", fmt.Errorf("failed to get home directory: %w", err) + } + + deviceRegPath := filepath.Join(homeDir, ".aws", "sso", "cache", clientIDHash+".json") + data, err := os.ReadFile(deviceRegPath) + if err != nil { + return "", "", fmt.Errorf("failed to read device registration file: %w", err) + } + + var deviceReg struct { + ClientID string `json:"clientId"` + ClientSecret string `json:"clientSecret"` + } + + if err := json.Unmarshal(data, &deviceReg); err != nil { + return "", "", fmt.Errorf("failed to parse device registration: %w", err) + } + + if deviceReg.ClientID == "" || deviceReg.ClientSecret == "" { + return "", "", fmt.Errorf("device registration missing clientId or clientSecret") + } + + return deviceReg.ClientID, deviceReg.ClientSecret, nil +}