mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-26 02:56:11 +00:00
fix: Implement graceful token refresh degradation and enhance IDC SSO support with device registration loading for Kiro.
This commit is contained in:
@@ -32,14 +32,17 @@ type KiroTokenData struct {
|
|||||||
ProfileArn string `json:"profileArn"`
|
ProfileArn string `json:"profileArn"`
|
||||||
// ExpiresAt is the timestamp when the token expires
|
// ExpiresAt is the timestamp when the token expires
|
||||||
ExpiresAt string `json:"expiresAt"`
|
ExpiresAt string `json:"expiresAt"`
|
||||||
// AuthMethod indicates the authentication method used (e.g., "builder-id", "social")
|
// AuthMethod indicates the authentication method used (e.g., "builder-id", "social", "idc")
|
||||||
AuthMethod string `json:"authMethod"`
|
AuthMethod string `json:"authMethod"`
|
||||||
// Provider indicates the OAuth provider (e.g., "AWS", "Google")
|
// Provider indicates the OAuth provider (e.g., "AWS", "Google", "Enterprise")
|
||||||
Provider string `json:"provider"`
|
Provider string `json:"provider"`
|
||||||
// ClientID is the OIDC client ID (needed for token refresh)
|
// ClientID is the OIDC client ID (needed for token refresh)
|
||||||
ClientID string `json:"clientId,omitempty"`
|
ClientID string `json:"clientId,omitempty"`
|
||||||
// ClientSecret is the OIDC client secret (needed for token refresh)
|
// ClientSecret is the OIDC client secret (needed for token refresh)
|
||||||
ClientSecret string `json:"clientSecret,omitempty"`
|
ClientSecret string `json:"clientSecret,omitempty"`
|
||||||
|
// ClientIDHash is the hash of client ID used to locate device registration file
|
||||||
|
// (Enterprise Kiro IDE stores clientId/clientSecret in ~/.aws/sso/cache/{clientIdHash}.json)
|
||||||
|
ClientIDHash string `json:"clientIdHash,omitempty"`
|
||||||
// Email is the user's email address (used for file naming)
|
// Email is the user's email address (used for file naming)
|
||||||
Email string `json:"email,omitempty"`
|
Email string `json:"email,omitempty"`
|
||||||
// StartURL is the IDC/Identity Center start URL (only for IDC auth method)
|
// StartURL is the IDC/Identity Center start URL (only for IDC auth method)
|
||||||
@@ -169,6 +172,8 @@ func LoadKiroIDETokenWithRetry(maxAttempts int, baseDelay time.Duration) (*KiroT
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LoadKiroIDEToken loads token data from Kiro IDE's token file.
|
// LoadKiroIDEToken loads token data from Kiro IDE's token file.
|
||||||
|
// For Enterprise Kiro IDE (IDC auth), it also loads clientId and clientSecret
|
||||||
|
// from the device registration file referenced by clientIdHash.
|
||||||
func LoadKiroIDEToken() (*KiroTokenData, error) {
|
func LoadKiroIDEToken() (*KiroTokenData, error) {
|
||||||
homeDir, err := os.UserHomeDir()
|
homeDir, err := os.UserHomeDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -193,18 +198,69 @@ func LoadKiroIDEToken() (*KiroTokenData, error) {
|
|||||||
// Normalize AuthMethod to lowercase (Kiro IDE uses "IdC" but we expect "idc")
|
// Normalize AuthMethod to lowercase (Kiro IDE uses "IdC" but we expect "idc")
|
||||||
token.AuthMethod = strings.ToLower(token.AuthMethod)
|
token.AuthMethod = strings.ToLower(token.AuthMethod)
|
||||||
|
|
||||||
|
// For Enterprise Kiro IDE (IDC auth), load clientId and clientSecret from device registration
|
||||||
|
// The device registration file is located at ~/.aws/sso/cache/{clientIdHash}.json
|
||||||
|
if token.ClientIDHash != "" && token.ClientID == "" {
|
||||||
|
if err := loadDeviceRegistration(homeDir, token.ClientIDHash, &token); err != nil {
|
||||||
|
// Log warning but don't fail - token might still work for some operations
|
||||||
|
fmt.Printf("warning: failed to load device registration for clientIdHash %s: %v\n", token.ClientIDHash, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &token, nil
|
return &token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// loadDeviceRegistration loads clientId and clientSecret from the device registration file.
|
||||||
|
// Enterprise Kiro IDE stores these in ~/.aws/sso/cache/{clientIdHash}.json
|
||||||
|
func loadDeviceRegistration(homeDir, clientIDHash string, token *KiroTokenData) error {
|
||||||
|
if clientIDHash == "" {
|
||||||
|
return fmt.Errorf("clientIdHash is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanitize clientIdHash to prevent path traversal
|
||||||
|
if strings.Contains(clientIDHash, "/") || strings.Contains(clientIDHash, "\\") || strings.Contains(clientIDHash, "..") {
|
||||||
|
return fmt.Errorf("invalid clientIdHash: contains path separator")
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceRegPath := filepath.Join(homeDir, ".aws", "sso", "cache", clientIDHash+".json")
|
||||||
|
data, err := os.ReadFile(deviceRegPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read device registration file (%s): %w", deviceRegPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Device registration file structure
|
||||||
|
var deviceReg struct {
|
||||||
|
ClientID string `json:"clientId"`
|
||||||
|
ClientSecret string `json:"clientSecret"`
|
||||||
|
ExpiresAt string `json:"expiresAt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(data, &deviceReg); err != nil {
|
||||||
|
return fmt.Errorf("failed to parse device registration: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if deviceReg.ClientID == "" || deviceReg.ClientSecret == "" {
|
||||||
|
return fmt.Errorf("device registration missing clientId or clientSecret")
|
||||||
|
}
|
||||||
|
|
||||||
|
token.ClientID = deviceReg.ClientID
|
||||||
|
token.ClientSecret = deviceReg.ClientSecret
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// LoadKiroTokenFromPath loads token data from a custom path.
|
// LoadKiroTokenFromPath loads token data from a custom path.
|
||||||
// This supports multiple accounts by allowing different token files.
|
// This supports multiple accounts by allowing different token files.
|
||||||
|
// For Enterprise Kiro IDE (IDC auth), it also loads clientId and clientSecret
|
||||||
|
// from the device registration file referenced by clientIdHash.
|
||||||
func LoadKiroTokenFromPath(tokenPath string) (*KiroTokenData, error) {
|
func LoadKiroTokenFromPath(tokenPath string) (*KiroTokenData, error) {
|
||||||
|
homeDir, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Expand ~ to home directory
|
// Expand ~ to home directory
|
||||||
if len(tokenPath) > 0 && tokenPath[0] == '~' {
|
if len(tokenPath) > 0 && tokenPath[0] == '~' {
|
||||||
homeDir, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
|
||||||
}
|
|
||||||
tokenPath = filepath.Join(homeDir, tokenPath[1:])
|
tokenPath = filepath.Join(homeDir, tokenPath[1:])
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -225,6 +281,14 @@ func LoadKiroTokenFromPath(tokenPath string) (*KiroTokenData, error) {
|
|||||||
// Normalize AuthMethod to lowercase (Kiro IDE uses "IdC" but we expect "idc")
|
// Normalize AuthMethod to lowercase (Kiro IDE uses "IdC" but we expect "idc")
|
||||||
token.AuthMethod = strings.ToLower(token.AuthMethod)
|
token.AuthMethod = strings.ToLower(token.AuthMethod)
|
||||||
|
|
||||||
|
// For Enterprise Kiro IDE (IDC auth), load clientId and clientSecret from device registration
|
||||||
|
if token.ClientIDHash != "" && token.ClientID == "" {
|
||||||
|
if err := loadDeviceRegistration(homeDir, token.ClientIDHash, &token); err != nil {
|
||||||
|
// Log warning but don't fail - token might still work for some operations
|
||||||
|
fmt.Printf("warning: failed to load device registration for clientIdHash %s: %v\n", token.ClientIDHash, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &token, nil
|
return &token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -161,40 +161,56 @@ func (r *BackgroundRefresher) refreshBatch(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *BackgroundRefresher) refreshSingle(ctx context.Context, token *Token) {
|
func (r *BackgroundRefresher) refreshSingle(ctx context.Context, token *Token) {
|
||||||
var newTokenData *KiroTokenData
|
// Create refresh function based on auth method
|
||||||
var err error
|
refreshFunc := func(ctx context.Context) (*KiroTokenData, error) {
|
||||||
|
switch token.AuthMethod {
|
||||||
// Normalize auth method to lowercase for case-insensitive matching
|
case "idc":
|
||||||
authMethod := strings.ToLower(token.AuthMethod)
|
return r.ssoClient.RefreshTokenWithRegion(
|
||||||
|
ctx,
|
||||||
switch authMethod {
|
token.ClientID,
|
||||||
case "idc":
|
token.ClientSecret,
|
||||||
newTokenData, err = r.ssoClient.RefreshTokenWithRegion(
|
token.RefreshToken,
|
||||||
ctx,
|
token.Region,
|
||||||
token.ClientID,
|
token.StartURL,
|
||||||
token.ClientSecret,
|
)
|
||||||
token.RefreshToken,
|
case "builder-id":
|
||||||
token.Region,
|
return r.ssoClient.RefreshToken(
|
||||||
token.StartURL,
|
ctx,
|
||||||
)
|
token.ClientID,
|
||||||
case "builder-id":
|
token.ClientSecret,
|
||||||
newTokenData, err = r.ssoClient.RefreshToken(
|
token.RefreshToken,
|
||||||
ctx,
|
)
|
||||||
token.ClientID,
|
default:
|
||||||
token.ClientSecret,
|
return r.oauth.RefreshTokenWithFingerprint(ctx, token.RefreshToken, token.ID)
|
||||||
token.RefreshToken,
|
}
|
||||||
)
|
|
||||||
default:
|
|
||||||
newTokenData, err = r.oauth.RefreshToken(ctx, token.RefreshToken)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
// Use graceful degradation for better reliability
|
||||||
log.Printf("failed to refresh token %s: %v", token.ID, err)
|
result := RefreshWithGracefulDegradation(
|
||||||
|
ctx,
|
||||||
|
refreshFunc,
|
||||||
|
token.AccessToken,
|
||||||
|
token.ExpiresAt,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.Error != nil {
|
||||||
|
log.Printf("failed to refresh token %s: %v", token.ID, result.Error)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
newTokenData := result.TokenData
|
||||||
|
if result.UsedFallback {
|
||||||
|
log.Printf("token %s: using existing token as fallback (refresh failed but token still valid)", token.ID)
|
||||||
|
// Don't update the token file if we're using fallback
|
||||||
|
// Just update LastVerified to prevent immediate re-check
|
||||||
|
token.LastVerified = time.Now()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
token.AccessToken = newTokenData.AccessToken
|
token.AccessToken = newTokenData.AccessToken
|
||||||
token.RefreshToken = newTokenData.RefreshToken
|
if newTokenData.RefreshToken != "" {
|
||||||
|
token.RefreshToken = newTokenData.RefreshToken
|
||||||
|
}
|
||||||
token.LastVerified = time.Now()
|
token.LastVerified = time.Now()
|
||||||
|
|
||||||
if newTokenData.ExpiresAt != "" {
|
if newTokenData.ExpiresAt != "" {
|
||||||
|
|||||||
@@ -190,7 +190,7 @@ func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier
|
|||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("User-Agent", "cli-proxy-api/1.0.0")
|
req.Header.Set("User-Agent", "KiroIDE-0.7.45-cli-proxy-api")
|
||||||
|
|
||||||
resp, err := o.httpClient.Do(req)
|
resp, err := o.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -232,7 +232,14 @@ func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RefreshToken refreshes an expired access token.
|
// RefreshToken refreshes an expired access token.
|
||||||
|
// Uses KiroIDE-style User-Agent to match official Kiro IDE behavior.
|
||||||
func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) {
|
func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) {
|
||||||
|
return o.RefreshTokenWithFingerprint(ctx, refreshToken, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshTokenWithFingerprint refreshes an expired access token with a specific fingerprint.
|
||||||
|
// tokenKey is used to generate a consistent fingerprint for the token.
|
||||||
|
func (o *KiroOAuth) RefreshTokenWithFingerprint(ctx context.Context, refreshToken, tokenKey string) (*KiroTokenData, error) {
|
||||||
payload := map[string]string{
|
payload := map[string]string{
|
||||||
"refreshToken": refreshToken,
|
"refreshToken": refreshToken,
|
||||||
}
|
}
|
||||||
@@ -249,7 +256,11 @@ func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*Kir
|
|||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("User-Agent", "cli-proxy-api/1.0.0")
|
|
||||||
|
// Use KiroIDE-style User-Agent to match official Kiro IDE behavior
|
||||||
|
// This helps avoid 403 errors from server-side User-Agent validation
|
||||||
|
userAgent := buildKiroUserAgent(tokenKey)
|
||||||
|
req.Header.Set("User-Agent", userAgent)
|
||||||
|
|
||||||
resp, err := o.httpClient.Do(req)
|
resp, err := o.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -264,7 +275,7 @@ func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*Kir
|
|||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
|
log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
|
return nil, fmt.Errorf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
}
|
}
|
||||||
|
|
||||||
var tokenResp KiroTokenResponse
|
var tokenResp KiroTokenResponse
|
||||||
@@ -290,6 +301,19 @@ func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*Kir
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// buildKiroUserAgent builds a KiroIDE-style User-Agent string.
|
||||||
|
// If tokenKey is provided, uses fingerprint manager for consistent fingerprint.
|
||||||
|
// Otherwise generates a simple KiroIDE User-Agent.
|
||||||
|
func buildKiroUserAgent(tokenKey string) string {
|
||||||
|
if tokenKey != "" {
|
||||||
|
fm := NewFingerprintManager()
|
||||||
|
fp := fm.GetFingerprint(tokenKey)
|
||||||
|
return fmt.Sprintf("KiroIDE-%s-%s", fp.KiroVersion, fp.KiroHash[:16])
|
||||||
|
}
|
||||||
|
// Default KiroIDE User-Agent matching kiro-openai-gateway format
|
||||||
|
return "KiroIDE-0.7.45-cli-proxy-api"
|
||||||
|
}
|
||||||
|
|
||||||
// LoginWithGoogle performs OAuth login with Google using Kiro's social auth.
|
// LoginWithGoogle performs OAuth login with Google using Kiro's social auth.
|
||||||
// This uses a custom protocol handler (kiro://) to receive the callback.
|
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||||
func (o *KiroOAuth) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) {
|
func (o *KiroOAuth) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) {
|
||||||
|
|||||||
@@ -9,14 +9,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
DefaultMinTokenInterval = 10 * time.Second
|
DefaultMinTokenInterval = 1 * time.Second
|
||||||
DefaultMaxTokenInterval = 30 * time.Second
|
DefaultMaxTokenInterval = 2 * time.Second
|
||||||
DefaultDailyMaxRequests = 500
|
DefaultDailyMaxRequests = 500
|
||||||
DefaultJitterPercent = 0.3
|
DefaultJitterPercent = 0.3
|
||||||
DefaultBackoffBase = 2 * time.Minute
|
DefaultBackoffBase = 30 * time.Second
|
||||||
DefaultBackoffMax = 60 * time.Minute
|
DefaultBackoffMax = 5 * time.Minute
|
||||||
DefaultBackoffMultiplier = 2.0
|
DefaultBackoffMultiplier = 1.5
|
||||||
DefaultSuspendCooldown = 24 * time.Hour
|
DefaultSuspendCooldown = 1 * time.Hour
|
||||||
)
|
)
|
||||||
|
|
||||||
// TokenState Token 状态
|
// TokenState Token 状态
|
||||||
|
|||||||
159
internal/auth/kiro/refresh_utils.go
Normal file
159
internal/auth/kiro/refresh_utils.go
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
// Package kiro provides refresh utilities for Kiro token management.
|
||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RefreshResult contains the result of a token refresh attempt.
|
||||||
|
type RefreshResult struct {
|
||||||
|
TokenData *KiroTokenData
|
||||||
|
Error error
|
||||||
|
UsedFallback bool // True if we used the existing token as fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshWithGracefulDegradation attempts to refresh a token with graceful degradation.
|
||||||
|
// If refresh fails but the existing access token is still valid, it returns the existing token.
|
||||||
|
// This matches kiro-openai-gateway's behavior for better reliability.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - ctx: Context for the request
|
||||||
|
// - refreshFunc: Function to perform the actual refresh
|
||||||
|
// - existingAccessToken: Current access token (for fallback)
|
||||||
|
// - expiresAt: Expiration time of the existing token
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - RefreshResult containing the new or existing token data
|
||||||
|
func RefreshWithGracefulDegradation(
|
||||||
|
ctx context.Context,
|
||||||
|
refreshFunc func(ctx context.Context) (*KiroTokenData, error),
|
||||||
|
existingAccessToken string,
|
||||||
|
expiresAt time.Time,
|
||||||
|
) RefreshResult {
|
||||||
|
// Try to refresh the token
|
||||||
|
newTokenData, err := refreshFunc(ctx)
|
||||||
|
if err == nil {
|
||||||
|
return RefreshResult{
|
||||||
|
TokenData: newTokenData,
|
||||||
|
Error: nil,
|
||||||
|
UsedFallback: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh failed - check if we can use the existing token
|
||||||
|
log.Warnf("kiro: token refresh failed: %v", err)
|
||||||
|
|
||||||
|
// Check if existing token is still valid (not expired)
|
||||||
|
if existingAccessToken != "" && time.Now().Before(expiresAt) {
|
||||||
|
remainingTime := time.Until(expiresAt)
|
||||||
|
log.Warnf("kiro: using existing access token (expires in %v). Will retry refresh later.", remainingTime.Round(time.Second))
|
||||||
|
|
||||||
|
return RefreshResult{
|
||||||
|
TokenData: &KiroTokenData{
|
||||||
|
AccessToken: existingAccessToken,
|
||||||
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
Error: nil,
|
||||||
|
UsedFallback: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token is expired and refresh failed - return the error
|
||||||
|
return RefreshResult{
|
||||||
|
TokenData: nil,
|
||||||
|
Error: fmt.Errorf("token refresh failed and existing token is expired: %w", err),
|
||||||
|
UsedFallback: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsTokenExpiringSoon checks if a token is expiring within the given threshold.
|
||||||
|
// Default threshold is 5 minutes if not specified.
|
||||||
|
func IsTokenExpiringSoon(expiresAt time.Time, threshold time.Duration) bool {
|
||||||
|
if threshold == 0 {
|
||||||
|
threshold = 5 * time.Minute
|
||||||
|
}
|
||||||
|
return time.Now().Add(threshold).After(expiresAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsTokenExpired checks if a token has already expired.
|
||||||
|
func IsTokenExpired(expiresAt time.Time) bool {
|
||||||
|
return time.Now().After(expiresAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseExpiresAt parses an expiration time string in RFC3339 format.
|
||||||
|
// Returns zero time if parsing fails.
|
||||||
|
func ParseExpiresAt(expiresAtStr string) time.Time {
|
||||||
|
if expiresAtStr == "" {
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
t, err := time.Parse(time.RFC3339, expiresAtStr)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("kiro: failed to parse expiresAt '%s': %v", expiresAtStr, err)
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshConfig contains configuration for token refresh behavior.
|
||||||
|
type RefreshConfig struct {
|
||||||
|
// MaxRetries is the maximum number of refresh attempts (default: 1)
|
||||||
|
MaxRetries int
|
||||||
|
// RetryDelay is the delay between retry attempts (default: 1 second)
|
||||||
|
RetryDelay time.Duration
|
||||||
|
// RefreshThreshold is how early to refresh before expiration (default: 5 minutes)
|
||||||
|
RefreshThreshold time.Duration
|
||||||
|
// EnableGracefulDegradation allows using existing token if refresh fails (default: true)
|
||||||
|
EnableGracefulDegradation bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultRefreshConfig returns the default refresh configuration.
|
||||||
|
func DefaultRefreshConfig() RefreshConfig {
|
||||||
|
return RefreshConfig{
|
||||||
|
MaxRetries: 1,
|
||||||
|
RetryDelay: time.Second,
|
||||||
|
RefreshThreshold: 5 * time.Minute,
|
||||||
|
EnableGracefulDegradation: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshWithRetry attempts to refresh a token with retry logic.
|
||||||
|
func RefreshWithRetry(
|
||||||
|
ctx context.Context,
|
||||||
|
refreshFunc func(ctx context.Context) (*KiroTokenData, error),
|
||||||
|
config RefreshConfig,
|
||||||
|
) (*KiroTokenData, error) {
|
||||||
|
var lastErr error
|
||||||
|
|
||||||
|
maxAttempts := config.MaxRetries + 1
|
||||||
|
if maxAttempts < 1 {
|
||||||
|
maxAttempts = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||||
|
tokenData, err := refreshFunc(ctx)
|
||||||
|
if err == nil {
|
||||||
|
if attempt > 1 {
|
||||||
|
log.Infof("kiro: token refresh succeeded on attempt %d", attempt)
|
||||||
|
}
|
||||||
|
return tokenData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
lastErr = err
|
||||||
|
log.Warnf("kiro: token refresh attempt %d/%d failed: %v", attempt, maxAttempts, err)
|
||||||
|
|
||||||
|
// Don't sleep after the last attempt
|
||||||
|
if attempt < maxAttempts {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(config.RetryDelay):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxAttempts, lastErr)
|
||||||
|
}
|
||||||
@@ -229,7 +229,7 @@ func (c *SocialAuthClient) CreateToken(ctx context.Context, req *CreateTokenRequ
|
|||||||
}
|
}
|
||||||
|
|
||||||
httpReq.Header.Set("Content-Type", "application/json")
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
httpReq.Header.Set("User-Agent", "cli-proxy-api/1.0.0")
|
httpReq.Header.Set("User-Agent", "KiroIDE-0.7.45-cli-proxy-api")
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(httpReq)
|
resp, err := c.httpClient.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -684,6 +684,7 @@ func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RefreshToken refreshes an access token using the refresh token.
|
// RefreshToken refreshes an access token using the refresh token.
|
||||||
|
// Includes retry logic and improved error handling for better reliability.
|
||||||
func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret, refreshToken string) (*KiroTokenData, error) {
|
func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret, refreshToken string) (*KiroTokenData, error) {
|
||||||
payload := map[string]string{
|
payload := map[string]string{
|
||||||
"clientId": clientID,
|
"clientId": clientID,
|
||||||
@@ -701,8 +702,13 @@ func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set headers matching Kiro IDE behavior for better compatibility
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("User-Agent", kiroUserAgent)
|
req.Header.Set("Host", "oidc.us-east-1.amazonaws.com")
|
||||||
|
req.Header.Set("x-amz-user-agent", idcAmzUserAgent)
|
||||||
|
req.Header.Set("User-Agent", "node")
|
||||||
|
req.Header.Set("Accept", "*/*")
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -716,8 +722,8 @@ func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret
|
|||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
|
log.Warnf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
|
return nil, fmt.Errorf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
}
|
}
|
||||||
|
|
||||||
var result CreateTokenResponse
|
var result CreateTokenResponse
|
||||||
|
|||||||
@@ -1537,11 +1537,27 @@ func determineAgenticMode(model string) (isAgentic, isChatOnly bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getEffectiveProfileArn determines if profileArn should be included based on auth method.
|
// getEffectiveProfileArn determines if profileArn should be included based on auth method.
|
||||||
// profileArn is only needed for social auth (Google OAuth), not for builder-id (AWS SSO).
|
// profileArn is only needed for social auth (Google OAuth), not for AWS SSO OIDC (Builder ID/IDC).
|
||||||
|
//
|
||||||
|
// Detection logic (matching kiro-openai-gateway):
|
||||||
|
// 1. Check auth_method field: "builder-id" or "idc"
|
||||||
|
// 2. Check auth_type field: "aws_sso_oidc" (from kiro-cli tokens)
|
||||||
|
// 3. Check for client_id + client_secret presence (AWS SSO OIDC signature)
|
||||||
func getEffectiveProfileArn(auth *cliproxyauth.Auth, profileArn string) string {
|
func getEffectiveProfileArn(auth *cliproxyauth.Auth, profileArn string) string {
|
||||||
if auth != nil && auth.Metadata != nil {
|
if auth != nil && auth.Metadata != nil {
|
||||||
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" {
|
// Check 1: auth_method field (from CLIProxyAPI tokens)
|
||||||
return "" // Don't include profileArn for builder-id auth
|
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && (authMethod == "builder-id" || authMethod == "idc") {
|
||||||
|
return "" // AWS SSO OIDC - don't include profileArn
|
||||||
|
}
|
||||||
|
// Check 2: auth_type field (from kiro-cli tokens)
|
||||||
|
if authType, ok := auth.Metadata["auth_type"].(string); ok && authType == "aws_sso_oidc" {
|
||||||
|
return "" // AWS SSO OIDC - don't include profileArn
|
||||||
|
}
|
||||||
|
// Check 3: client_id + client_secret presence (AWS SSO OIDC signature)
|
||||||
|
_, hasClientID := auth.Metadata["client_id"].(string)
|
||||||
|
_, hasClientSecret := auth.Metadata["client_secret"].(string)
|
||||||
|
if hasClientID && hasClientSecret {
|
||||||
|
return "" // AWS SSO OIDC - don't include profileArn
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return profileArn
|
return profileArn
|
||||||
@@ -1550,14 +1566,32 @@ func getEffectiveProfileArn(auth *cliproxyauth.Auth, profileArn string) string {
|
|||||||
// getEffectiveProfileArnWithWarning determines if profileArn should be included based on auth method,
|
// getEffectiveProfileArnWithWarning determines if profileArn should be included based on auth method,
|
||||||
// and logs a warning if profileArn is missing for non-builder-id auth.
|
// and logs a warning if profileArn is missing for non-builder-id auth.
|
||||||
// This consolidates the auth_method check that was previously done separately.
|
// This consolidates the auth_method check that was previously done separately.
|
||||||
|
//
|
||||||
|
// AWS SSO OIDC (Builder ID/IDC) users don't need profileArn - sending it causes 403 errors.
|
||||||
|
// Only Kiro Desktop (social auth like Google/GitHub) users need profileArn.
|
||||||
|
//
|
||||||
|
// Detection logic (matching kiro-openai-gateway):
|
||||||
|
// 1. Check auth_method field: "builder-id" or "idc"
|
||||||
|
// 2. Check auth_type field: "aws_sso_oidc" (from kiro-cli tokens)
|
||||||
|
// 3. Check for client_id + client_secret presence (AWS SSO OIDC signature)
|
||||||
func getEffectiveProfileArnWithWarning(auth *cliproxyauth.Auth, profileArn string) string {
|
func getEffectiveProfileArnWithWarning(auth *cliproxyauth.Auth, profileArn string) string {
|
||||||
if auth != nil && auth.Metadata != nil {
|
if auth != nil && auth.Metadata != nil {
|
||||||
|
// Check 1: auth_method field (from CLIProxyAPI tokens)
|
||||||
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && (authMethod == "builder-id" || authMethod == "idc") {
|
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && (authMethod == "builder-id" || authMethod == "idc") {
|
||||||
// builder-id and idc auth don't need profileArn
|
return "" // AWS SSO OIDC - don't include profileArn
|
||||||
return ""
|
}
|
||||||
|
// Check 2: auth_type field (from kiro-cli tokens)
|
||||||
|
if authType, ok := auth.Metadata["auth_type"].(string); ok && authType == "aws_sso_oidc" {
|
||||||
|
return "" // AWS SSO OIDC - don't include profileArn
|
||||||
|
}
|
||||||
|
// Check 3: client_id + client_secret presence (AWS SSO OIDC signature, like kiro-openai-gateway)
|
||||||
|
_, hasClientID := auth.Metadata["client_id"].(string)
|
||||||
|
_, hasClientSecret := auth.Metadata["client_secret"].(string)
|
||||||
|
if hasClientID && hasClientSecret {
|
||||||
|
return "" // AWS SSO OIDC - don't include profileArn
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// For non-builder-id/idc auth (social auth), profileArn is required
|
// For social auth (Kiro Desktop), profileArn is required
|
||||||
if profileArn == "" {
|
if profileArn == "" {
|
||||||
log.Warnf("kiro: profile ARN not found in auth, API calls may fail")
|
log.Warnf("kiro: profile ARN not found in auth, API calls may fail")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,10 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -279,18 +282,19 @@ func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.C
|
|||||||
CreatedAt: now,
|
CreatedAt: now,
|
||||||
UpdatedAt: now,
|
UpdatedAt: now,
|
||||||
Metadata: map[string]any{
|
Metadata: map[string]any{
|
||||||
"type": "kiro",
|
"type": "kiro",
|
||||||
"access_token": tokenData.AccessToken,
|
"access_token": tokenData.AccessToken,
|
||||||
"refresh_token": tokenData.RefreshToken,
|
"refresh_token": tokenData.RefreshToken,
|
||||||
"profile_arn": tokenData.ProfileArn,
|
"profile_arn": tokenData.ProfileArn,
|
||||||
"expires_at": tokenData.ExpiresAt,
|
"expires_at": tokenData.ExpiresAt,
|
||||||
"auth_method": tokenData.AuthMethod,
|
"auth_method": tokenData.AuthMethod,
|
||||||
"provider": tokenData.Provider,
|
"provider": tokenData.Provider,
|
||||||
"client_id": tokenData.ClientID,
|
"client_id": tokenData.ClientID,
|
||||||
"client_secret": tokenData.ClientSecret,
|
"client_secret": tokenData.ClientSecret,
|
||||||
"email": tokenData.Email,
|
"client_id_hash": tokenData.ClientIDHash,
|
||||||
"region": tokenData.Region,
|
"email": tokenData.Email,
|
||||||
"start_url": tokenData.StartURL,
|
"region": tokenData.Region,
|
||||||
|
"start_url": tokenData.StartURL,
|
||||||
},
|
},
|
||||||
Attributes: map[string]string{
|
Attributes: map[string]string{
|
||||||
"profile_arn": tokenData.ProfileArn,
|
"profile_arn": tokenData.ProfileArn,
|
||||||
@@ -325,10 +329,21 @@ func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, aut
|
|||||||
|
|
||||||
clientID, _ := auth.Metadata["client_id"].(string)
|
clientID, _ := auth.Metadata["client_id"].(string)
|
||||||
clientSecret, _ := auth.Metadata["client_secret"].(string)
|
clientSecret, _ := auth.Metadata["client_secret"].(string)
|
||||||
|
clientIDHash, _ := auth.Metadata["client_id_hash"].(string)
|
||||||
authMethod, _ := auth.Metadata["auth_method"].(string)
|
authMethod, _ := auth.Metadata["auth_method"].(string)
|
||||||
startURL, _ := auth.Metadata["start_url"].(string)
|
startURL, _ := auth.Metadata["start_url"].(string)
|
||||||
region, _ := auth.Metadata["region"].(string)
|
region, _ := auth.Metadata["region"].(string)
|
||||||
|
|
||||||
|
// For Enterprise Kiro IDE (IDC auth), try to load clientId/clientSecret from device registration
|
||||||
|
// if they are missing from metadata. This handles the case where token was imported without
|
||||||
|
// clientId/clientSecret but has clientIdHash.
|
||||||
|
if (clientID == "" || clientSecret == "") && clientIDHash != "" {
|
||||||
|
if loadedClientID, loadedClientSecret, err := loadDeviceRegistrationCredentials(clientIDHash); err == nil {
|
||||||
|
clientID = loadedClientID
|
||||||
|
clientSecret = loadedClientSecret
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var tokenData *kiroauth.KiroTokenData
|
var tokenData *kiroauth.KiroTokenData
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
@@ -339,8 +354,8 @@ func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, aut
|
|||||||
case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "":
|
case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "":
|
||||||
// IDC refresh with region-specific endpoint
|
// IDC refresh with region-specific endpoint
|
||||||
tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL)
|
tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL)
|
||||||
case clientID != "" && clientSecret != "" && authMethod == "builder-id":
|
case clientID != "" && clientSecret != "" && (authMethod == "builder-id" || authMethod == "idc"):
|
||||||
// Builder ID refresh with default endpoint
|
// Builder ID or IDC refresh with default endpoint (us-east-1)
|
||||||
tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken)
|
tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken)
|
||||||
default:
|
default:
|
||||||
// Fallback to Kiro's refresh endpoint (for social auth: Google/GitHub)
|
// Fallback to Kiro's refresh endpoint (for social auth: Google/GitHub)
|
||||||
@@ -367,8 +382,54 @@ func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, aut
|
|||||||
updated.Metadata["refresh_token"] = tokenData.RefreshToken
|
updated.Metadata["refresh_token"] = tokenData.RefreshToken
|
||||||
updated.Metadata["expires_at"] = tokenData.ExpiresAt
|
updated.Metadata["expires_at"] = tokenData.ExpiresAt
|
||||||
updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization
|
updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization
|
||||||
|
// Store clientId/clientSecret if they were loaded from device registration
|
||||||
|
if clientID != "" && updated.Metadata["client_id"] == nil {
|
||||||
|
updated.Metadata["client_id"] = clientID
|
||||||
|
}
|
||||||
|
if clientSecret != "" && updated.Metadata["client_secret"] == nil {
|
||||||
|
updated.Metadata["client_secret"] = clientSecret
|
||||||
|
}
|
||||||
// NextRefreshAfter: 20 minutes before expiry
|
// NextRefreshAfter: 20 minutes before expiry
|
||||||
updated.NextRefreshAfter = expiresAt.Add(-20 * time.Minute)
|
updated.NextRefreshAfter = expiresAt.Add(-20 * time.Minute)
|
||||||
|
|
||||||
return updated, nil
|
return updated, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// loadDeviceRegistrationCredentials loads clientId and clientSecret from device registration file.
|
||||||
|
// This is used when refreshing tokens that were imported without clientId/clientSecret.
|
||||||
|
func loadDeviceRegistrationCredentials(clientIDHash string) (clientID, clientSecret string, err error) {
|
||||||
|
if clientIDHash == "" {
|
||||||
|
return "", "", fmt.Errorf("clientIdHash is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanitize clientIdHash to prevent path traversal
|
||||||
|
if strings.Contains(clientIDHash, "/") || strings.Contains(clientIDHash, "\\") || strings.Contains(clientIDHash, "..") {
|
||||||
|
return "", "", fmt.Errorf("invalid clientIdHash: contains path separator")
|
||||||
|
}
|
||||||
|
|
||||||
|
homeDir, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("failed to get home directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceRegPath := filepath.Join(homeDir, ".aws", "sso", "cache", clientIDHash+".json")
|
||||||
|
data, err := os.ReadFile(deviceRegPath)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("failed to read device registration file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var deviceReg struct {
|
||||||
|
ClientID string `json:"clientId"`
|
||||||
|
ClientSecret string `json:"clientSecret"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(data, &deviceReg); err != nil {
|
||||||
|
return "", "", fmt.Errorf("failed to parse device registration: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if deviceReg.ClientID == "" || deviceReg.ClientSecret == "" {
|
||||||
|
return "", "", fmt.Errorf("device registration missing clientId or clientSecret")
|
||||||
|
}
|
||||||
|
|
||||||
|
return deviceReg.ClientID, deviceReg.ClientSecret, nil
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user