diff --git a/internal/api/server.go b/internal/api/server.go index 4df42ec8..40c41a94 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -23,6 +23,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/api/middleware" "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules" ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp" + "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" "github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset" @@ -295,6 +296,11 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk s.registerManagementRoutes() } + // === CLIProxyAPIPlus 扩展: 注册 Kiro OAuth Web 路由 === + kiroOAuthHandler := kiro.NewOAuthWebHandler(cfg) + kiroOAuthHandler.RegisterRoutes(engine) + log.Info("Kiro OAuth Web routes registered at /v0/oauth/kiro/*") + if optionState.keepAliveEnabled { s.enableKeepAlive(optionState.keepAliveTimeout, optionState.keepAliveOnTimeout) } diff --git a/internal/auth/kiro/aws.go b/internal/auth/kiro/aws.go index ba73af4d..d266b9bf 100644 --- a/internal/auth/kiro/aws.go +++ b/internal/auth/kiro/aws.go @@ -5,10 +5,12 @@ package kiro import ( "encoding/base64" "encoding/json" + "errors" "fmt" "os" "path/filepath" "strings" + "time" ) // PKCECodes holds PKCE verification codes for OAuth2 PKCE flow @@ -85,6 +87,87 @@ type KiroModel struct { // KiroIDETokenFile is the default path to Kiro IDE's token file const KiroIDETokenFile = ".aws/sso/cache/kiro-auth-token.json" +// Default retry configuration for file reading +const ( + defaultTokenReadMaxAttempts = 10 // Maximum retry attempts + defaultTokenReadBaseDelay = 50 * time.Millisecond // Base delay between retries +) + +// isTransientFileError checks if the error is a transient file access error +// that may be resolved by retrying (e.g., file locked by another process on Windows). +func isTransientFileError(err error) bool { + if err == nil { + return false + } + + // Check for OS-level file access errors (Windows sharing violation, etc.) + var pathErr *os.PathError + if errors.As(err, &pathErr) { + // Windows sharing violation (ERROR_SHARING_VIOLATION = 32) + // Windows lock violation (ERROR_LOCK_VIOLATION = 33) + errStr := pathErr.Err.Error() + if strings.Contains(errStr, "being used by another process") || + strings.Contains(errStr, "sharing violation") || + strings.Contains(errStr, "lock violation") { + return true + } + } + + // Check error message for common transient patterns + errMsg := strings.ToLower(err.Error()) + transientPatterns := []string{ + "being used by another process", + "sharing violation", + "lock violation", + "access is denied", + "unexpected end of json", + "unexpected eof", + } + for _, pattern := range transientPatterns { + if strings.Contains(errMsg, pattern) { + return true + } + } + + return false +} + +// LoadKiroIDETokenWithRetry loads token data from Kiro IDE's token file with retry logic. +// This handles transient file access errors (e.g., file locked by Kiro IDE during write). +// maxAttempts: maximum number of retry attempts (default 10 if <= 0) +// baseDelay: base delay between retries with exponential backoff (default 50ms if <= 0) +func LoadKiroIDETokenWithRetry(maxAttempts int, baseDelay time.Duration) (*KiroTokenData, error) { + if maxAttempts <= 0 { + maxAttempts = defaultTokenReadMaxAttempts + } + if baseDelay <= 0 { + baseDelay = defaultTokenReadBaseDelay + } + + var lastErr error + for attempt := 0; attempt < maxAttempts; attempt++ { + token, err := LoadKiroIDEToken() + if err == nil { + return token, nil + } + lastErr = err + + // Only retry for transient errors + if !isTransientFileError(err) { + return nil, err + } + + // Exponential backoff: delay * 2^attempt, capped at 500ms + delay := baseDelay * time.Duration(1< 500*time.Millisecond { + delay = 500 * time.Millisecond + } + time.Sleep(delay) + } + + return nil, fmt.Errorf("failed to read token file after %d attempts: %w", maxAttempts, lastErr) +} + // LoadKiroIDEToken loads token data from Kiro IDE's token file. func LoadKiroIDEToken() (*KiroTokenData, error) { homeDir, err := os.UserHomeDir() diff --git a/internal/auth/kiro/aws.go.bak b/internal/auth/kiro/aws.go.bak new file mode 100644 index 00000000..ba73af4d --- /dev/null +++ b/internal/auth/kiro/aws.go.bak @@ -0,0 +1,305 @@ +// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API. +// It includes interfaces and implementations for token storage and authentication methods. +package kiro + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" +) + +// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow +type PKCECodes struct { + // CodeVerifier is the cryptographically random string used to correlate + // the authorization request to the token request + CodeVerifier string `json:"code_verifier"` + // CodeChallenge is the SHA256 hash of the code verifier, base64url-encoded + CodeChallenge string `json:"code_challenge"` +} + +// KiroTokenData holds OAuth token information from AWS CodeWhisperer (Kiro) +type KiroTokenData struct { + // AccessToken is the OAuth2 access token for API access + AccessToken string `json:"accessToken"` + // RefreshToken is used to obtain new access tokens + RefreshToken string `json:"refreshToken"` + // ProfileArn is the AWS CodeWhisperer profile ARN + ProfileArn string `json:"profileArn"` + // ExpiresAt is the timestamp when the token expires + ExpiresAt string `json:"expiresAt"` + // AuthMethod indicates the authentication method used (e.g., "builder-id", "social") + AuthMethod string `json:"authMethod"` + // Provider indicates the OAuth provider (e.g., "AWS", "Google") + Provider string `json:"provider"` + // ClientID is the OIDC client ID (needed for token refresh) + ClientID string `json:"clientId,omitempty"` + // ClientSecret is the OIDC client secret (needed for token refresh) + ClientSecret string `json:"clientSecret,omitempty"` + // Email is the user's email address (used for file naming) + Email string `json:"email,omitempty"` + // StartURL is the IDC/Identity Center start URL (only for IDC auth method) + StartURL string `json:"startUrl,omitempty"` + // Region is the AWS region for IDC authentication (only for IDC auth method) + Region string `json:"region,omitempty"` +} + +// KiroAuthBundle aggregates authentication data after OAuth flow completion +type KiroAuthBundle struct { + // TokenData contains the OAuth tokens from the authentication flow + TokenData KiroTokenData `json:"token_data"` + // LastRefresh is the timestamp of the last token refresh + LastRefresh string `json:"last_refresh"` +} + +// KiroUsageInfo represents usage information from CodeWhisperer API +type KiroUsageInfo struct { + // SubscriptionTitle is the subscription plan name (e.g., "KIRO FREE") + SubscriptionTitle string `json:"subscription_title"` + // CurrentUsage is the current credit usage + CurrentUsage float64 `json:"current_usage"` + // UsageLimit is the maximum credit limit + UsageLimit float64 `json:"usage_limit"` + // NextReset is the timestamp of the next usage reset + NextReset string `json:"next_reset"` +} + +// KiroModel represents a model available through the CodeWhisperer API +type KiroModel struct { + // ModelID is the unique identifier for the model + ModelID string `json:"modelId"` + // ModelName is the human-readable name + ModelName string `json:"modelName"` + // Description is the model description + Description string `json:"description"` + // RateMultiplier is the credit multiplier for this model + RateMultiplier float64 `json:"rateMultiplier"` + // RateUnit is the unit for rate calculation (e.g., "credit") + RateUnit string `json:"rateUnit"` + // MaxInputTokens is the maximum input token limit + MaxInputTokens int `json:"maxInputTokens,omitempty"` +} + +// KiroIDETokenFile is the default path to Kiro IDE's token file +const KiroIDETokenFile = ".aws/sso/cache/kiro-auth-token.json" + +// LoadKiroIDEToken loads token data from Kiro IDE's token file. +func LoadKiroIDEToken() (*KiroTokenData, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("failed to get home directory: %w", err) + } + + tokenPath := filepath.Join(homeDir, KiroIDETokenFile) + data, err := os.ReadFile(tokenPath) + if err != nil { + return nil, fmt.Errorf("failed to read Kiro IDE token file (%s): %w", tokenPath, err) + } + + var token KiroTokenData + if err := json.Unmarshal(data, &token); err != nil { + return nil, fmt.Errorf("failed to parse Kiro IDE token: %w", err) + } + + if token.AccessToken == "" { + return nil, fmt.Errorf("access token is empty in Kiro IDE token file") + } + + return &token, nil +} + +// LoadKiroTokenFromPath loads token data from a custom path. +// This supports multiple accounts by allowing different token files. +func LoadKiroTokenFromPath(tokenPath string) (*KiroTokenData, error) { + // Expand ~ to home directory + if len(tokenPath) > 0 && tokenPath[0] == '~' { + homeDir, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("failed to get home directory: %w", err) + } + tokenPath = filepath.Join(homeDir, tokenPath[1:]) + } + + data, err := os.ReadFile(tokenPath) + if err != nil { + return nil, fmt.Errorf("failed to read token file (%s): %w", tokenPath, err) + } + + var token KiroTokenData + if err := json.Unmarshal(data, &token); err != nil { + return nil, fmt.Errorf("failed to parse token file: %w", err) + } + + if token.AccessToken == "" { + return nil, fmt.Errorf("access token is empty in token file") + } + + return &token, nil +} + +// ListKiroTokenFiles lists all Kiro token files in the cache directory. +// This supports multiple accounts by finding all token files. +func ListKiroTokenFiles() ([]string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("failed to get home directory: %w", err) + } + + cacheDir := filepath.Join(homeDir, ".aws", "sso", "cache") + + // Check if directory exists + if _, err := os.Stat(cacheDir); os.IsNotExist(err) { + return nil, nil // No token files + } + + entries, err := os.ReadDir(cacheDir) + if err != nil { + return nil, fmt.Errorf("failed to read cache directory: %w", err) + } + + var tokenFiles []string + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + // Look for kiro token files only (avoid matching unrelated AWS SSO cache files) + if strings.HasSuffix(name, ".json") && strings.HasPrefix(name, "kiro") { + tokenFiles = append(tokenFiles, filepath.Join(cacheDir, name)) + } + } + + return tokenFiles, nil +} + +// LoadAllKiroTokens loads all Kiro tokens from the cache directory. +// This supports multiple accounts. +func LoadAllKiroTokens() ([]*KiroTokenData, error) { + files, err := ListKiroTokenFiles() + if err != nil { + return nil, err + } + + var tokens []*KiroTokenData + for _, file := range files { + token, err := LoadKiroTokenFromPath(file) + if err != nil { + // Skip invalid token files + continue + } + tokens = append(tokens, token) + } + + return tokens, nil +} + +// JWTClaims represents the claims we care about from a JWT token. +// JWT tokens from Kiro/AWS contain user information in the payload. +type JWTClaims struct { + Email string `json:"email,omitempty"` + Sub string `json:"sub,omitempty"` + PreferredUser string `json:"preferred_username,omitempty"` + Name string `json:"name,omitempty"` + Iss string `json:"iss,omitempty"` +} + +// ExtractEmailFromJWT extracts the user's email from a JWT access token. +// JWT tokens typically have format: header.payload.signature +// The payload is base64url-encoded JSON containing user claims. +func ExtractEmailFromJWT(accessToken string) string { + if accessToken == "" { + return "" + } + + // JWT format: header.payload.signature + parts := strings.Split(accessToken, ".") + if len(parts) != 3 { + return "" + } + + // Decode the payload (second part) + payload := parts[1] + + // Add padding if needed (base64url requires padding) + switch len(payload) % 4 { + case 2: + payload += "==" + case 3: + payload += "=" + } + + decoded, err := base64.URLEncoding.DecodeString(payload) + if err != nil { + // Try RawURLEncoding (no padding) + decoded, err = base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "" + } + } + + var claims JWTClaims + if err := json.Unmarshal(decoded, &claims); err != nil { + return "" + } + + // Return email if available + if claims.Email != "" { + return claims.Email + } + + // Fallback to preferred_username (some providers use this) + if claims.PreferredUser != "" && strings.Contains(claims.PreferredUser, "@") { + return claims.PreferredUser + } + + // Fallback to sub if it looks like an email + if claims.Sub != "" && strings.Contains(claims.Sub, "@") { + return claims.Sub + } + + return "" +} + +// SanitizeEmailForFilename sanitizes an email address for use in a filename. +// Replaces special characters with underscores and prevents path traversal attacks. +// Also handles URL-encoded characters to prevent encoded path traversal attempts. +func SanitizeEmailForFilename(email string) string { + if email == "" { + return "" + } + + result := email + + // First, handle URL-encoded path traversal attempts (%2F, %2E, %5C, etc.) + // This prevents encoded characters from bypassing the sanitization. + // Note: We replace % last to catch any remaining encodings including double-encoding (%252F) + result = strings.ReplaceAll(result, "%2F", "_") // / + result = strings.ReplaceAll(result, "%2f", "_") + result = strings.ReplaceAll(result, "%5C", "_") // \ + result = strings.ReplaceAll(result, "%5c", "_") + result = strings.ReplaceAll(result, "%2E", "_") // . + result = strings.ReplaceAll(result, "%2e", "_") + result = strings.ReplaceAll(result, "%00", "_") // null byte + result = strings.ReplaceAll(result, "%", "_") // Catch remaining % to prevent double-encoding attacks + + // Replace characters that are problematic in filenames + // Keep @ and . in middle but replace other special characters + for _, char := range []string{"/", "\\", ":", "*", "?", "\"", "<", ">", "|", " ", "\x00"} { + result = strings.ReplaceAll(result, char, "_") + } + + // Prevent path traversal: replace leading dots in each path component + // This handles cases like "../../../etc/passwd" → "_.._.._.._etc_passwd" + parts := strings.Split(result, "_") + for i, part := range parts { + for strings.HasPrefix(part, ".") { + part = "_" + part[1:] + } + parts[i] = part + } + result = strings.Join(parts, "_") + + return result +} diff --git a/internal/auth/kiro/background_refresh.go b/internal/auth/kiro/background_refresh.go new file mode 100644 index 00000000..3fecc417 --- /dev/null +++ b/internal/auth/kiro/background_refresh.go @@ -0,0 +1,192 @@ +package kiro + +import ( + "context" + "log" + "sync" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "golang.org/x/sync/semaphore" +) + +type Token struct { + ID string + AccessToken string + RefreshToken string + ExpiresAt time.Time + LastVerified time.Time + ClientID string + ClientSecret string + AuthMethod string + Provider string + StartURL string + Region string +} + +type TokenRepository interface { + FindOldestUnverified(limit int) []*Token + UpdateToken(token *Token) error +} + +type RefresherOption func(*BackgroundRefresher) + +func WithInterval(interval time.Duration) RefresherOption { + return func(r *BackgroundRefresher) { + r.interval = interval + } +} + +func WithBatchSize(size int) RefresherOption { + return func(r *BackgroundRefresher) { + r.batchSize = size + } +} + +func WithConcurrency(concurrency int) RefresherOption { + return func(r *BackgroundRefresher) { + r.concurrency = concurrency + } +} + +type BackgroundRefresher struct { + interval time.Duration + batchSize int + concurrency int + tokenRepo TokenRepository + stopCh chan struct{} + wg sync.WaitGroup + oauth *KiroOAuth + ssoClient *SSOOIDCClient +} + +func NewBackgroundRefresher(repo TokenRepository, opts ...RefresherOption) *BackgroundRefresher { + r := &BackgroundRefresher{ + interval: time.Minute, + batchSize: 50, + concurrency: 10, + tokenRepo: repo, + stopCh: make(chan struct{}), + oauth: nil, // Lazy init - will be set when config available + ssoClient: nil, // Lazy init - will be set when config available + } + for _, opt := range opts { + opt(r) + } + return r +} + +// WithConfig sets the configuration for OAuth and SSO clients. +func WithConfig(cfg *config.Config) RefresherOption { + return func(r *BackgroundRefresher) { + r.oauth = NewKiroOAuth(cfg) + r.ssoClient = NewSSOOIDCClient(cfg) + } +} + +func (r *BackgroundRefresher) Start(ctx context.Context) { + r.wg.Add(1) + go func() { + defer r.wg.Done() + ticker := time.NewTicker(r.interval) + defer ticker.Stop() + + r.refreshBatch(ctx) + + for { + select { + case <-ctx.Done(): + return + case <-r.stopCh: + return + case <-ticker.C: + r.refreshBatch(ctx) + } + } + }() +} + +func (r *BackgroundRefresher) Stop() { + close(r.stopCh) + r.wg.Wait() +} + +func (r *BackgroundRefresher) refreshBatch(ctx context.Context) { + tokens := r.tokenRepo.FindOldestUnverified(r.batchSize) + if len(tokens) == 0 { + return + } + + sem := semaphore.NewWeighted(int64(r.concurrency)) + var wg sync.WaitGroup + + for i, token := range tokens { + if i > 0 { + select { + case <-ctx.Done(): + return + case <-r.stopCh: + return + case <-time.After(100 * time.Millisecond): + } + } + + if err := sem.Acquire(ctx, 1); err != nil { + return + } + + wg.Add(1) + go func(t *Token) { + defer wg.Done() + defer sem.Release(1) + r.refreshSingle(ctx, t) + }(token) + } + + wg.Wait() +} + +func (r *BackgroundRefresher) refreshSingle(ctx context.Context, token *Token) { + var newTokenData *KiroTokenData + var err error + + switch token.AuthMethod { + case "idc": + newTokenData, err = r.ssoClient.RefreshTokenWithRegion( + ctx, + token.ClientID, + token.ClientSecret, + token.RefreshToken, + token.Region, + token.StartURL, + ) + case "builder-id": + newTokenData, err = r.ssoClient.RefreshToken( + ctx, + token.ClientID, + token.ClientSecret, + token.RefreshToken, + ) + default: + newTokenData, err = r.oauth.RefreshToken(ctx, token.RefreshToken) + } + + if err != nil { + log.Printf("failed to refresh token %s: %v", token.ID, err) + return + } + + token.AccessToken = newTokenData.AccessToken + token.RefreshToken = newTokenData.RefreshToken + token.LastVerified = time.Now() + + if newTokenData.ExpiresAt != "" { + if expTime, parseErr := time.Parse(time.RFC3339, newTokenData.ExpiresAt); parseErr == nil { + token.ExpiresAt = expTime + } + } + + if err := r.tokenRepo.UpdateToken(token); err != nil { + log.Printf("failed to update token %s: %v", token.ID, err) + } +} diff --git a/internal/auth/kiro/cooldown.go b/internal/auth/kiro/cooldown.go new file mode 100644 index 00000000..c1aabbcb --- /dev/null +++ b/internal/auth/kiro/cooldown.go @@ -0,0 +1,112 @@ +package kiro + +import ( + "sync" + "time" +) + +const ( + CooldownReason429 = "rate_limit_exceeded" + CooldownReasonSuspended = "account_suspended" + CooldownReasonQuotaExhausted = "quota_exhausted" + + DefaultShortCooldown = 1 * time.Minute + MaxShortCooldown = 5 * time.Minute + LongCooldown = 24 * time.Hour +) + +type CooldownManager struct { + mu sync.RWMutex + cooldowns map[string]time.Time + reasons map[string]string +} + +func NewCooldownManager() *CooldownManager { + return &CooldownManager{ + cooldowns: make(map[string]time.Time), + reasons: make(map[string]string), + } +} + +func (cm *CooldownManager) SetCooldown(tokenKey string, duration time.Duration, reason string) { + cm.mu.Lock() + defer cm.mu.Unlock() + cm.cooldowns[tokenKey] = time.Now().Add(duration) + cm.reasons[tokenKey] = reason +} + +func (cm *CooldownManager) IsInCooldown(tokenKey string) bool { + cm.mu.RLock() + defer cm.mu.RUnlock() + endTime, exists := cm.cooldowns[tokenKey] + if !exists { + return false + } + return time.Now().Before(endTime) +} + +func (cm *CooldownManager) GetRemainingCooldown(tokenKey string) time.Duration { + cm.mu.RLock() + defer cm.mu.RUnlock() + endTime, exists := cm.cooldowns[tokenKey] + if !exists { + return 0 + } + remaining := time.Until(endTime) + if remaining < 0 { + return 0 + } + return remaining +} + +func (cm *CooldownManager) GetCooldownReason(tokenKey string) string { + cm.mu.RLock() + defer cm.mu.RUnlock() + return cm.reasons[tokenKey] +} + +func (cm *CooldownManager) ClearCooldown(tokenKey string) { + cm.mu.Lock() + defer cm.mu.Unlock() + delete(cm.cooldowns, tokenKey) + delete(cm.reasons, tokenKey) +} + +func (cm *CooldownManager) CleanupExpired() { + cm.mu.Lock() + defer cm.mu.Unlock() + now := time.Now() + for tokenKey, endTime := range cm.cooldowns { + if now.After(endTime) { + delete(cm.cooldowns, tokenKey) + delete(cm.reasons, tokenKey) + } + } +} + +func (cm *CooldownManager) StartCleanupRoutine(interval time.Duration, stopCh <-chan struct{}) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + cm.CleanupExpired() + case <-stopCh: + return + } + } +} + +func CalculateCooldownFor429(retryCount int) time.Duration { + duration := DefaultShortCooldown * time.Duration(1< MaxShortCooldown { + return MaxShortCooldown + } + return duration +} + +func CalculateCooldownUntilNextDay() time.Duration { + now := time.Now() + nextDay := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, now.Location()) + return time.Until(nextDay) +} diff --git a/internal/auth/kiro/cooldown_test.go b/internal/auth/kiro/cooldown_test.go new file mode 100644 index 00000000..e0b35df4 --- /dev/null +++ b/internal/auth/kiro/cooldown_test.go @@ -0,0 +1,240 @@ +package kiro + +import ( + "sync" + "testing" + "time" +) + +func TestNewCooldownManager(t *testing.T) { + cm := NewCooldownManager() + if cm == nil { + t.Fatal("expected non-nil CooldownManager") + } + if cm.cooldowns == nil { + t.Error("expected non-nil cooldowns map") + } + if cm.reasons == nil { + t.Error("expected non-nil reasons map") + } +} + +func TestSetCooldown(t *testing.T) { + cm := NewCooldownManager() + cm.SetCooldown("token1", 1*time.Minute, CooldownReason429) + + if !cm.IsInCooldown("token1") { + t.Error("expected token to be in cooldown") + } + if cm.GetCooldownReason("token1") != CooldownReason429 { + t.Errorf("expected reason %s, got %s", CooldownReason429, cm.GetCooldownReason("token1")) + } +} + +func TestIsInCooldown_NotSet(t *testing.T) { + cm := NewCooldownManager() + if cm.IsInCooldown("nonexistent") { + t.Error("expected non-existent token to not be in cooldown") + } +} + +func TestIsInCooldown_Expired(t *testing.T) { + cm := NewCooldownManager() + cm.SetCooldown("token1", 1*time.Millisecond, CooldownReason429) + + time.Sleep(10 * time.Millisecond) + + if cm.IsInCooldown("token1") { + t.Error("expected expired cooldown to return false") + } +} + +func TestGetRemainingCooldown(t *testing.T) { + cm := NewCooldownManager() + cm.SetCooldown("token1", 1*time.Second, CooldownReason429) + + remaining := cm.GetRemainingCooldown("token1") + if remaining <= 0 || remaining > 1*time.Second { + t.Errorf("expected remaining cooldown between 0 and 1s, got %v", remaining) + } +} + +func TestGetRemainingCooldown_NotSet(t *testing.T) { + cm := NewCooldownManager() + remaining := cm.GetRemainingCooldown("nonexistent") + if remaining != 0 { + t.Errorf("expected 0 remaining for non-existent, got %v", remaining) + } +} + +func TestGetRemainingCooldown_Expired(t *testing.T) { + cm := NewCooldownManager() + cm.SetCooldown("token1", 1*time.Millisecond, CooldownReason429) + + time.Sleep(10 * time.Millisecond) + + remaining := cm.GetRemainingCooldown("token1") + if remaining != 0 { + t.Errorf("expected 0 remaining for expired, got %v", remaining) + } +} + +func TestGetCooldownReason(t *testing.T) { + cm := NewCooldownManager() + cm.SetCooldown("token1", 1*time.Minute, CooldownReasonSuspended) + + reason := cm.GetCooldownReason("token1") + if reason != CooldownReasonSuspended { + t.Errorf("expected reason %s, got %s", CooldownReasonSuspended, reason) + } +} + +func TestGetCooldownReason_NotSet(t *testing.T) { + cm := NewCooldownManager() + reason := cm.GetCooldownReason("nonexistent") + if reason != "" { + t.Errorf("expected empty reason for non-existent, got %s", reason) + } +} + +func TestClearCooldown(t *testing.T) { + cm := NewCooldownManager() + cm.SetCooldown("token1", 1*time.Minute, CooldownReason429) + cm.ClearCooldown("token1") + + if cm.IsInCooldown("token1") { + t.Error("expected cooldown to be cleared") + } + if cm.GetCooldownReason("token1") != "" { + t.Error("expected reason to be cleared") + } +} + +func TestClearCooldown_NonExistent(t *testing.T) { + cm := NewCooldownManager() + cm.ClearCooldown("nonexistent") +} + +func TestCleanupExpired(t *testing.T) { + cm := NewCooldownManager() + cm.SetCooldown("expired1", 1*time.Millisecond, CooldownReason429) + cm.SetCooldown("expired2", 1*time.Millisecond, CooldownReason429) + cm.SetCooldown("active", 1*time.Hour, CooldownReason429) + + time.Sleep(10 * time.Millisecond) + cm.CleanupExpired() + + if cm.GetCooldownReason("expired1") != "" { + t.Error("expected expired1 to be cleaned up") + } + if cm.GetCooldownReason("expired2") != "" { + t.Error("expected expired2 to be cleaned up") + } + if cm.GetCooldownReason("active") != CooldownReason429 { + t.Error("expected active to remain") + } +} + +func TestCalculateCooldownFor429_FirstRetry(t *testing.T) { + duration := CalculateCooldownFor429(0) + if duration != DefaultShortCooldown { + t.Errorf("expected %v for retry 0, got %v", DefaultShortCooldown, duration) + } +} + +func TestCalculateCooldownFor429_Exponential(t *testing.T) { + d1 := CalculateCooldownFor429(1) + d2 := CalculateCooldownFor429(2) + + if d2 <= d1 { + t.Errorf("expected d2 > d1, got d1=%v, d2=%v", d1, d2) + } +} + +func TestCalculateCooldownFor429_MaxCap(t *testing.T) { + duration := CalculateCooldownFor429(10) + if duration > MaxShortCooldown { + t.Errorf("expected max %v, got %v", MaxShortCooldown, duration) + } +} + +func TestCalculateCooldownUntilNextDay(t *testing.T) { + duration := CalculateCooldownUntilNextDay() + if duration <= 0 || duration > 24*time.Hour { + t.Errorf("expected duration between 0 and 24h, got %v", duration) + } +} + +func TestCooldownManager_ConcurrentAccess(t *testing.T) { + cm := NewCooldownManager() + const numGoroutines = 50 + const numOperations = 100 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + tokenKey := "token" + string(rune('a'+id%10)) + for j := 0; j < numOperations; j++ { + switch j % 6 { + case 0: + cm.SetCooldown(tokenKey, time.Duration(j)*time.Millisecond, CooldownReason429) + case 1: + cm.IsInCooldown(tokenKey) + case 2: + cm.GetRemainingCooldown(tokenKey) + case 3: + cm.GetCooldownReason(tokenKey) + case 4: + cm.ClearCooldown(tokenKey) + case 5: + cm.CleanupExpired() + } + } + }(i) + } + + wg.Wait() +} + +func TestCooldownReasonConstants(t *testing.T) { + if CooldownReason429 != "rate_limit_exceeded" { + t.Errorf("unexpected CooldownReason429: %s", CooldownReason429) + } + if CooldownReasonSuspended != "account_suspended" { + t.Errorf("unexpected CooldownReasonSuspended: %s", CooldownReasonSuspended) + } + if CooldownReasonQuotaExhausted != "quota_exhausted" { + t.Errorf("unexpected CooldownReasonQuotaExhausted: %s", CooldownReasonQuotaExhausted) + } +} + +func TestDefaultConstants(t *testing.T) { + if DefaultShortCooldown != 1*time.Minute { + t.Errorf("unexpected DefaultShortCooldown: %v", DefaultShortCooldown) + } + if MaxShortCooldown != 5*time.Minute { + t.Errorf("unexpected MaxShortCooldown: %v", MaxShortCooldown) + } + if LongCooldown != 24*time.Hour { + t.Errorf("unexpected LongCooldown: %v", LongCooldown) + } +} + +func TestSetCooldown_OverwritesPrevious(t *testing.T) { + cm := NewCooldownManager() + cm.SetCooldown("token1", 1*time.Hour, CooldownReason429) + cm.SetCooldown("token1", 1*time.Minute, CooldownReasonSuspended) + + reason := cm.GetCooldownReason("token1") + if reason != CooldownReasonSuspended { + t.Errorf("expected reason to be overwritten to %s, got %s", CooldownReasonSuspended, reason) + } + + remaining := cm.GetRemainingCooldown("token1") + if remaining > 1*time.Minute { + t.Errorf("expected remaining <= 1 minute, got %v", remaining) + } +} diff --git a/internal/auth/kiro/fingerprint.go b/internal/auth/kiro/fingerprint.go new file mode 100644 index 00000000..c35e62b2 --- /dev/null +++ b/internal/auth/kiro/fingerprint.go @@ -0,0 +1,197 @@ +package kiro + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "math/rand" + "net/http" + "sync" + "time" +) + +// Fingerprint 多维度指纹信息 +type Fingerprint struct { + SDKVersion string // 1.0.20-1.0.27 + OSType string // darwin/windows/linux + OSVersion string // 10.0.22621 + NodeVersion string // 18.x/20.x/22.x + KiroVersion string // 0.3.x-0.8.x + KiroHash string // SHA256 + AcceptLanguage string + ScreenResolution string // 1920x1080 + ColorDepth int // 24 + HardwareConcurrency int // CPU 核心数 + TimezoneOffset int +} + +// FingerprintManager 指纹管理器 +type FingerprintManager struct { + mu sync.RWMutex + fingerprints map[string]*Fingerprint // tokenKey -> fingerprint + rng *rand.Rand +} + +var ( + sdkVersions = []string{ + "1.0.20", "1.0.21", "1.0.22", "1.0.23", + "1.0.24", "1.0.25", "1.0.26", "1.0.27", + } + osTypes = []string{"darwin", "windows", "linux"} + osVersions = map[string][]string{ + "darwin": {"14.0", "14.1", "14.2", "14.3", "14.4", "14.5", "15.0", "15.1"}, + "windows": {"10.0.19041", "10.0.19042", "10.0.19043", "10.0.19044", "10.0.22621", "10.0.22631"}, + "linux": {"5.15.0", "6.1.0", "6.2.0", "6.5.0", "6.6.0", "6.8.0"}, + } + nodeVersions = []string{ + "18.17.0", "18.18.0", "18.19.0", "18.20.0", + "20.9.0", "20.10.0", "20.11.0", "20.12.0", "20.13.0", + "22.0.0", "22.1.0", "22.2.0", "22.3.0", + } + kiroVersions = []string{ + "0.3.0", "0.3.1", "0.4.0", "0.4.1", "0.5.0", "0.5.1", + "0.6.0", "0.6.1", "0.7.0", "0.7.1", "0.8.0", "0.8.1", + } + acceptLanguages = []string{ + "en-US,en;q=0.9", + "en-GB,en;q=0.9", + "zh-CN,zh;q=0.9,en;q=0.8", + "zh-TW,zh;q=0.9,en;q=0.8", + "ja-JP,ja;q=0.9,en;q=0.8", + "ko-KR,ko;q=0.9,en;q=0.8", + "de-DE,de;q=0.9,en;q=0.8", + "fr-FR,fr;q=0.9,en;q=0.8", + } + screenResolutions = []string{ + "1920x1080", "2560x1440", "3840x2160", + "1366x768", "1440x900", "1680x1050", + "2560x1600", "3440x1440", + } + colorDepths = []int{24, 32} + hardwareConcurrencies = []int{4, 6, 8, 10, 12, 16, 20, 24, 32} + timezoneOffsets = []int{-480, -420, -360, -300, -240, 0, 60, 120, 480, 540} +) + +// NewFingerprintManager 创建指纹管理器 +func NewFingerprintManager() *FingerprintManager { + return &FingerprintManager{ + fingerprints: make(map[string]*Fingerprint), + rng: rand.New(rand.NewSource(time.Now().UnixNano())), + } +} + +// GetFingerprint 获取或生成 Token 关联的指纹 +func (fm *FingerprintManager) GetFingerprint(tokenKey string) *Fingerprint { + fm.mu.RLock() + if fp, exists := fm.fingerprints[tokenKey]; exists { + fm.mu.RUnlock() + return fp + } + fm.mu.RUnlock() + + fm.mu.Lock() + defer fm.mu.Unlock() + + if fp, exists := fm.fingerprints[tokenKey]; exists { + return fp + } + + fp := fm.generateFingerprint(tokenKey) + fm.fingerprints[tokenKey] = fp + return fp +} + +// generateFingerprint 生成新的指纹 +func (fm *FingerprintManager) generateFingerprint(tokenKey string) *Fingerprint { + osType := fm.randomChoice(osTypes) + osVersion := fm.randomChoice(osVersions[osType]) + kiroVersion := fm.randomChoice(kiroVersions) + + fp := &Fingerprint{ + SDKVersion: fm.randomChoice(sdkVersions), + OSType: osType, + OSVersion: osVersion, + NodeVersion: fm.randomChoice(nodeVersions), + KiroVersion: kiroVersion, + AcceptLanguage: fm.randomChoice(acceptLanguages), + ScreenResolution: fm.randomChoice(screenResolutions), + ColorDepth: fm.randomIntChoice(colorDepths), + HardwareConcurrency: fm.randomIntChoice(hardwareConcurrencies), + TimezoneOffset: fm.randomIntChoice(timezoneOffsets), + } + + fp.KiroHash = fm.generateKiroHash(tokenKey, kiroVersion, osType) + return fp +} + +// generateKiroHash 生成 Kiro Hash +func (fm *FingerprintManager) generateKiroHash(tokenKey, kiroVersion, osType string) string { + data := fmt.Sprintf("%s:%s:%s:%d", tokenKey, kiroVersion, osType, time.Now().UnixNano()) + hash := sha256.Sum256([]byte(data)) + return hex.EncodeToString(hash[:]) +} + +// randomChoice 随机选择字符串 +func (fm *FingerprintManager) randomChoice(choices []string) string { + return choices[fm.rng.Intn(len(choices))] +} + +// randomIntChoice 随机选择整数 +func (fm *FingerprintManager) randomIntChoice(choices []int) int { + return choices[fm.rng.Intn(len(choices))] +} + +// ApplyToRequest 将指纹信息应用到 HTTP 请求头 +func (fp *Fingerprint) ApplyToRequest(req *http.Request) { + req.Header.Set("X-Kiro-SDK-Version", fp.SDKVersion) + req.Header.Set("X-Kiro-OS-Type", fp.OSType) + req.Header.Set("X-Kiro-OS-Version", fp.OSVersion) + req.Header.Set("X-Kiro-Node-Version", fp.NodeVersion) + req.Header.Set("X-Kiro-Version", fp.KiroVersion) + req.Header.Set("X-Kiro-Hash", fp.KiroHash) + req.Header.Set("Accept-Language", fp.AcceptLanguage) + req.Header.Set("X-Screen-Resolution", fp.ScreenResolution) + req.Header.Set("X-Color-Depth", fmt.Sprintf("%d", fp.ColorDepth)) + req.Header.Set("X-Hardware-Concurrency", fmt.Sprintf("%d", fp.HardwareConcurrency)) + req.Header.Set("X-Timezone-Offset", fmt.Sprintf("%d", fp.TimezoneOffset)) +} + +// RemoveFingerprint 移除 Token 关联的指纹 +func (fm *FingerprintManager) RemoveFingerprint(tokenKey string) { + fm.mu.Lock() + defer fm.mu.Unlock() + delete(fm.fingerprints, tokenKey) +} + +// Count 返回当前管理的指纹数量 +func (fm *FingerprintManager) Count() int { + fm.mu.RLock() + defer fm.mu.RUnlock() + return len(fm.fingerprints) +} + +// BuildUserAgent 构建 User-Agent 字符串 (Kiro IDE 风格) +// 格式: aws-sdk-js/{SDKVersion} ua/2.1 os/{OSType}#{OSVersion} lang/js md/nodejs#{NodeVersion} api/codewhispererstreaming#{SDKVersion} m/E KiroIDE-{KiroVersion}-{KiroHash} +func (fp *Fingerprint) BuildUserAgent() string { + return fmt.Sprintf( + "aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s", + fp.SDKVersion, + fp.OSType, + fp.OSVersion, + fp.NodeVersion, + fp.SDKVersion, + fp.KiroVersion, + fp.KiroHash, + ) +} + +// BuildAmzUserAgent 构建 X-Amz-User-Agent 字符串 +// 格式: aws-sdk-js/{SDKVersion} KiroIDE-{KiroVersion}-{KiroHash} +func (fp *Fingerprint) BuildAmzUserAgent() string { + return fmt.Sprintf( + "aws-sdk-js/%s KiroIDE-%s-%s", + fp.SDKVersion, + fp.KiroVersion, + fp.KiroHash, + ) +} diff --git a/internal/auth/kiro/fingerprint_test.go b/internal/auth/kiro/fingerprint_test.go new file mode 100644 index 00000000..e0ae51f2 --- /dev/null +++ b/internal/auth/kiro/fingerprint_test.go @@ -0,0 +1,227 @@ +package kiro + +import ( + "net/http" + "sync" + "testing" +) + +func TestNewFingerprintManager(t *testing.T) { + fm := NewFingerprintManager() + if fm == nil { + t.Fatal("expected non-nil FingerprintManager") + } + if fm.fingerprints == nil { + t.Error("expected non-nil fingerprints map") + } + if fm.rng == nil { + t.Error("expected non-nil rng") + } +} + +func TestGetFingerprint_NewToken(t *testing.T) { + fm := NewFingerprintManager() + fp := fm.GetFingerprint("token1") + + if fp == nil { + t.Fatal("expected non-nil Fingerprint") + } + if fp.SDKVersion == "" { + t.Error("expected non-empty SDKVersion") + } + if fp.OSType == "" { + t.Error("expected non-empty OSType") + } + if fp.OSVersion == "" { + t.Error("expected non-empty OSVersion") + } + if fp.NodeVersion == "" { + t.Error("expected non-empty NodeVersion") + } + if fp.KiroVersion == "" { + t.Error("expected non-empty KiroVersion") + } + if fp.KiroHash == "" { + t.Error("expected non-empty KiroHash") + } + if fp.AcceptLanguage == "" { + t.Error("expected non-empty AcceptLanguage") + } + if fp.ScreenResolution == "" { + t.Error("expected non-empty ScreenResolution") + } + if fp.ColorDepth == 0 { + t.Error("expected non-zero ColorDepth") + } + if fp.HardwareConcurrency == 0 { + t.Error("expected non-zero HardwareConcurrency") + } +} + +func TestGetFingerprint_SameTokenReturnsSameFingerprint(t *testing.T) { + fm := NewFingerprintManager() + fp1 := fm.GetFingerprint("token1") + fp2 := fm.GetFingerprint("token1") + + if fp1 != fp2 { + t.Error("expected same fingerprint for same token") + } +} + +func TestGetFingerprint_DifferentTokens(t *testing.T) { + fm := NewFingerprintManager() + fp1 := fm.GetFingerprint("token1") + fp2 := fm.GetFingerprint("token2") + + if fp1 == fp2 { + t.Error("expected different fingerprints for different tokens") + } +} + +func TestRemoveFingerprint(t *testing.T) { + fm := NewFingerprintManager() + fm.GetFingerprint("token1") + if fm.Count() != 1 { + t.Fatalf("expected count 1, got %d", fm.Count()) + } + + fm.RemoveFingerprint("token1") + if fm.Count() != 0 { + t.Errorf("expected count 0, got %d", fm.Count()) + } +} + +func TestRemoveFingerprint_NonExistent(t *testing.T) { + fm := NewFingerprintManager() + fm.RemoveFingerprint("nonexistent") + if fm.Count() != 0 { + t.Errorf("expected count 0, got %d", fm.Count()) + } +} + +func TestCount(t *testing.T) { + fm := NewFingerprintManager() + if fm.Count() != 0 { + t.Errorf("expected count 0, got %d", fm.Count()) + } + + fm.GetFingerprint("token1") + fm.GetFingerprint("token2") + fm.GetFingerprint("token3") + + if fm.Count() != 3 { + t.Errorf("expected count 3, got %d", fm.Count()) + } +} + +func TestApplyToRequest(t *testing.T) { + fm := NewFingerprintManager() + fp := fm.GetFingerprint("token1") + + req, _ := http.NewRequest("GET", "http://example.com", nil) + fp.ApplyToRequest(req) + + if req.Header.Get("X-Kiro-SDK-Version") != fp.SDKVersion { + t.Error("X-Kiro-SDK-Version header mismatch") + } + if req.Header.Get("X-Kiro-OS-Type") != fp.OSType { + t.Error("X-Kiro-OS-Type header mismatch") + } + if req.Header.Get("X-Kiro-OS-Version") != fp.OSVersion { + t.Error("X-Kiro-OS-Version header mismatch") + } + if req.Header.Get("X-Kiro-Node-Version") != fp.NodeVersion { + t.Error("X-Kiro-Node-Version header mismatch") + } + if req.Header.Get("X-Kiro-Version") != fp.KiroVersion { + t.Error("X-Kiro-Version header mismatch") + } + if req.Header.Get("X-Kiro-Hash") != fp.KiroHash { + t.Error("X-Kiro-Hash header mismatch") + } + if req.Header.Get("Accept-Language") != fp.AcceptLanguage { + t.Error("Accept-Language header mismatch") + } + if req.Header.Get("X-Screen-Resolution") != fp.ScreenResolution { + t.Error("X-Screen-Resolution header mismatch") + } +} + +func TestGetFingerprint_OSVersionMatchesOSType(t *testing.T) { + fm := NewFingerprintManager() + + for i := 0; i < 20; i++ { + fp := fm.GetFingerprint("token" + string(rune('a'+i))) + validVersions := osVersions[fp.OSType] + found := false + for _, v := range validVersions { + if v == fp.OSVersion { + found = true + break + } + } + if !found { + t.Errorf("OS version %s not valid for OS type %s", fp.OSVersion, fp.OSType) + } + } +} + +func TestFingerprintManager_ConcurrentAccess(t *testing.T) { + fm := NewFingerprintManager() + const numGoroutines = 100 + const numOperations = 100 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < numOperations; j++ { + tokenKey := "token" + string(rune('a'+id%26)) + switch j % 4 { + case 0: + fm.GetFingerprint(tokenKey) + case 1: + fm.Count() + case 2: + fp := fm.GetFingerprint(tokenKey) + req, _ := http.NewRequest("GET", "http://example.com", nil) + fp.ApplyToRequest(req) + case 3: + fm.RemoveFingerprint(tokenKey) + } + } + }(i) + } + + wg.Wait() +} + +func TestKiroHashUniqueness(t *testing.T) { + fm := NewFingerprintManager() + hashes := make(map[string]bool) + + for i := 0; i < 100; i++ { + fp := fm.GetFingerprint("token" + string(rune(i))) + if hashes[fp.KiroHash] { + t.Errorf("duplicate KiroHash detected: %s", fp.KiroHash) + } + hashes[fp.KiroHash] = true + } +} + +func TestKiroHashFormat(t *testing.T) { + fm := NewFingerprintManager() + fp := fm.GetFingerprint("token1") + + if len(fp.KiroHash) != 64 { + t.Errorf("expected KiroHash length 64 (SHA256 hex), got %d", len(fp.KiroHash)) + } + + for _, c := range fp.KiroHash { + if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) { + t.Errorf("invalid hex character in KiroHash: %c", c) + } + } +} diff --git a/internal/auth/kiro/jitter.go b/internal/auth/kiro/jitter.go new file mode 100644 index 00000000..0569a8fb --- /dev/null +++ b/internal/auth/kiro/jitter.go @@ -0,0 +1,174 @@ +package kiro + +import ( + "math/rand" + "sync" + "time" +) + +// Jitter configuration constants +const ( + // JitterPercent is the default percentage of jitter to apply (±30%) + JitterPercent = 0.30 + + // Human-like delay ranges + ShortDelayMin = 50 * time.Millisecond // Minimum for rapid consecutive operations + ShortDelayMax = 200 * time.Millisecond // Maximum for rapid consecutive operations + NormalDelayMin = 1 * time.Second // Minimum for normal thinking time + NormalDelayMax = 3 * time.Second // Maximum for normal thinking time + LongDelayMin = 5 * time.Second // Minimum for reading/resting + LongDelayMax = 10 * time.Second // Maximum for reading/resting + + // Probability thresholds for human-like behavior + ShortDelayProbability = 0.20 // 20% chance of short delay (consecutive ops) + LongDelayProbability = 0.05 // 5% chance of long delay (reading/resting) + NormalDelayProbability = 0.75 // 75% chance of normal delay (thinking) +) + +var ( + jitterRand *rand.Rand + jitterRandOnce sync.Once + jitterMu sync.Mutex + lastRequestTime time.Time +) + +// initJitterRand initializes the random number generator for jitter calculations. +// Uses a time-based seed for unpredictable but reproducible randomness. +func initJitterRand() { + jitterRandOnce.Do(func() { + jitterRand = rand.New(rand.NewSource(time.Now().UnixNano())) + }) +} + +// RandomDelay generates a random delay between min and max duration. +// Thread-safe implementation using mutex protection. +func RandomDelay(min, max time.Duration) time.Duration { + initJitterRand() + jitterMu.Lock() + defer jitterMu.Unlock() + + if min >= max { + return min + } + + rangeMs := max.Milliseconds() - min.Milliseconds() + randomMs := jitterRand.Int63n(rangeMs) + return min + time.Duration(randomMs)*time.Millisecond +} + +// JitterDelay adds jitter to a base delay. +// Applies ±jitterPercent variation to the base delay. +// For example, JitterDelay(1*time.Second, 0.30) returns a value between 700ms and 1300ms. +func JitterDelay(baseDelay time.Duration, jitterPercent float64) time.Duration { + initJitterRand() + jitterMu.Lock() + defer jitterMu.Unlock() + + if jitterPercent <= 0 || jitterPercent > 1 { + jitterPercent = JitterPercent + } + + // Calculate jitter range: base * jitterPercent + jitterRange := float64(baseDelay) * jitterPercent + + // Generate random value in range [-jitterRange, +jitterRange] + jitter := (jitterRand.Float64()*2 - 1) * jitterRange + + result := time.Duration(float64(baseDelay) + jitter) + if result < 0 { + return 0 + } + return result +} + +// JitterDelayDefault applies the default ±30% jitter to a base delay. +func JitterDelayDefault(baseDelay time.Duration) time.Duration { + return JitterDelay(baseDelay, JitterPercent) +} + +// HumanLikeDelay generates a delay that mimics human behavior patterns. +// The delay is selected based on probability distribution: +// - 20% chance: Short delay (50-200ms) - simulates consecutive rapid operations +// - 75% chance: Normal delay (1-3s) - simulates thinking/reading time +// - 5% chance: Long delay (5-10s) - simulates breaks/reading longer content +// +// Returns the delay duration (caller should call time.Sleep with this value). +func HumanLikeDelay() time.Duration { + initJitterRand() + jitterMu.Lock() + defer jitterMu.Unlock() + + // Track time since last request for adaptive behavior + now := time.Now() + timeSinceLastRequest := now.Sub(lastRequestTime) + lastRequestTime = now + + // If requests are very close together, use short delay + if timeSinceLastRequest < 500*time.Millisecond && timeSinceLastRequest > 0 { + rangeMs := ShortDelayMax.Milliseconds() - ShortDelayMin.Milliseconds() + randomMs := jitterRand.Int63n(rangeMs) + return ShortDelayMin + time.Duration(randomMs)*time.Millisecond + } + + // Otherwise, use probability-based selection + roll := jitterRand.Float64() + + var min, max time.Duration + switch { + case roll < ShortDelayProbability: + // Short delay - consecutive operations + min, max = ShortDelayMin, ShortDelayMax + case roll < ShortDelayProbability+LongDelayProbability: + // Long delay - reading/resting + min, max = LongDelayMin, LongDelayMax + default: + // Normal delay - thinking time + min, max = NormalDelayMin, NormalDelayMax + } + + rangeMs := max.Milliseconds() - min.Milliseconds() + randomMs := jitterRand.Int63n(rangeMs) + return min + time.Duration(randomMs)*time.Millisecond +} + +// ApplyHumanLikeDelay applies human-like delay by sleeping. +// This is a convenience function that combines HumanLikeDelay with time.Sleep. +func ApplyHumanLikeDelay() { + delay := HumanLikeDelay() + if delay > 0 { + time.Sleep(delay) + } +} + +// ExponentialBackoffWithJitter calculates retry delay using exponential backoff with jitter. +// Formula: min(baseDelay * 2^attempt + jitter, maxDelay) +// This helps prevent thundering herd problem when multiple clients retry simultaneously. +func ExponentialBackoffWithJitter(attempt int, baseDelay, maxDelay time.Duration) time.Duration { + if attempt < 0 { + attempt = 0 + } + + // Calculate exponential backoff: baseDelay * 2^attempt + backoff := baseDelay * time.Duration(1< maxDelay { + backoff = maxDelay + } + + // Add ±30% jitter + return JitterDelay(backoff, JitterPercent) +} + +// ShouldSkipDelay determines if delay should be skipped based on context. +// Returns true for streaming responses, WebSocket connections, etc. +// This function can be extended to check additional skip conditions. +func ShouldSkipDelay(isStreaming bool) bool { + return isStreaming +} + +// ResetLastRequestTime resets the last request time tracker. +// Useful for testing or when starting a new session. +func ResetLastRequestTime() { + jitterMu.Lock() + defer jitterMu.Unlock() + lastRequestTime = time.Time{} +} diff --git a/internal/auth/kiro/metrics.go b/internal/auth/kiro/metrics.go new file mode 100644 index 00000000..0fe2d0c6 --- /dev/null +++ b/internal/auth/kiro/metrics.go @@ -0,0 +1,187 @@ +package kiro + +import ( + "math" + "sync" + "time" +) + +// TokenMetrics holds performance metrics for a single token. +type TokenMetrics struct { + SuccessRate float64 // Success rate (0.0 - 1.0) + AvgLatency float64 // Average latency in milliseconds + QuotaRemaining float64 // Remaining quota (0.0 - 1.0) + LastUsed time.Time // Last usage timestamp + FailCount int // Consecutive failure count + TotalRequests int // Total request count + successCount int // Internal: successful request count + totalLatency float64 // Internal: cumulative latency +} + +// TokenScorer manages token metrics and scoring. +type TokenScorer struct { + mu sync.RWMutex + metrics map[string]*TokenMetrics + + // Scoring weights + successRateWeight float64 + quotaWeight float64 + latencyWeight float64 + lastUsedWeight float64 + failPenaltyMultiplier float64 +} + +// NewTokenScorer creates a new TokenScorer with default weights. +func NewTokenScorer() *TokenScorer { + return &TokenScorer{ + metrics: make(map[string]*TokenMetrics), + successRateWeight: 0.4, + quotaWeight: 0.25, + latencyWeight: 0.2, + lastUsedWeight: 0.15, + failPenaltyMultiplier: 0.1, + } +} + +// getOrCreateMetrics returns existing metrics or creates new ones. +func (s *TokenScorer) getOrCreateMetrics(tokenKey string) *TokenMetrics { + if m, ok := s.metrics[tokenKey]; ok { + return m + } + m := &TokenMetrics{ + SuccessRate: 1.0, + QuotaRemaining: 1.0, + } + s.metrics[tokenKey] = m + return m +} + +// RecordRequest records the result of a request for a token. +func (s *TokenScorer) RecordRequest(tokenKey string, success bool, latency time.Duration) { + s.mu.Lock() + defer s.mu.Unlock() + + m := s.getOrCreateMetrics(tokenKey) + m.TotalRequests++ + m.LastUsed = time.Now() + m.totalLatency += float64(latency.Milliseconds()) + + if success { + m.successCount++ + m.FailCount = 0 + } else { + m.FailCount++ + } + + // Update derived metrics + if m.TotalRequests > 0 { + m.SuccessRate = float64(m.successCount) / float64(m.TotalRequests) + m.AvgLatency = m.totalLatency / float64(m.TotalRequests) + } +} + +// SetQuotaRemaining updates the remaining quota for a token. +func (s *TokenScorer) SetQuotaRemaining(tokenKey string, quota float64) { + s.mu.Lock() + defer s.mu.Unlock() + + m := s.getOrCreateMetrics(tokenKey) + m.QuotaRemaining = quota +} + +// GetMetrics returns a copy of the metrics for a token. +func (s *TokenScorer) GetMetrics(tokenKey string) *TokenMetrics { + s.mu.RLock() + defer s.mu.RUnlock() + + if m, ok := s.metrics[tokenKey]; ok { + copy := *m + return © + } + return nil +} + +// CalculateScore computes the score for a token (higher is better). +func (s *TokenScorer) CalculateScore(tokenKey string) float64 { + s.mu.RLock() + defer s.mu.RUnlock() + + m, ok := s.metrics[tokenKey] + if !ok { + return 1.0 // New tokens get a high initial score + } + + // Success rate component (0-1) + successScore := m.SuccessRate + + // Quota component (0-1) + quotaScore := m.QuotaRemaining + + // Latency component (normalized, lower is better) + // Using exponential decay: score = e^(-latency/1000) + // 1000ms latency -> ~0.37 score, 100ms -> ~0.90 score + latencyScore := math.Exp(-m.AvgLatency / 1000.0) + if m.TotalRequests == 0 { + latencyScore = 1.0 + } + + // Last used component (prefer tokens not recently used) + // Score increases as time since last use increases + timeSinceUse := time.Since(m.LastUsed).Seconds() + // Normalize: 60 seconds -> ~0.63 score, 0 seconds -> 0 score + lastUsedScore := 1.0 - math.Exp(-timeSinceUse/60.0) + if m.LastUsed.IsZero() { + lastUsedScore = 1.0 + } + + // Calculate weighted score + score := s.successRateWeight*successScore + + s.quotaWeight*quotaScore + + s.latencyWeight*latencyScore + + s.lastUsedWeight*lastUsedScore + + // Apply consecutive failure penalty + if m.FailCount > 0 { + penalty := s.failPenaltyMultiplier * float64(m.FailCount) + score = score * math.Max(0, 1.0-penalty) + } + + return score +} + +// SelectBestToken selects the token with the highest score. +func (s *TokenScorer) SelectBestToken(tokens []string) string { + if len(tokens) == 0 { + return "" + } + if len(tokens) == 1 { + return tokens[0] + } + + bestToken := tokens[0] + bestScore := s.CalculateScore(tokens[0]) + + for _, token := range tokens[1:] { + score := s.CalculateScore(token) + if score > bestScore { + bestScore = score + bestToken = token + } + } + + return bestToken +} + +// ResetMetrics clears all metrics for a token. +func (s *TokenScorer) ResetMetrics(tokenKey string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.metrics, tokenKey) +} + +// ResetAllMetrics clears all stored metrics. +func (s *TokenScorer) ResetAllMetrics() { + s.mu.Lock() + defer s.mu.Unlock() + s.metrics = make(map[string]*TokenMetrics) +} diff --git a/internal/auth/kiro/metrics_test.go b/internal/auth/kiro/metrics_test.go new file mode 100644 index 00000000..ffe2a876 --- /dev/null +++ b/internal/auth/kiro/metrics_test.go @@ -0,0 +1,301 @@ +package kiro + +import ( + "sync" + "testing" + "time" +) + +func TestNewTokenScorer(t *testing.T) { + s := NewTokenScorer() + if s == nil { + t.Fatal("expected non-nil TokenScorer") + } + if s.metrics == nil { + t.Error("expected non-nil metrics map") + } + if s.successRateWeight != 0.4 { + t.Errorf("expected successRateWeight 0.4, got %f", s.successRateWeight) + } + if s.quotaWeight != 0.25 { + t.Errorf("expected quotaWeight 0.25, got %f", s.quotaWeight) + } +} + +func TestRecordRequest_Success(t *testing.T) { + s := NewTokenScorer() + s.RecordRequest("token1", true, 100*time.Millisecond) + + m := s.GetMetrics("token1") + if m == nil { + t.Fatal("expected non-nil metrics") + } + if m.TotalRequests != 1 { + t.Errorf("expected TotalRequests 1, got %d", m.TotalRequests) + } + if m.SuccessRate != 1.0 { + t.Errorf("expected SuccessRate 1.0, got %f", m.SuccessRate) + } + if m.FailCount != 0 { + t.Errorf("expected FailCount 0, got %d", m.FailCount) + } + if m.AvgLatency != 100 { + t.Errorf("expected AvgLatency 100, got %f", m.AvgLatency) + } +} + +func TestRecordRequest_Failure(t *testing.T) { + s := NewTokenScorer() + s.RecordRequest("token1", false, 200*time.Millisecond) + + m := s.GetMetrics("token1") + if m.SuccessRate != 0.0 { + t.Errorf("expected SuccessRate 0.0, got %f", m.SuccessRate) + } + if m.FailCount != 1 { + t.Errorf("expected FailCount 1, got %d", m.FailCount) + } +} + +func TestRecordRequest_MixedResults(t *testing.T) { + s := NewTokenScorer() + s.RecordRequest("token1", true, 100*time.Millisecond) + s.RecordRequest("token1", true, 100*time.Millisecond) + s.RecordRequest("token1", false, 100*time.Millisecond) + s.RecordRequest("token1", true, 100*time.Millisecond) + + m := s.GetMetrics("token1") + if m.TotalRequests != 4 { + t.Errorf("expected TotalRequests 4, got %d", m.TotalRequests) + } + if m.SuccessRate != 0.75 { + t.Errorf("expected SuccessRate 0.75, got %f", m.SuccessRate) + } + if m.FailCount != 0 { + t.Errorf("expected FailCount 0 (reset on success), got %d", m.FailCount) + } +} + +func TestRecordRequest_ConsecutiveFailures(t *testing.T) { + s := NewTokenScorer() + s.RecordRequest("token1", true, 100*time.Millisecond) + s.RecordRequest("token1", false, 100*time.Millisecond) + s.RecordRequest("token1", false, 100*time.Millisecond) + s.RecordRequest("token1", false, 100*time.Millisecond) + + m := s.GetMetrics("token1") + if m.FailCount != 3 { + t.Errorf("expected FailCount 3, got %d", m.FailCount) + } +} + +func TestSetQuotaRemaining(t *testing.T) { + s := NewTokenScorer() + s.SetQuotaRemaining("token1", 0.5) + + m := s.GetMetrics("token1") + if m.QuotaRemaining != 0.5 { + t.Errorf("expected QuotaRemaining 0.5, got %f", m.QuotaRemaining) + } +} + +func TestGetMetrics_NonExistent(t *testing.T) { + s := NewTokenScorer() + m := s.GetMetrics("nonexistent") + if m != nil { + t.Error("expected nil metrics for non-existent token") + } +} + +func TestGetMetrics_ReturnsCopy(t *testing.T) { + s := NewTokenScorer() + s.RecordRequest("token1", true, 100*time.Millisecond) + + m1 := s.GetMetrics("token1") + m1.TotalRequests = 999 + + m2 := s.GetMetrics("token1") + if m2.TotalRequests == 999 { + t.Error("GetMetrics should return a copy") + } +} + +func TestCalculateScore_NewToken(t *testing.T) { + s := NewTokenScorer() + score := s.CalculateScore("newtoken") + if score != 1.0 { + t.Errorf("expected score 1.0 for new token, got %f", score) + } +} + +func TestCalculateScore_PerfectToken(t *testing.T) { + s := NewTokenScorer() + s.RecordRequest("token1", true, 50*time.Millisecond) + s.SetQuotaRemaining("token1", 1.0) + + time.Sleep(100 * time.Millisecond) + score := s.CalculateScore("token1") + if score < 0.5 || score > 1.0 { + t.Errorf("expected high score for perfect token, got %f", score) + } +} + +func TestCalculateScore_FailedToken(t *testing.T) { + s := NewTokenScorer() + for i := 0; i < 5; i++ { + s.RecordRequest("token1", false, 1000*time.Millisecond) + } + s.SetQuotaRemaining("token1", 0.1) + + score := s.CalculateScore("token1") + if score > 0.5 { + t.Errorf("expected low score for failed token, got %f", score) + } +} + +func TestCalculateScore_FailPenalty(t *testing.T) { + s := NewTokenScorer() + s.RecordRequest("token1", true, 100*time.Millisecond) + scoreNoFail := s.CalculateScore("token1") + + s.RecordRequest("token1", false, 100*time.Millisecond) + s.RecordRequest("token1", false, 100*time.Millisecond) + scoreWithFail := s.CalculateScore("token1") + + if scoreWithFail >= scoreNoFail { + t.Errorf("expected lower score with consecutive failures: noFail=%f, withFail=%f", scoreNoFail, scoreWithFail) + } +} + +func TestSelectBestToken_Empty(t *testing.T) { + s := NewTokenScorer() + best := s.SelectBestToken([]string{}) + if best != "" { + t.Errorf("expected empty string for empty tokens, got %s", best) + } +} + +func TestSelectBestToken_SingleToken(t *testing.T) { + s := NewTokenScorer() + best := s.SelectBestToken([]string{"token1"}) + if best != "token1" { + t.Errorf("expected token1, got %s", best) + } +} + +func TestSelectBestToken_MultipleTokens(t *testing.T) { + s := NewTokenScorer() + + s.RecordRequest("bad", false, 1000*time.Millisecond) + s.RecordRequest("bad", false, 1000*time.Millisecond) + s.SetQuotaRemaining("bad", 0.1) + + s.RecordRequest("good", true, 50*time.Millisecond) + s.SetQuotaRemaining("good", 0.9) + + time.Sleep(50 * time.Millisecond) + + best := s.SelectBestToken([]string{"bad", "good"}) + if best != "good" { + t.Errorf("expected good token to be selected, got %s", best) + } +} + +func TestResetMetrics(t *testing.T) { + s := NewTokenScorer() + s.RecordRequest("token1", true, 100*time.Millisecond) + s.ResetMetrics("token1") + + m := s.GetMetrics("token1") + if m != nil { + t.Error("expected nil metrics after reset") + } +} + +func TestResetAllMetrics(t *testing.T) { + s := NewTokenScorer() + s.RecordRequest("token1", true, 100*time.Millisecond) + s.RecordRequest("token2", true, 100*time.Millisecond) + s.RecordRequest("token3", true, 100*time.Millisecond) + + s.ResetAllMetrics() + + if s.GetMetrics("token1") != nil { + t.Error("expected nil metrics for token1 after reset all") + } + if s.GetMetrics("token2") != nil { + t.Error("expected nil metrics for token2 after reset all") + } +} + +func TestTokenScorer_ConcurrentAccess(t *testing.T) { + s := NewTokenScorer() + const numGoroutines = 50 + const numOperations = 100 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + tokenKey := "token" + string(rune('a'+id%10)) + for j := 0; j < numOperations; j++ { + switch j % 6 { + case 0: + s.RecordRequest(tokenKey, j%2 == 0, time.Duration(j)*time.Millisecond) + case 1: + s.SetQuotaRemaining(tokenKey, float64(j%100)/100) + case 2: + s.GetMetrics(tokenKey) + case 3: + s.CalculateScore(tokenKey) + case 4: + s.SelectBestToken([]string{tokenKey, "token_x", "token_y"}) + case 5: + if j%20 == 0 { + s.ResetMetrics(tokenKey) + } + } + } + }(i) + } + + wg.Wait() +} + +func TestAvgLatencyCalculation(t *testing.T) { + s := NewTokenScorer() + s.RecordRequest("token1", true, 100*time.Millisecond) + s.RecordRequest("token1", true, 200*time.Millisecond) + s.RecordRequest("token1", true, 300*time.Millisecond) + + m := s.GetMetrics("token1") + if m.AvgLatency != 200 { + t.Errorf("expected AvgLatency 200, got %f", m.AvgLatency) + } +} + +func TestLastUsedUpdated(t *testing.T) { + s := NewTokenScorer() + before := time.Now() + s.RecordRequest("token1", true, 100*time.Millisecond) + + m := s.GetMetrics("token1") + if m.LastUsed.Before(before) { + t.Error("expected LastUsed to be after test start time") + } + if m.LastUsed.After(time.Now()) { + t.Error("expected LastUsed to be before or equal to now") + } +} + +func TestDefaultQuotaForNewToken(t *testing.T) { + s := NewTokenScorer() + s.RecordRequest("token1", true, 100*time.Millisecond) + + m := s.GetMetrics("token1") + if m.QuotaRemaining != 1.0 { + t.Errorf("expected default QuotaRemaining 1.0, got %f", m.QuotaRemaining) + } +} diff --git a/internal/auth/kiro/oauth.go b/internal/auth/kiro/oauth.go index a7d3eb9a..0609610f 100644 --- a/internal/auth/kiro/oauth.go +++ b/internal/auth/kiro/oauth.go @@ -227,6 +227,7 @@ func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier ExpiresAt: expiresAt.Format(time.RFC3339), AuthMethod: "social", Provider: "", // Caller should preserve original provider + Region: "us-east-1", }, nil } @@ -285,6 +286,7 @@ func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*Kir ExpiresAt: expiresAt.Format(time.RFC3339), AuthMethod: "social", Provider: "", // Caller should preserve original provider + Region: "us-east-1", }, nil } diff --git a/internal/auth/kiro/oauth_web.go b/internal/auth/kiro/oauth_web.go new file mode 100644 index 00000000..13198516 --- /dev/null +++ b/internal/auth/kiro/oauth_web.go @@ -0,0 +1,825 @@ +// Package kiro provides OAuth Web authentication for Kiro. +package kiro + +import ( + "context" + "crypto/rand" + "encoding/base64" + "fmt" + "html/template" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + defaultSessionExpiry = 10 * time.Minute + pollIntervalSeconds = 5 +) + +type authSessionStatus string + +const ( + statusPending authSessionStatus = "pending" + statusSuccess authSessionStatus = "success" + statusFailed authSessionStatus = "failed" +) + +type webAuthSession struct { + stateID string + deviceCode string + userCode string + authURL string + verificationURI string + expiresIn int + interval int + status authSessionStatus + startedAt time.Time + completedAt time.Time + expiresAt time.Time + error string + tokenData *KiroTokenData + ssoClient *SSOOIDCClient + clientID string + clientSecret string + region string + cancelFunc context.CancelFunc + authMethod string // "google", "github", "builder-id", "idc" + startURL string // Used for IDC + codeVerifier string // Used for social auth PKCE + codeChallenge string // Used for social auth PKCE +} + +type OAuthWebHandler struct { + cfg *config.Config + sessions map[string]*webAuthSession + mu sync.RWMutex + onTokenObtained func(*KiroTokenData) +} + +func NewOAuthWebHandler(cfg *config.Config) *OAuthWebHandler { + return &OAuthWebHandler{ + cfg: cfg, + sessions: make(map[string]*webAuthSession), + } +} + +func (h *OAuthWebHandler) SetTokenCallback(callback func(*KiroTokenData)) { + h.onTokenObtained = callback +} + +func (h *OAuthWebHandler) RegisterRoutes(router gin.IRouter) { + oauth := router.Group("/v0/oauth/kiro") + { + oauth.GET("", h.handleSelect) + oauth.GET("/start", h.handleStart) + oauth.GET("/callback", h.handleCallback) + oauth.GET("/social/callback", h.handleSocialCallback) + oauth.GET("/status", h.handleStatus) + oauth.POST("/import", h.handleImportToken) + } +} + +func generateStateID() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +func (h *OAuthWebHandler) handleSelect(c *gin.Context) { + h.renderSelectPage(c) +} + +func (h *OAuthWebHandler) handleStart(c *gin.Context) { + method := c.Query("method") + + if method == "" { + c.Redirect(http.StatusFound, "/v0/oauth/kiro") + return + } + + switch method { + case "google", "github": + // Google/GitHub social login is not supported for third-party apps + // due to AWS Cognito redirect_uri restrictions + h.renderError(c, "Google/GitHub login is not available for third-party applications. Please use AWS Builder ID or import your token from Kiro IDE.") + case "builder-id": + h.startBuilderIDAuth(c) + case "idc": + h.startIDCAuth(c) + default: + h.renderError(c, fmt.Sprintf("Unknown authentication method: %s", method)) + } +} + +func (h *OAuthWebHandler) startSocialAuth(c *gin.Context, method string) { + stateID, err := generateStateID() + if err != nil { + h.renderError(c, "Failed to generate state parameter") + return + } + + codeVerifier, codeChallenge, err := generatePKCE() + if err != nil { + h.renderError(c, "Failed to generate PKCE parameters") + return + } + + socialClient := NewSocialAuthClient(h.cfg) + + var provider string + if method == "google" { + provider = string(ProviderGoogle) + } else { + provider = string(ProviderGitHub) + } + + redirectURI := h.getSocialCallbackURL(c) + authURL := socialClient.buildLoginURL(provider, redirectURI, codeChallenge, stateID) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + + session := &webAuthSession{ + stateID: stateID, + authMethod: method, + authURL: authURL, + status: statusPending, + startedAt: time.Now(), + expiresIn: 600, + codeVerifier: codeVerifier, + codeChallenge: codeChallenge, + region: "us-east-1", + cancelFunc: cancel, + } + + h.mu.Lock() + h.sessions[stateID] = session + h.mu.Unlock() + + go func() { + <-ctx.Done() + h.mu.Lock() + if session.status == statusPending { + session.status = statusFailed + session.error = "Authentication timed out" + } + h.mu.Unlock() + }() + + c.Redirect(http.StatusFound, authURL) +} + +func (h *OAuthWebHandler) getSocialCallbackURL(c *gin.Context) string { + scheme := "http" + if c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https" { + scheme = "https" + } + return fmt.Sprintf("%s://%s/v0/oauth/kiro/social/callback", scheme, c.Request.Host) +} + +func (h *OAuthWebHandler) startBuilderIDAuth(c *gin.Context) { + stateID, err := generateStateID() + if err != nil { + h.renderError(c, "Failed to generate state parameter") + return + } + + region := defaultIDCRegion + startURL := builderIDStartURL + + ssoClient := NewSSOOIDCClient(h.cfg) + + regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region) + if err != nil { + log.Errorf("OAuth Web: failed to register client: %v", err) + h.renderError(c, fmt.Sprintf("Failed to register client: %v", err)) + return + } + + authResp, err := ssoClient.StartDeviceAuthorizationWithIDC( + c.Request.Context(), + regResp.ClientID, + regResp.ClientSecret, + startURL, + region, + ) + if err != nil { + log.Errorf("OAuth Web: failed to start device authorization: %v", err) + h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err)) + return + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second) + + session := &webAuthSession{ + stateID: stateID, + deviceCode: authResp.DeviceCode, + userCode: authResp.UserCode, + authURL: authResp.VerificationURIComplete, + verificationURI: authResp.VerificationURI, + expiresIn: authResp.ExpiresIn, + interval: authResp.Interval, + status: statusPending, + startedAt: time.Now(), + ssoClient: ssoClient, + clientID: regResp.ClientID, + clientSecret: regResp.ClientSecret, + region: region, + authMethod: "builder-id", + startURL: startURL, + cancelFunc: cancel, + } + + h.mu.Lock() + h.sessions[stateID] = session + h.mu.Unlock() + + go h.pollForToken(ctx, session) + + h.renderStartPage(c, session) +} + +func (h *OAuthWebHandler) startIDCAuth(c *gin.Context) { + startURL := c.Query("startUrl") + region := c.Query("region") + + if startURL == "" { + h.renderError(c, "Missing startUrl parameter for IDC authentication") + return + } + if region == "" { + region = defaultIDCRegion + } + + stateID, err := generateStateID() + if err != nil { + h.renderError(c, "Failed to generate state parameter") + return + } + + ssoClient := NewSSOOIDCClient(h.cfg) + + regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region) + if err != nil { + log.Errorf("OAuth Web: failed to register client: %v", err) + h.renderError(c, fmt.Sprintf("Failed to register client: %v", err)) + return + } + + authResp, err := ssoClient.StartDeviceAuthorizationWithIDC( + c.Request.Context(), + regResp.ClientID, + regResp.ClientSecret, + startURL, + region, + ) + if err != nil { + log.Errorf("OAuth Web: failed to start device authorization: %v", err) + h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err)) + return + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second) + + session := &webAuthSession{ + stateID: stateID, + deviceCode: authResp.DeviceCode, + userCode: authResp.UserCode, + authURL: authResp.VerificationURIComplete, + verificationURI: authResp.VerificationURI, + expiresIn: authResp.ExpiresIn, + interval: authResp.Interval, + status: statusPending, + startedAt: time.Now(), + ssoClient: ssoClient, + clientID: regResp.ClientID, + clientSecret: regResp.ClientSecret, + region: region, + authMethod: "idc", + startURL: startURL, + cancelFunc: cancel, + } + + h.mu.Lock() + h.sessions[stateID] = session + h.mu.Unlock() + + go h.pollForToken(ctx, session) + + h.renderStartPage(c, session) +} + +func (h *OAuthWebHandler) pollForToken(ctx context.Context, session *webAuthSession) { + defer session.cancelFunc() + + interval := time.Duration(session.interval) * time.Second + if interval < time.Duration(pollIntervalSeconds)*time.Second { + interval = time.Duration(pollIntervalSeconds) * time.Second + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + h.mu.Lock() + if session.status == statusPending { + session.status = statusFailed + session.error = "Authentication timed out" + } + h.mu.Unlock() + return + case <-ticker.C: + tokenResp, err := h.ssoClient(session).CreateTokenWithRegion( + ctx, + session.clientID, + session.clientSecret, + session.deviceCode, + session.region, + ) + + if err != nil { + errStr := err.Error() + if errStr == ErrAuthorizationPending.Error() { + continue + } + if errStr == ErrSlowDown.Error() { + interval += 5 * time.Second + ticker.Reset(interval) + continue + } + + h.mu.Lock() + session.status = statusFailed + session.error = errStr + session.completedAt = time.Now() + h.mu.Unlock() + + log.Errorf("OAuth Web: token polling failed: %v", err) + return + } + + expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + profileArn := session.ssoClient.fetchProfileArn(ctx, tokenResp.AccessToken) + email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken) + + tokenData := &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: profileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: session.authMethod, + Provider: "AWS", + ClientID: session.clientID, + ClientSecret: session.clientSecret, + Email: email, + Region: session.region, + } + + h.mu.Lock() + session.status = statusSuccess + session.completedAt = time.Now() + session.expiresAt = expiresAt + session.tokenData = tokenData + h.mu.Unlock() + + if h.onTokenObtained != nil { + h.onTokenObtained(tokenData) + } + + // Save token to file + h.saveTokenToFile(tokenData) + + log.Infof("OAuth Web: authentication successful for %s", email) + return + } + } +} + +// saveTokenToFile saves the token data to the auth directory +func (h *OAuthWebHandler) saveTokenToFile(tokenData *KiroTokenData) { + // Get auth directory from config or use default + authDir := "" + if h.cfg != nil && h.cfg.AuthDir != "" { + var err error + authDir, err = util.ResolveAuthDir(h.cfg.AuthDir) + if err != nil { + log.Errorf("OAuth Web: failed to resolve auth directory: %v", err) + } + } + + // Fall back to default location + if authDir == "" { + home, err := os.UserHomeDir() + if err != nil { + log.Errorf("OAuth Web: failed to get home directory: %v", err) + return + } + authDir = filepath.Join(home, ".cli-proxy-api") + } + + // Create directory if not exists + if err := os.MkdirAll(authDir, 0700); err != nil { + log.Errorf("OAuth Web: failed to create auth directory: %v", err) + return + } + + // Generate filename based on auth method + // Format: kiro-{authMethod}.json or kiro-{authMethod}-{email}.json + fileName := fmt.Sprintf("kiro-%s.json", tokenData.AuthMethod) + if tokenData.Email != "" { + // Sanitize email for filename (replace @ and . with -) + sanitizedEmail := tokenData.Email + sanitizedEmail = strings.ReplaceAll(sanitizedEmail, "@", "-") + sanitizedEmail = strings.ReplaceAll(sanitizedEmail, ".", "-") + fileName = fmt.Sprintf("kiro-%s-%s.json", tokenData.AuthMethod, sanitizedEmail) + } + + authFilePath := filepath.Join(authDir, fileName) + + // Convert to storage format and save + storage := &KiroTokenStorage{ + Type: "kiro", + AccessToken: tokenData.AccessToken, + RefreshToken: tokenData.RefreshToken, + ProfileArn: tokenData.ProfileArn, + ExpiresAt: tokenData.ExpiresAt, + AuthMethod: tokenData.AuthMethod, + Provider: tokenData.Provider, + LastRefresh: time.Now().Format(time.RFC3339), + ClientID: tokenData.ClientID, + ClientSecret: tokenData.ClientSecret, + Region: tokenData.Region, + StartURL: tokenData.StartURL, + Email: tokenData.Email, + } + + if err := storage.SaveTokenToFile(authFilePath); err != nil { + log.Errorf("OAuth Web: failed to save token to file: %v", err) + return + } + + log.Infof("OAuth Web: token saved to %s", authFilePath) +} + +func (h *OAuthWebHandler) ssoClient(session *webAuthSession) *SSOOIDCClient { + return session.ssoClient +} + +func (h *OAuthWebHandler) handleCallback(c *gin.Context) { + stateID := c.Query("state") + errParam := c.Query("error") + + if errParam != "" { + h.renderError(c, errParam) + return + } + + if stateID == "" { + h.renderError(c, "Missing state parameter") + return + } + + h.mu.RLock() + session, exists := h.sessions[stateID] + h.mu.RUnlock() + + if !exists { + h.renderError(c, "Invalid or expired session") + return + } + + if session.status == statusSuccess { + h.renderSuccess(c, session) + } else if session.status == statusFailed { + h.renderError(c, session.error) + } else { + c.Redirect(http.StatusFound, "/v0/oauth/kiro/start") + } +} + +func (h *OAuthWebHandler) handleSocialCallback(c *gin.Context) { + stateID := c.Query("state") + code := c.Query("code") + errParam := c.Query("error") + + if errParam != "" { + h.renderError(c, errParam) + return + } + + if stateID == "" { + h.renderError(c, "Missing state parameter") + return + } + + if code == "" { + h.renderError(c, "Missing authorization code") + return + } + + h.mu.RLock() + session, exists := h.sessions[stateID] + h.mu.RUnlock() + + if !exists { + h.renderError(c, "Invalid or expired session") + return + } + + if session.authMethod != "google" && session.authMethod != "github" { + h.renderError(c, "Invalid session type for social callback") + return + } + + socialClient := NewSocialAuthClient(h.cfg) + redirectURI := h.getSocialCallbackURL(c) + + tokenReq := &CreateTokenRequest{ + Code: code, + CodeVerifier: session.codeVerifier, + RedirectURI: redirectURI, + } + + tokenResp, err := socialClient.CreateToken(c.Request.Context(), tokenReq) + if err != nil { + log.Errorf("OAuth Web: social token exchange failed: %v", err) + h.mu.Lock() + session.status = statusFailed + session.error = fmt.Sprintf("Token exchange failed: %v", err) + session.completedAt = time.Now() + h.mu.Unlock() + h.renderError(c, session.error) + return + } + + expiresIn := tokenResp.ExpiresIn + if expiresIn <= 0 { + expiresIn = 3600 + } + expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second) + + email := ExtractEmailFromJWT(tokenResp.AccessToken) + + var provider string + if session.authMethod == "google" { + provider = string(ProviderGoogle) + } else { + provider = string(ProviderGitHub) + } + + tokenData := &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: tokenResp.ProfileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: session.authMethod, + Provider: provider, + Email: email, + Region: "us-east-1", + } + + h.mu.Lock() + session.status = statusSuccess + session.completedAt = time.Now() + session.expiresAt = expiresAt + session.tokenData = tokenData + h.mu.Unlock() + + if session.cancelFunc != nil { + session.cancelFunc() + } + + if h.onTokenObtained != nil { + h.onTokenObtained(tokenData) + } + + // Save token to file + h.saveTokenToFile(tokenData) + + log.Infof("OAuth Web: social authentication successful for %s via %s", email, provider) + h.renderSuccess(c, session) +} + +func (h *OAuthWebHandler) handleStatus(c *gin.Context) { + stateID := c.Query("state") + if stateID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "missing state parameter"}) + return + } + + h.mu.RLock() + session, exists := h.sessions[stateID] + h.mu.RUnlock() + + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "session not found"}) + return + } + + response := gin.H{ + "status": string(session.status), + } + + switch session.status { + case statusPending: + elapsed := time.Since(session.startedAt).Seconds() + remaining := float64(session.expiresIn) - elapsed + if remaining < 0 { + remaining = 0 + } + response["remaining_seconds"] = int(remaining) + case statusSuccess: + response["completed_at"] = session.completedAt.Format(time.RFC3339) + response["expires_at"] = session.expiresAt.Format(time.RFC3339) + case statusFailed: + response["error"] = session.error + response["failed_at"] = session.completedAt.Format(time.RFC3339) + } + + c.JSON(http.StatusOK, response) +} + +func (h *OAuthWebHandler) renderStartPage(c *gin.Context, session *webAuthSession) { + tmpl, err := template.New("start").Parse(oauthWebStartPageHTML) + if err != nil { + log.Errorf("OAuth Web: failed to parse template: %v", err) + c.String(http.StatusInternalServerError, "Template error") + return + } + + data := map[string]interface{}{ + "AuthURL": session.authURL, + "UserCode": session.userCode, + "ExpiresIn": session.expiresIn, + "StateID": session.stateID, + } + + c.Header("Content-Type", "text/html; charset=utf-8") + if err := tmpl.Execute(c.Writer, data); err != nil { + log.Errorf("OAuth Web: failed to render template: %v", err) + } +} + +func (h *OAuthWebHandler) renderSelectPage(c *gin.Context) { + tmpl, err := template.New("select").Parse(oauthWebSelectPageHTML) + if err != nil { + log.Errorf("OAuth Web: failed to parse select template: %v", err) + c.String(http.StatusInternalServerError, "Template error") + return + } + + c.Header("Content-Type", "text/html; charset=utf-8") + if err := tmpl.Execute(c.Writer, nil); err != nil { + log.Errorf("OAuth Web: failed to render select template: %v", err) + } +} + +func (h *OAuthWebHandler) renderError(c *gin.Context, errMsg string) { + tmpl, err := template.New("error").Parse(oauthWebErrorPageHTML) + if err != nil { + log.Errorf("OAuth Web: failed to parse error template: %v", err) + c.String(http.StatusInternalServerError, "Template error") + return + } + + data := map[string]interface{}{ + "Error": errMsg, + } + + c.Header("Content-Type", "text/html; charset=utf-8") + c.Status(http.StatusBadRequest) + if err := tmpl.Execute(c.Writer, data); err != nil { + log.Errorf("OAuth Web: failed to render error template: %v", err) + } +} + +func (h *OAuthWebHandler) renderSuccess(c *gin.Context, session *webAuthSession) { + tmpl, err := template.New("success").Parse(oauthWebSuccessPageHTML) + if err != nil { + log.Errorf("OAuth Web: failed to parse success template: %v", err) + c.String(http.StatusInternalServerError, "Template error") + return + } + + data := map[string]interface{}{ + "ExpiresAt": session.expiresAt.Format(time.RFC3339), + } + + c.Header("Content-Type", "text/html; charset=utf-8") + if err := tmpl.Execute(c.Writer, data); err != nil { + log.Errorf("OAuth Web: failed to render success template: %v", err) + } +} + +func (h *OAuthWebHandler) CleanupExpiredSessions() { + h.mu.Lock() + defer h.mu.Unlock() + + now := time.Now() + for id, session := range h.sessions { + if session.status != statusPending && now.Sub(session.completedAt) > 30*time.Minute { + delete(h.sessions, id) + } else if session.status == statusPending && now.Sub(session.startedAt) > defaultSessionExpiry { + session.cancelFunc() + delete(h.sessions, id) + } + } +} + +func (h *OAuthWebHandler) GetSession(stateID string) (*webAuthSession, bool) { + h.mu.RLock() + defer h.mu.RUnlock() + session, exists := h.sessions[stateID] + return session, exists +} + +// ImportTokenRequest represents the request body for token import +type ImportTokenRequest struct { + RefreshToken string `json:"refreshToken"` +} + +// handleImportToken handles manual refresh token import from Kiro IDE +func (h *OAuthWebHandler) handleImportToken(c *gin.Context) { + var req ImportTokenRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "error": "Invalid request body", + }) + return + } + + refreshToken := strings.TrimSpace(req.RefreshToken) + if refreshToken == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "error": "Refresh token is required", + }) + return + } + + // Validate token format + if !strings.HasPrefix(refreshToken, "aorAAAAAG") { + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "error": "Invalid token format. Token should start with aorAAAAAG...", + }) + return + } + + // Create social auth client to refresh and validate the token + socialClient := NewSocialAuthClient(h.cfg) + + // Refresh the token to validate it and get access token + tokenData, err := socialClient.RefreshSocialToken(c.Request.Context(), refreshToken) + if err != nil { + log.Errorf("OAuth Web: token refresh failed during import: %v", err) + c.JSON(http.StatusBadRequest, gin.H{ + "success": false, + "error": fmt.Sprintf("Token validation failed: %v", err), + }) + return + } + + // Set the original refresh token (the refreshed one might be empty) + if tokenData.RefreshToken == "" { + tokenData.RefreshToken = refreshToken + } + tokenData.AuthMethod = "social" + tokenData.Provider = "imported" + + // Notify callback if set + if h.onTokenObtained != nil { + h.onTokenObtained(tokenData) + } + + // Save token to file + h.saveTokenToFile(tokenData) + + // Generate filename for response + fileName := fmt.Sprintf("kiro-%s.json", tokenData.AuthMethod) + if tokenData.Email != "" { + sanitizedEmail := strings.ReplaceAll(tokenData.Email, "@", "-") + sanitizedEmail = strings.ReplaceAll(sanitizedEmail, ".", "-") + fileName = fmt.Sprintf("kiro-%s-%s.json", tokenData.AuthMethod, sanitizedEmail) + } + + log.Infof("OAuth Web: token imported successfully") + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "Token imported successfully", + "fileName": fileName, + }) +} diff --git a/internal/auth/kiro/oauth_web.go.bak b/internal/auth/kiro/oauth_web.go.bak new file mode 100644 index 00000000..22d7809b --- /dev/null +++ b/internal/auth/kiro/oauth_web.go.bak @@ -0,0 +1,385 @@ +// Package kiro provides OAuth Web authentication for Kiro. +package kiro + +import ( + "context" + "crypto/rand" + "encoding/base64" + "fmt" + "html/template" + "net/http" + "sync" + "time" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + log "github.com/sirupsen/logrus" +) + +const ( + defaultSessionExpiry = 10 * time.Minute + pollIntervalSeconds = 5 +) + +type authSessionStatus string + +const ( + statusPending authSessionStatus = "pending" + statusSuccess authSessionStatus = "success" + statusFailed authSessionStatus = "failed" +) + +type webAuthSession struct { + stateID string + deviceCode string + userCode string + authURL string + verificationURI string + expiresIn int + interval int + status authSessionStatus + startedAt time.Time + completedAt time.Time + expiresAt time.Time + error string + tokenData *KiroTokenData + ssoClient *SSOOIDCClient + clientID string + clientSecret string + region string + cancelFunc context.CancelFunc +} + +type OAuthWebHandler struct { + cfg *config.Config + sessions map[string]*webAuthSession + mu sync.RWMutex + onTokenObtained func(*KiroTokenData) +} + +func NewOAuthWebHandler(cfg *config.Config) *OAuthWebHandler { + return &OAuthWebHandler{ + cfg: cfg, + sessions: make(map[string]*webAuthSession), + } +} + +func (h *OAuthWebHandler) SetTokenCallback(callback func(*KiroTokenData)) { + h.onTokenObtained = callback +} + +func (h *OAuthWebHandler) RegisterRoutes(router gin.IRouter) { + oauth := router.Group("/v0/oauth/kiro") + { + oauth.GET("/start", h.handleStart) + oauth.GET("/callback", h.handleCallback) + oauth.GET("/status", h.handleStatus) + } +} + +func generateStateID() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +func (h *OAuthWebHandler) handleStart(c *gin.Context) { + stateID, err := generateStateID() + if err != nil { + h.renderError(c, "Failed to generate state parameter") + return + } + + region := defaultIDCRegion + startURL := builderIDStartURL + + ssoClient := NewSSOOIDCClient(h.cfg) + + regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region) + if err != nil { + log.Errorf("OAuth Web: failed to register client: %v", err) + h.renderError(c, fmt.Sprintf("Failed to register client: %v", err)) + return + } + + authResp, err := ssoClient.StartDeviceAuthorizationWithIDC( + c.Request.Context(), + regResp.ClientID, + regResp.ClientSecret, + startURL, + region, + ) + if err != nil { + log.Errorf("OAuth Web: failed to start device authorization: %v", err) + h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err)) + return + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second) + + session := &webAuthSession{ + stateID: stateID, + deviceCode: authResp.DeviceCode, + userCode: authResp.UserCode, + authURL: authResp.VerificationURIComplete, + verificationURI: authResp.VerificationURI, + expiresIn: authResp.ExpiresIn, + interval: authResp.Interval, + status: statusPending, + startedAt: time.Now(), + ssoClient: ssoClient, + clientID: regResp.ClientID, + clientSecret: regResp.ClientSecret, + region: region, + cancelFunc: cancel, + } + + h.mu.Lock() + h.sessions[stateID] = session + h.mu.Unlock() + + go h.pollForToken(ctx, session) + + h.renderStartPage(c, session) +} + +func (h *OAuthWebHandler) pollForToken(ctx context.Context, session *webAuthSession) { + defer session.cancelFunc() + + interval := time.Duration(session.interval) * time.Second + if interval < time.Duration(pollIntervalSeconds)*time.Second { + interval = time.Duration(pollIntervalSeconds) * time.Second + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + h.mu.Lock() + if session.status == statusPending { + session.status = statusFailed + session.error = "Authentication timed out" + } + h.mu.Unlock() + return + case <-ticker.C: + tokenResp, err := h.ssoClient(session).CreateTokenWithRegion( + ctx, + session.clientID, + session.clientSecret, + session.deviceCode, + session.region, + ) + + if err != nil { + errStr := err.Error() + if errStr == ErrAuthorizationPending.Error() { + continue + } + if errStr == ErrSlowDown.Error() { + interval += 5 * time.Second + ticker.Reset(interval) + continue + } + + h.mu.Lock() + session.status = statusFailed + session.error = errStr + session.completedAt = time.Now() + h.mu.Unlock() + + log.Errorf("OAuth Web: token polling failed: %v", err) + return + } + + expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + profileArn := session.ssoClient.fetchProfileArn(ctx, tokenResp.AccessToken) + email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken) + + tokenData := &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: profileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "builder-id", + Provider: "AWS", + ClientID: session.clientID, + ClientSecret: session.clientSecret, + Email: email, + } + + h.mu.Lock() + session.status = statusSuccess + session.completedAt = time.Now() + session.expiresAt = expiresAt + session.tokenData = tokenData + h.mu.Unlock() + + if h.onTokenObtained != nil { + h.onTokenObtained(tokenData) + } + + log.Infof("OAuth Web: authentication successful for %s", email) + return + } + } +} + +func (h *OAuthWebHandler) ssoClient(session *webAuthSession) *SSOOIDCClient { + return session.ssoClient +} + +func (h *OAuthWebHandler) handleCallback(c *gin.Context) { + stateID := c.Query("state") + errParam := c.Query("error") + + if errParam != "" { + h.renderError(c, errParam) + return + } + + if stateID == "" { + h.renderError(c, "Missing state parameter") + return + } + + h.mu.RLock() + session, exists := h.sessions[stateID] + h.mu.RUnlock() + + if !exists { + h.renderError(c, "Invalid or expired session") + return + } + + if session.status == statusSuccess { + h.renderSuccess(c, session) + } else if session.status == statusFailed { + h.renderError(c, session.error) + } else { + c.Redirect(http.StatusFound, "/v0/oauth/kiro/start") + } +} + +func (h *OAuthWebHandler) handleStatus(c *gin.Context) { + stateID := c.Query("state") + if stateID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "missing state parameter"}) + return + } + + h.mu.RLock() + session, exists := h.sessions[stateID] + h.mu.RUnlock() + + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "session not found"}) + return + } + + response := gin.H{ + "status": string(session.status), + } + + switch session.status { + case statusPending: + elapsed := time.Since(session.startedAt).Seconds() + remaining := float64(session.expiresIn) - elapsed + if remaining < 0 { + remaining = 0 + } + response["remaining_seconds"] = int(remaining) + case statusSuccess: + response["completed_at"] = session.completedAt.Format(time.RFC3339) + response["expires_at"] = session.expiresAt.Format(time.RFC3339) + case statusFailed: + response["error"] = session.error + response["failed_at"] = session.completedAt.Format(time.RFC3339) + } + + c.JSON(http.StatusOK, response) +} + +func (h *OAuthWebHandler) renderStartPage(c *gin.Context, session *webAuthSession) { + tmpl, err := template.New("start").Parse(oauthWebStartPageHTML) + if err != nil { + log.Errorf("OAuth Web: failed to parse template: %v", err) + c.String(http.StatusInternalServerError, "Template error") + return + } + + data := map[string]interface{}{ + "AuthURL": session.authURL, + "UserCode": session.userCode, + "ExpiresIn": session.expiresIn, + "StateID": session.stateID, + } + + c.Header("Content-Type", "text/html; charset=utf-8") + if err := tmpl.Execute(c.Writer, data); err != nil { + log.Errorf("OAuth Web: failed to render template: %v", err) + } +} + +func (h *OAuthWebHandler) renderError(c *gin.Context, errMsg string) { + tmpl, err := template.New("error").Parse(oauthWebErrorPageHTML) + if err != nil { + log.Errorf("OAuth Web: failed to parse error template: %v", err) + c.String(http.StatusInternalServerError, "Template error") + return + } + + data := map[string]interface{}{ + "Error": errMsg, + } + + c.Header("Content-Type", "text/html; charset=utf-8") + c.Status(http.StatusBadRequest) + if err := tmpl.Execute(c.Writer, data); err != nil { + log.Errorf("OAuth Web: failed to render error template: %v", err) + } +} + +func (h *OAuthWebHandler) renderSuccess(c *gin.Context, session *webAuthSession) { + tmpl, err := template.New("success").Parse(oauthWebSuccessPageHTML) + if err != nil { + log.Errorf("OAuth Web: failed to parse success template: %v", err) + c.String(http.StatusInternalServerError, "Template error") + return + } + + data := map[string]interface{}{ + "ExpiresAt": session.expiresAt.Format(time.RFC3339), + } + + c.Header("Content-Type", "text/html; charset=utf-8") + if err := tmpl.Execute(c.Writer, data); err != nil { + log.Errorf("OAuth Web: failed to render success template: %v", err) + } +} + +func (h *OAuthWebHandler) CleanupExpiredSessions() { + h.mu.Lock() + defer h.mu.Unlock() + + now := time.Now() + for id, session := range h.sessions { + if session.status != statusPending && now.Sub(session.completedAt) > 30*time.Minute { + delete(h.sessions, id) + } else if session.status == statusPending && now.Sub(session.startedAt) > defaultSessionExpiry { + session.cancelFunc() + delete(h.sessions, id) + } + } +} + +func (h *OAuthWebHandler) GetSession(stateID string) (*webAuthSession, bool) { + h.mu.RLock() + defer h.mu.RUnlock() + session, exists := h.sessions[stateID] + return session, exists +} diff --git a/internal/auth/kiro/oauth_web_templates.go b/internal/auth/kiro/oauth_web_templates.go new file mode 100644 index 00000000..064a1ff9 --- /dev/null +++ b/internal/auth/kiro/oauth_web_templates.go @@ -0,0 +1,732 @@ +// Package kiro provides OAuth Web authentication templates. +package kiro + +const ( + oauthWebStartPageHTML = ` + + + + + AWS SSO Authentication + + + +
+

🔐 AWS SSO Authentication

+

Follow the steps below to complete authentication

+ +
+
+ 1 + Click the button below to open the authorization page +
+ + 🚀 Open Authorization Page + +
+ +
+
+ 2 + Enter the verification code below +
+
+
Verification Code
+
{{.UserCode}}
+
+
+ +
+
+ 3 + Complete AWS SSO login +
+

+ Use your AWS SSO account to login and authorize +

+
+ +
+
+
{{.ExpiresIn}}s
+
+ Waiting for authorization... +
+
+ +
+ 💡 Tip: The authorization page will open in a new tab. This page will automatically update once authorization is complete. +
+
+ + + +` + + oauthWebErrorPageHTML = ` + + + + + Authentication Failed + + + +
+

❌ Authentication Failed

+
+

Error:

+

{{.Error}}

+
+ 🔄 Retry +
+ +` + + oauthWebSuccessPageHTML = ` + + + + + Authentication Successful + + + +
+
+

Authentication Successful!

+
+

You can close this window.

+
+
Token expires: {{.ExpiresAt}}
+
+ +` + + oauthWebSelectPageHTML = ` + + + + + Select Authentication Method + + + +
+

🔐 Select Authentication Method

+

Choose how you want to authenticate with Kiro

+ +
+ + 🔶 + AWS Builder ID (Recommended) + + + + +
or
+ + +
+ +
+
+ + +
+ + +
Your AWS Identity Center Start URL
+
+ +
+ + +
AWS Region for your Identity Center
+
+ + +
+
+ +
+
+
+ + +
Copy from Kiro IDE: ~/.kiro/kiro-auth-token.json → refreshToken field
+
+ + + +
+
+
+ +
+ ⚠️ Note: Google and GitHub login are not available for third-party applications due to AWS Cognito restrictions. Please use AWS Builder ID or import your token from Kiro IDE. +
+ +
+ 💡 How to get RefreshToken:
+ 1. Open Kiro IDE and login with Google/GitHub
+ 2. Find the token file: ~/.kiro/kiro-auth-token.json
+ 3. Copy the refreshToken value and paste it above +
+
+ + + +` +) diff --git a/internal/auth/kiro/rate_limiter.go b/internal/auth/kiro/rate_limiter.go new file mode 100644 index 00000000..3c240ebe --- /dev/null +++ b/internal/auth/kiro/rate_limiter.go @@ -0,0 +1,316 @@ +package kiro + +import ( + "math" + "math/rand" + "strings" + "sync" + "time" +) + +const ( + DefaultMinTokenInterval = 10 * time.Second + DefaultMaxTokenInterval = 30 * time.Second + DefaultDailyMaxRequests = 500 + DefaultJitterPercent = 0.3 + DefaultBackoffBase = 2 * time.Minute + DefaultBackoffMax = 60 * time.Minute + DefaultBackoffMultiplier = 2.0 + DefaultSuspendCooldown = 24 * time.Hour +) + +// TokenState Token 状态 +type TokenState struct { + LastRequest time.Time + RequestCount int + CooldownEnd time.Time + FailCount int + DailyRequests int + DailyResetTime time.Time + IsSuspended bool + SuspendedAt time.Time + SuspendReason string +} + +// RateLimiter 频率限制器 +type RateLimiter struct { + mu sync.RWMutex + states map[string]*TokenState + minTokenInterval time.Duration + maxTokenInterval time.Duration + dailyMaxRequests int + jitterPercent float64 + backoffBase time.Duration + backoffMax time.Duration + backoffMultiplier float64 + suspendCooldown time.Duration + rng *rand.Rand +} + +// NewRateLimiter 创建默认配置的频率限制器 +func NewRateLimiter() *RateLimiter { + return &RateLimiter{ + states: make(map[string]*TokenState), + minTokenInterval: DefaultMinTokenInterval, + maxTokenInterval: DefaultMaxTokenInterval, + dailyMaxRequests: DefaultDailyMaxRequests, + jitterPercent: DefaultJitterPercent, + backoffBase: DefaultBackoffBase, + backoffMax: DefaultBackoffMax, + backoffMultiplier: DefaultBackoffMultiplier, + suspendCooldown: DefaultSuspendCooldown, + rng: rand.New(rand.NewSource(time.Now().UnixNano())), + } +} + +// RateLimiterConfig 频率限制器配置 +type RateLimiterConfig struct { + MinTokenInterval time.Duration + MaxTokenInterval time.Duration + DailyMaxRequests int + JitterPercent float64 + BackoffBase time.Duration + BackoffMax time.Duration + BackoffMultiplier float64 + SuspendCooldown time.Duration +} + +// NewRateLimiterWithConfig 使用自定义配置创建频率限制器 +func NewRateLimiterWithConfig(cfg RateLimiterConfig) *RateLimiter { + rl := NewRateLimiter() + if cfg.MinTokenInterval > 0 { + rl.minTokenInterval = cfg.MinTokenInterval + } + if cfg.MaxTokenInterval > 0 { + rl.maxTokenInterval = cfg.MaxTokenInterval + } + if cfg.DailyMaxRequests > 0 { + rl.dailyMaxRequests = cfg.DailyMaxRequests + } + if cfg.JitterPercent > 0 { + rl.jitterPercent = cfg.JitterPercent + } + if cfg.BackoffBase > 0 { + rl.backoffBase = cfg.BackoffBase + } + if cfg.BackoffMax > 0 { + rl.backoffMax = cfg.BackoffMax + } + if cfg.BackoffMultiplier > 0 { + rl.backoffMultiplier = cfg.BackoffMultiplier + } + if cfg.SuspendCooldown > 0 { + rl.suspendCooldown = cfg.SuspendCooldown + } + return rl +} + +// getOrCreateState 获取或创建 Token 状态 +func (rl *RateLimiter) getOrCreateState(tokenKey string) *TokenState { + state, exists := rl.states[tokenKey] + if !exists { + state = &TokenState{ + DailyResetTime: time.Now().Truncate(24 * time.Hour).Add(24 * time.Hour), + } + rl.states[tokenKey] = state + } + return state +} + +// resetDailyIfNeeded 如果需要则重置每日计数 +func (rl *RateLimiter) resetDailyIfNeeded(state *TokenState) { + now := time.Now() + if now.After(state.DailyResetTime) { + state.DailyRequests = 0 + state.DailyResetTime = now.Truncate(24 * time.Hour).Add(24 * time.Hour) + } +} + +// calculateInterval 计算带抖动的随机间隔 +func (rl *RateLimiter) calculateInterval() time.Duration { + baseInterval := rl.minTokenInterval + time.Duration(rl.rng.Int63n(int64(rl.maxTokenInterval-rl.minTokenInterval))) + jitter := time.Duration(float64(baseInterval) * rl.jitterPercent * (rl.rng.Float64()*2 - 1)) + return baseInterval + jitter +} + +// WaitForToken 等待 Token 可用(带抖动的随机间隔) +func (rl *RateLimiter) WaitForToken(tokenKey string) { + rl.mu.Lock() + state := rl.getOrCreateState(tokenKey) + rl.resetDailyIfNeeded(state) + + now := time.Now() + + // 检查是否在冷却期 + if now.Before(state.CooldownEnd) { + waitTime := state.CooldownEnd.Sub(now) + rl.mu.Unlock() + time.Sleep(waitTime) + rl.mu.Lock() + state = rl.getOrCreateState(tokenKey) + now = time.Now() + } + + // 计算距离上次请求的间隔 + interval := rl.calculateInterval() + nextAllowedTime := state.LastRequest.Add(interval) + + if now.Before(nextAllowedTime) { + waitTime := nextAllowedTime.Sub(now) + rl.mu.Unlock() + time.Sleep(waitTime) + rl.mu.Lock() + state = rl.getOrCreateState(tokenKey) + } + + state.LastRequest = time.Now() + state.RequestCount++ + state.DailyRequests++ + rl.mu.Unlock() +} + +// MarkTokenFailed 标记 Token 失败 +func (rl *RateLimiter) MarkTokenFailed(tokenKey string) { + rl.mu.Lock() + defer rl.mu.Unlock() + + state := rl.getOrCreateState(tokenKey) + state.FailCount++ + state.CooldownEnd = time.Now().Add(rl.calculateBackoff(state.FailCount)) +} + +// MarkTokenSuccess 标记 Token 成功 +func (rl *RateLimiter) MarkTokenSuccess(tokenKey string) { + rl.mu.Lock() + defer rl.mu.Unlock() + + state := rl.getOrCreateState(tokenKey) + state.FailCount = 0 + state.CooldownEnd = time.Time{} +} + +// CheckAndMarkSuspended 检测暂停错误并标记 +func (rl *RateLimiter) CheckAndMarkSuspended(tokenKey string, errorMsg string) bool { + suspendKeywords := []string{ + "suspended", + "banned", + "disabled", + "account has been", + "access denied", + "rate limit exceeded", + "too many requests", + "quota exceeded", + } + + lowerMsg := strings.ToLower(errorMsg) + for _, keyword := range suspendKeywords { + if strings.Contains(lowerMsg, keyword) { + rl.mu.Lock() + defer rl.mu.Unlock() + + state := rl.getOrCreateState(tokenKey) + state.IsSuspended = true + state.SuspendedAt = time.Now() + state.SuspendReason = errorMsg + state.CooldownEnd = time.Now().Add(rl.suspendCooldown) + return true + } + } + return false +} + +// IsTokenAvailable 检查 Token 是否可用 +func (rl *RateLimiter) IsTokenAvailable(tokenKey string) bool { + rl.mu.RLock() + defer rl.mu.RUnlock() + + state, exists := rl.states[tokenKey] + if !exists { + return true + } + + now := time.Now() + + // 检查是否被暂停 + if state.IsSuspended { + if now.After(state.SuspendedAt.Add(rl.suspendCooldown)) { + return true + } + return false + } + + // 检查是否在冷却期 + if now.Before(state.CooldownEnd) { + return false + } + + // 检查每日请求限制 + rl.mu.RUnlock() + rl.mu.Lock() + rl.resetDailyIfNeeded(state) + dailyRequests := state.DailyRequests + dailyMax := rl.dailyMaxRequests + rl.mu.Unlock() + rl.mu.RLock() + + if dailyRequests >= dailyMax { + return false + } + + return true +} + +// calculateBackoff 计算指数退避时间 +func (rl *RateLimiter) calculateBackoff(failCount int) time.Duration { + if failCount <= 0 { + return 0 + } + + backoff := float64(rl.backoffBase) * math.Pow(rl.backoffMultiplier, float64(failCount-1)) + + // 添加抖动 + jitter := backoff * rl.jitterPercent * (rl.rng.Float64()*2 - 1) + backoff += jitter + + if time.Duration(backoff) > rl.backoffMax { + return rl.backoffMax + } + return time.Duration(backoff) +} + +// GetTokenState 获取 Token 状态(只读) +func (rl *RateLimiter) GetTokenState(tokenKey string) *TokenState { + rl.mu.RLock() + defer rl.mu.RUnlock() + + state, exists := rl.states[tokenKey] + if !exists { + return nil + } + + // 返回副本以防止外部修改 + stateCopy := *state + return &stateCopy +} + +// ClearTokenState 清除 Token 状态 +func (rl *RateLimiter) ClearTokenState(tokenKey string) { + rl.mu.Lock() + defer rl.mu.Unlock() + delete(rl.states, tokenKey) +} + +// ResetSuspension 重置暂停状态 +func (rl *RateLimiter) ResetSuspension(tokenKey string) { + rl.mu.Lock() + defer rl.mu.Unlock() + + state, exists := rl.states[tokenKey] + if exists { + state.IsSuspended = false + state.SuspendedAt = time.Time{} + state.SuspendReason = "" + state.CooldownEnd = time.Time{} + state.FailCount = 0 + } +} diff --git a/internal/auth/kiro/rate_limiter_singleton.go b/internal/auth/kiro/rate_limiter_singleton.go new file mode 100644 index 00000000..4c02af89 --- /dev/null +++ b/internal/auth/kiro/rate_limiter_singleton.go @@ -0,0 +1,46 @@ +package kiro + +import ( + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +var ( + globalRateLimiter *RateLimiter + globalRateLimiterOnce sync.Once + + globalCooldownManager *CooldownManager + globalCooldownManagerOnce sync.Once + cooldownStopCh chan struct{} +) + +// GetGlobalRateLimiter returns the singleton RateLimiter instance. +func GetGlobalRateLimiter() *RateLimiter { + globalRateLimiterOnce.Do(func() { + globalRateLimiter = NewRateLimiter() + log.Info("kiro: global RateLimiter initialized") + }) + return globalRateLimiter +} + +// GetGlobalCooldownManager returns the singleton CooldownManager instance. +func GetGlobalCooldownManager() *CooldownManager { + globalCooldownManagerOnce.Do(func() { + globalCooldownManager = NewCooldownManager() + cooldownStopCh = make(chan struct{}) + go globalCooldownManager.StartCleanupRoutine(5*time.Minute, cooldownStopCh) + log.Info("kiro: global CooldownManager initialized with cleanup routine") + }) + return globalCooldownManager +} + +// ShutdownRateLimiters stops the cooldown cleanup routine. +// Should be called during application shutdown. +func ShutdownRateLimiters() { + if cooldownStopCh != nil { + close(cooldownStopCh) + log.Info("kiro: rate limiter cleanup routine stopped") + } +} diff --git a/internal/auth/kiro/rate_limiter_test.go b/internal/auth/kiro/rate_limiter_test.go new file mode 100644 index 00000000..636413dd --- /dev/null +++ b/internal/auth/kiro/rate_limiter_test.go @@ -0,0 +1,304 @@ +package kiro + +import ( + "sync" + "testing" + "time" +) + +func TestNewRateLimiter(t *testing.T) { + rl := NewRateLimiter() + if rl == nil { + t.Fatal("expected non-nil RateLimiter") + } + if rl.states == nil { + t.Error("expected non-nil states map") + } + if rl.minTokenInterval != DefaultMinTokenInterval { + t.Errorf("expected minTokenInterval %v, got %v", DefaultMinTokenInterval, rl.minTokenInterval) + } + if rl.maxTokenInterval != DefaultMaxTokenInterval { + t.Errorf("expected maxTokenInterval %v, got %v", DefaultMaxTokenInterval, rl.maxTokenInterval) + } + if rl.dailyMaxRequests != DefaultDailyMaxRequests { + t.Errorf("expected dailyMaxRequests %d, got %d", DefaultDailyMaxRequests, rl.dailyMaxRequests) + } +} + +func TestNewRateLimiterWithConfig(t *testing.T) { + cfg := RateLimiterConfig{ + MinTokenInterval: 5 * time.Second, + MaxTokenInterval: 15 * time.Second, + DailyMaxRequests: 100, + JitterPercent: 0.2, + BackoffBase: 1 * time.Minute, + BackoffMax: 30 * time.Minute, + BackoffMultiplier: 1.5, + SuspendCooldown: 12 * time.Hour, + } + + rl := NewRateLimiterWithConfig(cfg) + if rl.minTokenInterval != 5*time.Second { + t.Errorf("expected minTokenInterval 5s, got %v", rl.minTokenInterval) + } + if rl.maxTokenInterval != 15*time.Second { + t.Errorf("expected maxTokenInterval 15s, got %v", rl.maxTokenInterval) + } + if rl.dailyMaxRequests != 100 { + t.Errorf("expected dailyMaxRequests 100, got %d", rl.dailyMaxRequests) + } +} + +func TestNewRateLimiterWithConfig_PartialConfig(t *testing.T) { + cfg := RateLimiterConfig{ + MinTokenInterval: 5 * time.Second, + } + + rl := NewRateLimiterWithConfig(cfg) + if rl.minTokenInterval != 5*time.Second { + t.Errorf("expected minTokenInterval 5s, got %v", rl.minTokenInterval) + } + if rl.maxTokenInterval != DefaultMaxTokenInterval { + t.Errorf("expected default maxTokenInterval, got %v", rl.maxTokenInterval) + } +} + +func TestGetTokenState_NonExistent(t *testing.T) { + rl := NewRateLimiter() + state := rl.GetTokenState("nonexistent") + if state != nil { + t.Error("expected nil state for non-existent token") + } +} + +func TestIsTokenAvailable_NewToken(t *testing.T) { + rl := NewRateLimiter() + if !rl.IsTokenAvailable("newtoken") { + t.Error("expected new token to be available") + } +} + +func TestMarkTokenFailed(t *testing.T) { + rl := NewRateLimiter() + rl.MarkTokenFailed("token1") + + state := rl.GetTokenState("token1") + if state == nil { + t.Fatal("expected non-nil state") + } + if state.FailCount != 1 { + t.Errorf("expected FailCount 1, got %d", state.FailCount) + } + if state.CooldownEnd.IsZero() { + t.Error("expected non-zero CooldownEnd") + } +} + +func TestMarkTokenSuccess(t *testing.T) { + rl := NewRateLimiter() + rl.MarkTokenFailed("token1") + rl.MarkTokenFailed("token1") + rl.MarkTokenSuccess("token1") + + state := rl.GetTokenState("token1") + if state == nil { + t.Fatal("expected non-nil state") + } + if state.FailCount != 0 { + t.Errorf("expected FailCount 0, got %d", state.FailCount) + } + if !state.CooldownEnd.IsZero() { + t.Error("expected zero CooldownEnd after success") + } +} + +func TestCheckAndMarkSuspended_Suspended(t *testing.T) { + rl := NewRateLimiter() + + testCases := []string{ + "Account has been suspended", + "You are banned from this service", + "Account disabled", + "Access denied permanently", + "Rate limit exceeded", + "Too many requests", + "Quota exceeded for today", + } + + for i, msg := range testCases { + tokenKey := "token" + string(rune('a'+i)) + if !rl.CheckAndMarkSuspended(tokenKey, msg) { + t.Errorf("expected suspension detected for: %s", msg) + } + state := rl.GetTokenState(tokenKey) + if !state.IsSuspended { + t.Errorf("expected IsSuspended true for: %s", msg) + } + } +} + +func TestCheckAndMarkSuspended_NotSuspended(t *testing.T) { + rl := NewRateLimiter() + + normalErrors := []string{ + "connection timeout", + "internal server error", + "bad request", + "invalid token format", + } + + for i, msg := range normalErrors { + tokenKey := "token" + string(rune('a'+i)) + if rl.CheckAndMarkSuspended(tokenKey, msg) { + t.Errorf("unexpected suspension for: %s", msg) + } + } +} + +func TestIsTokenAvailable_Suspended(t *testing.T) { + rl := NewRateLimiter() + rl.CheckAndMarkSuspended("token1", "Account suspended") + + if rl.IsTokenAvailable("token1") { + t.Error("expected suspended token to be unavailable") + } +} + +func TestClearTokenState(t *testing.T) { + rl := NewRateLimiter() + rl.MarkTokenFailed("token1") + rl.ClearTokenState("token1") + + state := rl.GetTokenState("token1") + if state != nil { + t.Error("expected nil state after clear") + } +} + +func TestResetSuspension(t *testing.T) { + rl := NewRateLimiter() + rl.CheckAndMarkSuspended("token1", "Account suspended") + rl.ResetSuspension("token1") + + state := rl.GetTokenState("token1") + if state.IsSuspended { + t.Error("expected IsSuspended false after reset") + } + if state.FailCount != 0 { + t.Errorf("expected FailCount 0, got %d", state.FailCount) + } +} + +func TestResetSuspension_NonExistent(t *testing.T) { + rl := NewRateLimiter() + rl.ResetSuspension("nonexistent") +} + +func TestCalculateBackoff_ZeroFailCount(t *testing.T) { + rl := NewRateLimiter() + backoff := rl.calculateBackoff(0) + if backoff != 0 { + t.Errorf("expected 0 backoff for 0 fails, got %v", backoff) + } +} + +func TestCalculateBackoff_Exponential(t *testing.T) { + cfg := RateLimiterConfig{ + BackoffBase: 1 * time.Minute, + BackoffMax: 60 * time.Minute, + BackoffMultiplier: 2.0, + JitterPercent: 0.3, + } + rl := NewRateLimiterWithConfig(cfg) + + backoff1 := rl.calculateBackoff(1) + if backoff1 < 40*time.Second || backoff1 > 80*time.Second { + t.Errorf("expected ~1min (with jitter) for fail 1, got %v", backoff1) + } + + backoff2 := rl.calculateBackoff(2) + if backoff2 < 80*time.Second || backoff2 > 160*time.Second { + t.Errorf("expected ~2min (with jitter) for fail 2, got %v", backoff2) + } +} + +func TestCalculateBackoff_MaxCap(t *testing.T) { + cfg := RateLimiterConfig{ + BackoffBase: 1 * time.Minute, + BackoffMax: 10 * time.Minute, + BackoffMultiplier: 2.0, + JitterPercent: 0, + } + rl := NewRateLimiterWithConfig(cfg) + + backoff := rl.calculateBackoff(10) + if backoff > 10*time.Minute { + t.Errorf("expected backoff capped at 10min, got %v", backoff) + } +} + +func TestGetTokenState_ReturnsCopy(t *testing.T) { + rl := NewRateLimiter() + rl.MarkTokenFailed("token1") + + state1 := rl.GetTokenState("token1") + state1.FailCount = 999 + + state2 := rl.GetTokenState("token1") + if state2.FailCount == 999 { + t.Error("GetTokenState should return a copy") + } +} + +func TestRateLimiter_ConcurrentAccess(t *testing.T) { + rl := NewRateLimiter() + const numGoroutines = 50 + const numOperations = 50 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + tokenKey := "token" + string(rune('a'+id%10)) + for j := 0; j < numOperations; j++ { + switch j % 6 { + case 0: + rl.IsTokenAvailable(tokenKey) + case 1: + rl.MarkTokenFailed(tokenKey) + case 2: + rl.MarkTokenSuccess(tokenKey) + case 3: + rl.GetTokenState(tokenKey) + case 4: + rl.CheckAndMarkSuspended(tokenKey, "test error") + case 5: + rl.ResetSuspension(tokenKey) + } + } + }(i) + } + + wg.Wait() +} + +func TestCalculateInterval_WithinRange(t *testing.T) { + cfg := RateLimiterConfig{ + MinTokenInterval: 10 * time.Second, + MaxTokenInterval: 30 * time.Second, + JitterPercent: 0.3, + } + rl := NewRateLimiterWithConfig(cfg) + + minAllowed := 7 * time.Second + maxAllowed := 40 * time.Second + + for i := 0; i < 100; i++ { + interval := rl.calculateInterval() + if interval < minAllowed || interval > maxAllowed { + t.Errorf("interval %v outside expected range [%v, %v]", interval, minAllowed, maxAllowed) + } + } +} diff --git a/internal/auth/kiro/social_auth.go b/internal/auth/kiro/social_auth.go index 2ac29bf8..277b83db 100644 --- a/internal/auth/kiro/social_auth.go +++ b/internal/auth/kiro/social_auth.go @@ -9,7 +9,9 @@ import ( "encoding/base64" "encoding/json" "fmt" + "html" "io" + "net" "net/http" "net/url" "os" @@ -31,6 +33,9 @@ const ( // OAuth timeout socialAuthTimeout = 10 * time.Minute + + // Default callback port for social auth HTTP server + socialAuthCallbackPort = 9876 ) // SocialProvider represents the social login provider. @@ -67,6 +72,13 @@ type RefreshTokenRequest struct { RefreshToken string `json:"refreshToken"` } +// WebCallbackResult contains the OAuth callback result from HTTP server. +type WebCallbackResult struct { + Code string + State string + Error string +} + // SocialAuthClient handles social authentication with Kiro. type SocialAuthClient struct { httpClient *http.Client @@ -87,6 +99,83 @@ func NewSocialAuthClient(cfg *config.Config) *SocialAuthClient { } } +// startWebCallbackServer starts a local HTTP server to receive the OAuth callback. +// This is used instead of the kiro:// protocol handler to avoid redirect_mismatch errors. +func (c *SocialAuthClient) startWebCallbackServer(ctx context.Context, expectedState string) (string, <-chan WebCallbackResult, error) { + // Try to find an available port - use localhost like Kiro does + listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", socialAuthCallbackPort)) + if err != nil { + // Try with dynamic port (RFC 8252 allows dynamic ports for native apps) + log.Warnf("kiro social auth: default port %d is busy, falling back to dynamic port", socialAuthCallbackPort) + listener, err = net.Listen("tcp", "localhost:0") + if err != nil { + return "", nil, fmt.Errorf("failed to start callback server: %w", err) + } + } + + port := listener.Addr().(*net.TCPAddr).Port + // Use http scheme for local callback server + redirectURI := fmt.Sprintf("http://localhost:%d/oauth/callback", port) + resultChan := make(chan WebCallbackResult, 1) + + server := &http.Server{ + ReadHeaderTimeout: 10 * time.Second, + } + + mux := http.NewServeMux() + mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) { + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + errParam := r.URL.Query().Get("error") + + if errParam != "" { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, ` +Login Failed +

Login Failed

%s

You can close this window.

`, html.EscapeString(errParam)) + resultChan <- WebCallbackResult{Error: errParam} + return + } + + if state != expectedState { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, ` +Login Failed +

Login Failed

Invalid state parameter

You can close this window.

`) + resultChan <- WebCallbackResult{Error: "state mismatch"} + return + } + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + fmt.Fprint(w, ` +Login Successful +

Login Successful!

You can close this window and return to the terminal.

+`) + resultChan <- WebCallbackResult{Code: code, State: state} + }) + + server.Handler = mux + + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + log.Debugf("kiro social auth callback server error: %v", err) + } + }() + + go func() { + select { + case <-ctx.Done(): + case <-time.After(socialAuthTimeout): + case <-resultChan: + } + _ = server.Shutdown(context.Background()) + }() + + return redirectURI, resultChan, nil +} + // generatePKCE generates PKCE code verifier and challenge. func generatePKCE() (verifier, challenge string, err error) { // Generate 32 bytes of random data for verifier @@ -217,10 +306,12 @@ func (c *SocialAuthClient) RefreshSocialToken(ctx context.Context, refreshToken ExpiresAt: expiresAt.Format(time.RFC3339), AuthMethod: "social", Provider: "", // Caller should preserve original provider + Region: "us-east-1", }, nil } -// LoginWithSocial performs OAuth login with Google. +// LoginWithSocial performs OAuth login with Google or GitHub. +// Uses local HTTP callback server instead of custom protocol handler to avoid redirect_mismatch errors. func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialProvider) (*KiroTokenData, error) { providerName := string(provider) @@ -228,28 +319,10 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP fmt.Printf("║ Kiro Authentication (%s) ║\n", providerName) fmt.Println("╚══════════════════════════════════════════════════════════╝") - // Step 1: Setup protocol handler + // Step 1: Start local HTTP callback server (instead of kiro:// protocol handler) + // This avoids redirect_mismatch errors with AWS Cognito fmt.Println("\nSetting up authentication...") - // Start the local callback server - handlerPort, err := c.protocolHandler.Start(ctx) - if err != nil { - return nil, fmt.Errorf("failed to start callback server: %w", err) - } - defer c.protocolHandler.Stop() - - // Ensure protocol handler is installed and set as default - if err := SetupProtocolHandlerIfNeeded(handlerPort); err != nil { - fmt.Println("\n⚠ Protocol handler setup failed. Trying alternative method...") - fmt.Println(" If you see a browser 'Open with' dialog, select your default browser.") - fmt.Println(" For manual setup instructions, run: cliproxy kiro --help-protocol") - log.Debugf("kiro: protocol handler setup error: %v", err) - // Continue anyway - user might have set it up manually or select browser manually - } else { - // Force set our handler as default (prevents "Open with" dialog) - forceDefaultProtocolHandler() - } - // Step 2: Generate PKCE codes codeVerifier, codeChallenge, err := generatePKCE() if err != nil { @@ -262,8 +335,15 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP return nil, fmt.Errorf("failed to generate state: %w", err) } - // Step 4: Build the login URL (Kiro uses GET request with query params) - authURL := c.buildLoginURL(providerName, KiroRedirectURI, codeChallenge, state) + // Step 4: Start local HTTP callback server + redirectURI, resultChan, err := c.startWebCallbackServer(ctx, state) + if err != nil { + return nil, fmt.Errorf("failed to start callback server: %w", err) + } + log.Debugf("kiro social auth: callback server started at %s", redirectURI) + + // Step 5: Build the login URL using HTTP redirect URI + authURL := c.buildLoginURL(providerName, redirectURI, codeChallenge, state) // Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito) // Incognito mode enables multi-account support by bypassing cached sessions @@ -279,7 +359,7 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP log.Debug("kiro: using incognito mode for multi-account support (default)") } - // Step 5: Open browser for user authentication + // Step 6: Open browser for user authentication fmt.Println("\n════════════════════════════════════════════════════════════") fmt.Printf(" Opening browser for %s authentication...\n", providerName) fmt.Println("════════════════════════════════════════════════════════════") @@ -295,80 +375,78 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP fmt.Println("\n Waiting for authentication callback...") - // Step 6: Wait for callback - callback, err := c.protocolHandler.WaitForCallback(ctx) - if err != nil { - return nil, fmt.Errorf("failed to receive callback: %w", err) - } - - if callback.Error != "" { - return nil, fmt.Errorf("authentication error: %s", callback.Error) - } - - if callback.State != state { - // Log state values for debugging, but don't expose in user-facing error - log.Debugf("kiro: OAuth state mismatch - expected %s, got %s", state, callback.State) - return nil, fmt.Errorf("OAuth state validation failed - please try again") - } - - if callback.Code == "" { - return nil, fmt.Errorf("no authorization code received") - } - - fmt.Println("\n✓ Authorization received!") - - // Step 7: Exchange code for tokens - fmt.Println("Exchanging code for tokens...") - - tokenReq := &CreateTokenRequest{ - Code: callback.Code, - CodeVerifier: codeVerifier, - RedirectURI: KiroRedirectURI, - } - - tokenResp, err := c.CreateToken(ctx, tokenReq) - if err != nil { - return nil, fmt.Errorf("failed to exchange code for tokens: %w", err) - } - - fmt.Println("\n✓ Authentication successful!") - - // Close the browser window - if err := browser.CloseBrowser(); err != nil { - log.Debugf("Failed to close browser: %v", err) - } - - // Validate ExpiresIn - use default 1 hour if invalid - expiresIn := tokenResp.ExpiresIn - if expiresIn <= 0 { - expiresIn = 3600 - } - expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second) - - // Try to extract email from JWT access token first - email := ExtractEmailFromJWT(tokenResp.AccessToken) - - // If no email in JWT, ask user for account label (only in interactive mode) - if email == "" && isInteractiveTerminal() { - fmt.Print("\n Enter account label for file naming (optional, press Enter to skip): ") - reader := bufio.NewReader(os.Stdin) - var err error - email, err = reader.ReadString('\n') - if err != nil { - log.Debugf("Failed to read account label: %v", err) + // Step 7: Wait for callback from HTTP server + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(socialAuthTimeout): + return nil, fmt.Errorf("authentication timed out") + case callback := <-resultChan: + if callback.Error != "" { + return nil, fmt.Errorf("authentication error: %s", callback.Error) } - email = strings.TrimSpace(email) - } - return &KiroTokenData{ - AccessToken: tokenResp.AccessToken, - RefreshToken: tokenResp.RefreshToken, - ProfileArn: tokenResp.ProfileArn, - ExpiresAt: expiresAt.Format(time.RFC3339), - AuthMethod: "social", - Provider: providerName, - Email: email, // JWT email or user-provided label - }, nil + // State is already validated by the callback server + if callback.Code == "" { + return nil, fmt.Errorf("no authorization code received") + } + + fmt.Println("\n✓ Authorization received!") + + // Step 8: Exchange code for tokens + fmt.Println("Exchanging code for tokens...") + + tokenReq := &CreateTokenRequest{ + Code: callback.Code, + CodeVerifier: codeVerifier, + RedirectURI: redirectURI, // Use HTTP redirect URI, not kiro:// protocol + } + + tokenResp, err := c.CreateToken(ctx, tokenReq) + if err != nil { + return nil, fmt.Errorf("failed to exchange code for tokens: %w", err) + } + + fmt.Println("\n✓ Authentication successful!") + + // Close the browser window + if err := browser.CloseBrowser(); err != nil { + log.Debugf("Failed to close browser: %v", err) + } + + // Validate ExpiresIn - use default 1 hour if invalid + expiresIn := tokenResp.ExpiresIn + if expiresIn <= 0 { + expiresIn = 3600 + } + expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second) + + // Try to extract email from JWT access token first + email := ExtractEmailFromJWT(tokenResp.AccessToken) + + // If no email in JWT, ask user for account label (only in interactive mode) + if email == "" && isInteractiveTerminal() { + fmt.Print("\n Enter account label for file naming (optional, press Enter to skip): ") + reader := bufio.NewReader(os.Stdin) + var err error + email, err = reader.ReadString('\n') + if err != nil { + log.Debugf("Failed to read account label: %v", err) + } + email = strings.TrimSpace(email) + } + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: tokenResp.ProfileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "social", + Provider: providerName, + Email: email, // JWT email or user-provided label + Region: "us-east-1", + }, nil + } } // LoginWithGoogle performs OAuth login with Google. diff --git a/internal/auth/kiro/sso_oidc.go b/internal/auth/kiro/sso_oidc.go index ab44e55f..ba15dac9 100644 --- a/internal/auth/kiro/sso_oidc.go +++ b/internal/auth/kiro/sso_oidc.go @@ -735,6 +735,7 @@ func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret Provider: "AWS", ClientID: clientID, ClientSecret: clientSecret, + Region: defaultIDCRegion, }, nil } @@ -850,16 +851,17 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, ClientID: regResp.ClientID, ClientSecret: regResp.ClientSecret, Email: email, + Region: defaultIDCRegion, }, nil - } - } + } + } - // Close browser on timeout for better UX - if err := browser.CloseBrowser(); err != nil { - log.Debugf("Failed to close browser on timeout: %v", err) - } - return nil, fmt.Errorf("authorization timed out") -} + // Close browser on timeout for better UX + if err := browser.CloseBrowser(); err != nil { + log.Debugf("Failed to close browser on timeout: %v", err) + } + return nil, fmt.Errorf("authorization timed out") + } // FetchUserEmail retrieves the user's email from AWS SSO OIDC userinfo endpoint. // Falls back to JWT parsing if userinfo fails. @@ -1366,6 +1368,7 @@ func (c *SSOOIDCClient) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTo ClientID: regResp.ClientID, ClientSecret: regResp.ClientSecret, Email: email, + Region: defaultIDCRegion, }, nil } } diff --git a/internal/auth/kiro/sso_oidc.go.bak b/internal/auth/kiro/sso_oidc.go.bak new file mode 100644 index 00000000..ab44e55f --- /dev/null +++ b/internal/auth/kiro/sso_oidc.go.bak @@ -0,0 +1,1371 @@ +// Package kiro provides AWS SSO OIDC authentication for Kiro. +package kiro + +import ( + "bufio" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "html" + "io" + "net" + "net/http" + "os" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + // AWS SSO OIDC endpoints + ssoOIDCEndpoint = "https://oidc.us-east-1.amazonaws.com" + + // Kiro's start URL for Builder ID + builderIDStartURL = "https://view.awsapps.com/start" + + // Default region for IDC + defaultIDCRegion = "us-east-1" + + // Polling interval + pollInterval = 5 * time.Second + + // Authorization code flow callback + authCodeCallbackPath = "/oauth/callback" + authCodeCallbackPort = 19877 + + // User-Agent to match official Kiro IDE + kiroUserAgent = "KiroIDE" + + // IDC token refresh headers (matching Kiro IDE behavior) + idcAmzUserAgent = "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE" +) + +// Sentinel errors for OIDC token polling +var ( + ErrAuthorizationPending = errors.New("authorization_pending") + ErrSlowDown = errors.New("slow_down") +) + +// SSOOIDCClient handles AWS SSO OIDC authentication. +type SSOOIDCClient struct { + httpClient *http.Client + cfg *config.Config +} + +// NewSSOOIDCClient creates a new SSO OIDC client. +func NewSSOOIDCClient(cfg *config.Config) *SSOOIDCClient { + client := &http.Client{Timeout: 30 * time.Second} + if cfg != nil { + client = util.SetProxy(&cfg.SDKConfig, client) + } + return &SSOOIDCClient{ + httpClient: client, + cfg: cfg, + } +} + +// RegisterClientResponse from AWS SSO OIDC. +type RegisterClientResponse struct { + ClientID string `json:"clientId"` + ClientSecret string `json:"clientSecret"` + ClientIDIssuedAt int64 `json:"clientIdIssuedAt"` + ClientSecretExpiresAt int64 `json:"clientSecretExpiresAt"` +} + +// StartDeviceAuthResponse from AWS SSO OIDC. +type StartDeviceAuthResponse struct { + DeviceCode string `json:"deviceCode"` + UserCode string `json:"userCode"` + VerificationURI string `json:"verificationUri"` + VerificationURIComplete string `json:"verificationUriComplete"` + ExpiresIn int `json:"expiresIn"` + Interval int `json:"interval"` +} + +// CreateTokenResponse from AWS SSO OIDC. +type CreateTokenResponse struct { + AccessToken string `json:"accessToken"` + TokenType string `json:"tokenType"` + ExpiresIn int `json:"expiresIn"` + RefreshToken string `json:"refreshToken"` +} + +// getOIDCEndpoint returns the OIDC endpoint for the given region. +func getOIDCEndpoint(region string) string { + if region == "" { + region = defaultIDCRegion + } + return fmt.Sprintf("https://oidc.%s.amazonaws.com", region) +} + +// promptInput prompts the user for input with an optional default value. +func promptInput(prompt, defaultValue string) string { + reader := bufio.NewReader(os.Stdin) + if defaultValue != "" { + fmt.Printf("%s [%s]: ", prompt, defaultValue) + } else { + fmt.Printf("%s: ", prompt) + } + input, err := reader.ReadString('\n') + if err != nil { + log.Warnf("Error reading input: %v", err) + return defaultValue + } + input = strings.TrimSpace(input) + if input == "" { + return defaultValue + } + return input +} + +// promptSelect prompts the user to select from options using number input. +func promptSelect(prompt string, options []string) int { + reader := bufio.NewReader(os.Stdin) + + for { + fmt.Println(prompt) + for i, opt := range options { + fmt.Printf(" %d) %s\n", i+1, opt) + } + fmt.Printf("Enter selection (1-%d): ", len(options)) + + input, err := reader.ReadString('\n') + if err != nil { + log.Warnf("Error reading input: %v", err) + return 0 // Default to first option on error + } + input = strings.TrimSpace(input) + + // Parse the selection + var selection int + if _, err := fmt.Sscanf(input, "%d", &selection); err != nil || selection < 1 || selection > len(options) { + fmt.Printf("Invalid selection '%s'. Please enter a number between 1 and %d.\n\n", input, len(options)) + continue + } + return selection - 1 + } +} + +// RegisterClientWithRegion registers a new OIDC client with AWS using a specific region. +func (c *SSOOIDCClient) RegisterClientWithRegion(ctx context.Context, region string) (*RegisterClientResponse, error) { + endpoint := getOIDCEndpoint(region) + + payload := map[string]interface{}{ + "clientName": "Kiro IDE", + "clientType": "public", + "scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"}, + "grantTypes": []string{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"}, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/client/register", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", kiroUserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("register client failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode) + } + + var result RegisterClientResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// StartDeviceAuthorizationWithIDC starts the device authorization flow for IDC. +func (c *SSOOIDCClient) StartDeviceAuthorizationWithIDC(ctx context.Context, clientID, clientSecret, startURL, region string) (*StartDeviceAuthResponse, error) { + endpoint := getOIDCEndpoint(region) + + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "startUrl": startURL, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/device_authorization", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", kiroUserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("start device auth failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("start device auth failed (status %d)", resp.StatusCode) + } + + var result StartDeviceAuthResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// CreateTokenWithRegion polls for the access token after user authorization using a specific region. +func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, clientSecret, deviceCode, region string) (*CreateTokenResponse, error) { + endpoint := getOIDCEndpoint(region) + + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "deviceCode": deviceCode, + "grantType": "urn:ietf:params:oauth:grant-type:device_code", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", kiroUserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + // Check for pending authorization + if resp.StatusCode == http.StatusBadRequest { + var errResp struct { + Error string `json:"error"` + } + if json.Unmarshal(respBody, &errResp) == nil { + if errResp.Error == "authorization_pending" { + return nil, ErrAuthorizationPending + } + if errResp.Error == "slow_down" { + return nil, ErrSlowDown + } + } + log.Debugf("create token failed: %s", string(respBody)) + return nil, fmt.Errorf("create token failed") + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("create token failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode) + } + + var result CreateTokenResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// RefreshTokenWithRegion refreshes an access token using the refresh token with a specific region. +func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, clientSecret, refreshToken, region, startURL string) (*KiroTokenData, error) { + endpoint := getOIDCEndpoint(region) + + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "refreshToken": refreshToken, + "grantType": "refresh_token", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + + // Set headers matching kiro2api's IDC token refresh + // These headers are required for successful IDC token refresh + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region)) + req.Header.Set("Connection", "keep-alive") + req.Header.Set("x-amz-user-agent", idcAmzUserAgent) + req.Header.Set("Accept", "*/*") + req.Header.Set("Accept-Language", "*") + req.Header.Set("sec-fetch-mode", "cors") + req.Header.Set("User-Agent", "node") + req.Header.Set("Accept-Encoding", "br, gzip, deflate") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Warnf("IDC token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode) + } + + var result CreateTokenResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + expiresAt := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: result.AccessToken, + RefreshToken: result.RefreshToken, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "idc", + Provider: "AWS", + ClientID: clientID, + ClientSecret: clientSecret, + StartURL: startURL, + Region: region, + }, nil +} + +// LoginWithIDC performs the full device code flow for AWS Identity Center (IDC). +func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region string) (*KiroTokenData, error) { + fmt.Println("\n╔══════════════════════════════════════════════════════════╗") + fmt.Println("║ Kiro Authentication (AWS Identity Center) ║") + fmt.Println("╚══════════════════════════════════════════════════════════╝") + + // Step 1: Register client with the specified region + fmt.Println("\nRegistering client...") + regResp, err := c.RegisterClientWithRegion(ctx, region) + if err != nil { + return nil, fmt.Errorf("failed to register client: %w", err) + } + log.Debugf("Client registered: %s", regResp.ClientID) + + // Step 2: Start device authorization with IDC start URL + fmt.Println("Starting device authorization...") + authResp, err := c.StartDeviceAuthorizationWithIDC(ctx, regResp.ClientID, regResp.ClientSecret, startURL, region) + if err != nil { + return nil, fmt.Errorf("failed to start device auth: %w", err) + } + + // Step 3: Show user the verification URL + fmt.Printf("\n") + fmt.Println("════════════════════════════════════════════════════════════") + fmt.Printf(" Confirm the following code in the browser:\n") + fmt.Printf(" Code: %s\n", authResp.UserCode) + fmt.Println("════════════════════════════════════════════════════════════") + fmt.Printf("\n Open this URL: %s\n\n", authResp.VerificationURIComplete) + + // Set incognito mode based on config + if c.cfg != nil { + browser.SetIncognitoMode(c.cfg.IncognitoBrowser) + if !c.cfg.IncognitoBrowser { + log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.") + } else { + log.Debug("kiro: using incognito mode for multi-account support") + } + } else { + browser.SetIncognitoMode(true) + log.Debug("kiro: using incognito mode for multi-account support (default)") + } + + // Open browser + if err := browser.OpenURL(authResp.VerificationURIComplete); err != nil { + log.Warnf("Could not open browser automatically: %v", err) + fmt.Println(" Please open the URL manually in your browser.") + } else { + fmt.Println(" (Browser opened automatically)") + } + + // Step 4: Poll for token + fmt.Println("Waiting for authorization...") + + interval := pollInterval + if authResp.Interval > 0 { + interval = time.Duration(authResp.Interval) * time.Second + } + + deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second) + + for time.Now().Before(deadline) { + select { + case <-ctx.Done(): + browser.CloseBrowser() + return nil, ctx.Err() + case <-time.After(interval): + tokenResp, err := c.CreateTokenWithRegion(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode, region) + if err != nil { + if errors.Is(err, ErrAuthorizationPending) { + fmt.Print(".") + continue + } + if errors.Is(err, ErrSlowDown) { + interval += 5 * time.Second + continue + } + browser.CloseBrowser() + return nil, fmt.Errorf("token creation failed: %w", err) + } + + fmt.Println("\n\n✓ Authorization successful!") + + // Close the browser window + if err := browser.CloseBrowser(); err != nil { + log.Debugf("Failed to close browser: %v", err) + } + + // Step 5: Get profile ARN from CodeWhisperer API + fmt.Println("Fetching profile information...") + profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken) + + // Fetch user email + email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken) + if email != "" { + fmt.Printf(" Logged in as: %s\n", email) + } + + expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: profileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "idc", + Provider: "AWS", + ClientID: regResp.ClientID, + ClientSecret: regResp.ClientSecret, + Email: email, + StartURL: startURL, + Region: region, + }, nil + } + } + + // Close browser on timeout + if err := browser.CloseBrowser(); err != nil { + log.Debugf("Failed to close browser on timeout: %v", err) + } + return nil, fmt.Errorf("authorization timed out") +} + +// LoginWithMethodSelection prompts the user to select between Builder ID and IDC, then performs the login. +func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context) (*KiroTokenData, error) { + fmt.Println("\n╔══════════════════════════════════════════════════════════╗") + fmt.Println("║ Kiro Authentication (AWS) ║") + fmt.Println("╚══════════════════════════════════════════════════════════╝") + + // Prompt for login method + options := []string{ + "Use with Builder ID (personal AWS account)", + "Use with IDC Account (organization SSO)", + } + selection := promptSelect("\n? Select login method:", options) + + if selection == 0 { + // Builder ID flow - use existing implementation + return c.LoginWithBuilderID(ctx) + } + + // IDC flow - prompt for start URL and region + fmt.Println() + startURL := promptInput("? Enter Start URL", "") + if startURL == "" { + return nil, fmt.Errorf("start URL is required for IDC login") + } + + region := promptInput("? Enter Region", defaultIDCRegion) + + return c.LoginWithIDC(ctx, startURL, region) +} + +// RegisterClient registers a new OIDC client with AWS. +func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResponse, error) { + payload := map[string]interface{}{ + "clientName": "Kiro IDE", + "clientType": "public", + "scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"}, + "grantTypes": []string{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"}, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/client/register", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", kiroUserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("register client failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode) + } + + var result RegisterClientResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// StartDeviceAuthorization starts the device authorization flow. +func (c *SSOOIDCClient) StartDeviceAuthorization(ctx context.Context, clientID, clientSecret string) (*StartDeviceAuthResponse, error) { + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "startUrl": builderIDStartURL, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/device_authorization", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", kiroUserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("start device auth failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("start device auth failed (status %d)", resp.StatusCode) + } + + var result StartDeviceAuthResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// CreateToken polls for the access token after user authorization. +func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret, deviceCode string) (*CreateTokenResponse, error) { + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "deviceCode": deviceCode, + "grantType": "urn:ietf:params:oauth:grant-type:device_code", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", kiroUserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + // Check for pending authorization + if resp.StatusCode == http.StatusBadRequest { + var errResp struct { + Error string `json:"error"` + } + if json.Unmarshal(respBody, &errResp) == nil { + if errResp.Error == "authorization_pending" { + return nil, ErrAuthorizationPending + } + if errResp.Error == "slow_down" { + return nil, ErrSlowDown + } + } + log.Debugf("create token failed: %s", string(respBody)) + return nil, fmt.Errorf("create token failed") + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("create token failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode) + } + + var result CreateTokenResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// RefreshToken refreshes an access token using the refresh token. +func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret, refreshToken string) (*KiroTokenData, error) { + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "refreshToken": refreshToken, + "grantType": "refresh_token", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", kiroUserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode) + } + + var result CreateTokenResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + expiresAt := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: result.AccessToken, + RefreshToken: result.RefreshToken, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "builder-id", + Provider: "AWS", + ClientID: clientID, + ClientSecret: clientSecret, + }, nil +} + +// LoginWithBuilderID performs the full device code flow for AWS Builder ID. +func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, error) { + fmt.Println("\n╔══════════════════════════════════════════════════════════╗") + fmt.Println("║ Kiro Authentication (AWS Builder ID) ║") + fmt.Println("╚══════════════════════════════════════════════════════════╝") + + // Step 1: Register client + fmt.Println("\nRegistering client...") + regResp, err := c.RegisterClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to register client: %w", err) + } + log.Debugf("Client registered: %s", regResp.ClientID) + + // Step 2: Start device authorization + fmt.Println("Starting device authorization...") + authResp, err := c.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret) + if err != nil { + return nil, fmt.Errorf("failed to start device auth: %w", err) + } + + // Step 3: Show user the verification URL + fmt.Printf("\n") + fmt.Println("════════════════════════════════════════════════════════════") + fmt.Printf(" Open this URL in your browser:\n") + fmt.Printf(" %s\n", authResp.VerificationURIComplete) + fmt.Println("════════════════════════════════════════════════════════════") + fmt.Printf("\n Or go to: %s\n", authResp.VerificationURI) + fmt.Printf(" And enter code: %s\n\n", authResp.UserCode) + + // Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito) + // Incognito mode enables multi-account support by bypassing cached sessions + if c.cfg != nil { + browser.SetIncognitoMode(c.cfg.IncognitoBrowser) + if !c.cfg.IncognitoBrowser { + log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.") + } else { + log.Debug("kiro: using incognito mode for multi-account support") + } + } else { + browser.SetIncognitoMode(true) // Default to incognito if no config + log.Debug("kiro: using incognito mode for multi-account support (default)") + } + + // Open browser using cross-platform browser package + if err := browser.OpenURL(authResp.VerificationURIComplete); err != nil { + log.Warnf("Could not open browser automatically: %v", err) + fmt.Println(" Please open the URL manually in your browser.") + } else { + fmt.Println(" (Browser opened automatically)") + } + + // Step 4: Poll for token + fmt.Println("Waiting for authorization...") + + interval := pollInterval + if authResp.Interval > 0 { + interval = time.Duration(authResp.Interval) * time.Second + } + + deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second) + + for time.Now().Before(deadline) { + select { + case <-ctx.Done(): + browser.CloseBrowser() // Cleanup on cancel + return nil, ctx.Err() + case <-time.After(interval): + tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode) + if err != nil { + if errors.Is(err, ErrAuthorizationPending) { + fmt.Print(".") + continue + } + if errors.Is(err, ErrSlowDown) { + interval += 5 * time.Second + continue + } + // Close browser on error before returning + browser.CloseBrowser() + return nil, fmt.Errorf("token creation failed: %w", err) + } + + fmt.Println("\n\n✓ Authorization successful!") + + // Close the browser window + if err := browser.CloseBrowser(); err != nil { + log.Debugf("Failed to close browser: %v", err) + } + + // Step 5: Get profile ARN from CodeWhisperer API + fmt.Println("Fetching profile information...") + profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken) + + // Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing) + email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken) + if email != "" { + fmt.Printf(" Logged in as: %s\n", email) + } + + expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: profileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "builder-id", + Provider: "AWS", + ClientID: regResp.ClientID, + ClientSecret: regResp.ClientSecret, + Email: email, + }, nil + } + } + + // Close browser on timeout for better UX + if err := browser.CloseBrowser(); err != nil { + log.Debugf("Failed to close browser on timeout: %v", err) + } + return nil, fmt.Errorf("authorization timed out") +} + +// FetchUserEmail retrieves the user's email from AWS SSO OIDC userinfo endpoint. +// Falls back to JWT parsing if userinfo fails. +func (c *SSOOIDCClient) FetchUserEmail(ctx context.Context, accessToken string) string { + // Method 1: Try userinfo endpoint (standard OIDC) + email := c.tryUserInfoEndpoint(ctx, accessToken) + if email != "" { + return email + } + + // Method 2: Fallback to JWT parsing + return ExtractEmailFromJWT(accessToken) +} + +// tryUserInfoEndpoint attempts to get user info from AWS SSO OIDC userinfo endpoint. +func (c *SSOOIDCClient) tryUserInfoEndpoint(ctx context.Context, accessToken string) string { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, ssoOIDCEndpoint+"/userinfo", nil) + if err != nil { + return "" + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + log.Debugf("userinfo request failed: %v", err) + return "" + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + log.Debugf("userinfo endpoint returned status %d: %s", resp.StatusCode, string(respBody)) + return "" + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return "" + } + + log.Debugf("userinfo response: %s", string(respBody)) + + var userInfo struct { + Email string `json:"email"` + Sub string `json:"sub"` + PreferredUsername string `json:"preferred_username"` + Name string `json:"name"` + } + + if err := json.Unmarshal(respBody, &userInfo); err != nil { + return "" + } + + if userInfo.Email != "" { + return userInfo.Email + } + if userInfo.PreferredUsername != "" && strings.Contains(userInfo.PreferredUsername, "@") { + return userInfo.PreferredUsername + } + return "" +} + +// fetchProfileArn retrieves the profile ARN from CodeWhisperer API. +// This is needed for file naming since AWS SSO OIDC doesn't return profile info. +func (c *SSOOIDCClient) fetchProfileArn(ctx context.Context, accessToken string) string { + // Try ListProfiles API first + profileArn := c.tryListProfiles(ctx, accessToken) + if profileArn != "" { + return profileArn + } + + // Fallback: Try ListAvailableCustomizations + return c.tryListCustomizations(ctx, accessToken) +} + +func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string) string { + payload := map[string]interface{}{ + "origin": "AI_EDITOR", + } + + body, err := json.Marshal(payload) + if err != nil { + return "" + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body))) + if err != nil { + return "" + } + + req.Header.Set("Content-Type", "application/x-amz-json-1.0") + req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListProfiles") + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return "" + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + log.Debugf("ListProfiles failed (status %d): %s", resp.StatusCode, string(respBody)) + return "" + } + + log.Debugf("ListProfiles response: %s", string(respBody)) + + var result struct { + Profiles []struct { + Arn string `json:"arn"` + } `json:"profiles"` + ProfileArn string `json:"profileArn"` + } + + if err := json.Unmarshal(respBody, &result); err != nil { + return "" + } + + if result.ProfileArn != "" { + return result.ProfileArn + } + + if len(result.Profiles) > 0 { + return result.Profiles[0].Arn + } + + return "" +} + +func (c *SSOOIDCClient) tryListCustomizations(ctx context.Context, accessToken string) string { + payload := map[string]interface{}{ + "origin": "AI_EDITOR", + } + + body, err := json.Marshal(payload) + if err != nil { + return "" + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body))) + if err != nil { + return "" + } + + req.Header.Set("Content-Type", "application/x-amz-json-1.0") + req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListAvailableCustomizations") + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return "" + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + log.Debugf("ListAvailableCustomizations failed (status %d): %s", resp.StatusCode, string(respBody)) + return "" + } + + log.Debugf("ListAvailableCustomizations response: %s", string(respBody)) + + var result struct { + Customizations []struct { + Arn string `json:"arn"` + } `json:"customizations"` + ProfileArn string `json:"profileArn"` + } + + if err := json.Unmarshal(respBody, &result); err != nil { + return "" + } + + if result.ProfileArn != "" { + return result.ProfileArn + } + + if len(result.Customizations) > 0 { + return result.Customizations[0].Arn + } + + return "" +} + +// RegisterClientForAuthCode registers a new OIDC client for authorization code flow. +func (c *SSOOIDCClient) RegisterClientForAuthCode(ctx context.Context, redirectURI string) (*RegisterClientResponse, error) { + payload := map[string]interface{}{ + "clientName": "Kiro IDE", + "clientType": "public", + "scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"}, + "grantTypes": []string{"authorization_code", "refresh_token"}, + "redirectUris": []string{redirectURI}, + "issuerUrl": builderIDStartURL, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/client/register", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", kiroUserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("register client for auth code failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode) + } + + var result RegisterClientResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// AuthCodeCallbackResult contains the result from authorization code callback. +type AuthCodeCallbackResult struct { + Code string + State string + Error string +} + +// startAuthCodeCallbackServer starts a local HTTP server to receive the authorization code callback. +func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expectedState string) (string, <-chan AuthCodeCallbackResult, error) { + // Try to find an available port + listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", authCodeCallbackPort)) + if err != nil { + // Try with dynamic port + log.Warnf("sso oidc: default port %d is busy, falling back to dynamic port", authCodeCallbackPort) + listener, err = net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return "", nil, fmt.Errorf("failed to start callback server: %w", err) + } + } + + port := listener.Addr().(*net.TCPAddr).Port + redirectURI := fmt.Sprintf("http://127.0.0.1:%d%s", port, authCodeCallbackPath) + resultChan := make(chan AuthCodeCallbackResult, 1) + + server := &http.Server{ + ReadHeaderTimeout: 10 * time.Second, + } + + mux := http.NewServeMux() + mux.HandleFunc(authCodeCallbackPath, func(w http.ResponseWriter, r *http.Request) { + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + errParam := r.URL.Query().Get("error") + + // Send response to browser + w.Header().Set("Content-Type", "text/html; charset=utf-8") + if errParam != "" { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, ` +Login Failed +

Login Failed

Error: %s

You can close this window.

`, html.EscapeString(errParam)) + resultChan <- AuthCodeCallbackResult{Error: errParam} + return + } + + if state != expectedState { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, ` +Login Failed +

Login Failed

Invalid state parameter

You can close this window.

`) + resultChan <- AuthCodeCallbackResult{Error: "state mismatch"} + return + } + + fmt.Fprint(w, ` +Login Successful +

Login Successful!

You can close this window and return to the terminal.

+`) + resultChan <- AuthCodeCallbackResult{Code: code, State: state} + }) + + server.Handler = mux + + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + log.Debugf("auth code callback server error: %v", err) + } + }() + + go func() { + select { + case <-ctx.Done(): + case <-time.After(10 * time.Minute): + case <-resultChan: + } + _ = server.Shutdown(context.Background()) + }() + + return redirectURI, resultChan, nil +} + +// generatePKCEForAuthCode generates PKCE code verifier and challenge for authorization code flow. +func generatePKCEForAuthCode() (verifier, challenge string, err error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", "", fmt.Errorf("failed to generate random bytes: %w", err) + } + verifier = base64.RawURLEncoding.EncodeToString(b) + h := sha256.Sum256([]byte(verifier)) + challenge = base64.RawURLEncoding.EncodeToString(h[:]) + return verifier, challenge, nil +} + +// generateStateForAuthCode generates a random state parameter. +func generateStateForAuthCode() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// CreateTokenWithAuthCode exchanges authorization code for tokens. +func (c *SSOOIDCClient) CreateTokenWithAuthCode(ctx context.Context, clientID, clientSecret, code, codeVerifier, redirectURI string) (*CreateTokenResponse, error) { + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "code": code, + "codeVerifier": codeVerifier, + "redirectUri": redirectURI, + "grantType": "authorization_code", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", kiroUserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("create token with auth code failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode) + } + + var result CreateTokenResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// LoginWithBuilderIDAuthCode performs the authorization code flow for AWS Builder ID. +// This provides a better UX than device code flow as it uses automatic browser callback. +func (c *SSOOIDCClient) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTokenData, error) { + fmt.Println("\n╔══════════════════════════════════════════════════════════╗") + fmt.Println("║ Kiro Authentication (AWS Builder ID - Auth Code) ║") + fmt.Println("╚══════════════════════════════════════════════════════════╝") + + // Step 1: Generate PKCE and state + codeVerifier, codeChallenge, err := generatePKCEForAuthCode() + if err != nil { + return nil, fmt.Errorf("failed to generate PKCE: %w", err) + } + + state, err := generateStateForAuthCode() + if err != nil { + return nil, fmt.Errorf("failed to generate state: %w", err) + } + + // Step 2: Start callback server + fmt.Println("\nStarting callback server...") + redirectURI, resultChan, err := c.startAuthCodeCallbackServer(ctx, state) + if err != nil { + return nil, fmt.Errorf("failed to start callback server: %w", err) + } + log.Debugf("Callback server started, redirect URI: %s", redirectURI) + + // Step 3: Register client with auth code grant type + fmt.Println("Registering client...") + regResp, err := c.RegisterClientForAuthCode(ctx, redirectURI) + if err != nil { + return nil, fmt.Errorf("failed to register client: %w", err) + } + log.Debugf("Client registered: %s", regResp.ClientID) + + // Step 4: Build authorization URL + scopes := "codewhisperer:completions,codewhisperer:analysis,codewhisperer:conversations" + authURL := fmt.Sprintf("%s/authorize?response_type=code&client_id=%s&redirect_uri=%s&scopes=%s&state=%s&code_challenge=%s&code_challenge_method=S256", + ssoOIDCEndpoint, + regResp.ClientID, + redirectURI, + scopes, + state, + codeChallenge, + ) + + // Step 5: Open browser + fmt.Println("\n════════════════════════════════════════════════════════════") + fmt.Println(" Opening browser for authentication...") + fmt.Println("════════════════════════════════════════════════════════════") + fmt.Printf("\n URL: %s\n\n", authURL) + + // Set incognito mode + if c.cfg != nil { + browser.SetIncognitoMode(c.cfg.IncognitoBrowser) + } else { + browser.SetIncognitoMode(true) + } + + if err := browser.OpenURL(authURL); err != nil { + log.Warnf("Could not open browser automatically: %v", err) + fmt.Println(" ⚠ Could not open browser automatically.") + fmt.Println(" Please open the URL above in your browser manually.") + } else { + fmt.Println(" (Browser opened automatically)") + } + + fmt.Println("\n Waiting for authorization callback...") + + // Step 6: Wait for callback + select { + case <-ctx.Done(): + browser.CloseBrowser() + return nil, ctx.Err() + case <-time.After(10 * time.Minute): + browser.CloseBrowser() + return nil, fmt.Errorf("authorization timed out") + case result := <-resultChan: + if result.Error != "" { + browser.CloseBrowser() + return nil, fmt.Errorf("authorization failed: %s", result.Error) + } + + fmt.Println("\n✓ Authorization received!") + + // Close browser + if err := browser.CloseBrowser(); err != nil { + log.Debugf("Failed to close browser: %v", err) + } + + // Step 7: Exchange code for tokens + fmt.Println("Exchanging code for tokens...") + tokenResp, err := c.CreateTokenWithAuthCode(ctx, regResp.ClientID, regResp.ClientSecret, result.Code, codeVerifier, redirectURI) + if err != nil { + return nil, fmt.Errorf("failed to exchange code for tokens: %w", err) + } + + fmt.Println("\n✓ Authentication successful!") + + // Step 8: Get profile ARN + fmt.Println("Fetching profile information...") + profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken) + + // Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing) + email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken) + if email != "" { + fmt.Printf(" Logged in as: %s\n", email) + } + + expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: profileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "builder-id", + Provider: "AWS", + ClientID: regResp.ClientID, + ClientSecret: regResp.ClientSecret, + Email: email, + }, nil + } +} diff --git a/internal/auth/kiro/token.go b/internal/auth/kiro/token.go index e83b1728..bfbdc795 100644 --- a/internal/auth/kiro/token.go +++ b/internal/auth/kiro/token.go @@ -9,6 +9,8 @@ import ( // KiroTokenStorage holds the persistent token data for Kiro authentication. type KiroTokenStorage struct { + // Type is the provider type for management UI recognition (must be "kiro") + Type string `json:"type"` // AccessToken is the OAuth2 access token for API access AccessToken string `json:"access_token"` // RefreshToken is used to obtain new access tokens @@ -23,6 +25,16 @@ type KiroTokenStorage struct { Provider string `json:"provider"` // LastRefresh is the timestamp of the last token refresh LastRefresh string `json:"last_refresh"` + // ClientID is the OAuth client ID (required for token refresh) + ClientID string `json:"clientId,omitempty"` + // ClientSecret is the OAuth client secret (required for token refresh) + ClientSecret string `json:"clientSecret,omitempty"` + // Region is the AWS region + Region string `json:"region,omitempty"` + // StartURL is the AWS Identity Center start URL (for IDC auth) + StartURL string `json:"startUrl,omitempty"` + // Email is the user's email address + Email string `json:"email,omitempty"` } // SaveTokenToFile persists the token storage to the specified file path. @@ -68,5 +80,10 @@ func (s *KiroTokenStorage) ToTokenData() *KiroTokenData { ExpiresAt: s.ExpiresAt, AuthMethod: s.AuthMethod, Provider: s.Provider, + ClientID: s.ClientID, + ClientSecret: s.ClientSecret, + Region: s.Region, + StartURL: s.StartURL, + Email: s.Email, } } diff --git a/internal/auth/kiro/usage_checker.go b/internal/auth/kiro/usage_checker.go new file mode 100644 index 00000000..94870214 --- /dev/null +++ b/internal/auth/kiro/usage_checker.go @@ -0,0 +1,243 @@ +// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API. +// This file implements usage quota checking and monitoring. +package kiro + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" +) + +// UsageQuotaResponse represents the API response structure for usage quota checking. +type UsageQuotaResponse struct { + UsageBreakdownList []UsageBreakdownExtended `json:"usageBreakdownList"` + SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"` + NextDateReset float64 `json:"nextDateReset,omitempty"` +} + +// UsageBreakdownExtended represents detailed usage information for quota checking. +// Note: UsageBreakdown is already defined in codewhisperer_client.go +type UsageBreakdownExtended struct { + ResourceType string `json:"resourceType"` + UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"` + CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"` + FreeTrialInfo *FreeTrialInfoExtended `json:"freeTrialInfo,omitempty"` +} + +// FreeTrialInfoExtended represents free trial usage information. +type FreeTrialInfoExtended struct { + FreeTrialStatus string `json:"freeTrialStatus"` + UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"` + CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"` +} + +// QuotaStatus represents the quota status for a token. +type QuotaStatus struct { + TotalLimit float64 + CurrentUsage float64 + RemainingQuota float64 + IsExhausted bool + ResourceType string + NextReset time.Time +} + +// UsageChecker provides methods for checking token quota usage. +type UsageChecker struct { + httpClient *http.Client + endpoint string +} + +// NewUsageChecker creates a new UsageChecker instance. +func NewUsageChecker(cfg *config.Config) *UsageChecker { + return &UsageChecker{ + httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}), + endpoint: awsKiroEndpoint, + } +} + +// NewUsageCheckerWithClient creates a UsageChecker with a custom HTTP client. +func NewUsageCheckerWithClient(client *http.Client) *UsageChecker { + return &UsageChecker{ + httpClient: client, + endpoint: awsKiroEndpoint, + } +} + +// CheckUsage retrieves usage limits for the given token. +func (c *UsageChecker) CheckUsage(ctx context.Context, tokenData *KiroTokenData) (*UsageQuotaResponse, error) { + if tokenData == nil { + return nil, fmt.Errorf("token data is nil") + } + + if tokenData.AccessToken == "" { + return nil, fmt.Errorf("access token is empty") + } + + payload := map[string]interface{}{ + "origin": "AI_EDITOR", + "profileArn": tokenData.ProfileArn, + "resourceType": "AGENTIC_REQUEST", + } + + jsonBody, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, strings.NewReader(string(jsonBody))) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-amz-json-1.0") + req.Header.Set("x-amz-target", targetGetUsage) + req.Header.Set("Authorization", "Bearer "+tokenData.AccessToken) + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) + } + + var result UsageQuotaResponse + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse usage response: %w", err) + } + + return &result, nil +} + +// CheckUsageByAccessToken retrieves usage limits using an access token and profile ARN directly. +func (c *UsageChecker) CheckUsageByAccessToken(ctx context.Context, accessToken, profileArn string) (*UsageQuotaResponse, error) { + tokenData := &KiroTokenData{ + AccessToken: accessToken, + ProfileArn: profileArn, + } + return c.CheckUsage(ctx, tokenData) +} + +// GetRemainingQuota calculates the remaining quota from usage limits. +func GetRemainingQuota(usage *UsageQuotaResponse) float64 { + if usage == nil || len(usage.UsageBreakdownList) == 0 { + return 0 + } + + var totalRemaining float64 + for _, breakdown := range usage.UsageBreakdownList { + remaining := breakdown.UsageLimitWithPrecision - breakdown.CurrentUsageWithPrecision + if remaining > 0 { + totalRemaining += remaining + } + + if breakdown.FreeTrialInfo != nil { + freeRemaining := breakdown.FreeTrialInfo.UsageLimitWithPrecision - breakdown.FreeTrialInfo.CurrentUsageWithPrecision + if freeRemaining > 0 { + totalRemaining += freeRemaining + } + } + } + + return totalRemaining +} + +// IsQuotaExhausted checks if the quota is exhausted based on usage limits. +func IsQuotaExhausted(usage *UsageQuotaResponse) bool { + if usage == nil || len(usage.UsageBreakdownList) == 0 { + return true + } + + for _, breakdown := range usage.UsageBreakdownList { + if breakdown.CurrentUsageWithPrecision < breakdown.UsageLimitWithPrecision { + return false + } + + if breakdown.FreeTrialInfo != nil { + if breakdown.FreeTrialInfo.CurrentUsageWithPrecision < breakdown.FreeTrialInfo.UsageLimitWithPrecision { + return false + } + } + } + + return true +} + +// GetQuotaStatus retrieves a comprehensive quota status for a token. +func (c *UsageChecker) GetQuotaStatus(ctx context.Context, tokenData *KiroTokenData) (*QuotaStatus, error) { + usage, err := c.CheckUsage(ctx, tokenData) + if err != nil { + return nil, err + } + + status := &QuotaStatus{ + IsExhausted: IsQuotaExhausted(usage), + } + + if len(usage.UsageBreakdownList) > 0 { + breakdown := usage.UsageBreakdownList[0] + status.TotalLimit = breakdown.UsageLimitWithPrecision + status.CurrentUsage = breakdown.CurrentUsageWithPrecision + status.RemainingQuota = breakdown.UsageLimitWithPrecision - breakdown.CurrentUsageWithPrecision + status.ResourceType = breakdown.ResourceType + + if breakdown.FreeTrialInfo != nil { + status.TotalLimit += breakdown.FreeTrialInfo.UsageLimitWithPrecision + status.CurrentUsage += breakdown.FreeTrialInfo.CurrentUsageWithPrecision + freeRemaining := breakdown.FreeTrialInfo.UsageLimitWithPrecision - breakdown.FreeTrialInfo.CurrentUsageWithPrecision + if freeRemaining > 0 { + status.RemainingQuota += freeRemaining + } + } + } + + if usage.NextDateReset > 0 { + status.NextReset = time.Unix(int64(usage.NextDateReset/1000), 0) + } + + return status, nil +} + +// CalculateAvailableCount calculates the available request count based on usage limits. +func CalculateAvailableCount(usage *UsageQuotaResponse) float64 { + return GetRemainingQuota(usage) +} + +// GetUsagePercentage calculates the usage percentage. +func GetUsagePercentage(usage *UsageQuotaResponse) float64 { + if usage == nil || len(usage.UsageBreakdownList) == 0 { + return 100.0 + } + + var totalLimit, totalUsage float64 + for _, breakdown := range usage.UsageBreakdownList { + totalLimit += breakdown.UsageLimitWithPrecision + totalUsage += breakdown.CurrentUsageWithPrecision + + if breakdown.FreeTrialInfo != nil { + totalLimit += breakdown.FreeTrialInfo.UsageLimitWithPrecision + totalUsage += breakdown.FreeTrialInfo.CurrentUsageWithPrecision + } + } + + if totalLimit == 0 { + return 100.0 + } + + return (totalUsage / totalLimit) * 100 +} diff --git a/internal/registry/kiro_model_converter.go b/internal/registry/kiro_model_converter.go new file mode 100644 index 00000000..fe50a8f3 --- /dev/null +++ b/internal/registry/kiro_model_converter.go @@ -0,0 +1,303 @@ +// Package registry provides Kiro model conversion utilities. +// This file handles converting dynamic Kiro API model lists to the internal ModelInfo format, +// and merging with static metadata for thinking support and other capabilities. +package registry + +import ( + "strings" + "time" +) + +// KiroAPIModel represents a model from Kiro API response. +// This is a local copy to avoid import cycles with the kiro package. +// The structure mirrors kiro.KiroModel for easy data conversion. +type KiroAPIModel struct { + // ModelID is the unique identifier for the model (e.g., "claude-sonnet-4.5") + ModelID string + // ModelName is the human-readable name + ModelName string + // Description is the model description + Description string + // RateMultiplier is the credit multiplier for this model + RateMultiplier float64 + // RateUnit is the unit for rate calculation (e.g., "credit") + RateUnit string + // MaxInputTokens is the maximum input token limit + MaxInputTokens int +} + +// DefaultKiroThinkingSupport defines the default thinking configuration for Kiro models. +// All Kiro models support thinking with the following budget range. +var DefaultKiroThinkingSupport = &ThinkingSupport{ + Min: 1024, // Minimum thinking budget tokens + Max: 32000, // Maximum thinking budget tokens + ZeroAllowed: true, // Allow disabling thinking with 0 + DynamicAllowed: true, // Allow dynamic thinking budget (-1) +} + +// DefaultKiroContextLength is the default context window size for Kiro models. +const DefaultKiroContextLength = 200000 + +// DefaultKiroMaxCompletionTokens is the default max completion tokens for Kiro models. +const DefaultKiroMaxCompletionTokens = 64000 + +// ConvertKiroAPIModels converts Kiro API models to internal ModelInfo format. +// It performs the following transformations: +// - Normalizes model ID (e.g., claude-sonnet-4.5 → kiro-claude-sonnet-4-5) +// - Adds default thinking support metadata +// - Sets default context length and max completion tokens if not provided +// +// Parameters: +// - kiroModels: List of models from Kiro API response +// +// Returns: +// - []*ModelInfo: Converted model information list +func ConvertKiroAPIModels(kiroModels []*KiroAPIModel) []*ModelInfo { + if len(kiroModels) == 0 { + return nil + } + + now := time.Now().Unix() + result := make([]*ModelInfo, 0, len(kiroModels)) + + for _, km := range kiroModels { + // Skip nil models + if km == nil { + continue + } + + // Skip models without valid ID + if km.ModelID == "" { + continue + } + + // Normalize the model ID to kiro-* format + normalizedID := normalizeKiroModelID(km.ModelID) + + // Create ModelInfo with converted data + info := &ModelInfo{ + ID: normalizedID, + Object: "model", + Created: now, + OwnedBy: "aws", + Type: "kiro", + DisplayName: generateKiroDisplayName(km.ModelName, normalizedID), + Description: km.Description, + // Use MaxInputTokens from API if available, otherwise use default + ContextLength: getContextLength(km.MaxInputTokens), + MaxCompletionTokens: DefaultKiroMaxCompletionTokens, + // All Kiro models support thinking + Thinking: cloneThinkingSupport(DefaultKiroThinkingSupport), + } + + result = append(result, info) + } + + return result +} + +// GenerateAgenticVariants creates -agentic variants for each model. +// Agentic variants are optimized for coding agents with chunked writes. +// +// Parameters: +// - models: Base models to generate variants for +// +// Returns: +// - []*ModelInfo: Combined list of base models and their agentic variants +func GenerateAgenticVariants(models []*ModelInfo) []*ModelInfo { + if len(models) == 0 { + return nil + } + + // Pre-allocate result with capacity for both base models and variants + result := make([]*ModelInfo, 0, len(models)*2) + + for _, model := range models { + if model == nil { + continue + } + + // Add the base model first + result = append(result, model) + + // Skip if model already has -agentic suffix + if strings.HasSuffix(model.ID, "-agentic") { + continue + } + + // Skip special models that shouldn't have agentic variants + if model.ID == "kiro-auto" { + continue + } + + // Create agentic variant + agenticModel := &ModelInfo{ + ID: model.ID + "-agentic", + Object: model.Object, + Created: model.Created, + OwnedBy: model.OwnedBy, + Type: model.Type, + DisplayName: model.DisplayName + " (Agentic)", + Description: generateAgenticDescription(model.Description), + ContextLength: model.ContextLength, + MaxCompletionTokens: model.MaxCompletionTokens, + Thinking: cloneThinkingSupport(model.Thinking), + } + + result = append(result, agenticModel) + } + + return result +} + +// MergeWithStaticMetadata merges dynamic models with static metadata. +// Static metadata takes priority for any overlapping fields. +// This allows manual overrides for specific models while keeping dynamic discovery. +// +// Parameters: +// - dynamicModels: Models from Kiro API (converted to ModelInfo) +// - staticModels: Predefined model metadata (from GetKiroModels()) +// +// Returns: +// - []*ModelInfo: Merged model list with static metadata taking priority +func MergeWithStaticMetadata(dynamicModels, staticModels []*ModelInfo) []*ModelInfo { + if len(dynamicModels) == 0 && len(staticModels) == 0 { + return nil + } + + // Build a map of static models for quick lookup + staticMap := make(map[string]*ModelInfo, len(staticModels)) + for _, sm := range staticModels { + if sm != nil && sm.ID != "" { + staticMap[sm.ID] = sm + } + } + + // Build result, preferring static metadata where available + seenIDs := make(map[string]struct{}) + result := make([]*ModelInfo, 0, len(dynamicModels)+len(staticModels)) + + // First, process dynamic models and merge with static if available + for _, dm := range dynamicModels { + if dm == nil || dm.ID == "" { + continue + } + + // Skip duplicates + if _, seen := seenIDs[dm.ID]; seen { + continue + } + seenIDs[dm.ID] = struct{}{} + + // Check if static metadata exists for this model + if sm, exists := staticMap[dm.ID]; exists { + // Static metadata takes priority - use static model + result = append(result, sm) + } else { + // No static metadata - use dynamic model + result = append(result, dm) + } + } + + // Add any static models not in dynamic list + for _, sm := range staticModels { + if sm == nil || sm.ID == "" { + continue + } + if _, seen := seenIDs[sm.ID]; seen { + continue + } + seenIDs[sm.ID] = struct{}{} + result = append(result, sm) + } + + return result +} + +// normalizeKiroModelID converts Kiro API model IDs to internal format. +// Transformation rules: +// - Adds "kiro-" prefix if not present +// - Replaces dots with hyphens (e.g., 4.5 → 4-5) +// - Handles special cases like "auto" → "kiro-auto" +// +// Examples: +// - "claude-sonnet-4.5" → "kiro-claude-sonnet-4-5" +// - "claude-opus-4.5" → "kiro-claude-opus-4-5" +// - "auto" → "kiro-auto" +// - "kiro-claude-sonnet-4-5" → "kiro-claude-sonnet-4-5" (unchanged) +func normalizeKiroModelID(modelID string) string { + if modelID == "" { + return "" + } + + // Trim whitespace + modelID = strings.TrimSpace(modelID) + + // Replace dots with hyphens (e.g., 4.5 → 4-5) + normalized := strings.ReplaceAll(modelID, ".", "-") + + // Add kiro- prefix if not present + if !strings.HasPrefix(normalized, "kiro-") { + normalized = "kiro-" + normalized + } + + return normalized +} + +// generateKiroDisplayName creates a human-readable display name. +// Uses the API-provided model name if available, otherwise generates from ID. +func generateKiroDisplayName(modelName, normalizedID string) string { + if modelName != "" { + return "Kiro " + modelName + } + + // Generate from normalized ID by removing kiro- prefix and formatting + displayID := strings.TrimPrefix(normalizedID, "kiro-") + // Capitalize first letter of each word + words := strings.Split(displayID, "-") + for i, word := range words { + if len(word) > 0 { + words[i] = strings.ToUpper(word[:1]) + word[1:] + } + } + return "Kiro " + strings.Join(words, " ") +} + +// generateAgenticDescription creates description for agentic variants. +func generateAgenticDescription(baseDescription string) string { + if baseDescription == "" { + return "Optimized for coding agents with chunked writes" + } + return baseDescription + " (Agentic mode: chunked writes)" +} + +// getContextLength returns the context length, using default if not provided. +func getContextLength(maxInputTokens int) int { + if maxInputTokens > 0 { + return maxInputTokens + } + return DefaultKiroContextLength +} + +// cloneThinkingSupport creates a deep copy of ThinkingSupport. +// Returns nil if input is nil. +func cloneThinkingSupport(ts *ThinkingSupport) *ThinkingSupport { + if ts == nil { + return nil + } + + clone := &ThinkingSupport{ + Min: ts.Min, + Max: ts.Max, + ZeroAllowed: ts.ZeroAllowed, + DynamicAllowed: ts.DynamicAllowed, + } + + // Deep copy Levels slice if present + if len(ts.Levels) > 0 { + clone.Levels = make([]string, len(ts.Levels)) + copy(clone.Levels, ts.Levels) + } + + return clone +} diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index 3d152955..b0c14c61 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -7,13 +7,16 @@ import ( "encoding/base64" "encoding/binary" "encoding/json" + "errors" "fmt" "io" + "net" "net/http" "os" "path/filepath" "strings" "sync" + "syscall" "time" "github.com/google/uuid" @@ -53,9 +56,28 @@ const ( kiroIDEUserAgent = "aws-sdk-js/1.0.18 ua/2.1 os/darwin#25.0.0 lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-0.2.13-66c23a8c5d15afabec89ef9954ef52a119f10d369df04d548fc6c1eac694b0d1" kiroIDEAmzUserAgent = "aws-sdk-js/1.0.18 KiroIDE-0.2.13-66c23a8c5d15afabec89ef9954ef52a119f10d369df04d548fc6c1eac694b0d1" kiroIDEAgentModeSpec = "spec" - kiroAgentModeVibe = "vibe" + + // Socket retry configuration constants (based on kiro2Api reference implementation) + // Maximum number of retry attempts for socket/network errors + kiroSocketMaxRetries = 3 + // Base delay between retry attempts (uses exponential backoff: delay * 2^attempt) + kiroSocketBaseRetryDelay = 1 * time.Second + // Maximum delay between retry attempts (cap for exponential backoff) + kiroSocketMaxRetryDelay = 30 * time.Second + // First token timeout for streaming responses (how long to wait for first response) + kiroFirstTokenTimeout = 15 * time.Second + // Streaming read timeout (how long to wait between chunks) + kiroStreamingReadTimeout = 300 * time.Second ) +// retryableHTTPStatusCodes defines HTTP status codes that are considered retryable. +// Based on kiro2Api reference: 502 (Bad Gateway), 503 (Service Unavailable), 504 (Gateway Timeout) +var retryableHTTPStatusCodes = map[int]bool{ + 502: true, // Bad Gateway - upstream server error + 503: true, // Service Unavailable - server temporarily overloaded + 504: true, // Gateway Timeout - upstream server timeout +} + // Real-time usage estimation configuration // These control how often usage updates are sent during streaming var ( @@ -63,6 +85,241 @@ var ( usageUpdateTimeInterval = 15 * time.Second // Or every 15 seconds, whichever comes first ) +// Global FingerprintManager for dynamic User-Agent generation per token +// Each token gets a unique fingerprint on first use, which is cached for subsequent requests +var ( + globalFingerprintManager *kiroauth.FingerprintManager + globalFingerprintManagerOnce sync.Once +) + +// getGlobalFingerprintManager returns the global FingerprintManager instance +func getGlobalFingerprintManager() *kiroauth.FingerprintManager { + globalFingerprintManagerOnce.Do(func() { + globalFingerprintManager = kiroauth.NewFingerprintManager() + log.Infof("kiro: initialized global FingerprintManager for dynamic UA generation") + }) + return globalFingerprintManager +} + +// retryConfig holds configuration for socket retry logic. +// Based on kiro2Api Python implementation patterns. +type retryConfig struct { + MaxRetries int // Maximum number of retry attempts + BaseDelay time.Duration // Base delay between retries (exponential backoff) + MaxDelay time.Duration // Maximum delay cap + RetryableErrors []string // List of retryable error patterns + RetryableStatus map[int]bool // HTTP status codes to retry + FirstTokenTmout time.Duration // Timeout for first token in streaming + StreamReadTmout time.Duration // Timeout between stream chunks +} + +// defaultRetryConfig returns the default retry configuration for Kiro socket operations. +func defaultRetryConfig() retryConfig { + return retryConfig{ + MaxRetries: kiroSocketMaxRetries, + BaseDelay: kiroSocketBaseRetryDelay, + MaxDelay: kiroSocketMaxRetryDelay, + RetryableStatus: retryableHTTPStatusCodes, + RetryableErrors: []string{ + "connection reset", + "connection refused", + "broken pipe", + "EOF", + "timeout", + "temporary failure", + "no such host", + "network is unreachable", + "i/o timeout", + }, + FirstTokenTmout: kiroFirstTokenTimeout, + StreamReadTmout: kiroStreamingReadTimeout, + } +} + +// isRetryableError checks if an error is retryable based on error type and message. +// Returns true for network timeouts, connection resets, and temporary failures. +// Based on kiro2Api's retry logic patterns. +func isRetryableError(err error) bool { + if err == nil { + return false + } + + // Check for context cancellation - not retryable + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return false + } + + // Check for net.Error (timeout, temporary) + var netErr net.Error + if errors.As(err, &netErr) { + if netErr.Timeout() { + log.Debugf("kiro: isRetryableError: network timeout detected") + return true + } + // Note: Temporary() is deprecated but still useful for some error types + } + + // Check for specific syscall errors (connection reset, broken pipe, etc.) + var syscallErr syscall.Errno + if errors.As(err, &syscallErr) { + switch syscallErr { + case syscall.ECONNRESET: // Connection reset by peer + log.Debugf("kiro: isRetryableError: ECONNRESET detected") + return true + case syscall.ECONNREFUSED: // Connection refused + log.Debugf("kiro: isRetryableError: ECONNREFUSED detected") + return true + case syscall.EPIPE: // Broken pipe + log.Debugf("kiro: isRetryableError: EPIPE (broken pipe) detected") + return true + case syscall.ETIMEDOUT: // Connection timed out + log.Debugf("kiro: isRetryableError: ETIMEDOUT detected") + return true + case syscall.ENETUNREACH: // Network is unreachable + log.Debugf("kiro: isRetryableError: ENETUNREACH detected") + return true + case syscall.EHOSTUNREACH: // No route to host + log.Debugf("kiro: isRetryableError: EHOSTUNREACH detected") + return true + } + } + + // Check for net.OpError wrapping other errors + var opErr *net.OpError + if errors.As(err, &opErr) { + log.Debugf("kiro: isRetryableError: net.OpError detected, op=%s", opErr.Op) + // Recursively check the wrapped error + if opErr.Err != nil { + return isRetryableError(opErr.Err) + } + return true + } + + // Check error message for retryable patterns + errMsg := strings.ToLower(err.Error()) + cfg := defaultRetryConfig() + for _, pattern := range cfg.RetryableErrors { + if strings.Contains(errMsg, pattern) { + log.Debugf("kiro: isRetryableError: pattern '%s' matched in error: %s", pattern, errMsg) + return true + } + } + + // Check for EOF which may indicate connection was closed + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + log.Debugf("kiro: isRetryableError: EOF/UnexpectedEOF detected") + return true + } + + return false +} + +// isRetryableHTTPStatus checks if an HTTP status code is retryable. +// Based on kiro2Api: 502, 503, 504 are retryable server errors. +func isRetryableHTTPStatus(statusCode int) bool { + return retryableHTTPStatusCodes[statusCode] +} + +// calculateRetryDelay calculates the delay for the next retry attempt using exponential backoff. +// delay = min(baseDelay * 2^attempt, maxDelay) +// Adds ±30% jitter to prevent thundering herd. +func calculateRetryDelay(attempt int, cfg retryConfig) time.Duration { + return kiroauth.ExponentialBackoffWithJitter(attempt, cfg.BaseDelay, cfg.MaxDelay) +} + +// logRetryAttempt logs a retry attempt with relevant context. +func logRetryAttempt(attempt, maxRetries int, reason string, delay time.Duration, endpoint string) { + log.Warnf("kiro: retry attempt %d/%d for %s, waiting %v before next attempt (endpoint: %s)", + attempt+1, maxRetries, reason, delay, endpoint) +} + +// kiroHTTPClientPool provides a shared HTTP client with connection pooling for Kiro API. +// This reduces connection overhead and improves performance for concurrent requests. +// Based on kiro2Api's connection pooling pattern. +var ( + kiroHTTPClientPool *http.Client + kiroHTTPClientPoolOnce sync.Once +) + +// getKiroPooledHTTPClient returns a shared HTTP client with optimized connection pooling. +// The client is lazily initialized on first use and reused across requests. +// This is especially beneficial for: +// - Reducing TCP handshake overhead +// - Enabling HTTP/2 multiplexing +// - Better handling of keep-alive connections +func getKiroPooledHTTPClient() *http.Client { + kiroHTTPClientPoolOnce.Do(func() { + transport := &http.Transport{ + // Connection pool settings + MaxIdleConns: 100, // Max idle connections across all hosts + MaxIdleConnsPerHost: 20, // Max idle connections per host + MaxConnsPerHost: 50, // Max total connections per host + IdleConnTimeout: 90 * time.Second, // How long idle connections stay in pool + + // Timeouts for connection establishment + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, // TCP connection timeout + KeepAlive: 30 * time.Second, // TCP keep-alive interval + }).DialContext, + + // TLS handshake timeout + TLSHandshakeTimeout: 10 * time.Second, + + // Response header timeout + ResponseHeaderTimeout: 30 * time.Second, + + // Expect 100-continue timeout + ExpectContinueTimeout: 1 * time.Second, + + // Enable HTTP/2 when available + ForceAttemptHTTP2: true, + } + + kiroHTTPClientPool = &http.Client{ + Transport: transport, + // No global timeout - let individual requests set their own timeouts via context + } + + log.Debugf("kiro: initialized pooled HTTP client (MaxIdleConns=%d, MaxIdleConnsPerHost=%d, MaxConnsPerHost=%d)", + transport.MaxIdleConns, transport.MaxIdleConnsPerHost, transport.MaxConnsPerHost) + }) + + return kiroHTTPClientPool +} + +// newKiroHTTPClientWithPooling creates an HTTP client that uses connection pooling when appropriate. +// It respects proxy configuration from auth or config, falling back to the pooled client. +// This provides the best of both worlds: custom proxy support + connection reuse. +func newKiroHTTPClientWithPooling(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client { + // Check if a proxy is configured - if so, we need a custom client + var proxyURL string + if auth != nil { + proxyURL = strings.TrimSpace(auth.ProxyURL) + } + if proxyURL == "" && cfg != nil { + proxyURL = strings.TrimSpace(cfg.ProxyURL) + } + + // If proxy is configured, use the existing proxy-aware client (doesn't pool) + if proxyURL != "" { + log.Debugf("kiro: using proxy-aware HTTP client (proxy=%s)", proxyURL) + return newProxyAwareHTTPClient(ctx, cfg, auth, timeout) + } + + // No proxy - use pooled client for better performance + pooledClient := getKiroPooledHTTPClient() + + // If timeout is specified, we need to wrap the pooled transport with timeout + if timeout > 0 { + return &http.Client{ + Transport: pooledClient.Transport, + Timeout: timeout, + } + } + + return pooledClient +} + // kiroEndpointConfig bundles endpoint URL with its compatible Origin and AmzTarget values. // This solves the "triple mismatch" problem where different endpoints require matching // Origin and X-Amz-Target header values. @@ -99,7 +356,7 @@ var kiroEndpointConfigs = []kiroEndpointConfig{ Name: "CodeWhisperer", }, { - URL: "https://q.us-east-1.amazonaws.com/generateAssistantResponse", + URL: "https://q.us-east-1.amazonaws.com/", Origin: "CLI", AmzTarget: "AmazonQDeveloperStreamingService.SendMessage", Name: "AmazonQ", @@ -217,6 +474,29 @@ func NewKiroExecutor(cfg *config.Config) *KiroExecutor { // Identifier returns the unique identifier for this executor. func (e *KiroExecutor) Identifier() string { return "kiro" } +// applyDynamicFingerprint applies token-specific fingerprint headers to the request +// For IDC auth, uses dynamic fingerprint-based User-Agent +// For other auth types, uses static Amazon Q CLI style headers +func applyDynamicFingerprint(req *http.Request, auth *cliproxyauth.Auth) { + if isIDCAuth(auth) { + // Get token-specific fingerprint for dynamic UA generation + tokenKey := getTokenKey(auth) + fp := getGlobalFingerprintManager().GetFingerprint(tokenKey) + + // Use fingerprint-generated dynamic User-Agent + req.Header.Set("User-Agent", fp.BuildUserAgent()) + req.Header.Set("X-Amz-User-Agent", fp.BuildAmzUserAgent()) + req.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec) + + log.Debugf("kiro: using dynamic fingerprint for token %s (SDK:%s, OS:%s/%s, Kiro:%s)", + tokenKey[:8]+"...", fp.SDKVersion, fp.OSType, fp.OSVersion, fp.KiroVersion) + } else { + // Use static Amazon Q CLI style headers for non-IDC auth + req.Header.Set("User-Agent", kiroUserAgent) + req.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) + } +} + // PrepareRequest prepares the HTTP request before execution. func (e *KiroExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error { if req == nil { @@ -226,16 +506,10 @@ func (e *KiroExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth if strings.TrimSpace(accessToken) == "" { return statusErr{code: http.StatusUnauthorized, msg: "missing access token"} } - if isIDCAuth(auth) { - req.Header.Set("User-Agent", kiroIDEUserAgent) - req.Header.Set("X-Amz-User-Agent", kiroIDEAmzUserAgent) - req.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec) - } else { - req.Header.Set("User-Agent", kiroUserAgent) - req.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) - req.Header.Set("x-amzn-kiro-agent-mode", kiroAgentModeVibe) - } - req.Header.Set("x-amzn-codewhisperer-optout", "true") + + // Apply dynamic fingerprint-based headers + applyDynamicFingerprint(req, auth) + req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) req.Header.Set("Authorization", "Bearer "+accessToken) @@ -259,10 +533,23 @@ func (e *KiroExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, if errPrepare := e.PrepareRequest(httpReq, auth); errPrepare != nil { return nil, errPrepare } - httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpClient := newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 0) return httpClient.Do(httpReq) } +// getTokenKey returns a unique key for rate limiting based on auth credentials. +// Uses auth ID if available, otherwise falls back to a hash of the access token. +func getTokenKey(auth *cliproxyauth.Auth) string { + if auth != nil && auth.ID != "" { + return auth.ID + } + accessToken, _ := kiroCredentials(auth) + if len(accessToken) > 16 { + return accessToken[:16] + } + return accessToken +} + // Execute sends the request to Kiro API and returns the response. // Supports automatic token refresh on 401/403 errors. func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { @@ -271,6 +558,24 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req return resp, fmt.Errorf("kiro: access token not found in auth") } + // Rate limiting: get token key for tracking + tokenKey := getTokenKey(auth) + rateLimiter := kiroauth.GetGlobalRateLimiter() + cooldownMgr := kiroauth.GetGlobalCooldownManager() + + // Check if token is in cooldown period + if cooldownMgr.IsInCooldown(tokenKey) { + remaining := cooldownMgr.GetRemainingCooldown(tokenKey) + reason := cooldownMgr.GetCooldownReason(tokenKey) + log.Warnf("kiro: token %s is in cooldown (reason: %s), remaining: %v", tokenKey, reason, remaining) + return resp, fmt.Errorf("kiro: token is in cooldown for %v (reason: %s)", remaining, reason) + } + + // Wait for rate limiter before proceeding + log.Debugf("kiro: waiting for rate limiter for token %s", tokenKey) + rateLimiter.WaitForToken(tokenKey) + log.Debugf("kiro: rate limiter cleared for token %s", tokenKey) + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) @@ -303,7 +608,7 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req // Execute with retry on 401/403 and 429 (quota exhausted) // Note: currentOrigin and kiroPayload are built inside executeWithRetry for each endpoint - resp, err = e.executeWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, to, reporter, "", kiroModelID, isAgentic, isChatOnly) + resp, err = e.executeWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, to, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey) return resp, err } @@ -312,9 +617,12 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req // - Amazon Q endpoint (CLI origin) uses Amazon Q Developer quota // - CodeWhisperer endpoint (AI_EDITOR origin) uses Kiro IDE quota // Also supports multi-endpoint fallback similar to Antigravity implementation. -func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from, to sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool) (cliproxyexecutor.Response, error) { +// tokenKey is used for rate limiting and cooldown tracking. +func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from, to sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool, tokenKey string) (cliproxyexecutor.Response, error) { var resp cliproxyexecutor.Response maxRetries := 2 // Allow retries for token refresh + endpoint fallback + rateLimiter := kiroauth.GetGlobalRateLimiter() + cooldownMgr := kiroauth.GetGlobalCooldownManager() endpointConfigs := getKiroEndpointConfigs(auth) var last429Err error @@ -332,6 +640,12 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. endpointIdx+1, len(endpointConfigs), url, endpointConfig.Name, currentOrigin) for attempt := 0; attempt <= maxRetries; attempt++ { + // Apply human-like delay before first request (not on retries) + // This mimics natural user behavior patterns + if attempt == 0 && endpointIdx == 0 { + kiroauth.ApplyHumanLikeDelay() + } + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload)) if err != nil { return resp, err @@ -342,20 +656,9 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. // Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors) httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget) - // Use different headers based on auth type - // IDC auth uses Kiro IDE style headers (from kiro2api) - // Other auth types use Amazon Q CLI style headers - if isIDCAuth(auth) { - httpReq.Header.Set("User-Agent", kiroIDEUserAgent) - httpReq.Header.Set("X-Amz-User-Agent", kiroIDEAmzUserAgent) - httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec) - log.Debugf("kiro: using Kiro IDE headers for IDC auth") - } else { - httpReq.Header.Set("User-Agent", kiroUserAgent) - httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) - httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroAgentModeVibe) - } - httpReq.Header.Set("x-amzn-codewhisperer-optout", "true") + // Apply dynamic fingerprint-based headers + applyDynamicFingerprint(httpReq, auth) + httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) @@ -386,10 +689,34 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. AuthValue: authValue, }) - httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 120*time.Second) + httpClient := newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 120*time.Second) httpResp, err := httpClient.Do(httpReq) if err != nil { + // Check for context cancellation first - client disconnected, not a server error + // Use 499 (Client Closed Request - nginx convention) instead of 500 + if errors.Is(err, context.Canceled) { + log.Debugf("kiro: request canceled by client (context.Canceled)") + return resp, statusErr{code: 499, msg: "client canceled request"} + } + + // Check for context deadline exceeded - request timed out + // Return 504 Gateway Timeout instead of 500 + if errors.Is(err, context.DeadlineExceeded) { + log.Debugf("kiro: request timed out (context.DeadlineExceeded)") + return resp, statusErr{code: http.StatusGatewayTimeout, msg: "upstream request timed out"} + } + recordAPIResponseError(ctx, e.cfg, err) + + // Enhanced socket retry: Check if error is retryable (network timeout, connection reset, etc.) + retryCfg := defaultRetryConfig() + if isRetryableError(err) && attempt < retryCfg.MaxRetries { + delay := calculateRetryDelay(attempt, retryCfg) + logRetryAttempt(attempt, retryCfg.MaxRetries, fmt.Sprintf("socket error: %v", err), delay, endpointConfig.Name) + time.Sleep(delay) + continue + } + return resp, err } recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) @@ -401,6 +728,12 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. _ = httpResp.Body.Close() appendAPIResponseChunk(ctx, e.cfg, respBody) + // Record failure and set cooldown for 429 + rateLimiter.MarkTokenFailed(tokenKey) + cooldownDuration := kiroauth.CalculateCooldownFor429(attempt) + cooldownMgr.SetCooldown(tokenKey, cooldownDuration, kiroauth.CooldownReason429) + log.Warnf("kiro: rate limit hit (429), token %s set to cooldown for %v", tokenKey, cooldownDuration) + // Preserve last 429 so callers can correctly backoff when all endpoints are exhausted last429Err = statusErr{code: httpResp.StatusCode, msg: string(respBody)} @@ -412,13 +745,21 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. } // Handle 5xx server errors with exponential backoff retry + // Enhanced: Use retryConfig for consistent retry behavior if httpResp.StatusCode >= 500 && httpResp.StatusCode < 600 { respBody, _ := io.ReadAll(httpResp.Body) _ = httpResp.Body.Close() appendAPIResponseChunk(ctx, e.cfg, respBody) - if attempt < maxRetries { - // Exponential backoff: 1s, 2s, 4s... (max 30s) + retryCfg := defaultRetryConfig() + // Check if this specific 5xx code is retryable (502, 503, 504) + if isRetryableHTTPStatus(httpResp.StatusCode) && attempt < retryCfg.MaxRetries { + delay := calculateRetryDelay(attempt, retryCfg) + logRetryAttempt(attempt, retryCfg.MaxRetries, fmt.Sprintf("HTTP %d", httpResp.StatusCode), delay, endpointConfig.Name) + time.Sleep(delay) + continue + } else if attempt < maxRetries { + // Fallback for other 5xx errors (500, 501, etc.) backoff := time.Duration(1< 30*time.Second { backoff = 30 * time.Second @@ -492,7 +833,10 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. // Check for SUSPENDED status - return immediately without retry if strings.Contains(respBodyStr, "SUSPENDED") || strings.Contains(respBodyStr, "TEMPORARILY_SUSPENDED") { - log.Errorf("kiro: account is suspended, cannot proceed") + // Set long cooldown for suspended accounts + rateLimiter.CheckAndMarkSuspended(tokenKey, respBodyStr) + cooldownMgr.SetCooldown(tokenKey, kiroauth.LongCooldown, kiroauth.CooldownReasonSuspended) + log.Errorf("kiro: account is suspended, token %s set to cooldown for %v", tokenKey, kiroauth.LongCooldown) return resp, statusErr{code: httpResp.StatusCode, msg: "account suspended: " + string(respBody)} } @@ -581,6 +925,10 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. appendAPIResponseChunk(ctx, e.cfg, []byte(content)) reporter.publish(ctx, usageInfo) + // Record success for rate limiting + rateLimiter.MarkTokenSuccess(tokenKey) + log.Debugf("kiro: request successful, token %s marked as success", tokenKey) + // Build response in Claude format for Kiro translator // stopReason is extracted from upstream response by parseEventStream kiroResponse := kiroclaude.BuildClaudeResponse(content, toolUses, req.Model, usageInfo, stopReason) @@ -608,6 +956,24 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut return nil, fmt.Errorf("kiro: access token not found in auth") } + // Rate limiting: get token key for tracking + tokenKey := getTokenKey(auth) + rateLimiter := kiroauth.GetGlobalRateLimiter() + cooldownMgr := kiroauth.GetGlobalCooldownManager() + + // Check if token is in cooldown period + if cooldownMgr.IsInCooldown(tokenKey) { + remaining := cooldownMgr.GetRemainingCooldown(tokenKey) + reason := cooldownMgr.GetCooldownReason(tokenKey) + log.Warnf("kiro: token %s is in cooldown (reason: %s), remaining: %v", tokenKey, reason, remaining) + return nil, fmt.Errorf("kiro: token is in cooldown for %v (reason: %s)", remaining, reason) + } + + // Wait for rate limiter before proceeding + log.Debugf("kiro: stream waiting for rate limiter for token %s", tokenKey) + rateLimiter.WaitForToken(tokenKey) + log.Debugf("kiro: stream rate limiter cleared for token %s", tokenKey) + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) @@ -640,7 +1006,7 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut // Execute stream with retry on 401/403 and 429 (quota exhausted) // Note: currentOrigin and kiroPayload are built inside executeStreamWithRetry for each endpoint - return e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly) + return e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey) } // executeStreamWithRetry performs the streaming HTTP request with automatic retry on auth errors. @@ -648,8 +1014,11 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut // - Amazon Q endpoint (CLI origin) uses Amazon Q Developer quota // - CodeWhisperer endpoint (AI_EDITOR origin) uses Kiro IDE quota // Also supports multi-endpoint fallback similar to Antigravity implementation. -func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool) (<-chan cliproxyexecutor.StreamChunk, error) { +// tokenKey is used for rate limiting and cooldown tracking. +func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool, tokenKey string) (<-chan cliproxyexecutor.StreamChunk, error) { maxRetries := 2 // Allow retries for token refresh + endpoint fallback + rateLimiter := kiroauth.GetGlobalRateLimiter() + cooldownMgr := kiroauth.GetGlobalCooldownManager() endpointConfigs := getKiroEndpointConfigs(auth) var last429Err error @@ -667,6 +1036,13 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox endpointIdx+1, len(endpointConfigs), url, endpointConfig.Name, currentOrigin) for attempt := 0; attempt <= maxRetries; attempt++ { + // Apply human-like delay before first streaming request (not on retries) + // This mimics natural user behavior patterns + // Note: Delay is NOT applied during streaming response - only before initial request + if attempt == 0 && endpointIdx == 0 { + kiroauth.ApplyHumanLikeDelay() + } + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload)) if err != nil { return nil, err @@ -677,20 +1053,9 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox // Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors) httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget) - // Use different headers based on auth type - // IDC auth uses Kiro IDE style headers (from kiro2api) - // Other auth types use Amazon Q CLI style headers - if isIDCAuth(auth) { - httpReq.Header.Set("User-Agent", kiroIDEUserAgent) - httpReq.Header.Set("X-Amz-User-Agent", kiroIDEAmzUserAgent) - httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec) - log.Debugf("kiro: using Kiro IDE headers for IDC auth") - } else { - httpReq.Header.Set("User-Agent", kiroUserAgent) - httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) - httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroAgentModeVibe) - } - httpReq.Header.Set("x-amzn-codewhisperer-optout", "true") + // Apply dynamic fingerprint-based headers + applyDynamicFingerprint(httpReq, auth) + httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) @@ -721,10 +1086,20 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox AuthValue: authValue, }) - httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpClient := newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 0) httpResp, err := httpClient.Do(httpReq) if err != nil { recordAPIResponseError(ctx, e.cfg, err) + + // Enhanced socket retry for streaming: Check if error is retryable (network timeout, connection reset, etc.) + retryCfg := defaultRetryConfig() + if isRetryableError(err) && attempt < retryCfg.MaxRetries { + delay := calculateRetryDelay(attempt, retryCfg) + logRetryAttempt(attempt, retryCfg.MaxRetries, fmt.Sprintf("stream socket error: %v", err), delay, endpointConfig.Name) + time.Sleep(delay) + continue + } + return nil, err } recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) @@ -736,6 +1111,12 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox _ = httpResp.Body.Close() appendAPIResponseChunk(ctx, e.cfg, respBody) + // Record failure and set cooldown for 429 + rateLimiter.MarkTokenFailed(tokenKey) + cooldownDuration := kiroauth.CalculateCooldownFor429(attempt) + cooldownMgr.SetCooldown(tokenKey, cooldownDuration, kiroauth.CooldownReason429) + log.Warnf("kiro: stream rate limit hit (429), token %s set to cooldown for %v", tokenKey, cooldownDuration) + // Preserve last 429 so callers can correctly backoff when all endpoints are exhausted last429Err = statusErr{code: httpResp.StatusCode, msg: string(respBody)} @@ -747,13 +1128,21 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox } // Handle 5xx server errors with exponential backoff retry + // Enhanced: Use retryConfig for consistent retry behavior if httpResp.StatusCode >= 500 && httpResp.StatusCode < 600 { respBody, _ := io.ReadAll(httpResp.Body) _ = httpResp.Body.Close() appendAPIResponseChunk(ctx, e.cfg, respBody) - if attempt < maxRetries { - // Exponential backoff: 1s, 2s, 4s... (max 30s) + retryCfg := defaultRetryConfig() + // Check if this specific 5xx code is retryable (502, 503, 504) + if isRetryableHTTPStatus(httpResp.StatusCode) && attempt < retryCfg.MaxRetries { + delay := calculateRetryDelay(attempt, retryCfg) + logRetryAttempt(attempt, retryCfg.MaxRetries, fmt.Sprintf("stream HTTP %d", httpResp.StatusCode), delay, endpointConfig.Name) + time.Sleep(delay) + continue + } else if attempt < maxRetries { + // Fallback for other 5xx errors (500, 501, etc.) backoff := time.Duration(1< 30*time.Second { backoff = 30 * time.Second @@ -840,7 +1229,10 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox // Check for SUSPENDED status - return immediately without retry if strings.Contains(respBodyStr, "SUSPENDED") || strings.Contains(respBodyStr, "TEMPORARILY_SUSPENDED") { - log.Errorf("kiro: account is suspended, cannot proceed") + // Set long cooldown for suspended accounts + rateLimiter.CheckAndMarkSuspended(tokenKey, respBodyStr) + cooldownMgr.SetCooldown(tokenKey, kiroauth.LongCooldown, kiroauth.CooldownReasonSuspended) + log.Errorf("kiro: stream account is suspended, token %s set to cooldown for %v", tokenKey, kiroauth.LongCooldown) return nil, statusErr{code: httpResp.StatusCode, msg: "account suspended: " + string(respBody)} } @@ -890,6 +1282,11 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox out := make(chan cliproxyexecutor.StreamChunk) + // Record success immediately since connection was established successfully + // Streaming errors will be handled separately + rateLimiter.MarkTokenSuccess(tokenKey) + log.Debugf("kiro: stream request successful, token %s marked as success", tokenKey) + go func(resp *http.Response, thinkingEnabled bool) { defer close(out) defer func() { diff --git a/internal/translator/kiro/common/utf8_stream.go b/internal/translator/kiro/common/utf8_stream.go new file mode 100644 index 00000000..b8d24c82 --- /dev/null +++ b/internal/translator/kiro/common/utf8_stream.go @@ -0,0 +1,97 @@ +package common + +import ( + "unicode/utf8" +) + +type UTF8StreamParser struct { + buffer []byte +} + +func NewUTF8StreamParser() *UTF8StreamParser { + return &UTF8StreamParser{ + buffer: make([]byte, 0, 64), + } +} + +func (p *UTF8StreamParser) Write(data []byte) { + p.buffer = append(p.buffer, data...) +} + +func (p *UTF8StreamParser) Read() (string, bool) { + if len(p.buffer) == 0 { + return "", false + } + + validLen := p.findValidUTF8End(p.buffer) + if validLen == 0 { + return "", false + } + + result := string(p.buffer[:validLen]) + p.buffer = p.buffer[validLen:] + + return result, true +} + +func (p *UTF8StreamParser) Flush() string { + if len(p.buffer) == 0 { + return "" + } + result := string(p.buffer) + p.buffer = p.buffer[:0] + return result +} + +func (p *UTF8StreamParser) Reset() { + p.buffer = p.buffer[:0] +} + +func (p *UTF8StreamParser) findValidUTF8End(data []byte) int { + if len(data) == 0 { + return 0 + } + + end := len(data) + for i := 1; i <= 3 && i <= len(data); i++ { + b := data[len(data)-i] + if b&0x80 == 0 { + break + } + if b&0xC0 == 0xC0 { + size := p.utf8CharSize(b) + available := i + if size > available { + end = len(data) - i + } + break + } + } + + if end > 0 && !utf8.Valid(data[:end]) { + for i := end - 1; i >= 0; i-- { + if utf8.Valid(data[:i+1]) { + return i + 1 + } + } + return 0 + } + + return end +} + +func (p *UTF8StreamParser) utf8CharSize(b byte) int { + if b&0x80 == 0 { + return 1 + } + if b&0xE0 == 0xC0 { + return 2 + } + if b&0xF0 == 0xE0 { + return 3 + } + if b&0xF8 == 0xF0 { + return 4 + } + return 1 +} diff --git a/internal/translator/kiro/common/utf8_stream_test.go b/internal/translator/kiro/common/utf8_stream_test.go new file mode 100644 index 00000000..23e80989 --- /dev/null +++ b/internal/translator/kiro/common/utf8_stream_test.go @@ -0,0 +1,402 @@ +package common + +import ( + "strings" + "sync" + "testing" + "unicode/utf8" +) + +func TestNewUTF8StreamParser(t *testing.T) { + p := NewUTF8StreamParser() + if p == nil { + t.Fatal("expected non-nil UTF8StreamParser") + } + if p.buffer == nil { + t.Error("expected non-nil buffer") + } +} + +func TestWrite(t *testing.T) { + p := NewUTF8StreamParser() + p.Write([]byte("hello")) + + result, ok := p.Read() + if !ok { + t.Error("expected ok to be true") + } + if result != "hello" { + t.Errorf("expected 'hello', got '%s'", result) + } +} + +func TestWrite_MultipleWrites(t *testing.T) { + p := NewUTF8StreamParser() + p.Write([]byte("hel")) + p.Write([]byte("lo")) + + result, ok := p.Read() + if !ok { + t.Error("expected ok to be true") + } + if result != "hello" { + t.Errorf("expected 'hello', got '%s'", result) + } +} + +func TestRead_EmptyBuffer(t *testing.T) { + p := NewUTF8StreamParser() + result, ok := p.Read() + if ok { + t.Error("expected ok to be false for empty buffer") + } + if result != "" { + t.Errorf("expected empty string, got '%s'", result) + } +} + +func TestRead_IncompleteUTF8(t *testing.T) { + p := NewUTF8StreamParser() + + // Write incomplete multi-byte UTF-8 character + // 中 (U+4E2D) = E4 B8 AD + p.Write([]byte{0xE4, 0xB8}) + + result, ok := p.Read() + if ok { + t.Error("expected ok to be false for incomplete UTF-8") + } + if result != "" { + t.Errorf("expected empty string, got '%s'", result) + } + + // Complete the character + p.Write([]byte{0xAD}) + result, ok = p.Read() + if !ok { + t.Error("expected ok to be true after completing UTF-8") + } + if result != "中" { + t.Errorf("expected '中', got '%s'", result) + } +} + +func TestRead_MixedASCIIAndUTF8(t *testing.T) { + p := NewUTF8StreamParser() + p.Write([]byte("Hello 世界")) + + result, ok := p.Read() + if !ok { + t.Error("expected ok to be true") + } + if result != "Hello 世界" { + t.Errorf("expected 'Hello 世界', got '%s'", result) + } +} + +func TestRead_PartialMultibyteAtEnd(t *testing.T) { + p := NewUTF8StreamParser() + // "Hello" + partial "世" (E4 B8 96) + p.Write([]byte("Hello")) + p.Write([]byte{0xE4, 0xB8}) + + result, ok := p.Read() + if !ok { + t.Error("expected ok to be true for valid portion") + } + if result != "Hello" { + t.Errorf("expected 'Hello', got '%s'", result) + } + + // Complete the character + p.Write([]byte{0x96}) + result, ok = p.Read() + if !ok { + t.Error("expected ok to be true after completing") + } + if result != "世" { + t.Errorf("expected '世', got '%s'", result) + } +} + +func TestFlush(t *testing.T) { + p := NewUTF8StreamParser() + p.Write([]byte("hello")) + + result := p.Flush() + if result != "hello" { + t.Errorf("expected 'hello', got '%s'", result) + } + + // Verify buffer is cleared + result2, ok := p.Read() + if ok { + t.Error("expected ok to be false after flush") + } + if result2 != "" { + t.Errorf("expected empty string after flush, got '%s'", result2) + } +} + +func TestFlush_EmptyBuffer(t *testing.T) { + p := NewUTF8StreamParser() + result := p.Flush() + if result != "" { + t.Errorf("expected empty string, got '%s'", result) + } +} + +func TestFlush_IncompleteUTF8(t *testing.T) { + p := NewUTF8StreamParser() + p.Write([]byte{0xE4, 0xB8}) + + result := p.Flush() + // Flush returns everything including incomplete bytes + if len(result) != 2 { + t.Errorf("expected 2 bytes flushed, got %d", len(result)) + } +} + +func TestReset(t *testing.T) { + p := NewUTF8StreamParser() + p.Write([]byte("hello")) + p.Reset() + + result, ok := p.Read() + if ok { + t.Error("expected ok to be false after reset") + } + if result != "" { + t.Errorf("expected empty string after reset, got '%s'", result) + } +} + +func TestUtf8CharSize(t *testing.T) { + p := NewUTF8StreamParser() + + testCases := []struct { + b byte + expected int + }{ + {0x00, 1}, // ASCII + {0x7F, 1}, // ASCII max + {0xC0, 2}, // 2-byte start + {0xDF, 2}, // 2-byte start + {0xE0, 3}, // 3-byte start + {0xEF, 3}, // 3-byte start + {0xF0, 4}, // 4-byte start + {0xF7, 4}, // 4-byte start + {0x80, 1}, // Continuation byte (fallback) + } + + for _, tc := range testCases { + size := p.utf8CharSize(tc.b) + if size != tc.expected { + t.Errorf("utf8CharSize(0x%X) = %d, expected %d", tc.b, size, tc.expected) + } + } +} + +func TestStreamingScenario(t *testing.T) { + p := NewUTF8StreamParser() + + // Simulate streaming: "Hello, 世界! 🌍" + chunks := [][]byte{ + []byte("Hello, "), + {0xE4, 0xB8}, // partial 世 + {0x96, 0xE7}, // complete 世, partial 界 + {0x95, 0x8C}, // complete 界 + []byte("! "), + {0xF0, 0x9F}, // partial 🌍 + {0x8C, 0x8D}, // complete 🌍 + } + + var results []string + for _, chunk := range chunks { + p.Write(chunk) + if result, ok := p.Read(); ok { + results = append(results, result) + } + } + + combined := strings.Join(results, "") + if combined != "Hello, 世界! 🌍" { + t.Errorf("expected 'Hello, 世界! 🌍', got '%s'", combined) + } +} + +func TestValidUTF8Output(t *testing.T) { + p := NewUTF8StreamParser() + + testStrings := []string{ + "Hello World", + "你好世界", + "こんにちは", + "🎉🎊🎁", + "Mixed 混合 Текст ტექსტი", + } + + for _, s := range testStrings { + p.Reset() + p.Write([]byte(s)) + result, ok := p.Read() + if !ok { + t.Errorf("expected ok for '%s'", s) + } + if !utf8.ValidString(result) { + t.Errorf("invalid UTF-8 output for input '%s'", s) + } + if result != s { + t.Errorf("expected '%s', got '%s'", s, result) + } + } +} + +func TestLargeData(t *testing.T) { + p := NewUTF8StreamParser() + + // Generate large UTF-8 string + var builder strings.Builder + for i := 0; i < 1000; i++ { + builder.WriteString("Hello 世界! ") + } + largeString := builder.String() + + p.Write([]byte(largeString)) + result, ok := p.Read() + if !ok { + t.Error("expected ok for large data") + } + if result != largeString { + t.Error("large data mismatch") + } +} + +func TestByteByByteWriting(t *testing.T) { + p := NewUTF8StreamParser() + input := "Hello 世界" + inputBytes := []byte(input) + + var results []string + for _, b := range inputBytes { + p.Write([]byte{b}) + if result, ok := p.Read(); ok { + results = append(results, result) + } + } + + combined := strings.Join(results, "") + if combined != input { + t.Errorf("expected '%s', got '%s'", input, combined) + } +} + +func TestEmoji4ByteUTF8(t *testing.T) { + p := NewUTF8StreamParser() + + // 🎉 = F0 9F 8E 89 + emoji := "🎉" + emojiBytes := []byte(emoji) + + for i := 0; i < len(emojiBytes)-1; i++ { + p.Write(emojiBytes[i : i+1]) + result, ok := p.Read() + if ok && result != "" { + t.Errorf("unexpected output before emoji complete: '%s'", result) + } + } + + p.Write(emojiBytes[len(emojiBytes)-1:]) + result, ok := p.Read() + if !ok { + t.Error("expected ok after completing emoji") + } + if result != emoji { + t.Errorf("expected '%s', got '%s'", emoji, result) + } +} + +func TestContinuationBytesOnly(t *testing.T) { + p := NewUTF8StreamParser() + + // Write only continuation bytes (invalid UTF-8) + p.Write([]byte{0x80, 0x80, 0x80}) + + result, ok := p.Read() + // Should handle gracefully - either return nothing or return the bytes + _ = result + _ = ok +} + +func TestUTF8StreamParser_ConcurrentSafety(t *testing.T) { + // Note: UTF8StreamParser doesn't have built-in locks, + // so this test verifies it works with external synchronization + p := NewUTF8StreamParser() + var mu sync.Mutex + const numGoroutines = 10 + const numOperations = 100 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + for j := 0; j < numOperations; j++ { + mu.Lock() + switch j % 4 { + case 0: + p.Write([]byte("test")) + case 1: + p.Read() + case 2: + p.Flush() + case 3: + p.Reset() + } + mu.Unlock() + } + }() + } + + wg.Wait() +} + +func TestConsecutiveReads(t *testing.T) { + p := NewUTF8StreamParser() + p.Write([]byte("hello")) + + result1, ok1 := p.Read() + if !ok1 || result1 != "hello" { + t.Error("first read failed") + } + + result2, ok2 := p.Read() + if ok2 || result2 != "" { + t.Error("second read should return empty") + } +} + +func TestFlushThenWrite(t *testing.T) { + p := NewUTF8StreamParser() + p.Write([]byte("first")) + p.Flush() + p.Write([]byte("second")) + + result, ok := p.Read() + if !ok || result != "second" { + t.Errorf("expected 'second', got '%s'", result) + } +} + +func TestResetThenWrite(t *testing.T) { + p := NewUTF8StreamParser() + p.Write([]byte("first")) + p.Reset() + p.Write([]byte("second")) + + result, ok := p.Read() + if !ok || result != "second" { + t.Errorf("expected 'second', got '%s'", result) + } +} diff --git a/internal/watcher/events.go b/internal/watcher/events.go index eb428353..fb96ad2a 100644 --- a/internal/watcher/events.go +++ b/internal/watcher/events.go @@ -170,7 +170,9 @@ func (w *Watcher) handleKiroIDETokenChange(event fsnotify.Event) { } } - tokenData, err := kiroauth.LoadKiroIDEToken() + // Use retry logic to handle file lock contention (e.g., Kiro IDE writing the file) + // This prevents "being used by another process" errors on Windows + tokenData, err := kiroauth.LoadKiroIDETokenWithRetry(10, 50*time.Millisecond) if err != nil { log.Debugf("failed to load Kiro IDE token after change: %v", err) return diff --git a/sdk/auth/kiro.go b/sdk/auth/kiro.go index b75cd28e..7747c777 100644 --- a/sdk/auth/kiro.go +++ b/sdk/auth/kiro.go @@ -12,9 +12,9 @@ import ( ) // extractKiroIdentifier extracts a meaningful identifier for file naming. -// Returns account name if provided, otherwise profile ARN ID. +// Returns account name if provided, otherwise profile ARN ID, then client ID. // All extracted values are sanitized to prevent path injection attacks. -func extractKiroIdentifier(accountName, profileArn string) string { +func extractKiroIdentifier(accountName, profileArn, clientID string) string { // Priority 1: Use account name if provided if accountName != "" { return kiroauth.SanitizeEmailForFilename(accountName) @@ -29,6 +29,11 @@ func extractKiroIdentifier(accountName, profileArn string) string { } } + // Priority 3: Use client ID (for IDC auth without email/profileArn) + if clientID != "" { + return kiroauth.SanitizeEmailForFilename(clientID) + } + // Fallback: timestamp return fmt.Sprintf("%d", time.Now().UnixNano()%100000) } @@ -62,7 +67,7 @@ func (a *KiroAuthenticator) createAuthRecord(tokenData *kiroauth.KiroTokenData, } // Extract identifier for file naming - idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) + idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn, tokenData.ClientID) // Determine label based on auth method label := fmt.Sprintf("kiro-%s", source) @@ -173,7 +178,7 @@ func (a *KiroAuthenticator) LoginWithAuthCode(ctx context.Context, cfg *config.C } // Extract identifier for file naming - idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) + idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn, tokenData.ClientID) now := time.Now() fileName := fmt.Sprintf("kiro-aws-%s.json", idPart) @@ -217,129 +222,17 @@ func (a *KiroAuthenticator) LoginWithAuthCode(ctx context.Context, cfg *config.C } // LoginWithGoogle performs OAuth login for Kiro with Google. -// This uses a custom protocol handler (kiro://) to receive the callback. +// NOTE: Google login is not available for third-party applications due to AWS Cognito restrictions. +// Please use AWS Builder ID or import your token from Kiro IDE. func (a *KiroAuthenticator) LoginWithGoogle(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { - if cfg == nil { - return nil, fmt.Errorf("kiro auth: configuration is required") - } - - oauth := kiroauth.NewKiroOAuth(cfg) - - // Use Google OAuth flow with protocol handler - tokenData, err := oauth.LoginWithGoogle(ctx) - if err != nil { - return nil, fmt.Errorf("google login failed: %w", err) - } - - // Parse expires_at - expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) - if err != nil { - expiresAt = time.Now().Add(1 * time.Hour) - } - - // Extract identifier for file naming - idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) - - now := time.Now() - fileName := fmt.Sprintf("kiro-google-%s.json", idPart) - - record := &coreauth.Auth{ - ID: fileName, - Provider: "kiro", - FileName: fileName, - Label: "kiro-google", - Status: coreauth.StatusActive, - CreatedAt: now, - UpdatedAt: now, - Metadata: map[string]any{ - "type": "kiro", - "access_token": tokenData.AccessToken, - "refresh_token": tokenData.RefreshToken, - "profile_arn": tokenData.ProfileArn, - "expires_at": tokenData.ExpiresAt, - "auth_method": tokenData.AuthMethod, - "provider": tokenData.Provider, - "email": tokenData.Email, - }, - Attributes: map[string]string{ - "profile_arn": tokenData.ProfileArn, - "source": "google-oauth", - "email": tokenData.Email, - }, - // NextRefreshAfter is aligned with RefreshLead (5min) - NextRefreshAfter: expiresAt.Add(-5 * time.Minute), - } - - if tokenData.Email != "" { - fmt.Printf("\n✓ Kiro Google authentication completed successfully! (Account: %s)\n", tokenData.Email) - } else { - fmt.Println("\n✓ Kiro Google authentication completed successfully!") - } - - return record, nil + return nil, fmt.Errorf("Google login is not available for third-party applications due to AWS Cognito restrictions.\n\nAlternatives:\n 1. Use AWS Builder ID: cliproxy kiro --builder-id\n 2. Import token from Kiro IDE: cliproxy kiro --import\n\nTo get a token from Kiro IDE:\n 1. Open Kiro IDE and login with Google\n 2. Find: ~/.kiro/kiro-auth-token.json\n 3. Run: cliproxy kiro --import") } // LoginWithGitHub performs OAuth login for Kiro with GitHub. -// This uses a custom protocol handler (kiro://) to receive the callback. +// NOTE: GitHub login is not available for third-party applications due to AWS Cognito restrictions. +// Please use AWS Builder ID or import your token from Kiro IDE. func (a *KiroAuthenticator) LoginWithGitHub(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { - if cfg == nil { - return nil, fmt.Errorf("kiro auth: configuration is required") - } - - oauth := kiroauth.NewKiroOAuth(cfg) - - // Use GitHub OAuth flow with protocol handler - tokenData, err := oauth.LoginWithGitHub(ctx) - if err != nil { - return nil, fmt.Errorf("github login failed: %w", err) - } - - // Parse expires_at - expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) - if err != nil { - expiresAt = time.Now().Add(1 * time.Hour) - } - - // Extract identifier for file naming - idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) - - now := time.Now() - fileName := fmt.Sprintf("kiro-github-%s.json", idPart) - - record := &coreauth.Auth{ - ID: fileName, - Provider: "kiro", - FileName: fileName, - Label: "kiro-github", - Status: coreauth.StatusActive, - CreatedAt: now, - UpdatedAt: now, - Metadata: map[string]any{ - "type": "kiro", - "access_token": tokenData.AccessToken, - "refresh_token": tokenData.RefreshToken, - "profile_arn": tokenData.ProfileArn, - "expires_at": tokenData.ExpiresAt, - "auth_method": tokenData.AuthMethod, - "provider": tokenData.Provider, - "email": tokenData.Email, - }, - Attributes: map[string]string{ - "profile_arn": tokenData.ProfileArn, - "source": "github-oauth", - "email": tokenData.Email, - }, - // NextRefreshAfter is aligned with RefreshLead (5min) - NextRefreshAfter: expiresAt.Add(-5 * time.Minute), - } - - if tokenData.Email != "" { - fmt.Printf("\n✓ Kiro GitHub authentication completed successfully! (Account: %s)\n", tokenData.Email) - } else { - fmt.Println("\n✓ Kiro GitHub authentication completed successfully!") - } - - return record, nil + return nil, fmt.Errorf("GitHub login is not available for third-party applications due to AWS Cognito restrictions.\n\nAlternatives:\n 1. Use AWS Builder ID: cliproxy kiro --builder-id\n 2. Import token from Kiro IDE: cliproxy kiro --import\n\nTo get a token from Kiro IDE:\n 1. Open Kiro IDE and login with GitHub\n 2. Find: ~/.kiro/kiro-auth-token.json\n 3. Run: cliproxy kiro --import") } // ImportFromKiroIDE imports token from Kiro IDE's token file. @@ -361,7 +254,7 @@ func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.C } // Extract identifier for file naming - idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) + idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn, tokenData.ClientID) // Sanitize provider to prevent path traversal (defense-in-depth) provider := kiroauth.SanitizeEmailForFilename(strings.ToLower(strings.TrimSpace(tokenData.Provider))) if provider == "" { @@ -387,12 +280,17 @@ func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.C "expires_at": tokenData.ExpiresAt, "auth_method": tokenData.AuthMethod, "provider": tokenData.Provider, + "client_id": tokenData.ClientID, + "client_secret": tokenData.ClientSecret, "email": tokenData.Email, + "region": tokenData.Region, + "start_url": tokenData.StartURL, }, Attributes: map[string]string{ "profile_arn": tokenData.ProfileArn, "source": "kiro-ide-import", "email": tokenData.Email, + "region": tokenData.Region, }, // NextRefreshAfter is aligned with RefreshLead (5min) NextRefreshAfter: expiresAt.Add(-5 * time.Minute), diff --git a/sdk/auth/kiro.go.bak b/sdk/auth/kiro.go.bak new file mode 100644 index 00000000..b75cd28e --- /dev/null +++ b/sdk/auth/kiro.go.bak @@ -0,0 +1,470 @@ +package auth + +import ( + "context" + "fmt" + "strings" + "time" + + kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +// extractKiroIdentifier extracts a meaningful identifier for file naming. +// Returns account name if provided, otherwise profile ARN ID. +// All extracted values are sanitized to prevent path injection attacks. +func extractKiroIdentifier(accountName, profileArn string) string { + // Priority 1: Use account name if provided + if accountName != "" { + return kiroauth.SanitizeEmailForFilename(accountName) + } + + // Priority 2: Use profile ARN ID part (sanitized to prevent path injection) + if profileArn != "" { + parts := strings.Split(profileArn, "/") + if len(parts) >= 2 { + // Sanitize the ARN component to prevent path traversal + return kiroauth.SanitizeEmailForFilename(parts[len(parts)-1]) + } + } + + // Fallback: timestamp + return fmt.Sprintf("%d", time.Now().UnixNano()%100000) +} + +// KiroAuthenticator implements OAuth authentication for Kiro with Google login. +type KiroAuthenticator struct{} + +// NewKiroAuthenticator constructs a Kiro authenticator. +func NewKiroAuthenticator() *KiroAuthenticator { + return &KiroAuthenticator{} +} + +// Provider returns the provider key for the authenticator. +func (a *KiroAuthenticator) Provider() string { + return "kiro" +} + +// RefreshLead indicates how soon before expiry a refresh should be attempted. +// Set to 5 minutes to match Antigravity and avoid frequent refresh checks while still ensuring timely token refresh. +func (a *KiroAuthenticator) RefreshLead() *time.Duration { + d := 5 * time.Minute + return &d +} + +// createAuthRecord creates an auth record from token data. +func (a *KiroAuthenticator) createAuthRecord(tokenData *kiroauth.KiroTokenData, source string) (*coreauth.Auth, error) { + // Parse expires_at + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + expiresAt = time.Now().Add(1 * time.Hour) + } + + // Extract identifier for file naming + idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) + + // Determine label based on auth method + label := fmt.Sprintf("kiro-%s", source) + if tokenData.AuthMethod == "idc" { + label = "kiro-idc" + } + + now := time.Now() + fileName := fmt.Sprintf("%s-%s.json", label, idPart) + + metadata := map[string]any{ + "type": "kiro", + "access_token": tokenData.AccessToken, + "refresh_token": tokenData.RefreshToken, + "profile_arn": tokenData.ProfileArn, + "expires_at": tokenData.ExpiresAt, + "auth_method": tokenData.AuthMethod, + "provider": tokenData.Provider, + "client_id": tokenData.ClientID, + "client_secret": tokenData.ClientSecret, + "email": tokenData.Email, + } + + // Add IDC-specific fields if present + if tokenData.StartURL != "" { + metadata["start_url"] = tokenData.StartURL + } + if tokenData.Region != "" { + metadata["region"] = tokenData.Region + } + + attributes := map[string]string{ + "profile_arn": tokenData.ProfileArn, + "source": source, + "email": tokenData.Email, + } + + // Add IDC-specific attributes if present + if tokenData.AuthMethod == "idc" { + attributes["source"] = "aws-idc" + if tokenData.StartURL != "" { + attributes["start_url"] = tokenData.StartURL + } + if tokenData.Region != "" { + attributes["region"] = tokenData.Region + } + } + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Label: label, + Status: coreauth.StatusActive, + CreatedAt: now, + UpdatedAt: now, + Metadata: metadata, + Attributes: attributes, + // NextRefreshAfter is aligned with RefreshLead (5min) + NextRefreshAfter: expiresAt.Add(-5 * time.Minute), + } + + if tokenData.Email != "" { + fmt.Printf("\n✓ Kiro authentication completed successfully! (Account: %s)\n", tokenData.Email) + } else { + fmt.Println("\n✓ Kiro authentication completed successfully!") + } + + return record, nil +} + +// Login performs OAuth login for Kiro with AWS (Builder ID or IDC). +// This shows a method selection prompt and handles both flows. +func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if cfg == nil { + return nil, fmt.Errorf("kiro auth: configuration is required") + } + + // Use the unified method selection flow (Builder ID or IDC) + ssoClient := kiroauth.NewSSOOIDCClient(cfg) + tokenData, err := ssoClient.LoginWithMethodSelection(ctx) + if err != nil { + return nil, fmt.Errorf("login failed: %w", err) + } + + return a.createAuthRecord(tokenData, "aws") +} + +// LoginWithAuthCode performs OAuth login for Kiro with AWS Builder ID using authorization code flow. +// This provides a better UX than device code flow as it uses automatic browser callback. +func (a *KiroAuthenticator) LoginWithAuthCode(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if cfg == nil { + return nil, fmt.Errorf("kiro auth: configuration is required") + } + + oauth := kiroauth.NewKiroOAuth(cfg) + + // Use AWS Builder ID authorization code flow + tokenData, err := oauth.LoginWithBuilderIDAuthCode(ctx) + if err != nil { + return nil, fmt.Errorf("login failed: %w", err) + } + + // Parse expires_at + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + expiresAt = time.Now().Add(1 * time.Hour) + } + + // Extract identifier for file naming + idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) + + now := time.Now() + fileName := fmt.Sprintf("kiro-aws-%s.json", idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Label: "kiro-aws", + Status: coreauth.StatusActive, + CreatedAt: now, + UpdatedAt: now, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenData.AccessToken, + "refresh_token": tokenData.RefreshToken, + "profile_arn": tokenData.ProfileArn, + "expires_at": tokenData.ExpiresAt, + "auth_method": tokenData.AuthMethod, + "provider": tokenData.Provider, + "client_id": tokenData.ClientID, + "client_secret": tokenData.ClientSecret, + "email": tokenData.Email, + }, + Attributes: map[string]string{ + "profile_arn": tokenData.ProfileArn, + "source": "aws-builder-id-authcode", + "email": tokenData.Email, + }, + // NextRefreshAfter is aligned with RefreshLead (5min) + NextRefreshAfter: expiresAt.Add(-5 * time.Minute), + } + + if tokenData.Email != "" { + fmt.Printf("\n✓ Kiro authentication completed successfully! (Account: %s)\n", tokenData.Email) + } else { + fmt.Println("\n✓ Kiro authentication completed successfully!") + } + + return record, nil +} + +// LoginWithGoogle performs OAuth login for Kiro with Google. +// This uses a custom protocol handler (kiro://) to receive the callback. +func (a *KiroAuthenticator) LoginWithGoogle(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if cfg == nil { + return nil, fmt.Errorf("kiro auth: configuration is required") + } + + oauth := kiroauth.NewKiroOAuth(cfg) + + // Use Google OAuth flow with protocol handler + tokenData, err := oauth.LoginWithGoogle(ctx) + if err != nil { + return nil, fmt.Errorf("google login failed: %w", err) + } + + // Parse expires_at + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + expiresAt = time.Now().Add(1 * time.Hour) + } + + // Extract identifier for file naming + idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) + + now := time.Now() + fileName := fmt.Sprintf("kiro-google-%s.json", idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Label: "kiro-google", + Status: coreauth.StatusActive, + CreatedAt: now, + UpdatedAt: now, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenData.AccessToken, + "refresh_token": tokenData.RefreshToken, + "profile_arn": tokenData.ProfileArn, + "expires_at": tokenData.ExpiresAt, + "auth_method": tokenData.AuthMethod, + "provider": tokenData.Provider, + "email": tokenData.Email, + }, + Attributes: map[string]string{ + "profile_arn": tokenData.ProfileArn, + "source": "google-oauth", + "email": tokenData.Email, + }, + // NextRefreshAfter is aligned with RefreshLead (5min) + NextRefreshAfter: expiresAt.Add(-5 * time.Minute), + } + + if tokenData.Email != "" { + fmt.Printf("\n✓ Kiro Google authentication completed successfully! (Account: %s)\n", tokenData.Email) + } else { + fmt.Println("\n✓ Kiro Google authentication completed successfully!") + } + + return record, nil +} + +// LoginWithGitHub performs OAuth login for Kiro with GitHub. +// This uses a custom protocol handler (kiro://) to receive the callback. +func (a *KiroAuthenticator) LoginWithGitHub(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if cfg == nil { + return nil, fmt.Errorf("kiro auth: configuration is required") + } + + oauth := kiroauth.NewKiroOAuth(cfg) + + // Use GitHub OAuth flow with protocol handler + tokenData, err := oauth.LoginWithGitHub(ctx) + if err != nil { + return nil, fmt.Errorf("github login failed: %w", err) + } + + // Parse expires_at + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + expiresAt = time.Now().Add(1 * time.Hour) + } + + // Extract identifier for file naming + idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) + + now := time.Now() + fileName := fmt.Sprintf("kiro-github-%s.json", idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Label: "kiro-github", + Status: coreauth.StatusActive, + CreatedAt: now, + UpdatedAt: now, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenData.AccessToken, + "refresh_token": tokenData.RefreshToken, + "profile_arn": tokenData.ProfileArn, + "expires_at": tokenData.ExpiresAt, + "auth_method": tokenData.AuthMethod, + "provider": tokenData.Provider, + "email": tokenData.Email, + }, + Attributes: map[string]string{ + "profile_arn": tokenData.ProfileArn, + "source": "github-oauth", + "email": tokenData.Email, + }, + // NextRefreshAfter is aligned with RefreshLead (5min) + NextRefreshAfter: expiresAt.Add(-5 * time.Minute), + } + + if tokenData.Email != "" { + fmt.Printf("\n✓ Kiro GitHub authentication completed successfully! (Account: %s)\n", tokenData.Email) + } else { + fmt.Println("\n✓ Kiro GitHub authentication completed successfully!") + } + + return record, nil +} + +// ImportFromKiroIDE imports token from Kiro IDE's token file. +func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.Config) (*coreauth.Auth, error) { + tokenData, err := kiroauth.LoadKiroIDEToken() + if err != nil { + return nil, fmt.Errorf("failed to load Kiro IDE token: %w", err) + } + + // Parse expires_at + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + expiresAt = time.Now().Add(1 * time.Hour) + } + + // Extract email from JWT if not already set (for imported tokens) + if tokenData.Email == "" { + tokenData.Email = kiroauth.ExtractEmailFromJWT(tokenData.AccessToken) + } + + // Extract identifier for file naming + idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) + // Sanitize provider to prevent path traversal (defense-in-depth) + provider := kiroauth.SanitizeEmailForFilename(strings.ToLower(strings.TrimSpace(tokenData.Provider))) + if provider == "" { + provider = "imported" // Fallback for legacy tokens without provider + } + + now := time.Now() + fileName := fmt.Sprintf("kiro-%s-%s.json", provider, idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Label: fmt.Sprintf("kiro-%s", provider), + Status: coreauth.StatusActive, + CreatedAt: now, + UpdatedAt: now, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenData.AccessToken, + "refresh_token": tokenData.RefreshToken, + "profile_arn": tokenData.ProfileArn, + "expires_at": tokenData.ExpiresAt, + "auth_method": tokenData.AuthMethod, + "provider": tokenData.Provider, + "email": tokenData.Email, + }, + Attributes: map[string]string{ + "profile_arn": tokenData.ProfileArn, + "source": "kiro-ide-import", + "email": tokenData.Email, + }, + // NextRefreshAfter is aligned with RefreshLead (5min) + NextRefreshAfter: expiresAt.Add(-5 * time.Minute), + } + + // Display the email if extracted + if tokenData.Email != "" { + fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s, Account: %s)\n", tokenData.Provider, tokenData.Email) + } else { + fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s)\n", tokenData.Provider) + } + + return record, nil +} + +// Refresh refreshes an expired Kiro token using AWS SSO OIDC. +func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, auth *coreauth.Auth) (*coreauth.Auth, error) { + if auth == nil || auth.Metadata == nil { + return nil, fmt.Errorf("invalid auth record") + } + + refreshToken, ok := auth.Metadata["refresh_token"].(string) + if !ok || refreshToken == "" { + return nil, fmt.Errorf("refresh token not found") + } + + clientID, _ := auth.Metadata["client_id"].(string) + clientSecret, _ := auth.Metadata["client_secret"].(string) + authMethod, _ := auth.Metadata["auth_method"].(string) + startURL, _ := auth.Metadata["start_url"].(string) + region, _ := auth.Metadata["region"].(string) + + var tokenData *kiroauth.KiroTokenData + var err error + + ssoClient := kiroauth.NewSSOOIDCClient(cfg) + + // Use SSO OIDC refresh for AWS Builder ID or IDC, otherwise use Kiro's OAuth refresh endpoint + switch { + case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "": + // IDC refresh with region-specific endpoint + tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL) + case clientID != "" && clientSecret != "" && authMethod == "builder-id": + // Builder ID refresh with default endpoint + tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken) + default: + // Fallback to Kiro's refresh endpoint (for social auth: Google/GitHub) + oauth := kiroauth.NewKiroOAuth(cfg) + tokenData, err = oauth.RefreshToken(ctx, refreshToken) + } + + if err != nil { + return nil, fmt.Errorf("token refresh failed: %w", err) + } + + // Parse expires_at + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + expiresAt = time.Now().Add(1 * time.Hour) + } + + // Clone auth to avoid mutating the input parameter + updated := auth.Clone() + now := time.Now() + updated.UpdatedAt = now + updated.LastRefreshedAt = now + updated.Metadata["access_token"] = tokenData.AccessToken + updated.Metadata["refresh_token"] = tokenData.RefreshToken + updated.Metadata["expires_at"] = tokenData.ExpiresAt + updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization + // NextRefreshAfter is aligned with RefreshLead (5min) + updated.NextRefreshAfter = expiresAt.Add(-5 * time.Minute) + + return updated, nil +} diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 66d1b8dd..885304ad 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -13,6 +13,7 @@ import ( "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/api" + kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor" _ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage" @@ -775,7 +776,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { models = registry.GetGitHubCopilotModels() models = applyExcludedModels(models, excluded) case "kiro": - models = registry.GetKiroModels() + models = s.fetchKiroModels(a) models = applyExcludedModels(models, excluded) default: // Handle OpenAI-compatibility providers by name using config @@ -1338,3 +1339,201 @@ func applyOAuthModelAlias(cfg *config.Config, provider, authKind string, models } return out } + +// fetchKiroModels attempts to dynamically fetch Kiro models from the API. +// If dynamic fetch fails, it falls back to static registry.GetKiroModels(). +func (s *Service) fetchKiroModels(a *coreauth.Auth) []*ModelInfo { + if a == nil { + log.Debug("kiro: auth is nil, using static models") + return registry.GetKiroModels() + } + + // Extract token data from auth attributes + tokenData := s.extractKiroTokenData(a) + if tokenData == nil || tokenData.AccessToken == "" { + log.Debug("kiro: no valid token data in auth, using static models") + return registry.GetKiroModels() + } + + // Create KiroAuth instance + kAuth := kiroauth.NewKiroAuth(s.cfg) + if kAuth == nil { + log.Warn("kiro: failed to create KiroAuth instance, using static models") + return registry.GetKiroModels() + } + + // Use timeout context for API call + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + // Attempt to fetch dynamic models + apiModels, err := kAuth.ListAvailableModels(ctx, tokenData) + if err != nil { + log.Warnf("kiro: failed to fetch dynamic models: %v, using static models", err) + return registry.GetKiroModels() + } + + if len(apiModels) == 0 { + log.Debug("kiro: API returned no models, using static models") + return registry.GetKiroModels() + } + + // Convert API models to ModelInfo + models := convertKiroAPIModels(apiModels) + + // Generate agentic variants + models = generateKiroAgenticVariants(models) + + log.Infof("kiro: successfully fetched %d models from API (including agentic variants)", len(models)) + return models +} + +// extractKiroTokenData extracts KiroTokenData from auth attributes and metadata. +func (s *Service) extractKiroTokenData(a *coreauth.Auth) *kiroauth.KiroTokenData { + if a == nil || a.Attributes == nil { + return nil + } + + accessToken := strings.TrimSpace(a.Attributes["access_token"]) + if accessToken == "" { + return nil + } + + tokenData := &kiroauth.KiroTokenData{ + AccessToken: accessToken, + ProfileArn: strings.TrimSpace(a.Attributes["profile_arn"]), + } + + // Also try to get refresh token from metadata + if a.Metadata != nil { + if rt, ok := a.Metadata["refresh_token"].(string); ok { + tokenData.RefreshToken = rt + } + } + + return tokenData +} + +// convertKiroAPIModels converts Kiro API models to ModelInfo slice. +func convertKiroAPIModels(apiModels []*kiroauth.KiroModel) []*ModelInfo { + if len(apiModels) == 0 { + return nil + } + + now := time.Now().Unix() + models := make([]*ModelInfo, 0, len(apiModels)) + + for _, m := range apiModels { + if m == nil || m.ModelID == "" { + continue + } + + // Create model ID with kiro- prefix + modelID := "kiro-" + normalizeKiroModelID(m.ModelID) + + info := &ModelInfo{ + ID: modelID, + Object: "model", + Created: now, + OwnedBy: "aws", + Type: "kiro", + DisplayName: formatKiroDisplayName(m.ModelName, m.RateMultiplier), + Description: m.Description, + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + } + + if m.MaxInputTokens > 0 { + info.ContextLength = m.MaxInputTokens + } + + models = append(models, info) + } + + return models +} + +// normalizeKiroModelID normalizes a Kiro model ID by converting dots to dashes +// and removing common prefixes. +func normalizeKiroModelID(modelID string) string { + // Remove common prefixes + modelID = strings.TrimPrefix(modelID, "anthropic.") + modelID = strings.TrimPrefix(modelID, "amazon.") + + // Replace dots with dashes for consistency + modelID = strings.ReplaceAll(modelID, ".", "-") + + // Replace underscores with dashes + modelID = strings.ReplaceAll(modelID, "_", "-") + + return strings.ToLower(modelID) +} + +// formatKiroDisplayName formats the display name with rate multiplier info. +func formatKiroDisplayName(modelName string, rateMultiplier float64) string { + if modelName == "" { + return "" + } + + displayName := "Kiro " + modelName + if rateMultiplier > 0 && rateMultiplier != 1.0 { + displayName += fmt.Sprintf(" (%.1fx credit)", rateMultiplier) + } + + return displayName +} + +// generateKiroAgenticVariants generates agentic variants for Kiro models. +// Agentic variants have optimized system prompts for coding agents. +func generateKiroAgenticVariants(models []*ModelInfo) []*ModelInfo { + if len(models) == 0 { + return models + } + + result := make([]*ModelInfo, 0, len(models)*2) + result = append(result, models...) + + for _, m := range models { + if m == nil { + continue + } + + // Skip if already an agentic variant + if strings.HasSuffix(m.ID, "-agentic") { + continue + } + + // Skip auto models from agentic variant generation + if strings.Contains(m.ID, "-auto") { + continue + } + + // Create agentic variant + agentic := &ModelInfo{ + ID: m.ID + "-agentic", + Object: m.Object, + Created: m.Created, + OwnedBy: m.OwnedBy, + Type: m.Type, + DisplayName: m.DisplayName + " (Agentic)", + Description: m.Description + " - Optimized for coding agents (chunked writes)", + ContextLength: m.ContextLength, + MaxCompletionTokens: m.MaxCompletionTokens, + } + + // Copy thinking support if present + if m.Thinking != nil { + agentic.Thinking = ®istry.ThinkingSupport{ + Min: m.Thinking.Min, + Max: m.Thinking.Max, + ZeroAllowed: m.Thinking.ZeroAllowed, + DynamicAllowed: m.Thinking.DynamicAllowed, + } + } + + result = append(result, agentic) + } + + return result +} diff --git a/test_api.py b/test_api.py new file mode 100644 index 00000000..1849e2ba --- /dev/null +++ b/test_api.py @@ -0,0 +1,452 @@ +#!/usr/bin/env python3 +""" +CLIProxyAPI 全面测试脚本 +测试模型列表、流式输出、thinking模式及复杂任务 +""" + +import requests +import json +import time +import sys +import io +from typing import Optional, List, Dict, Any + +# 修复 Windows 控制台编码问题 +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace') +sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace') + +# 配置 +BASE_URL = "http://localhost:8317" +API_KEY = "your-api-key-1" +HEADERS = { + "Authorization": f"Bearer {API_KEY}", + "Content-Type": "application/json" +} + +# 复杂任务提示词 - 用于测试 thinking 模式 +COMPLEX_TASK_PROMPT = """请帮我分析以下复杂的编程问题,并给出详细的解决方案: + +问题:设计一个高并发的分布式任务调度系统,需要满足以下要求: +1. 支持百万级任务队列 +2. 任务可以设置优先级、延迟执行、定时执行 +3. 支持任务依赖关系(DAG调度) +4. 失败重试机制,支持指数退避 +5. 任务结果持久化和查询 +6. 水平扩展能力 +7. 监控和告警 + +请从以下几个方面详细分析: +1. 整体架构设计 +2. 核心数据结构 +3. 调度算法选择 +4. 容错机制设计 +5. 性能优化策略 +6. 技术选型建议 + +请逐步思考每个方面,给出你的推理过程。""" + +# 简单测试提示词 +SIMPLE_PROMPT = "Hello! Please respond with 'OK' if you receive this message." + +def print_separator(title: str): + print(f"\n{'='*60}") + print(f" {title}") + print(f"{'='*60}\n") + +def print_result(name: str, success: bool, detail: str = ""): + status = "✅ PASS" if success else "❌ FAIL" + print(f"{status} | {name}") + if detail: + print(f" └─ {detail[:200]}{'...' if len(detail) > 200 else ''}") + +def get_models() -> List[str]: + """获取可用模型列表""" + print_separator("获取模型列表") + try: + resp = requests.get(f"{BASE_URL}/v1/models", headers=HEADERS, timeout=30) + if resp.status_code == 200: + data = resp.json() + models = [m.get("id", m.get("name", "unknown")) for m in data.get("data", [])] + print(f"找到 {len(models)} 个模型:") + for m in models: + print(f" - {m}") + return models + else: + print(f"❌ 获取模型列表失败: HTTP {resp.status_code}") + print(f" 响应: {resp.text[:500]}") + return [] + except Exception as e: + print(f"❌ 获取模型列表异常: {e}") + return [] + +def test_model_basic(model: str) -> tuple: + """基础可用性测试,返回 (success, error_detail)""" + try: + payload = { + "model": model, + "messages": [{"role": "user", "content": SIMPLE_PROMPT}], + "max_tokens": 50, + "stream": False + } + resp = requests.post( + f"{BASE_URL}/v1/chat/completions", + headers=HEADERS, + json=payload, + timeout=60 + ) + if resp.status_code == 200: + data = resp.json() + content = data.get("choices", [{}])[0].get("message", {}).get("content", "") + return (bool(content), f"content_len={len(content)}") + else: + return (False, f"HTTP {resp.status_code}: {resp.text[:300]}") + except Exception as e: + return (False, str(e)) + +def test_streaming(model: str) -> Dict[str, Any]: + """测试流式输出""" + result = {"success": False, "chunks": 0, "content": "", "error": None} + try: + payload = { + "model": model, + "messages": [{"role": "user", "content": "Count from 1 to 5, one number per line."}], + "max_tokens": 100, + "stream": True + } + resp = requests.post( + f"{BASE_URL}/v1/chat/completions", + headers=HEADERS, + json=payload, + timeout=60, + stream=True + ) + + if resp.status_code != 200: + result["error"] = f"HTTP {resp.status_code}: {resp.text[:200]}" + return result + + content_parts = [] + for line in resp.iter_lines(): + if line: + line_str = line.decode('utf-8') + if line_str.startswith("data: "): + data_str = line_str[6:] + if data_str.strip() == "[DONE]": + break + try: + data = json.loads(data_str) + result["chunks"] += 1 + choices = data.get("choices", []) + if choices: + delta = choices[0].get("delta", {}) + if "content" in delta and delta["content"]: + content_parts.append(delta["content"]) + except json.JSONDecodeError: + pass + except Exception as e: + result["error"] = f"Parse error: {e}, data: {data_str[:200]}" + + result["content"] = "".join(content_parts) + result["success"] = result["chunks"] > 0 and len(result["content"]) > 0 + + except Exception as e: + result["error"] = str(e) + + return result + +def test_thinking_mode(model: str, complex_task: bool = False) -> Dict[str, Any]: + """测试 thinking 模式""" + result = { + "success": False, + "has_reasoning": False, + "reasoning_content": "", + "content": "", + "error": None, + "chunks": 0 + } + + prompt = COMPLEX_TASK_PROMPT if complex_task else "What is 15 * 23? Please think step by step." + + try: + # 尝试不同的 thinking 模式参数格式 + payload = { + "model": model, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": 8000 if complex_task else 2000, + "stream": True + } + + # 根据模型类型添加 thinking 参数 + if "claude" in model.lower(): + payload["thinking"] = {"type": "enabled", "budget_tokens": 5000 if complex_task else 2000} + elif "gemini" in model.lower(): + payload["thinking"] = {"thinking_budget": 5000 if complex_task else 2000} + elif "gpt" in model.lower() or "codex" in model.lower() or "o1" in model.lower() or "o3" in model.lower(): + payload["reasoning_effort"] = "high" if complex_task else "medium" + else: + # 通用格式 + payload["thinking"] = {"type": "enabled", "budget_tokens": 5000 if complex_task else 2000} + + resp = requests.post( + f"{BASE_URL}/v1/chat/completions", + headers=HEADERS, + json=payload, + timeout=300 if complex_task else 120, + stream=True + ) + + if resp.status_code != 200: + result["error"] = f"HTTP {resp.status_code}: {resp.text[:500]}" + return result + + content_parts = [] + reasoning_parts = [] + + for line in resp.iter_lines(): + if line: + line_str = line.decode('utf-8') + if line_str.startswith("data: "): + data_str = line_str[6:] + if data_str.strip() == "[DONE]": + break + try: + data = json.loads(data_str) + result["chunks"] += 1 + + choices = data.get("choices", []) + if not choices: + continue + choice = choices[0] + delta = choice.get("delta", {}) + + # 检查 reasoning_content (Claude/OpenAI格式) + if "reasoning_content" in delta and delta["reasoning_content"]: + reasoning_parts.append(delta["reasoning_content"]) + result["has_reasoning"] = True + + # 检查 thinking (Gemini格式) + if "thinking" in delta and delta["thinking"]: + reasoning_parts.append(delta["thinking"]) + result["has_reasoning"] = True + + # 常规内容 + if "content" in delta and delta["content"]: + content_parts.append(delta["content"]) + + except json.JSONDecodeError as e: + pass + except Exception as e: + result["error"] = f"Parse error: {e}" + + result["reasoning_content"] = "".join(reasoning_parts) + result["content"] = "".join(content_parts) + result["success"] = result["chunks"] > 0 and (len(result["content"]) > 0 or len(result["reasoning_content"]) > 0) + + except requests.exceptions.Timeout: + result["error"] = "Request timeout" + except Exception as e: + result["error"] = str(e) + + return result + +def run_full_test(): + """运行完整测试""" + print("\n" + "="*60) + print(" CLIProxyAPI 全面测试") + print("="*60) + print(f"目标地址: {BASE_URL}") + print(f"API Key: {API_KEY[:10]}...") + + # 1. 获取模型列表 + models = get_models() + if not models: + print("\n❌ 无法获取模型列表,测试终止") + return + + # 2. 基础可用性测试 + print_separator("基础可用性测试") + available_models = [] + for model in models: + success, detail = test_model_basic(model) + print_result(f"模型: {model}", success, detail) + if success: + available_models.append(model) + + print(f"\n可用模型: {len(available_models)}/{len(models)}") + + if not available_models: + print("\n❌ 没有可用的模型,测试终止") + return + + # 3. 流式输出测试 + print_separator("流式输出测试") + streaming_results = {} + for model in available_models: + result = test_streaming(model) + streaming_results[model] = result + detail = f"chunks={result['chunks']}, content_len={len(result['content'])}" + if result["error"]: + detail = f"error: {result['error']}" + print_result(f"模型: {model}", result["success"], detail) + + # 4. Thinking 模式测试 (简单任务) + print_separator("Thinking 模式测试 (简单任务)") + thinking_results = {} + for model in available_models: + result = test_thinking_mode(model, complex_task=False) + thinking_results[model] = result + detail = f"reasoning={result['has_reasoning']}, chunks={result['chunks']}" + if result["error"]: + detail = f"error: {result['error']}" + print_result(f"模型: {model}", result["success"], detail) + + # 5. Thinking 模式测试 (复杂任务) - 只测试支持 thinking 的模型 + print_separator("Thinking 模式测试 (复杂任务)") + complex_thinking_results = {} + + # 选择前3个可用模型进行复杂任务测试 + test_models = available_models[:3] + print(f"测试模型 (取前3个): {test_models}\n") + + for model in test_models: + print(f"⏳ 正在测试 {model} (复杂任务,可能需要较长时间)...") + result = test_thinking_mode(model, complex_task=True) + complex_thinking_results[model] = result + + if result["success"]: + detail = f"reasoning={result['has_reasoning']}, reasoning_len={len(result['reasoning_content'])}, content_len={len(result['content'])}" + else: + detail = f"error: {result['error']}" if result["error"] else "Unknown error" + + print_result(f"模型: {model}", result["success"], detail) + + # 如果有 reasoning 内容,打印前500字符 + if result["has_reasoning"] and result["reasoning_content"]: + print(f"\n 📝 Reasoning 内容预览 (前500字符):") + print(f" {result['reasoning_content'][:500]}...") + + # 6. 总结报告 + print_separator("测试总结报告") + + print(f"📊 模型总数: {len(models)}") + print(f"✅ 可用模型: {len(available_models)}") + print(f"❌ 不可用模型: {len(models) - len(available_models)}") + + print(f"\n📊 流式输出测试:") + streaming_pass = sum(1 for r in streaming_results.values() if r["success"]) + print(f" 通过: {streaming_pass}/{len(streaming_results)}") + + print(f"\n📊 Thinking 模式测试 (简单):") + thinking_pass = sum(1 for r in thinking_results.values() if r["success"]) + thinking_with_reasoning = sum(1 for r in thinking_results.values() if r["has_reasoning"]) + print(f" 通过: {thinking_pass}/{len(thinking_results)}") + print(f" 包含推理内容: {thinking_with_reasoning}/{len(thinking_results)}") + + print(f"\n📊 Thinking 模式测试 (复杂):") + complex_pass = sum(1 for r in complex_thinking_results.values() if r["success"]) + complex_with_reasoning = sum(1 for r in complex_thinking_results.values() if r["has_reasoning"]) + print(f" 通过: {complex_pass}/{len(complex_thinking_results)}") + print(f" 包含推理内容: {complex_with_reasoning}/{len(complex_thinking_results)}") + + # 列出所有错误 + print(f"\n📋 错误详情:") + has_errors = False + + for model, result in streaming_results.items(): + if result["error"]: + has_errors = True + print(f" [流式] {model}: {result['error'][:100]}") + + for model, result in thinking_results.items(): + if result["error"]: + has_errors = True + print(f" [Thinking简单] {model}: {result['error'][:100]}") + + for model, result in complex_thinking_results.items(): + if result["error"]: + has_errors = True + print(f" [Thinking复杂] {model}: {result['error'][:100]}") + + if not has_errors: + print(" 无错误") + + print("\n" + "="*60) + print(" 测试完成") + print("="*60 + "\n") + +def test_single_model_basic(model: str): + """单独测试一个模型的基础功能""" + print_separator(f"基础测试: {model}") + success, detail = test_model_basic(model) + print_result(f"模型: {model}", success, detail) + return success + +def test_single_model_streaming(model: str): + """单独测试一个模型的流式输出""" + print_separator(f"流式测试: {model}") + result = test_streaming(model) + detail = f"chunks={result['chunks']}, content_len={len(result['content'])}" + if result["error"]: + detail = f"error: {result['error']}" + print_result(f"模型: {model}", result["success"], detail) + if result["content"]: + print(f"\n内容: {result['content'][:300]}") + return result + +def test_single_model_thinking(model: str, complex_task: bool = False): + """单独测试一个模型的thinking模式""" + task_type = "复杂" if complex_task else "简单" + print_separator(f"Thinking测试({task_type}): {model}") + result = test_thinking_mode(model, complex_task=complex_task) + detail = f"reasoning={result['has_reasoning']}, chunks={result['chunks']}" + if result["error"]: + detail = f"error: {result['error']}" + print_result(f"模型: {model}", result["success"], detail) + if result["reasoning_content"]: + print(f"\nReasoning预览: {result['reasoning_content'][:500]}") + if result["content"]: + print(f"\n内容预览: {result['content'][:500]}") + return result + +def print_usage(): + print(""" +用法: python test_api.py [options] + +命令: + models - 获取模型列表 + basic - 测试单个模型基础功能 + stream - 测试单个模型流式输出 + thinking - 测试单个模型thinking模式(简单任务) + thinking-complex - 测试单个模型thinking模式(复杂任务) + all - 运行完整测试(原有功能) + +示例: + python test_api.py models + python test_api.py basic claude-sonnet + python test_api.py stream claude-sonnet + python test_api.py thinking claude-sonnet +""") + +if __name__ == "__main__": + import sys + + if len(sys.argv) < 2: + print_usage() + sys.exit(0) + + cmd = sys.argv[1].lower() + + if cmd == "models": + get_models() + elif cmd == "basic" and len(sys.argv) >= 3: + test_single_model_basic(sys.argv[2]) + elif cmd == "stream" and len(sys.argv) >= 3: + test_single_model_streaming(sys.argv[2]) + elif cmd == "thinking" and len(sys.argv) >= 3: + test_single_model_thinking(sys.argv[2], complex_task=False) + elif cmd == "thinking-complex" and len(sys.argv) >= 3: + test_single_model_thinking(sys.argv[2], complex_task=True) + elif cmd == "all": + run_full_test() + else: + print_usage() diff --git a/test_auth_diff.go b/test_auth_diff.go new file mode 100644 index 00000000..b294622e --- /dev/null +++ b/test_auth_diff.go @@ -0,0 +1,273 @@ +// 测试脚本 3:对比 CLIProxyAPIPlus 与官方格式的差异 +// 这个脚本分析 CLIProxyAPIPlus 保存的 token 与官方格式的差异 +// 运行方式: go run test_auth_diff.go +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" +) + +func main() { + fmt.Println("=" + strings.Repeat("=", 59)) + fmt.Println(" 测试脚本 3: Token 格式差异分析") + fmt.Println("=" + strings.Repeat("=", 59)) + + homeDir := os.Getenv("USERPROFILE") + + // 加载官方 IDE Token (Kiro IDE 生成) + fmt.Println("\n[1] 官方 Kiro IDE Token 格式") + fmt.Println("-" + strings.Repeat("-", 59)) + + ideTokenPath := filepath.Join(homeDir, ".aws", "sso", "cache", "kiro-auth-token.json") + ideToken := loadAndAnalyze(ideTokenPath, "Kiro IDE") + + // 加载 CLIProxyAPIPlus 保存的 Token + fmt.Println("\n[2] CLIProxyAPIPlus 保存的 Token 格式") + fmt.Println("-" + strings.Repeat("-", 59)) + + cliProxyDir := filepath.Join(homeDir, ".cli-proxy-api") + files, _ := os.ReadDir(cliProxyDir) + + var cliProxyTokens []map[string]interface{} + for _, f := range files { + if strings.HasPrefix(f.Name(), "kiro") && strings.HasSuffix(f.Name(), ".json") { + p := filepath.Join(cliProxyDir, f.Name()) + token := loadAndAnalyze(p, f.Name()) + if token != nil { + cliProxyTokens = append(cliProxyTokens, token) + } + } + } + + // 对比分析 + fmt.Println("\n[3] 关键差异分析") + fmt.Println("-" + strings.Repeat("-", 59)) + + if ideToken == nil { + fmt.Println("❌ 无法加载 IDE Token,跳过对比") + } else if len(cliProxyTokens) == 0 { + fmt.Println("❌ 无法加载 CLIProxyAPIPlus Token,跳过对比") + } else { + // 对比最新的 CLIProxyAPIPlus token + cliToken := cliProxyTokens[0] + + fmt.Println("\n字段对比:") + fmt.Printf("%-20s | %-15s | %-15s\n", "字段", "IDE Token", "CLIProxy Token") + fmt.Println(strings.Repeat("-", 55)) + + fields := []string{ + "accessToken", "refreshToken", "clientId", "clientSecret", + "authMethod", "auth_method", "provider", "region", "expiresAt", "expires_at", + } + + for _, field := range fields { + ideVal := getFieldStatus(ideToken, field) + cliVal := getFieldStatus(cliToken, field) + + status := " " + if ideVal != cliVal { + if ideVal == "✅ 有" && cliVal == "❌ 无" { + status = "⚠️" + } else if ideVal == "❌ 无" && cliVal == "✅ 有" { + status = "📝" + } + } + + fmt.Printf("%-20s | %-15s | %-15s %s\n", field, ideVal, cliVal, status) + } + + // 关键问题检测 + fmt.Println("\n🔍 问题检测:") + + // 检查 clientId/clientSecret + if hasField(ideToken, "clientId") && !hasField(cliToken, "clientId") { + fmt.Println(" ⚠️ 问题: CLIProxyAPIPlus 缺少 clientId 字段!") + fmt.Println(" 原因: IdC 认证刷新 token 时需要 clientId") + } + + if hasField(ideToken, "clientSecret") && !hasField(cliToken, "clientSecret") { + fmt.Println(" ⚠️ 问题: CLIProxyAPIPlus 缺少 clientSecret 字段!") + fmt.Println(" 原因: IdC 认证刷新 token 时需要 clientSecret") + } + + // 检查字段名差异 + if hasField(cliToken, "auth_method") && !hasField(cliToken, "authMethod") { + fmt.Println(" 📝 注意: CLIProxy 使用 auth_method (snake_case)") + fmt.Println(" 而官方使用 authMethod (camelCase)") + } + + if hasField(cliToken, "expires_at") && !hasField(cliToken, "expiresAt") { + fmt.Println(" 📝 注意: CLIProxy 使用 expires_at (snake_case)") + fmt.Println(" 而官方使用 expiresAt (camelCase)") + } + } + + // Step 4: 测试使用完整格式的 token + fmt.Println("\n[4] 测试完整格式 Token (带 clientId/clientSecret)") + fmt.Println("-" + strings.Repeat("-", 59)) + + if ideToken != nil { + testWithFullToken(ideToken) + } + + fmt.Println("\n" + strings.Repeat("=", 60)) + fmt.Println(" 分析完成") + fmt.Println(strings.Repeat("=", 60)) + + // 给出建议 + fmt.Println("\n💡 修复建议:") + fmt.Println(" 1. CLIProxyAPIPlus 导入 token 时需要保留 clientId 和 clientSecret") + fmt.Println(" 2. IdC 认证刷新 token 必须使用这两个字段") + fmt.Println(" 3. 检查 CLIProxyAPIPlus 的 token 导入逻辑:") + fmt.Println(" - internal/auth/kiro/aws.go LoadKiroIDEToken()") + fmt.Println(" - sdk/auth/kiro.go ImportFromKiroIDE()") +} + +func loadAndAnalyze(path, name string) map[string]interface{} { + data, err := os.ReadFile(path) + if err != nil { + fmt.Printf("❌ 无法加载 %s: %v\n", name, err) + return nil + } + + var token map[string]interface{} + if err := json.Unmarshal(data, &token); err != nil { + fmt.Printf("❌ 无法解析 %s: %v\n", name, err) + return nil + } + + fmt.Printf("📄 %s\n", path) + fmt.Printf(" 字段数: %d\n", len(token)) + + // 列出所有字段 + fmt.Printf(" 字段列表: ") + keys := make([]string, 0, len(token)) + for k := range token { + keys = append(keys, k) + } + fmt.Printf("%v\n", keys) + + return token +} + +func getFieldStatus(token map[string]interface{}, field string) string { + if token == nil { + return "N/A" + } + if v, ok := token[field]; ok && v != nil && v != "" { + return "✅ 有" + } + return "❌ 无" +} + +func hasField(token map[string]interface{}, field string) bool { + if token == nil { + return false + } + v, ok := token[field] + return ok && v != nil && v != "" +} + +func testWithFullToken(token map[string]interface{}) { + accessToken, _ := token["accessToken"].(string) + refreshToken, _ := token["refreshToken"].(string) + clientId, _ := token["clientId"].(string) + clientSecret, _ := token["clientSecret"].(string) + region, _ := token["region"].(string) + + if region == "" { + region = "us-east-1" + } + + // 测试当前 accessToken + fmt.Println("\n测试当前 accessToken...") + if testAPICall(accessToken, region) { + fmt.Println("✅ 当前 accessToken 有效") + return + } + + fmt.Println("⚠️ 当前 accessToken 无效,尝试刷新...") + + // 检查是否有完整的刷新所需字段 + if clientId == "" || clientSecret == "" { + fmt.Println("❌ 缺少 clientId 或 clientSecret,无法刷新") + fmt.Println(" 这就是问题所在!") + return + } + + // 尝试刷新 + fmt.Println("\n使用完整字段刷新 token...") + url := fmt.Sprintf("https://oidc.%s.amazonaws.com/token", region) + + requestBody := map[string]interface{}{ + "refreshToken": refreshToken, + "clientId": clientId, + "clientSecret": clientSecret, + "grantType": "refresh_token", + } + + body, _ := json.Marshal(requestBody) + req, _ := http.NewRequest("POST", url, bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + fmt.Printf("❌ 请求失败: %v\n", err) + return + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode == 200 { + var refreshResp map[string]interface{} + json.Unmarshal(respBody, &refreshResp) + + newAccessToken, _ := refreshResp["accessToken"].(string) + fmt.Println("✅ Token 刷新成功!") + + // 验证新 token + if testAPICall(newAccessToken, region) { + fmt.Println("✅ 新 Token 验证成功!") + fmt.Println("\n✅ 结论: 使用完整格式 (含 clientId/clientSecret) 可以正常工作") + } + } else { + fmt.Printf("❌ 刷新失败: HTTP %d\n", resp.StatusCode) + fmt.Printf(" 响应: %s\n", string(respBody)) + } +} + +func testAPICall(accessToken, region string) bool { + url := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com", region) + + payload := map[string]interface{}{ + "origin": "AI_EDITOR", + "isEmailRequired": true, + "resourceType": "AGENTIC_REQUEST", + } + body, _ := json.Marshal(payload) + + req, _ := http.NewRequest("POST", url, bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/x-amz-json-1.0") + req.Header.Set("x-amz-target", "AmazonCodeWhispererService.GetUsageLimits") + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return false + } + defer resp.Body.Close() + + return resp.StatusCode == 200 +} diff --git a/test_auth_idc_go1.go b/test_auth_idc_go1.go new file mode 100644 index 00000000..55fd5829 --- /dev/null +++ b/test_auth_idc_go1.go @@ -0,0 +1,323 @@ +// 测试脚本 1:模拟 kiro2api_go1 的 IdC 认证方式 +// 这个脚本完整模拟 kiro-gateway/temp/kiro2api_go1 的认证逻辑 +// 运行方式: go run test_auth_idc_go1.go +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "math/rand" + "net/http" + "os" + "path/filepath" + "strings" + "time" +) + +// 配置常量 - 来自 kiro2api_go1/config/config.go +const ( + IdcRefreshTokenURL = "https://oidc.us-east-1.amazonaws.com/token" + CodeWhispererAPIURL = "https://codewhisperer.us-east-1.amazonaws.com" +) + +// AuthConfig - 来自 kiro2api_go1/auth/config.go +type AuthConfig struct { + AuthType string `json:"auth"` + RefreshToken string `json:"refreshToken"` + ClientID string `json:"clientId,omitempty"` + ClientSecret string `json:"clientSecret,omitempty"` +} + +// IdcRefreshRequest - 来自 kiro2api_go1/types/token.go +type IdcRefreshRequest struct { + ClientId string `json:"clientId"` + ClientSecret string `json:"clientSecret"` + GrantType string `json:"grantType"` + RefreshToken string `json:"refreshToken"` +} + +// RefreshResponse - 来自 kiro2api_go1/types/token.go +type RefreshResponse struct { + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken,omitempty"` + ExpiresIn int `json:"expiresIn"` + TokenType string `json:"tokenType,omitempty"` +} + +// Fingerprint - 简化的指纹结构 +type Fingerprint struct { + OSType string + ConnectionBehavior string + AcceptLanguage string + SecFetchMode string + AcceptEncoding string +} + +func generateFingerprint() *Fingerprint { + osTypes := []string{"darwin", "windows", "linux"} + connections := []string{"keep-alive", "close"} + languages := []string{"en-US,en;q=0.9", "zh-CN,zh;q=0.9", "en-GB,en;q=0.9"} + fetchModes := []string{"cors", "navigate", "no-cors"} + + return &Fingerprint{ + OSType: osTypes[rand.Intn(len(osTypes))], + ConnectionBehavior: connections[rand.Intn(len(connections))], + AcceptLanguage: languages[rand.Intn(len(languages))], + SecFetchMode: fetchModes[rand.Intn(len(fetchModes))], + AcceptEncoding: "gzip, deflate, br", + } +} + +func main() { + rand.Seed(time.Now().UnixNano()) + + fmt.Println("=" + strings.Repeat("=", 59)) + fmt.Println(" 测试脚本 1: kiro2api_go1 风格 IdC 认证") + fmt.Println("=" + strings.Repeat("=", 59)) + + // Step 1: 加载官方格式的 token 文件 + fmt.Println("\n[Step 1] 加载官方格式 Token 文件") + fmt.Println("-" + strings.Repeat("-", 59)) + + // 尝试从多个位置加载 + tokenPaths := []string{ + // 优先使用包含完整 clientId/clientSecret 的文件 + "E:/ai_project_2api/kiro-gateway/configs/kiro/kiro-auth-token-1768317098.json", + filepath.Join(os.Getenv("USERPROFILE"), ".aws", "sso", "cache", "kiro-auth-token.json"), + } + + var tokenData map[string]interface{} + var loadedPath string + + for _, p := range tokenPaths { + data, err := os.ReadFile(p) + if err == nil { + if err := json.Unmarshal(data, &tokenData); err == nil { + loadedPath = p + break + } + } + } + + if tokenData == nil { + fmt.Println("❌ 无法加载任何 token 文件") + return + } + + fmt.Printf("✅ 加载文件: %s\n", loadedPath) + + // 提取关键字段 + accessToken, _ := tokenData["accessToken"].(string) + refreshToken, _ := tokenData["refreshToken"].(string) + clientId, _ := tokenData["clientId"].(string) + clientSecret, _ := tokenData["clientSecret"].(string) + authMethod, _ := tokenData["authMethod"].(string) + region, _ := tokenData["region"].(string) + + if region == "" { + region = "us-east-1" + } + + fmt.Printf("\n当前 Token 信息:\n") + fmt.Printf(" AuthMethod: %s\n", authMethod) + fmt.Printf(" Region: %s\n", region) + fmt.Printf(" AccessToken: %s...\n", truncate(accessToken, 50)) + fmt.Printf(" RefreshToken: %s...\n", truncate(refreshToken, 50)) + fmt.Printf(" ClientID: %s\n", truncate(clientId, 30)) + fmt.Printf(" ClientSecret: %s...\n", truncate(clientSecret, 50)) + + // Step 2: 验证 IdC 认证所需字段 + fmt.Println("\n[Step 2] 验证 IdC 认证必需字段") + fmt.Println("-" + strings.Repeat("-", 59)) + + missingFields := []string{} + if refreshToken == "" { + missingFields = append(missingFields, "refreshToken") + } + if clientId == "" { + missingFields = append(missingFields, "clientId") + } + if clientSecret == "" { + missingFields = append(missingFields, "clientSecret") + } + + if len(missingFields) > 0 { + fmt.Printf("❌ 缺少必需字段: %v\n", missingFields) + fmt.Println(" IdC 认证需要: refreshToken, clientId, clientSecret") + return + } + fmt.Println("✅ 所有必需字段都存在") + + // Step 3: 测试直接使用 accessToken 调用 API + fmt.Println("\n[Step 3] 测试当前 AccessToken 有效性") + fmt.Println("-" + strings.Repeat("-", 59)) + + if testAPICall(accessToken, region) { + fmt.Println("✅ 当前 AccessToken 有效,无需刷新") + } else { + fmt.Println("⚠️ 当前 AccessToken 无效,需要刷新") + + // Step 4: 使用 kiro2api_go1 风格刷新 token + fmt.Println("\n[Step 4] 使用 kiro2api_go1 风格刷新 Token") + fmt.Println("-" + strings.Repeat("-", 59)) + + newToken, err := refreshIdCToken(AuthConfig{ + AuthType: "IdC", + RefreshToken: refreshToken, + ClientID: clientId, + ClientSecret: clientSecret, + }, region) + + if err != nil { + fmt.Printf("❌ 刷新失败: %v\n", err) + return + } + + fmt.Println("✅ Token 刷新成功!") + fmt.Printf(" 新 AccessToken: %s...\n", truncate(newToken.AccessToken, 50)) + fmt.Printf(" ExpiresIn: %d 秒\n", newToken.ExpiresIn) + + // Step 5: 验证新 token + fmt.Println("\n[Step 5] 验证新 Token") + fmt.Println("-" + strings.Repeat("-", 59)) + + if testAPICall(newToken.AccessToken, region) { + fmt.Println("✅ 新 Token 验证成功!") + + // 保存新 token + saveNewToken(loadedPath, newToken, tokenData) + } else { + fmt.Println("❌ 新 Token 验证失败") + } + } + + fmt.Println("\n" + strings.Repeat("=", 60)) + fmt.Println(" 测试完成") + fmt.Println(strings.Repeat("=", 60)) +} + +// refreshIdCToken - 完全模拟 kiro2api_go1/auth/refresh.go 的 refreshIdCToken 函数 +func refreshIdCToken(authConfig AuthConfig, region string) (*RefreshResponse, error) { + refreshReq := IdcRefreshRequest{ + ClientId: authConfig.ClientID, + ClientSecret: authConfig.ClientSecret, + GrantType: "refresh_token", + RefreshToken: authConfig.RefreshToken, + } + + reqBody, err := json.Marshal(refreshReq) + if err != nil { + return nil, fmt.Errorf("序列化IdC请求失败: %v", err) + } + + url := fmt.Sprintf("https://oidc.%s.amazonaws.com/token", region) + req, err := http.NewRequest("POST", url, bytes.NewBuffer(reqBody)) + if err != nil { + return nil, fmt.Errorf("创建IdC请求失败: %v", err) + } + + // 设置 IdC 特殊 headers(使用指纹随机化)- 完全模拟 kiro2api_go1 + fp := generateFingerprint() + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region)) + req.Header.Set("Connection", fp.ConnectionBehavior) + req.Header.Set("x-amz-user-agent", fmt.Sprintf("aws-sdk-js/3.738.0 ua/2.1 os/%s lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE", fp.OSType)) + req.Header.Set("Accept", "*/*") + req.Header.Set("Accept-Language", fp.AcceptLanguage) + req.Header.Set("sec-fetch-mode", fp.SecFetchMode) + req.Header.Set("User-Agent", "node") + req.Header.Set("Accept-Encoding", fp.AcceptEncoding) + + fmt.Println("发送刷新请求:") + fmt.Printf(" URL: %s\n", url) + fmt.Println(" Headers:") + for k, v := range req.Header { + if k == "Content-Type" || k == "Host" || k == "X-Amz-User-Agent" || k == "User-Agent" { + fmt.Printf(" %s: %s\n", k, v[0]) + } + } + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("IdC请求失败: %v", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("IdC刷新失败: 状态码 %d, 响应: %s", resp.StatusCode, string(body)) + } + + var refreshResp RefreshResponse + if err := json.Unmarshal(body, &refreshResp); err != nil { + return nil, fmt.Errorf("解析IdC响应失败: %v", err) + } + + return &refreshResp, nil +} + +func testAPICall(accessToken, region string) bool { + url := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com", region) + + payload := map[string]interface{}{ + "origin": "AI_EDITOR", + "isEmailRequired": true, + "resourceType": "AGENTIC_REQUEST", + } + body, _ := json.Marshal(payload) + + req, _ := http.NewRequest("POST", url, bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/x-amz-json-1.0") + req.Header.Set("x-amz-target", "AmazonCodeWhispererService.GetUsageLimits") + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + fmt.Printf(" 请求错误: %v\n", err) + return false + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + fmt.Printf(" API 响应: HTTP %d\n", resp.StatusCode) + + if resp.StatusCode == 200 { + return true + } + + fmt.Printf(" 错误详情: %s\n", truncate(string(respBody), 200)) + return false +} + +func saveNewToken(originalPath string, newToken *RefreshResponse, originalData map[string]interface{}) { + // 更新 token 数据 + originalData["accessToken"] = newToken.AccessToken + if newToken.RefreshToken != "" { + originalData["refreshToken"] = newToken.RefreshToken + } + originalData["expiresAt"] = time.Now().Add(time.Duration(newToken.ExpiresIn) * time.Second).Format(time.RFC3339) + + data, _ := json.MarshalIndent(originalData, "", " ") + + // 保存到新文件 + newPath := strings.TrimSuffix(originalPath, ".json") + "_refreshed.json" + if err := os.WriteFile(newPath, data, 0644); err != nil { + fmt.Printf("⚠️ 保存失败: %v\n", err) + } else { + fmt.Printf("✅ 新 Token 已保存到: %s\n", newPath) + } +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] +} diff --git a/test_auth_js_style.go b/test_auth_js_style.go new file mode 100644 index 00000000..6ded3305 --- /dev/null +++ b/test_auth_js_style.go @@ -0,0 +1,237 @@ +// 测试脚本 2:模拟 kiro2Api_js 的认证方式 +// 这个脚本完整模拟 kiro-gateway/temp/kiro2Api_js 的认证逻辑 +// 运行方式: go run test_auth_js_style.go +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" +) + +// 常量 - 来自 kiro2Api_js/src/kiro/auth.js +const ( + REFRESH_URL_TEMPLATE = "https://prod.{{region}}.auth.desktop.kiro.dev/refreshToken" + REFRESH_IDC_URL_TEMPLATE = "https://oidc.{{region}}.amazonaws.com/token" + AUTH_METHOD_SOCIAL = "social" + AUTH_METHOD_IDC = "IdC" +) + +func main() { + fmt.Println("=" + strings.Repeat("=", 59)) + fmt.Println(" 测试脚本 2: kiro2Api_js 风格认证") + fmt.Println("=" + strings.Repeat("=", 59)) + + // Step 1: 加载 token 文件 + fmt.Println("\n[Step 1] 加载 Token 文件") + fmt.Println("-" + strings.Repeat("-", 59)) + + tokenPaths := []string{ + filepath.Join(os.Getenv("USERPROFILE"), ".aws", "sso", "cache", "kiro-auth-token.json"), + "E:/ai_project_2api/kiro-gateway/configs/kiro/kiro-auth-token-1768317098.json", + } + + var tokenData map[string]interface{} + var loadedPath string + + for _, p := range tokenPaths { + data, err := os.ReadFile(p) + if err == nil { + if err := json.Unmarshal(data, &tokenData); err == nil { + loadedPath = p + break + } + } + } + + if tokenData == nil { + fmt.Println("❌ 无法加载任何 token 文件") + return + } + + fmt.Printf("✅ 加载文件: %s\n", loadedPath) + + // 提取字段 - 模拟 kiro2Api_js/src/kiro/auth.js initializeAuth + accessToken, _ := tokenData["accessToken"].(string) + refreshToken, _ := tokenData["refreshToken"].(string) + clientId, _ := tokenData["clientId"].(string) + clientSecret, _ := tokenData["clientSecret"].(string) + authMethod, _ := tokenData["authMethod"].(string) + region, _ := tokenData["region"].(string) + + if region == "" { + region = "us-east-1" + fmt.Println("⚠️ Region 未设置,使用默认值 us-east-1") + } + + fmt.Printf("\nToken 信息:\n") + fmt.Printf(" AuthMethod: %s\n", authMethod) + fmt.Printf(" Region: %s\n", region) + fmt.Printf(" 有 ClientID: %v\n", clientId != "") + fmt.Printf(" 有 ClientSecret: %v\n", clientSecret != "") + + // Step 2: 测试当前 token + fmt.Println("\n[Step 2] 测试当前 AccessToken") + fmt.Println("-" + strings.Repeat("-", 59)) + + if testAPI(accessToken, region) { + fmt.Println("✅ 当前 AccessToken 有效") + return + } + + fmt.Println("⚠️ 当前 AccessToken 无效,开始刷新...") + + // Step 3: 根据 authMethod 选择刷新方式 - 模拟 doRefreshToken + fmt.Println("\n[Step 3] 刷新 Token (JS 风格)") + fmt.Println("-" + strings.Repeat("-", 59)) + + var refreshURL string + var requestBody map[string]interface{} + + // 判断认证方式 - 模拟 kiro2Api_js auth.js doRefreshToken + if authMethod == AUTH_METHOD_SOCIAL { + // Social 认证 + refreshURL = strings.Replace(REFRESH_URL_TEMPLATE, "{{region}}", region, 1) + requestBody = map[string]interface{}{ + "refreshToken": refreshToken, + } + fmt.Println("使用 Social 认证方式") + } else { + // IdC 认证 (默认) + refreshURL = strings.Replace(REFRESH_IDC_URL_TEMPLATE, "{{region}}", region, 1) + requestBody = map[string]interface{}{ + "refreshToken": refreshToken, + "clientId": clientId, + "clientSecret": clientSecret, + "grantType": "refresh_token", + } + fmt.Println("使用 IdC 认证方式") + } + + fmt.Printf("刷新 URL: %s\n", refreshURL) + fmt.Printf("请求字段: %v\n", getKeys(requestBody)) + + // 发送刷新请求 + body, _ := json.Marshal(requestBody) + req, _ := http.NewRequest("POST", refreshURL, bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + fmt.Printf("❌ 请求失败: %v\n", err) + return + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + fmt.Printf("\n响应状态: HTTP %d\n", resp.StatusCode) + + if resp.StatusCode != 200 { + fmt.Printf("❌ 刷新失败: %s\n", string(respBody)) + + // 分析错误 + var errResp map[string]interface{} + if err := json.Unmarshal(respBody, &errResp); err == nil { + if errType, ok := errResp["error"].(string); ok { + fmt.Printf("错误类型: %s\n", errType) + if errType == "invalid_grant" { + fmt.Println("\n💡 提示: refresh_token 可能已过期,需要重新授权") + } + } + if errDesc, ok := errResp["error_description"].(string); ok { + fmt.Printf("错误描述: %s\n", errDesc) + } + } + return + } + + // 解析响应 + var refreshResp map[string]interface{} + json.Unmarshal(respBody, &refreshResp) + + newAccessToken, _ := refreshResp["accessToken"].(string) + newRefreshToken, _ := refreshResp["refreshToken"].(string) + expiresIn, _ := refreshResp["expiresIn"].(float64) + + fmt.Println("✅ Token 刷新成功!") + fmt.Printf(" 新 AccessToken: %s...\n", truncate(newAccessToken, 50)) + fmt.Printf(" ExpiresIn: %.0f 秒\n", expiresIn) + if newRefreshToken != "" { + fmt.Printf(" 新 RefreshToken: %s...\n", truncate(newRefreshToken, 50)) + } + + // Step 4: 验证新 token + fmt.Println("\n[Step 4] 验证新 Token") + fmt.Println("-" + strings.Repeat("-", 59)) + + if testAPI(newAccessToken, region) { + fmt.Println("✅ 新 Token 验证成功!") + + // 保存新 token - 模拟 saveCredentialsToFile + tokenData["accessToken"] = newAccessToken + if newRefreshToken != "" { + tokenData["refreshToken"] = newRefreshToken + } + tokenData["expiresAt"] = time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339) + + saveData, _ := json.MarshalIndent(tokenData, "", " ") + newPath := strings.TrimSuffix(loadedPath, ".json") + "_js_refreshed.json" + os.WriteFile(newPath, saveData, 0644) + fmt.Printf("✅ 已保存到: %s\n", newPath) + } else { + fmt.Println("❌ 新 Token 验证失败") + } + + fmt.Println("\n" + strings.Repeat("=", 60)) + fmt.Println(" 测试完成") + fmt.Println(strings.Repeat("=", 60)) +} + +func testAPI(accessToken, region string) bool { + url := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com", region) + + payload := map[string]interface{}{ + "origin": "AI_EDITOR", + "isEmailRequired": true, + "resourceType": "AGENTIC_REQUEST", + } + body, _ := json.Marshal(payload) + + req, _ := http.NewRequest("POST", url, bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/x-amz-json-1.0") + req.Header.Set("x-amz-target", "AmazonCodeWhispererService.GetUsageLimits") + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return false + } + defer resp.Body.Close() + + return resp.StatusCode == 200 +} + +func getKeys(m map[string]interface{}) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] +} diff --git a/test_kiro_debug.go b/test_kiro_debug.go new file mode 100644 index 00000000..0fbbed6c --- /dev/null +++ b/test_kiro_debug.go @@ -0,0 +1,348 @@ +// 独立测试脚本:排查 Kiro Token 403 错误 +// 运行方式: go run test_kiro_debug.go +package main + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" +) + +// Token 结构 - 匹配 Kiro IDE 格式 +type KiroIDEToken struct { + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` + ExpiresAt string `json:"expiresAt"` + ClientIDHash string `json:"clientIdHash,omitempty"` + AuthMethod string `json:"authMethod"` + Provider string `json:"provider"` + Region string `json:"region,omitempty"` +} + +// Token 结构 - 匹配 CLIProxyAPIPlus 格式 +type CLIProxyToken struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ProfileArn string `json:"profile_arn"` + ExpiresAt string `json:"expires_at"` + AuthMethod string `json:"auth_method"` + Provider string `json:"provider"` + ClientID string `json:"client_id,omitempty"` + ClientSecret string `json:"client_secret,omitempty"` + Email string `json:"email,omitempty"` + Type string `json:"type"` +} + +func main() { + fmt.Println("=" + strings.Repeat("=", 59)) + fmt.Println(" Kiro Token 403 错误排查工具") + fmt.Println("=" + strings.Repeat("=", 59)) + + homeDir, _ := os.UserHomeDir() + + // Step 1: 检查 Kiro IDE Token 文件 + fmt.Println("\n[Step 1] 检查 Kiro IDE Token 文件") + fmt.Println("-" + strings.Repeat("-", 59)) + + ideTokenPath := filepath.Join(homeDir, ".aws", "sso", "cache", "kiro-auth-token.json") + ideToken, err := loadKiroIDEToken(ideTokenPath) + if err != nil { + fmt.Printf("❌ 无法加载 Kiro IDE Token: %v\n", err) + return + } + fmt.Printf("✅ Token 文件: %s\n", ideTokenPath) + fmt.Printf(" AuthMethod: %s\n", ideToken.AuthMethod) + fmt.Printf(" Provider: %s\n", ideToken.Provider) + fmt.Printf(" Region: %s\n", ideToken.Region) + fmt.Printf(" ExpiresAt: %s\n", ideToken.ExpiresAt) + fmt.Printf(" AccessToken (前50字符): %s...\n", truncate(ideToken.AccessToken, 50)) + + // Step 2: 检查 Token 过期状态 + fmt.Println("\n[Step 2] 检查 Token 过期状态") + fmt.Println("-" + strings.Repeat("-", 59)) + + expiresAt, err := parseExpiresAt(ideToken.ExpiresAt) + if err != nil { + fmt.Printf("❌ 无法解析过期时间: %v\n", err) + } else { + now := time.Now() + if now.After(expiresAt) { + fmt.Printf("❌ Token 已过期!过期时间: %s,当前时间: %s\n", expiresAt.Format(time.RFC3339), now.Format(time.RFC3339)) + } else { + remaining := expiresAt.Sub(now) + fmt.Printf("✅ Token 未过期,剩余: %s\n", remaining.Round(time.Second)) + } + } + + // Step 3: 检查 CLIProxyAPIPlus 保存的 Token + fmt.Println("\n[Step 3] 检查 CLIProxyAPIPlus 保存的 Token") + fmt.Println("-" + strings.Repeat("-", 59)) + + cliProxyDir := filepath.Join(homeDir, ".cli-proxy-api") + files, _ := os.ReadDir(cliProxyDir) + for _, f := range files { + if strings.HasPrefix(f.Name(), "kiro") && strings.HasSuffix(f.Name(), ".json") { + filePath := filepath.Join(cliProxyDir, f.Name()) + cliToken, err := loadCLIProxyToken(filePath) + if err != nil { + fmt.Printf("❌ %s: 加载失败 - %v\n", f.Name(), err) + continue + } + fmt.Printf("📄 %s:\n", f.Name()) + fmt.Printf(" AuthMethod: %s\n", cliToken.AuthMethod) + fmt.Printf(" Provider: %s\n", cliToken.Provider) + fmt.Printf(" ExpiresAt: %s\n", cliToken.ExpiresAt) + fmt.Printf(" AccessToken (前50字符): %s...\n", truncate(cliToken.AccessToken, 50)) + + // 比较 Token + if cliToken.AccessToken == ideToken.AccessToken { + fmt.Printf(" ✅ AccessToken 与 IDE Token 一致\n") + } else { + fmt.Printf(" ⚠️ AccessToken 与 IDE Token 不一致!\n") + } + } + } + + // Step 4: 直接测试 Token 有效性 (调用 Kiro API) + fmt.Println("\n[Step 4] 直接测试 Token 有效性") + fmt.Println("-" + strings.Repeat("-", 59)) + + testTokenValidity(ideToken.AccessToken, ideToken.Region) + + // Step 5: 测试不同的请求头格式 + fmt.Println("\n[Step 5] 测试不同的请求头格式") + fmt.Println("-" + strings.Repeat("-", 59)) + + testDifferentHeaders(ideToken.AccessToken, ideToken.Region) + + // Step 6: 解析 JWT 内容 + fmt.Println("\n[Step 6] 解析 JWT Token 内容") + fmt.Println("-" + strings.Repeat("-", 59)) + + parseJWT(ideToken.AccessToken) + + fmt.Println("\n" + strings.Repeat("=", 60)) + fmt.Println(" 排查完成") + fmt.Println(strings.Repeat("=", 60)) +} + +func loadKiroIDEToken(path string) (*KiroIDEToken, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var token KiroIDEToken + if err := json.Unmarshal(data, &token); err != nil { + return nil, err + } + return &token, nil +} + +func loadCLIProxyToken(path string) (*CLIProxyToken, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var token CLIProxyToken + if err := json.Unmarshal(data, &token); err != nil { + return nil, err + } + return &token, nil +} + +func parseExpiresAt(s string) (time.Time, error) { + formats := []string{ + time.RFC3339, + "2006-01-02T15:04:05.000Z", + "2006-01-02T15:04:05Z", + } + for _, f := range formats { + if t, err := time.Parse(f, s); err == nil { + return t, nil + } + } + return time.Time{}, fmt.Errorf("无法解析时间格式: %s", s) +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] +} + +func testTokenValidity(accessToken, region string) { + if region == "" { + region = "us-east-1" + } + + // 测试 GetUsageLimits API + url := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com", region) + + payload := map[string]interface{}{ + "origin": "AI_EDITOR", + "isEmailRequired": true, + "resourceType": "AGENTIC_REQUEST", + } + body, _ := json.Marshal(payload) + + req, _ := http.NewRequest("POST", url, bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/x-amz-json-1.0") + req.Header.Set("x-amz-target", "AmazonCodeWhispererService.GetUsageLimits") + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + fmt.Printf("请求 URL: %s\n", url) + fmt.Printf("请求头:\n") + for k, v := range req.Header { + if k == "Authorization" { + fmt.Printf(" %s: Bearer %s...\n", k, truncate(v[0][7:], 30)) + } else { + fmt.Printf(" %s: %s\n", k, v[0]) + } + } + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + fmt.Printf("❌ 请求失败: %v\n", err) + return + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + fmt.Printf("响应状态: %d\n", resp.StatusCode) + fmt.Printf("响应内容: %s\n", string(respBody)) + + if resp.StatusCode == 200 { + fmt.Println("✅ Token 有效!") + } else if resp.StatusCode == 403 { + fmt.Println("❌ Token 无效或已过期 (403)") + } +} + +func testDifferentHeaders(accessToken, region string) { + if region == "" { + region = "us-east-1" + } + + tests := []struct { + name string + headers map[string]string + }{ + { + name: "最小请求头", + headers: map[string]string{ + "Content-Type": "application/json", + "Authorization": "Bearer " + accessToken, + }, + }, + { + name: "模拟 kiro2api_go1 风格", + headers: map[string]string{ + "Content-Type": "application/json", + "Accept": "text/event-stream", + "Authorization": "Bearer " + accessToken, + "x-amzn-kiro-agent-mode": "vibe", + "x-amzn-codewhisperer-optout": "true", + "amz-sdk-invocation-id": "test-invocation-id", + "amz-sdk-request": "attempt=1; max=3", + "x-amz-user-agent": "aws-sdk-js/1.0.27 KiroIDE-0.8.0-abc123", + "User-Agent": "aws-sdk-js/1.0.27 ua/2.1 os/windows#10.0 lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.27 m/E KiroIDE-0.8.0-abc123", + }, + }, + { + name: "模拟 CLIProxyAPIPlus 风格", + headers: map[string]string{ + "Content-Type": "application/x-amz-json-1.0", + "x-amz-target": "AmazonCodeWhispererService.GetUsageLimits", + "Authorization": "Bearer " + accessToken, + "Accept": "application/json", + "amz-sdk-invocation-id": "test-invocation-id", + "amz-sdk-request": "attempt=1; max=1", + "Connection": "close", + }, + }, + } + + url := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com", region) + payload := map[string]interface{}{ + "origin": "AI_EDITOR", + "isEmailRequired": true, + "resourceType": "AGENTIC_REQUEST", + } + body, _ := json.Marshal(payload) + + for _, test := range tests { + fmt.Printf("\n测试: %s\n", test.name) + + req, _ := http.NewRequest("POST", url, bytes.NewBuffer(body)) + for k, v := range test.headers { + req.Header.Set(k, v) + } + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + fmt.Printf(" ❌ 请求失败: %v\n", err) + continue + } + + respBody, _ := io.ReadAll(resp.Body) + resp.Body.Close() + + if resp.StatusCode == 200 { + fmt.Printf(" ✅ 成功 (HTTP %d)\n", resp.StatusCode) + } else { + fmt.Printf(" ❌ 失败 (HTTP %d): %s\n", resp.StatusCode, truncate(string(respBody), 100)) + } + } +} + +func parseJWT(token string) { + parts := strings.Split(token, ".") + if len(parts) < 2 { + fmt.Println("Token 不是 JWT 格式") + return + } + + // 解码 header + headerData, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + fmt.Printf("无法解码 JWT header: %v\n", err) + } else { + var header map[string]interface{} + json.Unmarshal(headerData, &header) + fmt.Printf("JWT Header: %v\n", header) + } + + // 解码 payload + payloadData, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + fmt.Printf("无法解码 JWT payload: %v\n", err) + } else { + var payload map[string]interface{} + json.Unmarshal(payloadData, &payload) + fmt.Printf("JWT Payload:\n") + for k, v := range payload { + fmt.Printf(" %s: %v\n", k, v) + } + + // 检查过期时间 + if exp, ok := payload["exp"].(float64); ok { + expTime := time.Unix(int64(exp), 0) + if time.Now().After(expTime) { + fmt.Printf(" ⚠️ JWT 已过期! exp=%s\n", expTime.Format(time.RFC3339)) + } else { + fmt.Printf(" ✅ JWT 未过期, 剩余: %s\n", expTime.Sub(time.Now()).Round(time.Second)) + } + } + } +} diff --git a/test_proxy_debug.go b/test_proxy_debug.go new file mode 100644 index 00000000..82369e74 --- /dev/null +++ b/test_proxy_debug.go @@ -0,0 +1,367 @@ +// 测试脚本 2:通过 CLIProxyAPIPlus 代理层排查问题 +// 运行方式: go run test_proxy_debug.go +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" +) + +const ( + ProxyURL = "http://localhost:8317" + APIKey = "your-api-key-1" +) + +func main() { + fmt.Println("=" + strings.Repeat("=", 59)) + fmt.Println(" CLIProxyAPIPlus 代理层问题排查") + fmt.Println("=" + strings.Repeat("=", 59)) + + // Step 1: 检查代理服务状态 + fmt.Println("\n[Step 1] 检查代理服务状态") + fmt.Println("-" + strings.Repeat("-", 59)) + + resp, err := http.Get(ProxyURL + "/health") + if err != nil { + fmt.Printf("❌ 代理服务不可达: %v\n", err) + fmt.Println("请确保服务正在运行: go run ./cmd/server/main.go") + return + } + resp.Body.Close() + fmt.Printf("✅ 代理服务正常 (HTTP %d)\n", resp.StatusCode) + + // Step 2: 获取模型列表 + fmt.Println("\n[Step 2] 获取模型列表") + fmt.Println("-" + strings.Repeat("-", 59)) + + models := getModels() + if len(models) == 0 { + fmt.Println("❌ 没有可用的模型,检查凭据加载") + checkCredentials() + return + } + fmt.Printf("✅ 找到 %d 个模型:\n", len(models)) + for _, m := range models { + fmt.Printf(" - %s\n", m) + } + + // Step 3: 测试模型请求 - 捕获详细错误 + fmt.Println("\n[Step 3] 测试模型请求(详细日志)") + fmt.Println("-" + strings.Repeat("-", 59)) + + if len(models) > 0 { + testModel := models[0] + testModelRequest(testModel) + } + + // Step 4: 检查代理内部 Token 状态 + fmt.Println("\n[Step 4] 检查代理服务加载的凭据") + fmt.Println("-" + strings.Repeat("-", 59)) + + checkProxyCredentials() + + // Step 5: 对比直接请求和代理请求 + fmt.Println("\n[Step 5] 对比直接请求 vs 代理请求") + fmt.Println("-" + strings.Repeat("-", 59)) + + compareDirectVsProxy() + + fmt.Println("\n" + strings.Repeat("=", 60)) + fmt.Println(" 排查完成") + fmt.Println(strings.Repeat("=", 60)) +} + +func getModels() []string { + req, _ := http.NewRequest("GET", ProxyURL+"/v1/models", nil) + req.Header.Set("Authorization", "Bearer "+APIKey) + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + fmt.Printf("❌ 请求失败: %v\n", err) + return nil + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != 200 { + fmt.Printf("❌ HTTP %d: %s\n", resp.StatusCode, string(body)) + return nil + } + + var result struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` + } + json.Unmarshal(body, &result) + + models := make([]string, len(result.Data)) + for i, m := range result.Data { + models[i] = m.ID + } + return models +} + +func checkCredentials() { + homeDir, _ := os.UserHomeDir() + cliProxyDir := filepath.Join(homeDir, ".cli-proxy-api") + + fmt.Printf("\n检查凭据目录: %s\n", cliProxyDir) + files, err := os.ReadDir(cliProxyDir) + if err != nil { + fmt.Printf("❌ 无法读取目录: %v\n", err) + return + } + + for _, f := range files { + if strings.HasSuffix(f.Name(), ".json") { + fmt.Printf(" 📄 %s\n", f.Name()) + } + } +} + +func testModelRequest(model string) { + fmt.Printf("测试模型: %s\n", model) + + payload := map[string]interface{}{ + "model": model, + "messages": []map[string]string{ + {"role": "user", "content": "Say 'OK' if you receive this."}, + }, + "max_tokens": 50, + "stream": false, + } + body, _ := json.Marshal(payload) + + req, _ := http.NewRequest("POST", ProxyURL+"/v1/chat/completions", bytes.NewBuffer(body)) + req.Header.Set("Authorization", "Bearer "+APIKey) + req.Header.Set("Content-Type", "application/json") + + fmt.Println("\n发送请求:") + fmt.Printf(" URL: %s/v1/chat/completions\n", ProxyURL) + fmt.Printf(" Model: %s\n", model) + + client := &http.Client{Timeout: 60 * time.Second} + resp, err := client.Do(req) + if err != nil { + fmt.Printf("❌ 请求失败: %v\n", err) + return + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + fmt.Printf("\n响应:\n") + fmt.Printf(" Status: %d\n", resp.StatusCode) + fmt.Printf(" Headers:\n") + for k, v := range resp.Header { + fmt.Printf(" %s: %s\n", k, strings.Join(v, ", ")) + } + + // 格式化 JSON 输出 + var prettyJSON bytes.Buffer + if err := json.Indent(&prettyJSON, respBody, " ", " "); err == nil { + fmt.Printf(" Body:\n %s\n", prettyJSON.String()) + } else { + fmt.Printf(" Body: %s\n", string(respBody)) + } + + if resp.StatusCode == 200 { + fmt.Println("\n✅ 请求成功!") + } else { + fmt.Println("\n❌ 请求失败!分析错误原因...") + analyzeError(respBody) + } +} + +func analyzeError(body []byte) { + var errResp struct { + Message string `json:"message"` + Reason string `json:"reason"` + Error struct { + Message string `json:"message"` + Type string `json:"type"` + } `json:"error"` + } + json.Unmarshal(body, &errResp) + + if errResp.Message != "" { + fmt.Printf("错误消息: %s\n", errResp.Message) + } + if errResp.Reason != "" { + fmt.Printf("错误原因: %s\n", errResp.Reason) + } + if errResp.Error.Message != "" { + fmt.Printf("错误详情: %s (类型: %s)\n", errResp.Error.Message, errResp.Error.Type) + } + + // 分析常见错误 + bodyStr := string(body) + if strings.Contains(bodyStr, "bearer token") || strings.Contains(bodyStr, "invalid") { + fmt.Println("\n可能的原因:") + fmt.Println(" 1. Token 已过期 - 需要刷新") + fmt.Println(" 2. Token 格式不正确 - 检查凭据文件") + fmt.Println(" 3. 代理服务加载了旧的 Token") + } +} + +func checkProxyCredentials() { + // 尝试通过管理 API 获取凭据状态 + req, _ := http.NewRequest("GET", ProxyURL+"/v0/management/auth/list", nil) + // 使用配置中的管理密钥 admin123 + req.Header.Set("Authorization", "Bearer admin123") + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + fmt.Printf("❌ 无法访问管理 API: %v\n", err) + return + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + + if resp.StatusCode == 200 { + fmt.Println("管理 API 返回的凭据列表:") + var prettyJSON bytes.Buffer + if err := json.Indent(&prettyJSON, body, " ", " "); err == nil { + fmt.Printf("%s\n", prettyJSON.String()) + } else { + fmt.Printf("%s\n", string(body)) + } + } else { + fmt.Printf("管理 API 返回: HTTP %d\n", resp.StatusCode) + fmt.Printf("响应: %s\n", truncate(string(body), 200)) + } +} + +func compareDirectVsProxy() { + homeDir, _ := os.UserHomeDir() + tokenPath := filepath.Join(homeDir, ".aws", "sso", "cache", "kiro-auth-token.json") + + data, err := os.ReadFile(tokenPath) + if err != nil { + fmt.Printf("❌ 无法读取 Token 文件: %v\n", err) + return + } + + var token struct { + AccessToken string `json:"accessToken"` + Region string `json:"region"` + } + json.Unmarshal(data, &token) + + if token.Region == "" { + token.Region = "us-east-1" + } + + // 直接请求 + fmt.Println("\n1. 直接请求 Kiro API:") + directSuccess := testDirectKiroAPI(token.AccessToken, token.Region) + + // 通过代理请求 + fmt.Println("\n2. 通过代理请求:") + proxySuccess := testProxyAPI() + + // 结论 + fmt.Println("\n结论:") + if directSuccess && !proxySuccess { + fmt.Println(" ⚠️ 直接请求成功,代理请求失败") + fmt.Println(" 问题在于 CLIProxyAPIPlus 代理层") + fmt.Println(" 可能原因:") + fmt.Println(" 1. 代理服务使用了过期的 Token") + fmt.Println(" 2. Token 刷新逻辑有问题") + fmt.Println(" 3. 请求头构造不正确") + } else if directSuccess && proxySuccess { + fmt.Println(" ✅ 两者都成功") + } else if !directSuccess && !proxySuccess { + fmt.Println(" ❌ 两者都失败 - Token 本身可能有问题") + } +} + +func testDirectKiroAPI(accessToken, region string) bool { + url := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com", region) + + payload := map[string]interface{}{ + "origin": "AI_EDITOR", + "isEmailRequired": true, + "resourceType": "AGENTIC_REQUEST", + } + body, _ := json.Marshal(payload) + + req, _ := http.NewRequest("POST", url, bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/x-amz-json-1.0") + req.Header.Set("x-amz-target", "AmazonCodeWhispererService.GetUsageLimits") + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + fmt.Printf(" ❌ 请求失败: %v\n", err) + return false + } + defer resp.Body.Close() + + if resp.StatusCode == 200 { + fmt.Printf(" ✅ 成功 (HTTP %d)\n", resp.StatusCode) + return true + } + respBody, _ := io.ReadAll(resp.Body) + fmt.Printf(" ❌ 失败 (HTTP %d): %s\n", resp.StatusCode, truncate(string(respBody), 100)) + return false +} + +func testProxyAPI() bool { + models := getModels() + if len(models) == 0 { + fmt.Println(" ❌ 没有可用模型") + return false + } + + payload := map[string]interface{}{ + "model": models[0], + "messages": []map[string]string{ + {"role": "user", "content": "Say OK"}, + }, + "max_tokens": 10, + "stream": false, + } + body, _ := json.Marshal(payload) + + req, _ := http.NewRequest("POST", ProxyURL+"/v1/chat/completions", bytes.NewBuffer(body)) + req.Header.Set("Authorization", "Bearer "+APIKey) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 60 * time.Second} + resp, err := client.Do(req) + if err != nil { + fmt.Printf(" ❌ 请求失败: %v\n", err) + return false + } + defer resp.Body.Close() + + if resp.StatusCode == 200 { + fmt.Printf(" ✅ 成功 (HTTP %d)\n", resp.StatusCode) + return true + } + respBody, _ := io.ReadAll(resp.Body) + fmt.Printf(" ❌ 失败 (HTTP %d): %s\n", resp.StatusCode, truncate(string(respBody), 100)) + return false +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] + "..." +}