mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-21 16:40:22 +00:00
feat: add Kiro OAuth web, rate limiter, metrics, fingerprint, background refresh and model converter
This commit is contained in:
@@ -5,10 +5,12 @@ package kiro
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
|
||||
@@ -85,6 +87,87 @@ type KiroModel struct {
|
||||
// KiroIDETokenFile is the default path to Kiro IDE's token file
|
||||
const KiroIDETokenFile = ".aws/sso/cache/kiro-auth-token.json"
|
||||
|
||||
// Default retry configuration for file reading
|
||||
const (
|
||||
defaultTokenReadMaxAttempts = 10 // Maximum retry attempts
|
||||
defaultTokenReadBaseDelay = 50 * time.Millisecond // Base delay between retries
|
||||
)
|
||||
|
||||
// isTransientFileError checks if the error is a transient file access error
|
||||
// that may be resolved by retrying (e.g., file locked by another process on Windows).
|
||||
func isTransientFileError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for OS-level file access errors (Windows sharing violation, etc.)
|
||||
var pathErr *os.PathError
|
||||
if errors.As(err, &pathErr) {
|
||||
// Windows sharing violation (ERROR_SHARING_VIOLATION = 32)
|
||||
// Windows lock violation (ERROR_LOCK_VIOLATION = 33)
|
||||
errStr := pathErr.Err.Error()
|
||||
if strings.Contains(errStr, "being used by another process") ||
|
||||
strings.Contains(errStr, "sharing violation") ||
|
||||
strings.Contains(errStr, "lock violation") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check error message for common transient patterns
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
transientPatterns := []string{
|
||||
"being used by another process",
|
||||
"sharing violation",
|
||||
"lock violation",
|
||||
"access is denied",
|
||||
"unexpected end of json",
|
||||
"unexpected eof",
|
||||
}
|
||||
for _, pattern := range transientPatterns {
|
||||
if strings.Contains(errMsg, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// LoadKiroIDETokenWithRetry loads token data from Kiro IDE's token file with retry logic.
|
||||
// This handles transient file access errors (e.g., file locked by Kiro IDE during write).
|
||||
// maxAttempts: maximum number of retry attempts (default 10 if <= 0)
|
||||
// baseDelay: base delay between retries with exponential backoff (default 50ms if <= 0)
|
||||
func LoadKiroIDETokenWithRetry(maxAttempts int, baseDelay time.Duration) (*KiroTokenData, error) {
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = defaultTokenReadMaxAttempts
|
||||
}
|
||||
if baseDelay <= 0 {
|
||||
baseDelay = defaultTokenReadBaseDelay
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||
token, err := LoadKiroIDEToken()
|
||||
if err == nil {
|
||||
return token, nil
|
||||
}
|
||||
lastErr = err
|
||||
|
||||
// Only retry for transient errors
|
||||
if !isTransientFileError(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Exponential backoff: delay * 2^attempt, capped at 500ms
|
||||
delay := baseDelay * time.Duration(1<<uint(attempt))
|
||||
if delay > 500*time.Millisecond {
|
||||
delay = 500 * time.Millisecond
|
||||
}
|
||||
time.Sleep(delay)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to read token file after %d attempts: %w", maxAttempts, lastErr)
|
||||
}
|
||||
|
||||
// LoadKiroIDEToken loads token data from Kiro IDE's token file.
|
||||
func LoadKiroIDEToken() (*KiroTokenData, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
|
||||
305
internal/auth/kiro/aws.go.bak
Normal file
305
internal/auth/kiro/aws.go.bak
Normal file
@@ -0,0 +1,305 @@
|
||||
// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API.
|
||||
// It includes interfaces and implementations for token storage and authentication methods.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
|
||||
type PKCECodes struct {
|
||||
// CodeVerifier is the cryptographically random string used to correlate
|
||||
// the authorization request to the token request
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
// CodeChallenge is the SHA256 hash of the code verifier, base64url-encoded
|
||||
CodeChallenge string `json:"code_challenge"`
|
||||
}
|
||||
|
||||
// KiroTokenData holds OAuth token information from AWS CodeWhisperer (Kiro)
|
||||
type KiroTokenData struct {
|
||||
// AccessToken is the OAuth2 access token for API access
|
||||
AccessToken string `json:"accessToken"`
|
||||
// RefreshToken is used to obtain new access tokens
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
// ProfileArn is the AWS CodeWhisperer profile ARN
|
||||
ProfileArn string `json:"profileArn"`
|
||||
// ExpiresAt is the timestamp when the token expires
|
||||
ExpiresAt string `json:"expiresAt"`
|
||||
// AuthMethod indicates the authentication method used (e.g., "builder-id", "social")
|
||||
AuthMethod string `json:"authMethod"`
|
||||
// Provider indicates the OAuth provider (e.g., "AWS", "Google")
|
||||
Provider string `json:"provider"`
|
||||
// ClientID is the OIDC client ID (needed for token refresh)
|
||||
ClientID string `json:"clientId,omitempty"`
|
||||
// ClientSecret is the OIDC client secret (needed for token refresh)
|
||||
ClientSecret string `json:"clientSecret,omitempty"`
|
||||
// Email is the user's email address (used for file naming)
|
||||
Email string `json:"email,omitempty"`
|
||||
// StartURL is the IDC/Identity Center start URL (only for IDC auth method)
|
||||
StartURL string `json:"startUrl,omitempty"`
|
||||
// Region is the AWS region for IDC authentication (only for IDC auth method)
|
||||
Region string `json:"region,omitempty"`
|
||||
}
|
||||
|
||||
// KiroAuthBundle aggregates authentication data after OAuth flow completion
|
||||
type KiroAuthBundle struct {
|
||||
// TokenData contains the OAuth tokens from the authentication flow
|
||||
TokenData KiroTokenData `json:"token_data"`
|
||||
// LastRefresh is the timestamp of the last token refresh
|
||||
LastRefresh string `json:"last_refresh"`
|
||||
}
|
||||
|
||||
// KiroUsageInfo represents usage information from CodeWhisperer API
|
||||
type KiroUsageInfo struct {
|
||||
// SubscriptionTitle is the subscription plan name (e.g., "KIRO FREE")
|
||||
SubscriptionTitle string `json:"subscription_title"`
|
||||
// CurrentUsage is the current credit usage
|
||||
CurrentUsage float64 `json:"current_usage"`
|
||||
// UsageLimit is the maximum credit limit
|
||||
UsageLimit float64 `json:"usage_limit"`
|
||||
// NextReset is the timestamp of the next usage reset
|
||||
NextReset string `json:"next_reset"`
|
||||
}
|
||||
|
||||
// KiroModel represents a model available through the CodeWhisperer API
|
||||
type KiroModel struct {
|
||||
// ModelID is the unique identifier for the model
|
||||
ModelID string `json:"modelId"`
|
||||
// ModelName is the human-readable name
|
||||
ModelName string `json:"modelName"`
|
||||
// Description is the model description
|
||||
Description string `json:"description"`
|
||||
// RateMultiplier is the credit multiplier for this model
|
||||
RateMultiplier float64 `json:"rateMultiplier"`
|
||||
// RateUnit is the unit for rate calculation (e.g., "credit")
|
||||
RateUnit string `json:"rateUnit"`
|
||||
// MaxInputTokens is the maximum input token limit
|
||||
MaxInputTokens int `json:"maxInputTokens,omitempty"`
|
||||
}
|
||||
|
||||
// KiroIDETokenFile is the default path to Kiro IDE's token file
|
||||
const KiroIDETokenFile = ".aws/sso/cache/kiro-auth-token.json"
|
||||
|
||||
// LoadKiroIDEToken loads token data from Kiro IDE's token file.
|
||||
func LoadKiroIDEToken() (*KiroTokenData, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||
}
|
||||
|
||||
tokenPath := filepath.Join(homeDir, KiroIDETokenFile)
|
||||
data, err := os.ReadFile(tokenPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read Kiro IDE token file (%s): %w", tokenPath, err)
|
||||
}
|
||||
|
||||
var token KiroTokenData
|
||||
if err := json.Unmarshal(data, &token); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse Kiro IDE token: %w", err)
|
||||
}
|
||||
|
||||
if token.AccessToken == "" {
|
||||
return nil, fmt.Errorf("access token is empty in Kiro IDE token file")
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// LoadKiroTokenFromPath loads token data from a custom path.
|
||||
// This supports multiple accounts by allowing different token files.
|
||||
func LoadKiroTokenFromPath(tokenPath string) (*KiroTokenData, error) {
|
||||
// Expand ~ to home directory
|
||||
if len(tokenPath) > 0 && tokenPath[0] == '~' {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||
}
|
||||
tokenPath = filepath.Join(homeDir, tokenPath[1:])
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(tokenPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read token file (%s): %w", tokenPath, err)
|
||||
}
|
||||
|
||||
var token KiroTokenData
|
||||
if err := json.Unmarshal(data, &token); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token file: %w", err)
|
||||
}
|
||||
|
||||
if token.AccessToken == "" {
|
||||
return nil, fmt.Errorf("access token is empty in token file")
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// ListKiroTokenFiles lists all Kiro token files in the cache directory.
|
||||
// This supports multiple accounts by finding all token files.
|
||||
func ListKiroTokenFiles() ([]string, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||
}
|
||||
|
||||
cacheDir := filepath.Join(homeDir, ".aws", "sso", "cache")
|
||||
|
||||
// Check if directory exists
|
||||
if _, err := os.Stat(cacheDir); os.IsNotExist(err) {
|
||||
return nil, nil // No token files
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(cacheDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read cache directory: %w", err)
|
||||
}
|
||||
|
||||
var tokenFiles []string
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
// Look for kiro token files only (avoid matching unrelated AWS SSO cache files)
|
||||
if strings.HasSuffix(name, ".json") && strings.HasPrefix(name, "kiro") {
|
||||
tokenFiles = append(tokenFiles, filepath.Join(cacheDir, name))
|
||||
}
|
||||
}
|
||||
|
||||
return tokenFiles, nil
|
||||
}
|
||||
|
||||
// LoadAllKiroTokens loads all Kiro tokens from the cache directory.
|
||||
// This supports multiple accounts.
|
||||
func LoadAllKiroTokens() ([]*KiroTokenData, error) {
|
||||
files, err := ListKiroTokenFiles()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tokens []*KiroTokenData
|
||||
for _, file := range files {
|
||||
token, err := LoadKiroTokenFromPath(file)
|
||||
if err != nil {
|
||||
// Skip invalid token files
|
||||
continue
|
||||
}
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// JWTClaims represents the claims we care about from a JWT token.
|
||||
// JWT tokens from Kiro/AWS contain user information in the payload.
|
||||
type JWTClaims struct {
|
||||
Email string `json:"email,omitempty"`
|
||||
Sub string `json:"sub,omitempty"`
|
||||
PreferredUser string `json:"preferred_username,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Iss string `json:"iss,omitempty"`
|
||||
}
|
||||
|
||||
// ExtractEmailFromJWT extracts the user's email from a JWT access token.
|
||||
// JWT tokens typically have format: header.payload.signature
|
||||
// The payload is base64url-encoded JSON containing user claims.
|
||||
func ExtractEmailFromJWT(accessToken string) string {
|
||||
if accessToken == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// JWT format: header.payload.signature
|
||||
parts := strings.Split(accessToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Decode the payload (second part)
|
||||
payload := parts[1]
|
||||
|
||||
// Add padding if needed (base64url requires padding)
|
||||
switch len(payload) % 4 {
|
||||
case 2:
|
||||
payload += "=="
|
||||
case 3:
|
||||
payload += "="
|
||||
}
|
||||
|
||||
decoded, err := base64.URLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
// Try RawURLEncoding (no padding)
|
||||
decoded, err = base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
var claims JWTClaims
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Return email if available
|
||||
if claims.Email != "" {
|
||||
return claims.Email
|
||||
}
|
||||
|
||||
// Fallback to preferred_username (some providers use this)
|
||||
if claims.PreferredUser != "" && strings.Contains(claims.PreferredUser, "@") {
|
||||
return claims.PreferredUser
|
||||
}
|
||||
|
||||
// Fallback to sub if it looks like an email
|
||||
if claims.Sub != "" && strings.Contains(claims.Sub, "@") {
|
||||
return claims.Sub
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// SanitizeEmailForFilename sanitizes an email address for use in a filename.
|
||||
// Replaces special characters with underscores and prevents path traversal attacks.
|
||||
// Also handles URL-encoded characters to prevent encoded path traversal attempts.
|
||||
func SanitizeEmailForFilename(email string) string {
|
||||
if email == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
result := email
|
||||
|
||||
// First, handle URL-encoded path traversal attempts (%2F, %2E, %5C, etc.)
|
||||
// This prevents encoded characters from bypassing the sanitization.
|
||||
// Note: We replace % last to catch any remaining encodings including double-encoding (%252F)
|
||||
result = strings.ReplaceAll(result, "%2F", "_") // /
|
||||
result = strings.ReplaceAll(result, "%2f", "_")
|
||||
result = strings.ReplaceAll(result, "%5C", "_") // \
|
||||
result = strings.ReplaceAll(result, "%5c", "_")
|
||||
result = strings.ReplaceAll(result, "%2E", "_") // .
|
||||
result = strings.ReplaceAll(result, "%2e", "_")
|
||||
result = strings.ReplaceAll(result, "%00", "_") // null byte
|
||||
result = strings.ReplaceAll(result, "%", "_") // Catch remaining % to prevent double-encoding attacks
|
||||
|
||||
// Replace characters that are problematic in filenames
|
||||
// Keep @ and . in middle but replace other special characters
|
||||
for _, char := range []string{"/", "\\", ":", "*", "?", "\"", "<", ">", "|", " ", "\x00"} {
|
||||
result = strings.ReplaceAll(result, char, "_")
|
||||
}
|
||||
|
||||
// Prevent path traversal: replace leading dots in each path component
|
||||
// This handles cases like "../../../etc/passwd" → "_.._.._.._etc_passwd"
|
||||
parts := strings.Split(result, "_")
|
||||
for i, part := range parts {
|
||||
for strings.HasPrefix(part, ".") {
|
||||
part = "_" + part[1:]
|
||||
}
|
||||
parts[i] = part
|
||||
}
|
||||
result = strings.Join(parts, "_")
|
||||
|
||||
return result
|
||||
}
|
||||
192
internal/auth/kiro/background_refresh.go
Normal file
192
internal/auth/kiro/background_refresh.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"golang.org/x/sync/semaphore"
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
ID string
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
ExpiresAt time.Time
|
||||
LastVerified time.Time
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
AuthMethod string
|
||||
Provider string
|
||||
StartURL string
|
||||
Region string
|
||||
}
|
||||
|
||||
type TokenRepository interface {
|
||||
FindOldestUnverified(limit int) []*Token
|
||||
UpdateToken(token *Token) error
|
||||
}
|
||||
|
||||
type RefresherOption func(*BackgroundRefresher)
|
||||
|
||||
func WithInterval(interval time.Duration) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.interval = interval
|
||||
}
|
||||
}
|
||||
|
||||
func WithBatchSize(size int) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.batchSize = size
|
||||
}
|
||||
}
|
||||
|
||||
func WithConcurrency(concurrency int) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.concurrency = concurrency
|
||||
}
|
||||
}
|
||||
|
||||
type BackgroundRefresher struct {
|
||||
interval time.Duration
|
||||
batchSize int
|
||||
concurrency int
|
||||
tokenRepo TokenRepository
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
oauth *KiroOAuth
|
||||
ssoClient *SSOOIDCClient
|
||||
}
|
||||
|
||||
func NewBackgroundRefresher(repo TokenRepository, opts ...RefresherOption) *BackgroundRefresher {
|
||||
r := &BackgroundRefresher{
|
||||
interval: time.Minute,
|
||||
batchSize: 50,
|
||||
concurrency: 10,
|
||||
tokenRepo: repo,
|
||||
stopCh: make(chan struct{}),
|
||||
oauth: nil, // Lazy init - will be set when config available
|
||||
ssoClient: nil, // Lazy init - will be set when config available
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(r)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// WithConfig sets the configuration for OAuth and SSO clients.
|
||||
func WithConfig(cfg *config.Config) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.oauth = NewKiroOAuth(cfg)
|
||||
r.ssoClient = NewSSOOIDCClient(cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *BackgroundRefresher) Start(ctx context.Context) {
|
||||
r.wg.Add(1)
|
||||
go func() {
|
||||
defer r.wg.Done()
|
||||
ticker := time.NewTicker(r.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
r.refreshBatch(ctx)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-r.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
r.refreshBatch(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (r *BackgroundRefresher) Stop() {
|
||||
close(r.stopCh)
|
||||
r.wg.Wait()
|
||||
}
|
||||
|
||||
func (r *BackgroundRefresher) refreshBatch(ctx context.Context) {
|
||||
tokens := r.tokenRepo.FindOldestUnverified(r.batchSize)
|
||||
if len(tokens) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
sem := semaphore.NewWeighted(int64(r.concurrency))
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i, token := range tokens {
|
||||
if i > 0 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-r.stopCh:
|
||||
return
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
if err := sem.Acquire(ctx, 1); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(t *Token) {
|
||||
defer wg.Done()
|
||||
defer sem.Release(1)
|
||||
r.refreshSingle(ctx, t)
|
||||
}(token)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (r *BackgroundRefresher) refreshSingle(ctx context.Context, token *Token) {
|
||||
var newTokenData *KiroTokenData
|
||||
var err error
|
||||
|
||||
switch token.AuthMethod {
|
||||
case "idc":
|
||||
newTokenData, err = r.ssoClient.RefreshTokenWithRegion(
|
||||
ctx,
|
||||
token.ClientID,
|
||||
token.ClientSecret,
|
||||
token.RefreshToken,
|
||||
token.Region,
|
||||
token.StartURL,
|
||||
)
|
||||
case "builder-id":
|
||||
newTokenData, err = r.ssoClient.RefreshToken(
|
||||
ctx,
|
||||
token.ClientID,
|
||||
token.ClientSecret,
|
||||
token.RefreshToken,
|
||||
)
|
||||
default:
|
||||
newTokenData, err = r.oauth.RefreshToken(ctx, token.RefreshToken)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Printf("failed to refresh token %s: %v", token.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
token.AccessToken = newTokenData.AccessToken
|
||||
token.RefreshToken = newTokenData.RefreshToken
|
||||
token.LastVerified = time.Now()
|
||||
|
||||
if newTokenData.ExpiresAt != "" {
|
||||
if expTime, parseErr := time.Parse(time.RFC3339, newTokenData.ExpiresAt); parseErr == nil {
|
||||
token.ExpiresAt = expTime
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.tokenRepo.UpdateToken(token); err != nil {
|
||||
log.Printf("failed to update token %s: %v", token.ID, err)
|
||||
}
|
||||
}
|
||||
112
internal/auth/kiro/cooldown.go
Normal file
112
internal/auth/kiro/cooldown.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
CooldownReason429 = "rate_limit_exceeded"
|
||||
CooldownReasonSuspended = "account_suspended"
|
||||
CooldownReasonQuotaExhausted = "quota_exhausted"
|
||||
|
||||
DefaultShortCooldown = 1 * time.Minute
|
||||
MaxShortCooldown = 5 * time.Minute
|
||||
LongCooldown = 24 * time.Hour
|
||||
)
|
||||
|
||||
type CooldownManager struct {
|
||||
mu sync.RWMutex
|
||||
cooldowns map[string]time.Time
|
||||
reasons map[string]string
|
||||
}
|
||||
|
||||
func NewCooldownManager() *CooldownManager {
|
||||
return &CooldownManager{
|
||||
cooldowns: make(map[string]time.Time),
|
||||
reasons: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) SetCooldown(tokenKey string, duration time.Duration, reason string) {
|
||||
cm.mu.Lock()
|
||||
defer cm.mu.Unlock()
|
||||
cm.cooldowns[tokenKey] = time.Now().Add(duration)
|
||||
cm.reasons[tokenKey] = reason
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) IsInCooldown(tokenKey string) bool {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
endTime, exists := cm.cooldowns[tokenKey]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
return time.Now().Before(endTime)
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) GetRemainingCooldown(tokenKey string) time.Duration {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
endTime, exists := cm.cooldowns[tokenKey]
|
||||
if !exists {
|
||||
return 0
|
||||
}
|
||||
remaining := time.Until(endTime)
|
||||
if remaining < 0 {
|
||||
return 0
|
||||
}
|
||||
return remaining
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) GetCooldownReason(tokenKey string) string {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return cm.reasons[tokenKey]
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) ClearCooldown(tokenKey string) {
|
||||
cm.mu.Lock()
|
||||
defer cm.mu.Unlock()
|
||||
delete(cm.cooldowns, tokenKey)
|
||||
delete(cm.reasons, tokenKey)
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) CleanupExpired() {
|
||||
cm.mu.Lock()
|
||||
defer cm.mu.Unlock()
|
||||
now := time.Now()
|
||||
for tokenKey, endTime := range cm.cooldowns {
|
||||
if now.After(endTime) {
|
||||
delete(cm.cooldowns, tokenKey)
|
||||
delete(cm.reasons, tokenKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) StartCleanupRoutine(interval time.Duration, stopCh <-chan struct{}) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
cm.CleanupExpired()
|
||||
case <-stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CalculateCooldownFor429(retryCount int) time.Duration {
|
||||
duration := DefaultShortCooldown * time.Duration(1<<retryCount)
|
||||
if duration > MaxShortCooldown {
|
||||
return MaxShortCooldown
|
||||
}
|
||||
return duration
|
||||
}
|
||||
|
||||
func CalculateCooldownUntilNextDay() time.Duration {
|
||||
now := time.Now()
|
||||
nextDay := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, now.Location())
|
||||
return time.Until(nextDay)
|
||||
}
|
||||
240
internal/auth/kiro/cooldown_test.go
Normal file
240
internal/auth/kiro/cooldown_test.go
Normal file
@@ -0,0 +1,240 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewCooldownManager(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
if cm == nil {
|
||||
t.Fatal("expected non-nil CooldownManager")
|
||||
}
|
||||
if cm.cooldowns == nil {
|
||||
t.Error("expected non-nil cooldowns map")
|
||||
}
|
||||
if cm.reasons == nil {
|
||||
t.Error("expected non-nil reasons map")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCooldown(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Minute, CooldownReason429)
|
||||
|
||||
if !cm.IsInCooldown("token1") {
|
||||
t.Error("expected token to be in cooldown")
|
||||
}
|
||||
if cm.GetCooldownReason("token1") != CooldownReason429 {
|
||||
t.Errorf("expected reason %s, got %s", CooldownReason429, cm.GetCooldownReason("token1"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsInCooldown_NotSet(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
if cm.IsInCooldown("nonexistent") {
|
||||
t.Error("expected non-existent token to not be in cooldown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsInCooldown_Expired(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Millisecond, CooldownReason429)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
if cm.IsInCooldown("token1") {
|
||||
t.Error("expected expired cooldown to return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRemainingCooldown(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Second, CooldownReason429)
|
||||
|
||||
remaining := cm.GetRemainingCooldown("token1")
|
||||
if remaining <= 0 || remaining > 1*time.Second {
|
||||
t.Errorf("expected remaining cooldown between 0 and 1s, got %v", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRemainingCooldown_NotSet(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
remaining := cm.GetRemainingCooldown("nonexistent")
|
||||
if remaining != 0 {
|
||||
t.Errorf("expected 0 remaining for non-existent, got %v", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRemainingCooldown_Expired(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Millisecond, CooldownReason429)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
remaining := cm.GetRemainingCooldown("token1")
|
||||
if remaining != 0 {
|
||||
t.Errorf("expected 0 remaining for expired, got %v", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCooldownReason(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Minute, CooldownReasonSuspended)
|
||||
|
||||
reason := cm.GetCooldownReason("token1")
|
||||
if reason != CooldownReasonSuspended {
|
||||
t.Errorf("expected reason %s, got %s", CooldownReasonSuspended, reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCooldownReason_NotSet(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
reason := cm.GetCooldownReason("nonexistent")
|
||||
if reason != "" {
|
||||
t.Errorf("expected empty reason for non-existent, got %s", reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearCooldown(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Minute, CooldownReason429)
|
||||
cm.ClearCooldown("token1")
|
||||
|
||||
if cm.IsInCooldown("token1") {
|
||||
t.Error("expected cooldown to be cleared")
|
||||
}
|
||||
if cm.GetCooldownReason("token1") != "" {
|
||||
t.Error("expected reason to be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearCooldown_NonExistent(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.ClearCooldown("nonexistent")
|
||||
}
|
||||
|
||||
func TestCleanupExpired(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("expired1", 1*time.Millisecond, CooldownReason429)
|
||||
cm.SetCooldown("expired2", 1*time.Millisecond, CooldownReason429)
|
||||
cm.SetCooldown("active", 1*time.Hour, CooldownReason429)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
cm.CleanupExpired()
|
||||
|
||||
if cm.GetCooldownReason("expired1") != "" {
|
||||
t.Error("expected expired1 to be cleaned up")
|
||||
}
|
||||
if cm.GetCooldownReason("expired2") != "" {
|
||||
t.Error("expected expired2 to be cleaned up")
|
||||
}
|
||||
if cm.GetCooldownReason("active") != CooldownReason429 {
|
||||
t.Error("expected active to remain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCooldownFor429_FirstRetry(t *testing.T) {
|
||||
duration := CalculateCooldownFor429(0)
|
||||
if duration != DefaultShortCooldown {
|
||||
t.Errorf("expected %v for retry 0, got %v", DefaultShortCooldown, duration)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCooldownFor429_Exponential(t *testing.T) {
|
||||
d1 := CalculateCooldownFor429(1)
|
||||
d2 := CalculateCooldownFor429(2)
|
||||
|
||||
if d2 <= d1 {
|
||||
t.Errorf("expected d2 > d1, got d1=%v, d2=%v", d1, d2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCooldownFor429_MaxCap(t *testing.T) {
|
||||
duration := CalculateCooldownFor429(10)
|
||||
if duration > MaxShortCooldown {
|
||||
t.Errorf("expected max %v, got %v", MaxShortCooldown, duration)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCooldownUntilNextDay(t *testing.T) {
|
||||
duration := CalculateCooldownUntilNextDay()
|
||||
if duration <= 0 || duration > 24*time.Hour {
|
||||
t.Errorf("expected duration between 0 and 24h, got %v", duration)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldownManager_ConcurrentAccess(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
const numGoroutines = 50
|
||||
const numOperations = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
tokenKey := "token" + string(rune('a'+id%10))
|
||||
for j := 0; j < numOperations; j++ {
|
||||
switch j % 6 {
|
||||
case 0:
|
||||
cm.SetCooldown(tokenKey, time.Duration(j)*time.Millisecond, CooldownReason429)
|
||||
case 1:
|
||||
cm.IsInCooldown(tokenKey)
|
||||
case 2:
|
||||
cm.GetRemainingCooldown(tokenKey)
|
||||
case 3:
|
||||
cm.GetCooldownReason(tokenKey)
|
||||
case 4:
|
||||
cm.ClearCooldown(tokenKey)
|
||||
case 5:
|
||||
cm.CleanupExpired()
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestCooldownReasonConstants(t *testing.T) {
|
||||
if CooldownReason429 != "rate_limit_exceeded" {
|
||||
t.Errorf("unexpected CooldownReason429: %s", CooldownReason429)
|
||||
}
|
||||
if CooldownReasonSuspended != "account_suspended" {
|
||||
t.Errorf("unexpected CooldownReasonSuspended: %s", CooldownReasonSuspended)
|
||||
}
|
||||
if CooldownReasonQuotaExhausted != "quota_exhausted" {
|
||||
t.Errorf("unexpected CooldownReasonQuotaExhausted: %s", CooldownReasonQuotaExhausted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConstants(t *testing.T) {
|
||||
if DefaultShortCooldown != 1*time.Minute {
|
||||
t.Errorf("unexpected DefaultShortCooldown: %v", DefaultShortCooldown)
|
||||
}
|
||||
if MaxShortCooldown != 5*time.Minute {
|
||||
t.Errorf("unexpected MaxShortCooldown: %v", MaxShortCooldown)
|
||||
}
|
||||
if LongCooldown != 24*time.Hour {
|
||||
t.Errorf("unexpected LongCooldown: %v", LongCooldown)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCooldown_OverwritesPrevious(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Hour, CooldownReason429)
|
||||
cm.SetCooldown("token1", 1*time.Minute, CooldownReasonSuspended)
|
||||
|
||||
reason := cm.GetCooldownReason("token1")
|
||||
if reason != CooldownReasonSuspended {
|
||||
t.Errorf("expected reason to be overwritten to %s, got %s", CooldownReasonSuspended, reason)
|
||||
}
|
||||
|
||||
remaining := cm.GetRemainingCooldown("token1")
|
||||
if remaining > 1*time.Minute {
|
||||
t.Errorf("expected remaining <= 1 minute, got %v", remaining)
|
||||
}
|
||||
}
|
||||
197
internal/auth/kiro/fingerprint.go
Normal file
197
internal/auth/kiro/fingerprint.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Fingerprint 多维度指纹信息
|
||||
type Fingerprint struct {
|
||||
SDKVersion string // 1.0.20-1.0.27
|
||||
OSType string // darwin/windows/linux
|
||||
OSVersion string // 10.0.22621
|
||||
NodeVersion string // 18.x/20.x/22.x
|
||||
KiroVersion string // 0.3.x-0.8.x
|
||||
KiroHash string // SHA256
|
||||
AcceptLanguage string
|
||||
ScreenResolution string // 1920x1080
|
||||
ColorDepth int // 24
|
||||
HardwareConcurrency int // CPU 核心数
|
||||
TimezoneOffset int
|
||||
}
|
||||
|
||||
// FingerprintManager 指纹管理器
|
||||
type FingerprintManager struct {
|
||||
mu sync.RWMutex
|
||||
fingerprints map[string]*Fingerprint // tokenKey -> fingerprint
|
||||
rng *rand.Rand
|
||||
}
|
||||
|
||||
var (
|
||||
sdkVersions = []string{
|
||||
"1.0.20", "1.0.21", "1.0.22", "1.0.23",
|
||||
"1.0.24", "1.0.25", "1.0.26", "1.0.27",
|
||||
}
|
||||
osTypes = []string{"darwin", "windows", "linux"}
|
||||
osVersions = map[string][]string{
|
||||
"darwin": {"14.0", "14.1", "14.2", "14.3", "14.4", "14.5", "15.0", "15.1"},
|
||||
"windows": {"10.0.19041", "10.0.19042", "10.0.19043", "10.0.19044", "10.0.22621", "10.0.22631"},
|
||||
"linux": {"5.15.0", "6.1.0", "6.2.0", "6.5.0", "6.6.0", "6.8.0"},
|
||||
}
|
||||
nodeVersions = []string{
|
||||
"18.17.0", "18.18.0", "18.19.0", "18.20.0",
|
||||
"20.9.0", "20.10.0", "20.11.0", "20.12.0", "20.13.0",
|
||||
"22.0.0", "22.1.0", "22.2.0", "22.3.0",
|
||||
}
|
||||
kiroVersions = []string{
|
||||
"0.3.0", "0.3.1", "0.4.0", "0.4.1", "0.5.0", "0.5.1",
|
||||
"0.6.0", "0.6.1", "0.7.0", "0.7.1", "0.8.0", "0.8.1",
|
||||
}
|
||||
acceptLanguages = []string{
|
||||
"en-US,en;q=0.9",
|
||||
"en-GB,en;q=0.9",
|
||||
"zh-CN,zh;q=0.9,en;q=0.8",
|
||||
"zh-TW,zh;q=0.9,en;q=0.8",
|
||||
"ja-JP,ja;q=0.9,en;q=0.8",
|
||||
"ko-KR,ko;q=0.9,en;q=0.8",
|
||||
"de-DE,de;q=0.9,en;q=0.8",
|
||||
"fr-FR,fr;q=0.9,en;q=0.8",
|
||||
}
|
||||
screenResolutions = []string{
|
||||
"1920x1080", "2560x1440", "3840x2160",
|
||||
"1366x768", "1440x900", "1680x1050",
|
||||
"2560x1600", "3440x1440",
|
||||
}
|
||||
colorDepths = []int{24, 32}
|
||||
hardwareConcurrencies = []int{4, 6, 8, 10, 12, 16, 20, 24, 32}
|
||||
timezoneOffsets = []int{-480, -420, -360, -300, -240, 0, 60, 120, 480, 540}
|
||||
)
|
||||
|
||||
// NewFingerprintManager 创建指纹管理器
|
||||
func NewFingerprintManager() *FingerprintManager {
|
||||
return &FingerprintManager{
|
||||
fingerprints: make(map[string]*Fingerprint),
|
||||
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
|
||||
// GetFingerprint 获取或生成 Token 关联的指纹
|
||||
func (fm *FingerprintManager) GetFingerprint(tokenKey string) *Fingerprint {
|
||||
fm.mu.RLock()
|
||||
if fp, exists := fm.fingerprints[tokenKey]; exists {
|
||||
fm.mu.RUnlock()
|
||||
return fp
|
||||
}
|
||||
fm.mu.RUnlock()
|
||||
|
||||
fm.mu.Lock()
|
||||
defer fm.mu.Unlock()
|
||||
|
||||
if fp, exists := fm.fingerprints[tokenKey]; exists {
|
||||
return fp
|
||||
}
|
||||
|
||||
fp := fm.generateFingerprint(tokenKey)
|
||||
fm.fingerprints[tokenKey] = fp
|
||||
return fp
|
||||
}
|
||||
|
||||
// generateFingerprint 生成新的指纹
|
||||
func (fm *FingerprintManager) generateFingerprint(tokenKey string) *Fingerprint {
|
||||
osType := fm.randomChoice(osTypes)
|
||||
osVersion := fm.randomChoice(osVersions[osType])
|
||||
kiroVersion := fm.randomChoice(kiroVersions)
|
||||
|
||||
fp := &Fingerprint{
|
||||
SDKVersion: fm.randomChoice(sdkVersions),
|
||||
OSType: osType,
|
||||
OSVersion: osVersion,
|
||||
NodeVersion: fm.randomChoice(nodeVersions),
|
||||
KiroVersion: kiroVersion,
|
||||
AcceptLanguage: fm.randomChoice(acceptLanguages),
|
||||
ScreenResolution: fm.randomChoice(screenResolutions),
|
||||
ColorDepth: fm.randomIntChoice(colorDepths),
|
||||
HardwareConcurrency: fm.randomIntChoice(hardwareConcurrencies),
|
||||
TimezoneOffset: fm.randomIntChoice(timezoneOffsets),
|
||||
}
|
||||
|
||||
fp.KiroHash = fm.generateKiroHash(tokenKey, kiroVersion, osType)
|
||||
return fp
|
||||
}
|
||||
|
||||
// generateKiroHash 生成 Kiro Hash
|
||||
func (fm *FingerprintManager) generateKiroHash(tokenKey, kiroVersion, osType string) string {
|
||||
data := fmt.Sprintf("%s:%s:%s:%d", tokenKey, kiroVersion, osType, time.Now().UnixNano())
|
||||
hash := sha256.Sum256([]byte(data))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// randomChoice 随机选择字符串
|
||||
func (fm *FingerprintManager) randomChoice(choices []string) string {
|
||||
return choices[fm.rng.Intn(len(choices))]
|
||||
}
|
||||
|
||||
// randomIntChoice 随机选择整数
|
||||
func (fm *FingerprintManager) randomIntChoice(choices []int) int {
|
||||
return choices[fm.rng.Intn(len(choices))]
|
||||
}
|
||||
|
||||
// ApplyToRequest 将指纹信息应用到 HTTP 请求头
|
||||
func (fp *Fingerprint) ApplyToRequest(req *http.Request) {
|
||||
req.Header.Set("X-Kiro-SDK-Version", fp.SDKVersion)
|
||||
req.Header.Set("X-Kiro-OS-Type", fp.OSType)
|
||||
req.Header.Set("X-Kiro-OS-Version", fp.OSVersion)
|
||||
req.Header.Set("X-Kiro-Node-Version", fp.NodeVersion)
|
||||
req.Header.Set("X-Kiro-Version", fp.KiroVersion)
|
||||
req.Header.Set("X-Kiro-Hash", fp.KiroHash)
|
||||
req.Header.Set("Accept-Language", fp.AcceptLanguage)
|
||||
req.Header.Set("X-Screen-Resolution", fp.ScreenResolution)
|
||||
req.Header.Set("X-Color-Depth", fmt.Sprintf("%d", fp.ColorDepth))
|
||||
req.Header.Set("X-Hardware-Concurrency", fmt.Sprintf("%d", fp.HardwareConcurrency))
|
||||
req.Header.Set("X-Timezone-Offset", fmt.Sprintf("%d", fp.TimezoneOffset))
|
||||
}
|
||||
|
||||
// RemoveFingerprint 移除 Token 关联的指纹
|
||||
func (fm *FingerprintManager) RemoveFingerprint(tokenKey string) {
|
||||
fm.mu.Lock()
|
||||
defer fm.mu.Unlock()
|
||||
delete(fm.fingerprints, tokenKey)
|
||||
}
|
||||
|
||||
// Count 返回当前管理的指纹数量
|
||||
func (fm *FingerprintManager) Count() int {
|
||||
fm.mu.RLock()
|
||||
defer fm.mu.RUnlock()
|
||||
return len(fm.fingerprints)
|
||||
}
|
||||
|
||||
// BuildUserAgent 构建 User-Agent 字符串 (Kiro IDE 风格)
|
||||
// 格式: aws-sdk-js/{SDKVersion} ua/2.1 os/{OSType}#{OSVersion} lang/js md/nodejs#{NodeVersion} api/codewhispererstreaming#{SDKVersion} m/E KiroIDE-{KiroVersion}-{KiroHash}
|
||||
func (fp *Fingerprint) BuildUserAgent() string {
|
||||
return fmt.Sprintf(
|
||||
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s",
|
||||
fp.SDKVersion,
|
||||
fp.OSType,
|
||||
fp.OSVersion,
|
||||
fp.NodeVersion,
|
||||
fp.SDKVersion,
|
||||
fp.KiroVersion,
|
||||
fp.KiroHash,
|
||||
)
|
||||
}
|
||||
|
||||
// BuildAmzUserAgent 构建 X-Amz-User-Agent 字符串
|
||||
// 格式: aws-sdk-js/{SDKVersion} KiroIDE-{KiroVersion}-{KiroHash}
|
||||
func (fp *Fingerprint) BuildAmzUserAgent() string {
|
||||
return fmt.Sprintf(
|
||||
"aws-sdk-js/%s KiroIDE-%s-%s",
|
||||
fp.SDKVersion,
|
||||
fp.KiroVersion,
|
||||
fp.KiroHash,
|
||||
)
|
||||
}
|
||||
227
internal/auth/kiro/fingerprint_test.go
Normal file
227
internal/auth/kiro/fingerprint_test.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewFingerprintManager(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
if fm == nil {
|
||||
t.Fatal("expected non-nil FingerprintManager")
|
||||
}
|
||||
if fm.fingerprints == nil {
|
||||
t.Error("expected non-nil fingerprints map")
|
||||
}
|
||||
if fm.rng == nil {
|
||||
t.Error("expected non-nil rng")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFingerprint_NewToken(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
if fp == nil {
|
||||
t.Fatal("expected non-nil Fingerprint")
|
||||
}
|
||||
if fp.SDKVersion == "" {
|
||||
t.Error("expected non-empty SDKVersion")
|
||||
}
|
||||
if fp.OSType == "" {
|
||||
t.Error("expected non-empty OSType")
|
||||
}
|
||||
if fp.OSVersion == "" {
|
||||
t.Error("expected non-empty OSVersion")
|
||||
}
|
||||
if fp.NodeVersion == "" {
|
||||
t.Error("expected non-empty NodeVersion")
|
||||
}
|
||||
if fp.KiroVersion == "" {
|
||||
t.Error("expected non-empty KiroVersion")
|
||||
}
|
||||
if fp.KiroHash == "" {
|
||||
t.Error("expected non-empty KiroHash")
|
||||
}
|
||||
if fp.AcceptLanguage == "" {
|
||||
t.Error("expected non-empty AcceptLanguage")
|
||||
}
|
||||
if fp.ScreenResolution == "" {
|
||||
t.Error("expected non-empty ScreenResolution")
|
||||
}
|
||||
if fp.ColorDepth == 0 {
|
||||
t.Error("expected non-zero ColorDepth")
|
||||
}
|
||||
if fp.HardwareConcurrency == 0 {
|
||||
t.Error("expected non-zero HardwareConcurrency")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFingerprint_SameTokenReturnsSameFingerprint(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp1 := fm.GetFingerprint("token1")
|
||||
fp2 := fm.GetFingerprint("token1")
|
||||
|
||||
if fp1 != fp2 {
|
||||
t.Error("expected same fingerprint for same token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFingerprint_DifferentTokens(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp1 := fm.GetFingerprint("token1")
|
||||
fp2 := fm.GetFingerprint("token2")
|
||||
|
||||
if fp1 == fp2 {
|
||||
t.Error("expected different fingerprints for different tokens")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveFingerprint(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fm.GetFingerprint("token1")
|
||||
if fm.Count() != 1 {
|
||||
t.Fatalf("expected count 1, got %d", fm.Count())
|
||||
}
|
||||
|
||||
fm.RemoveFingerprint("token1")
|
||||
if fm.Count() != 0 {
|
||||
t.Errorf("expected count 0, got %d", fm.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveFingerprint_NonExistent(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fm.RemoveFingerprint("nonexistent")
|
||||
if fm.Count() != 0 {
|
||||
t.Errorf("expected count 0, got %d", fm.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCount(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
if fm.Count() != 0 {
|
||||
t.Errorf("expected count 0, got %d", fm.Count())
|
||||
}
|
||||
|
||||
fm.GetFingerprint("token1")
|
||||
fm.GetFingerprint("token2")
|
||||
fm.GetFingerprint("token3")
|
||||
|
||||
if fm.Count() != 3 {
|
||||
t.Errorf("expected count 3, got %d", fm.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyToRequest(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
fp.ApplyToRequest(req)
|
||||
|
||||
if req.Header.Get("X-Kiro-SDK-Version") != fp.SDKVersion {
|
||||
t.Error("X-Kiro-SDK-Version header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Kiro-OS-Type") != fp.OSType {
|
||||
t.Error("X-Kiro-OS-Type header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Kiro-OS-Version") != fp.OSVersion {
|
||||
t.Error("X-Kiro-OS-Version header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Kiro-Node-Version") != fp.NodeVersion {
|
||||
t.Error("X-Kiro-Node-Version header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Kiro-Version") != fp.KiroVersion {
|
||||
t.Error("X-Kiro-Version header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Kiro-Hash") != fp.KiroHash {
|
||||
t.Error("X-Kiro-Hash header mismatch")
|
||||
}
|
||||
if req.Header.Get("Accept-Language") != fp.AcceptLanguage {
|
||||
t.Error("Accept-Language header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Screen-Resolution") != fp.ScreenResolution {
|
||||
t.Error("X-Screen-Resolution header mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFingerprint_OSVersionMatchesOSType(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
fp := fm.GetFingerprint("token" + string(rune('a'+i)))
|
||||
validVersions := osVersions[fp.OSType]
|
||||
found := false
|
||||
for _, v := range validVersions {
|
||||
if v == fp.OSVersion {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("OS version %s not valid for OS type %s", fp.OSVersion, fp.OSType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFingerprintManager_ConcurrentAccess(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
const numGoroutines = 100
|
||||
const numOperations = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numOperations; j++ {
|
||||
tokenKey := "token" + string(rune('a'+id%26))
|
||||
switch j % 4 {
|
||||
case 0:
|
||||
fm.GetFingerprint(tokenKey)
|
||||
case 1:
|
||||
fm.Count()
|
||||
case 2:
|
||||
fp := fm.GetFingerprint(tokenKey)
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
fp.ApplyToRequest(req)
|
||||
case 3:
|
||||
fm.RemoveFingerprint(tokenKey)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestKiroHashUniqueness(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
hashes := make(map[string]bool)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
fp := fm.GetFingerprint("token" + string(rune(i)))
|
||||
if hashes[fp.KiroHash] {
|
||||
t.Errorf("duplicate KiroHash detected: %s", fp.KiroHash)
|
||||
}
|
||||
hashes[fp.KiroHash] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestKiroHashFormat(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
if len(fp.KiroHash) != 64 {
|
||||
t.Errorf("expected KiroHash length 64 (SHA256 hex), got %d", len(fp.KiroHash))
|
||||
}
|
||||
|
||||
for _, c := range fp.KiroHash {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
|
||||
t.Errorf("invalid hex character in KiroHash: %c", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
174
internal/auth/kiro/jitter.go
Normal file
174
internal/auth/kiro/jitter.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Jitter configuration constants
|
||||
const (
|
||||
// JitterPercent is the default percentage of jitter to apply (±30%)
|
||||
JitterPercent = 0.30
|
||||
|
||||
// Human-like delay ranges
|
||||
ShortDelayMin = 50 * time.Millisecond // Minimum for rapid consecutive operations
|
||||
ShortDelayMax = 200 * time.Millisecond // Maximum for rapid consecutive operations
|
||||
NormalDelayMin = 1 * time.Second // Minimum for normal thinking time
|
||||
NormalDelayMax = 3 * time.Second // Maximum for normal thinking time
|
||||
LongDelayMin = 5 * time.Second // Minimum for reading/resting
|
||||
LongDelayMax = 10 * time.Second // Maximum for reading/resting
|
||||
|
||||
// Probability thresholds for human-like behavior
|
||||
ShortDelayProbability = 0.20 // 20% chance of short delay (consecutive ops)
|
||||
LongDelayProbability = 0.05 // 5% chance of long delay (reading/resting)
|
||||
NormalDelayProbability = 0.75 // 75% chance of normal delay (thinking)
|
||||
)
|
||||
|
||||
var (
|
||||
jitterRand *rand.Rand
|
||||
jitterRandOnce sync.Once
|
||||
jitterMu sync.Mutex
|
||||
lastRequestTime time.Time
|
||||
)
|
||||
|
||||
// initJitterRand initializes the random number generator for jitter calculations.
|
||||
// Uses a time-based seed for unpredictable but reproducible randomness.
|
||||
func initJitterRand() {
|
||||
jitterRandOnce.Do(func() {
|
||||
jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
})
|
||||
}
|
||||
|
||||
// RandomDelay generates a random delay between min and max duration.
|
||||
// Thread-safe implementation using mutex protection.
|
||||
func RandomDelay(min, max time.Duration) time.Duration {
|
||||
initJitterRand()
|
||||
jitterMu.Lock()
|
||||
defer jitterMu.Unlock()
|
||||
|
||||
if min >= max {
|
||||
return min
|
||||
}
|
||||
|
||||
rangeMs := max.Milliseconds() - min.Milliseconds()
|
||||
randomMs := jitterRand.Int63n(rangeMs)
|
||||
return min + time.Duration(randomMs)*time.Millisecond
|
||||
}
|
||||
|
||||
// JitterDelay adds jitter to a base delay.
|
||||
// Applies ±jitterPercent variation to the base delay.
|
||||
// For example, JitterDelay(1*time.Second, 0.30) returns a value between 700ms and 1300ms.
|
||||
func JitterDelay(baseDelay time.Duration, jitterPercent float64) time.Duration {
|
||||
initJitterRand()
|
||||
jitterMu.Lock()
|
||||
defer jitterMu.Unlock()
|
||||
|
||||
if jitterPercent <= 0 || jitterPercent > 1 {
|
||||
jitterPercent = JitterPercent
|
||||
}
|
||||
|
||||
// Calculate jitter range: base * jitterPercent
|
||||
jitterRange := float64(baseDelay) * jitterPercent
|
||||
|
||||
// Generate random value in range [-jitterRange, +jitterRange]
|
||||
jitter := (jitterRand.Float64()*2 - 1) * jitterRange
|
||||
|
||||
result := time.Duration(float64(baseDelay) + jitter)
|
||||
if result < 0 {
|
||||
return 0
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// JitterDelayDefault applies the default ±30% jitter to a base delay.
|
||||
func JitterDelayDefault(baseDelay time.Duration) time.Duration {
|
||||
return JitterDelay(baseDelay, JitterPercent)
|
||||
}
|
||||
|
||||
// HumanLikeDelay generates a delay that mimics human behavior patterns.
|
||||
// The delay is selected based on probability distribution:
|
||||
// - 20% chance: Short delay (50-200ms) - simulates consecutive rapid operations
|
||||
// - 75% chance: Normal delay (1-3s) - simulates thinking/reading time
|
||||
// - 5% chance: Long delay (5-10s) - simulates breaks/reading longer content
|
||||
//
|
||||
// Returns the delay duration (caller should call time.Sleep with this value).
|
||||
func HumanLikeDelay() time.Duration {
|
||||
initJitterRand()
|
||||
jitterMu.Lock()
|
||||
defer jitterMu.Unlock()
|
||||
|
||||
// Track time since last request for adaptive behavior
|
||||
now := time.Now()
|
||||
timeSinceLastRequest := now.Sub(lastRequestTime)
|
||||
lastRequestTime = now
|
||||
|
||||
// If requests are very close together, use short delay
|
||||
if timeSinceLastRequest < 500*time.Millisecond && timeSinceLastRequest > 0 {
|
||||
rangeMs := ShortDelayMax.Milliseconds() - ShortDelayMin.Milliseconds()
|
||||
randomMs := jitterRand.Int63n(rangeMs)
|
||||
return ShortDelayMin + time.Duration(randomMs)*time.Millisecond
|
||||
}
|
||||
|
||||
// Otherwise, use probability-based selection
|
||||
roll := jitterRand.Float64()
|
||||
|
||||
var min, max time.Duration
|
||||
switch {
|
||||
case roll < ShortDelayProbability:
|
||||
// Short delay - consecutive operations
|
||||
min, max = ShortDelayMin, ShortDelayMax
|
||||
case roll < ShortDelayProbability+LongDelayProbability:
|
||||
// Long delay - reading/resting
|
||||
min, max = LongDelayMin, LongDelayMax
|
||||
default:
|
||||
// Normal delay - thinking time
|
||||
min, max = NormalDelayMin, NormalDelayMax
|
||||
}
|
||||
|
||||
rangeMs := max.Milliseconds() - min.Milliseconds()
|
||||
randomMs := jitterRand.Int63n(rangeMs)
|
||||
return min + time.Duration(randomMs)*time.Millisecond
|
||||
}
|
||||
|
||||
// ApplyHumanLikeDelay applies human-like delay by sleeping.
|
||||
// This is a convenience function that combines HumanLikeDelay with time.Sleep.
|
||||
func ApplyHumanLikeDelay() {
|
||||
delay := HumanLikeDelay()
|
||||
if delay > 0 {
|
||||
time.Sleep(delay)
|
||||
}
|
||||
}
|
||||
|
||||
// ExponentialBackoffWithJitter calculates retry delay using exponential backoff with jitter.
|
||||
// Formula: min(baseDelay * 2^attempt + jitter, maxDelay)
|
||||
// This helps prevent thundering herd problem when multiple clients retry simultaneously.
|
||||
func ExponentialBackoffWithJitter(attempt int, baseDelay, maxDelay time.Duration) time.Duration {
|
||||
if attempt < 0 {
|
||||
attempt = 0
|
||||
}
|
||||
|
||||
// Calculate exponential backoff: baseDelay * 2^attempt
|
||||
backoff := baseDelay * time.Duration(1<<uint(attempt))
|
||||
if backoff > maxDelay {
|
||||
backoff = maxDelay
|
||||
}
|
||||
|
||||
// Add ±30% jitter
|
||||
return JitterDelay(backoff, JitterPercent)
|
||||
}
|
||||
|
||||
// ShouldSkipDelay determines if delay should be skipped based on context.
|
||||
// Returns true for streaming responses, WebSocket connections, etc.
|
||||
// This function can be extended to check additional skip conditions.
|
||||
func ShouldSkipDelay(isStreaming bool) bool {
|
||||
return isStreaming
|
||||
}
|
||||
|
||||
// ResetLastRequestTime resets the last request time tracker.
|
||||
// Useful for testing or when starting a new session.
|
||||
func ResetLastRequestTime() {
|
||||
jitterMu.Lock()
|
||||
defer jitterMu.Unlock()
|
||||
lastRequestTime = time.Time{}
|
||||
}
|
||||
187
internal/auth/kiro/metrics.go
Normal file
187
internal/auth/kiro/metrics.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TokenMetrics holds performance metrics for a single token.
|
||||
type TokenMetrics struct {
|
||||
SuccessRate float64 // Success rate (0.0 - 1.0)
|
||||
AvgLatency float64 // Average latency in milliseconds
|
||||
QuotaRemaining float64 // Remaining quota (0.0 - 1.0)
|
||||
LastUsed time.Time // Last usage timestamp
|
||||
FailCount int // Consecutive failure count
|
||||
TotalRequests int // Total request count
|
||||
successCount int // Internal: successful request count
|
||||
totalLatency float64 // Internal: cumulative latency
|
||||
}
|
||||
|
||||
// TokenScorer manages token metrics and scoring.
|
||||
type TokenScorer struct {
|
||||
mu sync.RWMutex
|
||||
metrics map[string]*TokenMetrics
|
||||
|
||||
// Scoring weights
|
||||
successRateWeight float64
|
||||
quotaWeight float64
|
||||
latencyWeight float64
|
||||
lastUsedWeight float64
|
||||
failPenaltyMultiplier float64
|
||||
}
|
||||
|
||||
// NewTokenScorer creates a new TokenScorer with default weights.
|
||||
func NewTokenScorer() *TokenScorer {
|
||||
return &TokenScorer{
|
||||
metrics: make(map[string]*TokenMetrics),
|
||||
successRateWeight: 0.4,
|
||||
quotaWeight: 0.25,
|
||||
latencyWeight: 0.2,
|
||||
lastUsedWeight: 0.15,
|
||||
failPenaltyMultiplier: 0.1,
|
||||
}
|
||||
}
|
||||
|
||||
// getOrCreateMetrics returns existing metrics or creates new ones.
|
||||
func (s *TokenScorer) getOrCreateMetrics(tokenKey string) *TokenMetrics {
|
||||
if m, ok := s.metrics[tokenKey]; ok {
|
||||
return m
|
||||
}
|
||||
m := &TokenMetrics{
|
||||
SuccessRate: 1.0,
|
||||
QuotaRemaining: 1.0,
|
||||
}
|
||||
s.metrics[tokenKey] = m
|
||||
return m
|
||||
}
|
||||
|
||||
// RecordRequest records the result of a request for a token.
|
||||
func (s *TokenScorer) RecordRequest(tokenKey string, success bool, latency time.Duration) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
m := s.getOrCreateMetrics(tokenKey)
|
||||
m.TotalRequests++
|
||||
m.LastUsed = time.Now()
|
||||
m.totalLatency += float64(latency.Milliseconds())
|
||||
|
||||
if success {
|
||||
m.successCount++
|
||||
m.FailCount = 0
|
||||
} else {
|
||||
m.FailCount++
|
||||
}
|
||||
|
||||
// Update derived metrics
|
||||
if m.TotalRequests > 0 {
|
||||
m.SuccessRate = float64(m.successCount) / float64(m.TotalRequests)
|
||||
m.AvgLatency = m.totalLatency / float64(m.TotalRequests)
|
||||
}
|
||||
}
|
||||
|
||||
// SetQuotaRemaining updates the remaining quota for a token.
|
||||
func (s *TokenScorer) SetQuotaRemaining(tokenKey string, quota float64) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
m := s.getOrCreateMetrics(tokenKey)
|
||||
m.QuotaRemaining = quota
|
||||
}
|
||||
|
||||
// GetMetrics returns a copy of the metrics for a token.
|
||||
func (s *TokenScorer) GetMetrics(tokenKey string) *TokenMetrics {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if m, ok := s.metrics[tokenKey]; ok {
|
||||
copy := *m
|
||||
return ©
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CalculateScore computes the score for a token (higher is better).
|
||||
func (s *TokenScorer) CalculateScore(tokenKey string) float64 {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
m, ok := s.metrics[tokenKey]
|
||||
if !ok {
|
||||
return 1.0 // New tokens get a high initial score
|
||||
}
|
||||
|
||||
// Success rate component (0-1)
|
||||
successScore := m.SuccessRate
|
||||
|
||||
// Quota component (0-1)
|
||||
quotaScore := m.QuotaRemaining
|
||||
|
||||
// Latency component (normalized, lower is better)
|
||||
// Using exponential decay: score = e^(-latency/1000)
|
||||
// 1000ms latency -> ~0.37 score, 100ms -> ~0.90 score
|
||||
latencyScore := math.Exp(-m.AvgLatency / 1000.0)
|
||||
if m.TotalRequests == 0 {
|
||||
latencyScore = 1.0
|
||||
}
|
||||
|
||||
// Last used component (prefer tokens not recently used)
|
||||
// Score increases as time since last use increases
|
||||
timeSinceUse := time.Since(m.LastUsed).Seconds()
|
||||
// Normalize: 60 seconds -> ~0.63 score, 0 seconds -> 0 score
|
||||
lastUsedScore := 1.0 - math.Exp(-timeSinceUse/60.0)
|
||||
if m.LastUsed.IsZero() {
|
||||
lastUsedScore = 1.0
|
||||
}
|
||||
|
||||
// Calculate weighted score
|
||||
score := s.successRateWeight*successScore +
|
||||
s.quotaWeight*quotaScore +
|
||||
s.latencyWeight*latencyScore +
|
||||
s.lastUsedWeight*lastUsedScore
|
||||
|
||||
// Apply consecutive failure penalty
|
||||
if m.FailCount > 0 {
|
||||
penalty := s.failPenaltyMultiplier * float64(m.FailCount)
|
||||
score = score * math.Max(0, 1.0-penalty)
|
||||
}
|
||||
|
||||
return score
|
||||
}
|
||||
|
||||
// SelectBestToken selects the token with the highest score.
|
||||
func (s *TokenScorer) SelectBestToken(tokens []string) string {
|
||||
if len(tokens) == 0 {
|
||||
return ""
|
||||
}
|
||||
if len(tokens) == 1 {
|
||||
return tokens[0]
|
||||
}
|
||||
|
||||
bestToken := tokens[0]
|
||||
bestScore := s.CalculateScore(tokens[0])
|
||||
|
||||
for _, token := range tokens[1:] {
|
||||
score := s.CalculateScore(token)
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
bestToken = token
|
||||
}
|
||||
}
|
||||
|
||||
return bestToken
|
||||
}
|
||||
|
||||
// ResetMetrics clears all metrics for a token.
|
||||
func (s *TokenScorer) ResetMetrics(tokenKey string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.metrics, tokenKey)
|
||||
}
|
||||
|
||||
// ResetAllMetrics clears all stored metrics.
|
||||
func (s *TokenScorer) ResetAllMetrics() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.metrics = make(map[string]*TokenMetrics)
|
||||
}
|
||||
301
internal/auth/kiro/metrics_test.go
Normal file
301
internal/auth/kiro/metrics_test.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewTokenScorer(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
if s == nil {
|
||||
t.Fatal("expected non-nil TokenScorer")
|
||||
}
|
||||
if s.metrics == nil {
|
||||
t.Error("expected non-nil metrics map")
|
||||
}
|
||||
if s.successRateWeight != 0.4 {
|
||||
t.Errorf("expected successRateWeight 0.4, got %f", s.successRateWeight)
|
||||
}
|
||||
if s.quotaWeight != 0.25 {
|
||||
t.Errorf("expected quotaWeight 0.25, got %f", s.quotaWeight)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordRequest_Success(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m == nil {
|
||||
t.Fatal("expected non-nil metrics")
|
||||
}
|
||||
if m.TotalRequests != 1 {
|
||||
t.Errorf("expected TotalRequests 1, got %d", m.TotalRequests)
|
||||
}
|
||||
if m.SuccessRate != 1.0 {
|
||||
t.Errorf("expected SuccessRate 1.0, got %f", m.SuccessRate)
|
||||
}
|
||||
if m.FailCount != 0 {
|
||||
t.Errorf("expected FailCount 0, got %d", m.FailCount)
|
||||
}
|
||||
if m.AvgLatency != 100 {
|
||||
t.Errorf("expected AvgLatency 100, got %f", m.AvgLatency)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordRequest_Failure(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", false, 200*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.SuccessRate != 0.0 {
|
||||
t.Errorf("expected SuccessRate 0.0, got %f", m.SuccessRate)
|
||||
}
|
||||
if m.FailCount != 1 {
|
||||
t.Errorf("expected FailCount 1, got %d", m.FailCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordRequest_MixedResults(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.TotalRequests != 4 {
|
||||
t.Errorf("expected TotalRequests 4, got %d", m.TotalRequests)
|
||||
}
|
||||
if m.SuccessRate != 0.75 {
|
||||
t.Errorf("expected SuccessRate 0.75, got %f", m.SuccessRate)
|
||||
}
|
||||
if m.FailCount != 0 {
|
||||
t.Errorf("expected FailCount 0 (reset on success), got %d", m.FailCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordRequest_ConsecutiveFailures(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.FailCount != 3 {
|
||||
t.Errorf("expected FailCount 3, got %d", m.FailCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetQuotaRemaining(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.SetQuotaRemaining("token1", 0.5)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.QuotaRemaining != 0.5 {
|
||||
t.Errorf("expected QuotaRemaining 0.5, got %f", m.QuotaRemaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMetrics_NonExistent(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
m := s.GetMetrics("nonexistent")
|
||||
if m != nil {
|
||||
t.Error("expected nil metrics for non-existent token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMetrics_ReturnsCopy(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m1 := s.GetMetrics("token1")
|
||||
m1.TotalRequests = 999
|
||||
|
||||
m2 := s.GetMetrics("token1")
|
||||
if m2.TotalRequests == 999 {
|
||||
t.Error("GetMetrics should return a copy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateScore_NewToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
score := s.CalculateScore("newtoken")
|
||||
if score != 1.0 {
|
||||
t.Errorf("expected score 1.0 for new token, got %f", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateScore_PerfectToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 50*time.Millisecond)
|
||||
s.SetQuotaRemaining("token1", 1.0)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
score := s.CalculateScore("token1")
|
||||
if score < 0.5 || score > 1.0 {
|
||||
t.Errorf("expected high score for perfect token, got %f", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateScore_FailedToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
for i := 0; i < 5; i++ {
|
||||
s.RecordRequest("token1", false, 1000*time.Millisecond)
|
||||
}
|
||||
s.SetQuotaRemaining("token1", 0.1)
|
||||
|
||||
score := s.CalculateScore("token1")
|
||||
if score > 0.5 {
|
||||
t.Errorf("expected low score for failed token, got %f", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateScore_FailPenalty(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
scoreNoFail := s.CalculateScore("token1")
|
||||
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
scoreWithFail := s.CalculateScore("token1")
|
||||
|
||||
if scoreWithFail >= scoreNoFail {
|
||||
t.Errorf("expected lower score with consecutive failures: noFail=%f, withFail=%f", scoreNoFail, scoreWithFail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectBestToken_Empty(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
best := s.SelectBestToken([]string{})
|
||||
if best != "" {
|
||||
t.Errorf("expected empty string for empty tokens, got %s", best)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectBestToken_SingleToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
best := s.SelectBestToken([]string{"token1"})
|
||||
if best != "token1" {
|
||||
t.Errorf("expected token1, got %s", best)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectBestToken_MultipleTokens(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
|
||||
s.RecordRequest("bad", false, 1000*time.Millisecond)
|
||||
s.RecordRequest("bad", false, 1000*time.Millisecond)
|
||||
s.SetQuotaRemaining("bad", 0.1)
|
||||
|
||||
s.RecordRequest("good", true, 50*time.Millisecond)
|
||||
s.SetQuotaRemaining("good", 0.9)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
best := s.SelectBestToken([]string{"bad", "good"})
|
||||
if best != "good" {
|
||||
t.Errorf("expected good token to be selected, got %s", best)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetMetrics(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.ResetMetrics("token1")
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m != nil {
|
||||
t.Error("expected nil metrics after reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetAllMetrics(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token2", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token3", true, 100*time.Millisecond)
|
||||
|
||||
s.ResetAllMetrics()
|
||||
|
||||
if s.GetMetrics("token1") != nil {
|
||||
t.Error("expected nil metrics for token1 after reset all")
|
||||
}
|
||||
if s.GetMetrics("token2") != nil {
|
||||
t.Error("expected nil metrics for token2 after reset all")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenScorer_ConcurrentAccess(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
const numGoroutines = 50
|
||||
const numOperations = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
tokenKey := "token" + string(rune('a'+id%10))
|
||||
for j := 0; j < numOperations; j++ {
|
||||
switch j % 6 {
|
||||
case 0:
|
||||
s.RecordRequest(tokenKey, j%2 == 0, time.Duration(j)*time.Millisecond)
|
||||
case 1:
|
||||
s.SetQuotaRemaining(tokenKey, float64(j%100)/100)
|
||||
case 2:
|
||||
s.GetMetrics(tokenKey)
|
||||
case 3:
|
||||
s.CalculateScore(tokenKey)
|
||||
case 4:
|
||||
s.SelectBestToken([]string{tokenKey, "token_x", "token_y"})
|
||||
case 5:
|
||||
if j%20 == 0 {
|
||||
s.ResetMetrics(tokenKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAvgLatencyCalculation(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", true, 200*time.Millisecond)
|
||||
s.RecordRequest("token1", true, 300*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.AvgLatency != 200 {
|
||||
t.Errorf("expected AvgLatency 200, got %f", m.AvgLatency)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLastUsedUpdated(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
before := time.Now()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.LastUsed.Before(before) {
|
||||
t.Error("expected LastUsed to be after test start time")
|
||||
}
|
||||
if m.LastUsed.After(time.Now()) {
|
||||
t.Error("expected LastUsed to be before or equal to now")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultQuotaForNewToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.QuotaRemaining != 1.0 {
|
||||
t.Errorf("expected default QuotaRemaining 1.0, got %f", m.QuotaRemaining)
|
||||
}
|
||||
}
|
||||
@@ -227,6 +227,7 @@ func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: "", // Caller should preserve original provider
|
||||
Region: "us-east-1",
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -285,6 +286,7 @@ func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*Kir
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: "", // Caller should preserve original provider
|
||||
Region: "us-east-1",
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
825
internal/auth/kiro/oauth_web.go
Normal file
825
internal/auth/kiro/oauth_web.go
Normal file
@@ -0,0 +1,825 @@
|
||||
// Package kiro provides OAuth Web authentication for Kiro.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultSessionExpiry = 10 * time.Minute
|
||||
pollIntervalSeconds = 5
|
||||
)
|
||||
|
||||
type authSessionStatus string
|
||||
|
||||
const (
|
||||
statusPending authSessionStatus = "pending"
|
||||
statusSuccess authSessionStatus = "success"
|
||||
statusFailed authSessionStatus = "failed"
|
||||
)
|
||||
|
||||
type webAuthSession struct {
|
||||
stateID string
|
||||
deviceCode string
|
||||
userCode string
|
||||
authURL string
|
||||
verificationURI string
|
||||
expiresIn int
|
||||
interval int
|
||||
status authSessionStatus
|
||||
startedAt time.Time
|
||||
completedAt time.Time
|
||||
expiresAt time.Time
|
||||
error string
|
||||
tokenData *KiroTokenData
|
||||
ssoClient *SSOOIDCClient
|
||||
clientID string
|
||||
clientSecret string
|
||||
region string
|
||||
cancelFunc context.CancelFunc
|
||||
authMethod string // "google", "github", "builder-id", "idc"
|
||||
startURL string // Used for IDC
|
||||
codeVerifier string // Used for social auth PKCE
|
||||
codeChallenge string // Used for social auth PKCE
|
||||
}
|
||||
|
||||
type OAuthWebHandler struct {
|
||||
cfg *config.Config
|
||||
sessions map[string]*webAuthSession
|
||||
mu sync.RWMutex
|
||||
onTokenObtained func(*KiroTokenData)
|
||||
}
|
||||
|
||||
func NewOAuthWebHandler(cfg *config.Config) *OAuthWebHandler {
|
||||
return &OAuthWebHandler{
|
||||
cfg: cfg,
|
||||
sessions: make(map[string]*webAuthSession),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) SetTokenCallback(callback func(*KiroTokenData)) {
|
||||
h.onTokenObtained = callback
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) RegisterRoutes(router gin.IRouter) {
|
||||
oauth := router.Group("/v0/oauth/kiro")
|
||||
{
|
||||
oauth.GET("", h.handleSelect)
|
||||
oauth.GET("/start", h.handleStart)
|
||||
oauth.GET("/callback", h.handleCallback)
|
||||
oauth.GET("/social/callback", h.handleSocialCallback)
|
||||
oauth.GET("/status", h.handleStatus)
|
||||
oauth.POST("/import", h.handleImportToken)
|
||||
}
|
||||
}
|
||||
|
||||
func generateStateID() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleSelect(c *gin.Context) {
|
||||
h.renderSelectPage(c)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleStart(c *gin.Context) {
|
||||
method := c.Query("method")
|
||||
|
||||
if method == "" {
|
||||
c.Redirect(http.StatusFound, "/v0/oauth/kiro")
|
||||
return
|
||||
}
|
||||
|
||||
switch method {
|
||||
case "google", "github":
|
||||
// Google/GitHub social login is not supported for third-party apps
|
||||
// due to AWS Cognito redirect_uri restrictions
|
||||
h.renderError(c, "Google/GitHub login is not available for third-party applications. Please use AWS Builder ID or import your token from Kiro IDE.")
|
||||
case "builder-id":
|
||||
h.startBuilderIDAuth(c)
|
||||
case "idc":
|
||||
h.startIDCAuth(c)
|
||||
default:
|
||||
h.renderError(c, fmt.Sprintf("Unknown authentication method: %s", method))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) startSocialAuth(c *gin.Context, method string) {
|
||||
stateID, err := generateStateID()
|
||||
if err != nil {
|
||||
h.renderError(c, "Failed to generate state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
codeVerifier, codeChallenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
h.renderError(c, "Failed to generate PKCE parameters")
|
||||
return
|
||||
}
|
||||
|
||||
socialClient := NewSocialAuthClient(h.cfg)
|
||||
|
||||
var provider string
|
||||
if method == "google" {
|
||||
provider = string(ProviderGoogle)
|
||||
} else {
|
||||
provider = string(ProviderGitHub)
|
||||
}
|
||||
|
||||
redirectURI := h.getSocialCallbackURL(c)
|
||||
authURL := socialClient.buildLoginURL(provider, redirectURI, codeChallenge, stateID)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
|
||||
session := &webAuthSession{
|
||||
stateID: stateID,
|
||||
authMethod: method,
|
||||
authURL: authURL,
|
||||
status: statusPending,
|
||||
startedAt: time.Now(),
|
||||
expiresIn: 600,
|
||||
codeVerifier: codeVerifier,
|
||||
codeChallenge: codeChallenge,
|
||||
region: "us-east-1",
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
h.sessions[stateID] = session
|
||||
h.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
h.mu.Lock()
|
||||
if session.status == statusPending {
|
||||
session.status = statusFailed
|
||||
session.error = "Authentication timed out"
|
||||
}
|
||||
h.mu.Unlock()
|
||||
}()
|
||||
|
||||
c.Redirect(http.StatusFound, authURL)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) getSocialCallbackURL(c *gin.Context) string {
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https" {
|
||||
scheme = "https"
|
||||
}
|
||||
return fmt.Sprintf("%s://%s/v0/oauth/kiro/social/callback", scheme, c.Request.Host)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) startBuilderIDAuth(c *gin.Context) {
|
||||
stateID, err := generateStateID()
|
||||
if err != nil {
|
||||
h.renderError(c, "Failed to generate state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
region := defaultIDCRegion
|
||||
startURL := builderIDStartURL
|
||||
|
||||
ssoClient := NewSSOOIDCClient(h.cfg)
|
||||
|
||||
regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to register client: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to register client: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
authResp, err := ssoClient.StartDeviceAuthorizationWithIDC(
|
||||
c.Request.Context(),
|
||||
regResp.ClientID,
|
||||
regResp.ClientSecret,
|
||||
startURL,
|
||||
region,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to start device authorization: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second)
|
||||
|
||||
session := &webAuthSession{
|
||||
stateID: stateID,
|
||||
deviceCode: authResp.DeviceCode,
|
||||
userCode: authResp.UserCode,
|
||||
authURL: authResp.VerificationURIComplete,
|
||||
verificationURI: authResp.VerificationURI,
|
||||
expiresIn: authResp.ExpiresIn,
|
||||
interval: authResp.Interval,
|
||||
status: statusPending,
|
||||
startedAt: time.Now(),
|
||||
ssoClient: ssoClient,
|
||||
clientID: regResp.ClientID,
|
||||
clientSecret: regResp.ClientSecret,
|
||||
region: region,
|
||||
authMethod: "builder-id",
|
||||
startURL: startURL,
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
h.sessions[stateID] = session
|
||||
h.mu.Unlock()
|
||||
|
||||
go h.pollForToken(ctx, session)
|
||||
|
||||
h.renderStartPage(c, session)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) startIDCAuth(c *gin.Context) {
|
||||
startURL := c.Query("startUrl")
|
||||
region := c.Query("region")
|
||||
|
||||
if startURL == "" {
|
||||
h.renderError(c, "Missing startUrl parameter for IDC authentication")
|
||||
return
|
||||
}
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
|
||||
stateID, err := generateStateID()
|
||||
if err != nil {
|
||||
h.renderError(c, "Failed to generate state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
ssoClient := NewSSOOIDCClient(h.cfg)
|
||||
|
||||
regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to register client: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to register client: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
authResp, err := ssoClient.StartDeviceAuthorizationWithIDC(
|
||||
c.Request.Context(),
|
||||
regResp.ClientID,
|
||||
regResp.ClientSecret,
|
||||
startURL,
|
||||
region,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to start device authorization: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second)
|
||||
|
||||
session := &webAuthSession{
|
||||
stateID: stateID,
|
||||
deviceCode: authResp.DeviceCode,
|
||||
userCode: authResp.UserCode,
|
||||
authURL: authResp.VerificationURIComplete,
|
||||
verificationURI: authResp.VerificationURI,
|
||||
expiresIn: authResp.ExpiresIn,
|
||||
interval: authResp.Interval,
|
||||
status: statusPending,
|
||||
startedAt: time.Now(),
|
||||
ssoClient: ssoClient,
|
||||
clientID: regResp.ClientID,
|
||||
clientSecret: regResp.ClientSecret,
|
||||
region: region,
|
||||
authMethod: "idc",
|
||||
startURL: startURL,
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
h.sessions[stateID] = session
|
||||
h.mu.Unlock()
|
||||
|
||||
go h.pollForToken(ctx, session)
|
||||
|
||||
h.renderStartPage(c, session)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) pollForToken(ctx context.Context, session *webAuthSession) {
|
||||
defer session.cancelFunc()
|
||||
|
||||
interval := time.Duration(session.interval) * time.Second
|
||||
if interval < time.Duration(pollIntervalSeconds)*time.Second {
|
||||
interval = time.Duration(pollIntervalSeconds) * time.Second
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
h.mu.Lock()
|
||||
if session.status == statusPending {
|
||||
session.status = statusFailed
|
||||
session.error = "Authentication timed out"
|
||||
}
|
||||
h.mu.Unlock()
|
||||
return
|
||||
case <-ticker.C:
|
||||
tokenResp, err := h.ssoClient(session).CreateTokenWithRegion(
|
||||
ctx,
|
||||
session.clientID,
|
||||
session.clientSecret,
|
||||
session.deviceCode,
|
||||
session.region,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
if errStr == ErrAuthorizationPending.Error() {
|
||||
continue
|
||||
}
|
||||
if errStr == ErrSlowDown.Error() {
|
||||
interval += 5 * time.Second
|
||||
ticker.Reset(interval)
|
||||
continue
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
session.status = statusFailed
|
||||
session.error = errStr
|
||||
session.completedAt = time.Now()
|
||||
h.mu.Unlock()
|
||||
|
||||
log.Errorf("OAuth Web: token polling failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||
profileArn := session.ssoClient.fetchProfileArn(ctx, tokenResp.AccessToken)
|
||||
email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken)
|
||||
|
||||
tokenData := &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: profileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: session.authMethod,
|
||||
Provider: "AWS",
|
||||
ClientID: session.clientID,
|
||||
ClientSecret: session.clientSecret,
|
||||
Email: email,
|
||||
Region: session.region,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
session.status = statusSuccess
|
||||
session.completedAt = time.Now()
|
||||
session.expiresAt = expiresAt
|
||||
session.tokenData = tokenData
|
||||
h.mu.Unlock()
|
||||
|
||||
if h.onTokenObtained != nil {
|
||||
h.onTokenObtained(tokenData)
|
||||
}
|
||||
|
||||
// Save token to file
|
||||
h.saveTokenToFile(tokenData)
|
||||
|
||||
log.Infof("OAuth Web: authentication successful for %s", email)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// saveTokenToFile saves the token data to the auth directory
|
||||
func (h *OAuthWebHandler) saveTokenToFile(tokenData *KiroTokenData) {
|
||||
// Get auth directory from config or use default
|
||||
authDir := ""
|
||||
if h.cfg != nil && h.cfg.AuthDir != "" {
|
||||
var err error
|
||||
authDir, err = util.ResolveAuthDir(h.cfg.AuthDir)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to resolve auth directory: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to default location
|
||||
if authDir == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to get home directory: %v", err)
|
||||
return
|
||||
}
|
||||
authDir = filepath.Join(home, ".cli-proxy-api")
|
||||
}
|
||||
|
||||
// Create directory if not exists
|
||||
if err := os.MkdirAll(authDir, 0700); err != nil {
|
||||
log.Errorf("OAuth Web: failed to create auth directory: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate filename based on auth method
|
||||
// Format: kiro-{authMethod}.json or kiro-{authMethod}-{email}.json
|
||||
fileName := fmt.Sprintf("kiro-%s.json", tokenData.AuthMethod)
|
||||
if tokenData.Email != "" {
|
||||
// Sanitize email for filename (replace @ and . with -)
|
||||
sanitizedEmail := tokenData.Email
|
||||
sanitizedEmail = strings.ReplaceAll(sanitizedEmail, "@", "-")
|
||||
sanitizedEmail = strings.ReplaceAll(sanitizedEmail, ".", "-")
|
||||
fileName = fmt.Sprintf("kiro-%s-%s.json", tokenData.AuthMethod, sanitizedEmail)
|
||||
}
|
||||
|
||||
authFilePath := filepath.Join(authDir, fileName)
|
||||
|
||||
// Convert to storage format and save
|
||||
storage := &KiroTokenStorage{
|
||||
Type: "kiro",
|
||||
AccessToken: tokenData.AccessToken,
|
||||
RefreshToken: tokenData.RefreshToken,
|
||||
ProfileArn: tokenData.ProfileArn,
|
||||
ExpiresAt: tokenData.ExpiresAt,
|
||||
AuthMethod: tokenData.AuthMethod,
|
||||
Provider: tokenData.Provider,
|
||||
LastRefresh: time.Now().Format(time.RFC3339),
|
||||
ClientID: tokenData.ClientID,
|
||||
ClientSecret: tokenData.ClientSecret,
|
||||
Region: tokenData.Region,
|
||||
StartURL: tokenData.StartURL,
|
||||
Email: tokenData.Email,
|
||||
}
|
||||
|
||||
if err := storage.SaveTokenToFile(authFilePath); err != nil {
|
||||
log.Errorf("OAuth Web: failed to save token to file: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("OAuth Web: token saved to %s", authFilePath)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) ssoClient(session *webAuthSession) *SSOOIDCClient {
|
||||
return session.ssoClient
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleCallback(c *gin.Context) {
|
||||
stateID := c.Query("state")
|
||||
errParam := c.Query("error")
|
||||
|
||||
if errParam != "" {
|
||||
h.renderError(c, errParam)
|
||||
return
|
||||
}
|
||||
|
||||
if stateID == "" {
|
||||
h.renderError(c, "Missing state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
session, exists := h.sessions[stateID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
h.renderError(c, "Invalid or expired session")
|
||||
return
|
||||
}
|
||||
|
||||
if session.status == statusSuccess {
|
||||
h.renderSuccess(c, session)
|
||||
} else if session.status == statusFailed {
|
||||
h.renderError(c, session.error)
|
||||
} else {
|
||||
c.Redirect(http.StatusFound, "/v0/oauth/kiro/start")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleSocialCallback(c *gin.Context) {
|
||||
stateID := c.Query("state")
|
||||
code := c.Query("code")
|
||||
errParam := c.Query("error")
|
||||
|
||||
if errParam != "" {
|
||||
h.renderError(c, errParam)
|
||||
return
|
||||
}
|
||||
|
||||
if stateID == "" {
|
||||
h.renderError(c, "Missing state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
if code == "" {
|
||||
h.renderError(c, "Missing authorization code")
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
session, exists := h.sessions[stateID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
h.renderError(c, "Invalid or expired session")
|
||||
return
|
||||
}
|
||||
|
||||
if session.authMethod != "google" && session.authMethod != "github" {
|
||||
h.renderError(c, "Invalid session type for social callback")
|
||||
return
|
||||
}
|
||||
|
||||
socialClient := NewSocialAuthClient(h.cfg)
|
||||
redirectURI := h.getSocialCallbackURL(c)
|
||||
|
||||
tokenReq := &CreateTokenRequest{
|
||||
Code: code,
|
||||
CodeVerifier: session.codeVerifier,
|
||||
RedirectURI: redirectURI,
|
||||
}
|
||||
|
||||
tokenResp, err := socialClient.CreateToken(c.Request.Context(), tokenReq)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: social token exchange failed: %v", err)
|
||||
h.mu.Lock()
|
||||
session.status = statusFailed
|
||||
session.error = fmt.Sprintf("Token exchange failed: %v", err)
|
||||
session.completedAt = time.Now()
|
||||
h.mu.Unlock()
|
||||
h.renderError(c, session.error)
|
||||
return
|
||||
}
|
||||
|
||||
expiresIn := tokenResp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
email := ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||
|
||||
var provider string
|
||||
if session.authMethod == "google" {
|
||||
provider = string(ProviderGoogle)
|
||||
} else {
|
||||
provider = string(ProviderGitHub)
|
||||
}
|
||||
|
||||
tokenData := &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: tokenResp.ProfileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: session.authMethod,
|
||||
Provider: provider,
|
||||
Email: email,
|
||||
Region: "us-east-1",
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
session.status = statusSuccess
|
||||
session.completedAt = time.Now()
|
||||
session.expiresAt = expiresAt
|
||||
session.tokenData = tokenData
|
||||
h.mu.Unlock()
|
||||
|
||||
if session.cancelFunc != nil {
|
||||
session.cancelFunc()
|
||||
}
|
||||
|
||||
if h.onTokenObtained != nil {
|
||||
h.onTokenObtained(tokenData)
|
||||
}
|
||||
|
||||
// Save token to file
|
||||
h.saveTokenToFile(tokenData)
|
||||
|
||||
log.Infof("OAuth Web: social authentication successful for %s via %s", email, provider)
|
||||
h.renderSuccess(c, session)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleStatus(c *gin.Context) {
|
||||
stateID := c.Query("state")
|
||||
if stateID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing state parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
session, exists := h.sessions[stateID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
|
||||
return
|
||||
}
|
||||
|
||||
response := gin.H{
|
||||
"status": string(session.status),
|
||||
}
|
||||
|
||||
switch session.status {
|
||||
case statusPending:
|
||||
elapsed := time.Since(session.startedAt).Seconds()
|
||||
remaining := float64(session.expiresIn) - elapsed
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
response["remaining_seconds"] = int(remaining)
|
||||
case statusSuccess:
|
||||
response["completed_at"] = session.completedAt.Format(time.RFC3339)
|
||||
response["expires_at"] = session.expiresAt.Format(time.RFC3339)
|
||||
case statusFailed:
|
||||
response["error"] = session.error
|
||||
response["failed_at"] = session.completedAt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderStartPage(c *gin.Context, session *webAuthSession) {
|
||||
tmpl, err := template.New("start").Parse(oauthWebStartPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"AuthURL": session.authURL,
|
||||
"UserCode": session.userCode,
|
||||
"ExpiresIn": session.expiresIn,
|
||||
"StateID": session.stateID,
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
if err := tmpl.Execute(c.Writer, data); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderSelectPage(c *gin.Context) {
|
||||
tmpl, err := template.New("select").Parse(oauthWebSelectPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse select template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
if err := tmpl.Execute(c.Writer, nil); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render select template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderError(c *gin.Context, errMsg string) {
|
||||
tmpl, err := template.New("error").Parse(oauthWebErrorPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse error template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"Error": errMsg,
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.Status(http.StatusBadRequest)
|
||||
if err := tmpl.Execute(c.Writer, data); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render error template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderSuccess(c *gin.Context, session *webAuthSession) {
|
||||
tmpl, err := template.New("success").Parse(oauthWebSuccessPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse success template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"ExpiresAt": session.expiresAt.Format(time.RFC3339),
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
if err := tmpl.Execute(c.Writer, data); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render success template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) CleanupExpiredSessions() {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for id, session := range h.sessions {
|
||||
if session.status != statusPending && now.Sub(session.completedAt) > 30*time.Minute {
|
||||
delete(h.sessions, id)
|
||||
} else if session.status == statusPending && now.Sub(session.startedAt) > defaultSessionExpiry {
|
||||
session.cancelFunc()
|
||||
delete(h.sessions, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) GetSession(stateID string) (*webAuthSession, bool) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
session, exists := h.sessions[stateID]
|
||||
return session, exists
|
||||
}
|
||||
|
||||
// ImportTokenRequest represents the request body for token import
|
||||
type ImportTokenRequest struct {
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
// handleImportToken handles manual refresh token import from Kiro IDE
|
||||
func (h *OAuthWebHandler) handleImportToken(c *gin.Context) {
|
||||
var req ImportTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"error": "Invalid request body",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
refreshToken := strings.TrimSpace(req.RefreshToken)
|
||||
if refreshToken == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"error": "Refresh token is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate token format
|
||||
if !strings.HasPrefix(refreshToken, "aorAAAAAG") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"error": "Invalid token format. Token should start with aorAAAAAG...",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Create social auth client to refresh and validate the token
|
||||
socialClient := NewSocialAuthClient(h.cfg)
|
||||
|
||||
// Refresh the token to validate it and get access token
|
||||
tokenData, err := socialClient.RefreshSocialToken(c.Request.Context(), refreshToken)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: token refresh failed during import: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"error": fmt.Sprintf("Token validation failed: %v", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Set the original refresh token (the refreshed one might be empty)
|
||||
if tokenData.RefreshToken == "" {
|
||||
tokenData.RefreshToken = refreshToken
|
||||
}
|
||||
tokenData.AuthMethod = "social"
|
||||
tokenData.Provider = "imported"
|
||||
|
||||
// Notify callback if set
|
||||
if h.onTokenObtained != nil {
|
||||
h.onTokenObtained(tokenData)
|
||||
}
|
||||
|
||||
// Save token to file
|
||||
h.saveTokenToFile(tokenData)
|
||||
|
||||
// Generate filename for response
|
||||
fileName := fmt.Sprintf("kiro-%s.json", tokenData.AuthMethod)
|
||||
if tokenData.Email != "" {
|
||||
sanitizedEmail := strings.ReplaceAll(tokenData.Email, "@", "-")
|
||||
sanitizedEmail = strings.ReplaceAll(sanitizedEmail, ".", "-")
|
||||
fileName = fmt.Sprintf("kiro-%s-%s.json", tokenData.AuthMethod, sanitizedEmail)
|
||||
}
|
||||
|
||||
log.Infof("OAuth Web: token imported successfully")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Token imported successfully",
|
||||
"fileName": fileName,
|
||||
})
|
||||
}
|
||||
385
internal/auth/kiro/oauth_web.go.bak
Normal file
385
internal/auth/kiro/oauth_web.go.bak
Normal file
@@ -0,0 +1,385 @@
|
||||
// Package kiro provides OAuth Web authentication for Kiro.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultSessionExpiry = 10 * time.Minute
|
||||
pollIntervalSeconds = 5
|
||||
)
|
||||
|
||||
type authSessionStatus string
|
||||
|
||||
const (
|
||||
statusPending authSessionStatus = "pending"
|
||||
statusSuccess authSessionStatus = "success"
|
||||
statusFailed authSessionStatus = "failed"
|
||||
)
|
||||
|
||||
type webAuthSession struct {
|
||||
stateID string
|
||||
deviceCode string
|
||||
userCode string
|
||||
authURL string
|
||||
verificationURI string
|
||||
expiresIn int
|
||||
interval int
|
||||
status authSessionStatus
|
||||
startedAt time.Time
|
||||
completedAt time.Time
|
||||
expiresAt time.Time
|
||||
error string
|
||||
tokenData *KiroTokenData
|
||||
ssoClient *SSOOIDCClient
|
||||
clientID string
|
||||
clientSecret string
|
||||
region string
|
||||
cancelFunc context.CancelFunc
|
||||
}
|
||||
|
||||
type OAuthWebHandler struct {
|
||||
cfg *config.Config
|
||||
sessions map[string]*webAuthSession
|
||||
mu sync.RWMutex
|
||||
onTokenObtained func(*KiroTokenData)
|
||||
}
|
||||
|
||||
func NewOAuthWebHandler(cfg *config.Config) *OAuthWebHandler {
|
||||
return &OAuthWebHandler{
|
||||
cfg: cfg,
|
||||
sessions: make(map[string]*webAuthSession),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) SetTokenCallback(callback func(*KiroTokenData)) {
|
||||
h.onTokenObtained = callback
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) RegisterRoutes(router gin.IRouter) {
|
||||
oauth := router.Group("/v0/oauth/kiro")
|
||||
{
|
||||
oauth.GET("/start", h.handleStart)
|
||||
oauth.GET("/callback", h.handleCallback)
|
||||
oauth.GET("/status", h.handleStatus)
|
||||
}
|
||||
}
|
||||
|
||||
func generateStateID() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleStart(c *gin.Context) {
|
||||
stateID, err := generateStateID()
|
||||
if err != nil {
|
||||
h.renderError(c, "Failed to generate state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
region := defaultIDCRegion
|
||||
startURL := builderIDStartURL
|
||||
|
||||
ssoClient := NewSSOOIDCClient(h.cfg)
|
||||
|
||||
regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to register client: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to register client: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
authResp, err := ssoClient.StartDeviceAuthorizationWithIDC(
|
||||
c.Request.Context(),
|
||||
regResp.ClientID,
|
||||
regResp.ClientSecret,
|
||||
startURL,
|
||||
region,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to start device authorization: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second)
|
||||
|
||||
session := &webAuthSession{
|
||||
stateID: stateID,
|
||||
deviceCode: authResp.DeviceCode,
|
||||
userCode: authResp.UserCode,
|
||||
authURL: authResp.VerificationURIComplete,
|
||||
verificationURI: authResp.VerificationURI,
|
||||
expiresIn: authResp.ExpiresIn,
|
||||
interval: authResp.Interval,
|
||||
status: statusPending,
|
||||
startedAt: time.Now(),
|
||||
ssoClient: ssoClient,
|
||||
clientID: regResp.ClientID,
|
||||
clientSecret: regResp.ClientSecret,
|
||||
region: region,
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
h.sessions[stateID] = session
|
||||
h.mu.Unlock()
|
||||
|
||||
go h.pollForToken(ctx, session)
|
||||
|
||||
h.renderStartPage(c, session)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) pollForToken(ctx context.Context, session *webAuthSession) {
|
||||
defer session.cancelFunc()
|
||||
|
||||
interval := time.Duration(session.interval) * time.Second
|
||||
if interval < time.Duration(pollIntervalSeconds)*time.Second {
|
||||
interval = time.Duration(pollIntervalSeconds) * time.Second
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
h.mu.Lock()
|
||||
if session.status == statusPending {
|
||||
session.status = statusFailed
|
||||
session.error = "Authentication timed out"
|
||||
}
|
||||
h.mu.Unlock()
|
||||
return
|
||||
case <-ticker.C:
|
||||
tokenResp, err := h.ssoClient(session).CreateTokenWithRegion(
|
||||
ctx,
|
||||
session.clientID,
|
||||
session.clientSecret,
|
||||
session.deviceCode,
|
||||
session.region,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
if errStr == ErrAuthorizationPending.Error() {
|
||||
continue
|
||||
}
|
||||
if errStr == ErrSlowDown.Error() {
|
||||
interval += 5 * time.Second
|
||||
ticker.Reset(interval)
|
||||
continue
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
session.status = statusFailed
|
||||
session.error = errStr
|
||||
session.completedAt = time.Now()
|
||||
h.mu.Unlock()
|
||||
|
||||
log.Errorf("OAuth Web: token polling failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||
profileArn := session.ssoClient.fetchProfileArn(ctx, tokenResp.AccessToken)
|
||||
email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken)
|
||||
|
||||
tokenData := &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: profileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "builder-id",
|
||||
Provider: "AWS",
|
||||
ClientID: session.clientID,
|
||||
ClientSecret: session.clientSecret,
|
||||
Email: email,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
session.status = statusSuccess
|
||||
session.completedAt = time.Now()
|
||||
session.expiresAt = expiresAt
|
||||
session.tokenData = tokenData
|
||||
h.mu.Unlock()
|
||||
|
||||
if h.onTokenObtained != nil {
|
||||
h.onTokenObtained(tokenData)
|
||||
}
|
||||
|
||||
log.Infof("OAuth Web: authentication successful for %s", email)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) ssoClient(session *webAuthSession) *SSOOIDCClient {
|
||||
return session.ssoClient
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleCallback(c *gin.Context) {
|
||||
stateID := c.Query("state")
|
||||
errParam := c.Query("error")
|
||||
|
||||
if errParam != "" {
|
||||
h.renderError(c, errParam)
|
||||
return
|
||||
}
|
||||
|
||||
if stateID == "" {
|
||||
h.renderError(c, "Missing state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
session, exists := h.sessions[stateID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
h.renderError(c, "Invalid or expired session")
|
||||
return
|
||||
}
|
||||
|
||||
if session.status == statusSuccess {
|
||||
h.renderSuccess(c, session)
|
||||
} else if session.status == statusFailed {
|
||||
h.renderError(c, session.error)
|
||||
} else {
|
||||
c.Redirect(http.StatusFound, "/v0/oauth/kiro/start")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleStatus(c *gin.Context) {
|
||||
stateID := c.Query("state")
|
||||
if stateID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing state parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
session, exists := h.sessions[stateID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
|
||||
return
|
||||
}
|
||||
|
||||
response := gin.H{
|
||||
"status": string(session.status),
|
||||
}
|
||||
|
||||
switch session.status {
|
||||
case statusPending:
|
||||
elapsed := time.Since(session.startedAt).Seconds()
|
||||
remaining := float64(session.expiresIn) - elapsed
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
response["remaining_seconds"] = int(remaining)
|
||||
case statusSuccess:
|
||||
response["completed_at"] = session.completedAt.Format(time.RFC3339)
|
||||
response["expires_at"] = session.expiresAt.Format(time.RFC3339)
|
||||
case statusFailed:
|
||||
response["error"] = session.error
|
||||
response["failed_at"] = session.completedAt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderStartPage(c *gin.Context, session *webAuthSession) {
|
||||
tmpl, err := template.New("start").Parse(oauthWebStartPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"AuthURL": session.authURL,
|
||||
"UserCode": session.userCode,
|
||||
"ExpiresIn": session.expiresIn,
|
||||
"StateID": session.stateID,
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
if err := tmpl.Execute(c.Writer, data); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderError(c *gin.Context, errMsg string) {
|
||||
tmpl, err := template.New("error").Parse(oauthWebErrorPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse error template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"Error": errMsg,
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.Status(http.StatusBadRequest)
|
||||
if err := tmpl.Execute(c.Writer, data); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render error template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderSuccess(c *gin.Context, session *webAuthSession) {
|
||||
tmpl, err := template.New("success").Parse(oauthWebSuccessPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse success template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"ExpiresAt": session.expiresAt.Format(time.RFC3339),
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
if err := tmpl.Execute(c.Writer, data); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render success template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) CleanupExpiredSessions() {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for id, session := range h.sessions {
|
||||
if session.status != statusPending && now.Sub(session.completedAt) > 30*time.Minute {
|
||||
delete(h.sessions, id)
|
||||
} else if session.status == statusPending && now.Sub(session.startedAt) > defaultSessionExpiry {
|
||||
session.cancelFunc()
|
||||
delete(h.sessions, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) GetSession(stateID string) (*webAuthSession, bool) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
session, exists := h.sessions[stateID]
|
||||
return session, exists
|
||||
}
|
||||
732
internal/auth/kiro/oauth_web_templates.go
Normal file
732
internal/auth/kiro/oauth_web_templates.go
Normal file
@@ -0,0 +1,732 @@
|
||||
// Package kiro provides OAuth Web authentication templates.
|
||||
package kiro
|
||||
|
||||
const (
|
||||
oauthWebStartPageHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>AWS SSO Authentication</title>
|
||||
<style>
|
||||
* { box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
.container {
|
||||
max-width: 500px;
|
||||
width: 100%;
|
||||
background: #fff;
|
||||
padding: 40px;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
|
||||
}
|
||||
h1 {
|
||||
margin: 0 0 10px;
|
||||
color: #333;
|
||||
font-size: 24px;
|
||||
text-align: center;
|
||||
}
|
||||
.subtitle {
|
||||
text-align: center;
|
||||
color: #666;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
.step {
|
||||
background: #f8f9fa;
|
||||
padding: 20px;
|
||||
border-radius: 8px;
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
.step-title {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
font-weight: 600;
|
||||
color: #333;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
.step-number {
|
||||
width: 28px;
|
||||
height: 28px;
|
||||
background: #667eea;
|
||||
color: white;
|
||||
border-radius: 50%;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: 14px;
|
||||
margin-right: 12px;
|
||||
}
|
||||
.user-code {
|
||||
background: #e7f3ff;
|
||||
border: 2px dashed #2196F3;
|
||||
border-radius: 8px;
|
||||
padding: 20px;
|
||||
text-align: center;
|
||||
margin-top: 10px;
|
||||
}
|
||||
.user-code-label {
|
||||
font-size: 12px;
|
||||
color: #666;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 1px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.user-code-value {
|
||||
font-size: 32px;
|
||||
font-weight: bold;
|
||||
font-family: monospace;
|
||||
color: #2196F3;
|
||||
letter-spacing: 4px;
|
||||
}
|
||||
.auth-btn {
|
||||
display: block;
|
||||
width: 100%;
|
||||
padding: 15px;
|
||||
background: #667eea;
|
||||
color: white;
|
||||
text-align: center;
|
||||
text-decoration: none;
|
||||
border-radius: 8px;
|
||||
font-weight: 600;
|
||||
font-size: 16px;
|
||||
transition: all 0.3s;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
margin-top: 20px;
|
||||
}
|
||||
.auth-btn:hover {
|
||||
background: #5568d3;
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
|
||||
}
|
||||
.status {
|
||||
margin-top: 30px;
|
||||
padding: 20px;
|
||||
background: #f8f9fa;
|
||||
border-radius: 8px;
|
||||
text-align: center;
|
||||
}
|
||||
.status-pending { border-left: 4px solid #ffc107; }
|
||||
.status-success { border-left: 4px solid #28a745; }
|
||||
.status-failed { border-left: 4px solid #dc3545; }
|
||||
.spinner {
|
||||
border: 3px solid #f3f3f3;
|
||||
border-top: 3px solid #667eea;
|
||||
border-radius: 50%;
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
animation: spin 1s linear infinite;
|
||||
margin: 0 auto 15px;
|
||||
}
|
||||
@keyframes spin {
|
||||
0% { transform: rotate(0deg); }
|
||||
100% { transform: rotate(360deg); }
|
||||
}
|
||||
.timer {
|
||||
font-size: 24px;
|
||||
font-weight: bold;
|
||||
color: #667eea;
|
||||
margin: 10px 0;
|
||||
}
|
||||
.timer.warning { color: #ffc107; }
|
||||
.timer.danger { color: #dc3545; }
|
||||
.status-message { color: #666; line-height: 1.6; }
|
||||
.success-icon, .error-icon { font-size: 48px; margin-bottom: 15px; }
|
||||
.info-box {
|
||||
background: #e7f3ff;
|
||||
border-left: 4px solid #2196F3;
|
||||
padding: 15px;
|
||||
margin-top: 20px;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
color: #666;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>🔐 AWS SSO Authentication</h1>
|
||||
<p class="subtitle">Follow the steps below to complete authentication</p>
|
||||
|
||||
<div class="step">
|
||||
<div class="step-title">
|
||||
<span class="step-number">1</span>
|
||||
Click the button below to open the authorization page
|
||||
</div>
|
||||
<a href="{{.AuthURL}}" target="_blank" class="auth-btn" id="authBtn">
|
||||
🚀 Open Authorization Page
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div class="step">
|
||||
<div class="step-title">
|
||||
<span class="step-number">2</span>
|
||||
Enter the verification code below
|
||||
</div>
|
||||
<div class="user-code">
|
||||
<div class="user-code-label">Verification Code</div>
|
||||
<div class="user-code-value">{{.UserCode}}</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="step">
|
||||
<div class="step-title">
|
||||
<span class="step-number">3</span>
|
||||
Complete AWS SSO login
|
||||
</div>
|
||||
<p style="color: #666; font-size: 14px; margin-top: 10px;">
|
||||
Use your AWS SSO account to login and authorize
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div class="status status-pending" id="statusBox">
|
||||
<div class="spinner" id="spinner"></div>
|
||||
<div class="timer" id="timer">{{.ExpiresIn}}s</div>
|
||||
<div class="status-message" id="statusMessage">
|
||||
Waiting for authorization...
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="info-box">
|
||||
💡 <strong>Tip:</strong> The authorization page will open in a new tab. This page will automatically update once authorization is complete.
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let pollInterval;
|
||||
let timerInterval;
|
||||
let remainingSeconds = {{.ExpiresIn}};
|
||||
const stateID = "{{.StateID}}";
|
||||
|
||||
setTimeout(() => {
|
||||
document.getElementById('authBtn').click();
|
||||
}, 500);
|
||||
|
||||
function pollStatus() {
|
||||
fetch('/v0/oauth/kiro/status?state=' + stateID)
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
console.log('Status:', data);
|
||||
if (data.status === 'success') {
|
||||
clearInterval(pollInterval);
|
||||
clearInterval(timerInterval);
|
||||
showSuccess(data);
|
||||
} else if (data.status === 'failed') {
|
||||
clearInterval(pollInterval);
|
||||
clearInterval(timerInterval);
|
||||
showError(data);
|
||||
} else {
|
||||
remainingSeconds = data.remaining_seconds || 0;
|
||||
}
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Poll error:', error);
|
||||
});
|
||||
}
|
||||
|
||||
function updateTimer() {
|
||||
const timerEl = document.getElementById('timer');
|
||||
const minutes = Math.floor(remainingSeconds / 60);
|
||||
const seconds = remainingSeconds % 60;
|
||||
timerEl.textContent = minutes + ':' + seconds.toString().padStart(2, '0');
|
||||
|
||||
if (remainingSeconds < 60) {
|
||||
timerEl.className = 'timer danger';
|
||||
} else if (remainingSeconds < 180) {
|
||||
timerEl.className = 'timer warning';
|
||||
} else {
|
||||
timerEl.className = 'timer';
|
||||
}
|
||||
|
||||
remainingSeconds--;
|
||||
|
||||
if (remainingSeconds < 0) {
|
||||
clearInterval(timerInterval);
|
||||
clearInterval(pollInterval);
|
||||
showError({ error: 'Authentication timed out. Please refresh and try again.' });
|
||||
}
|
||||
}
|
||||
|
||||
function showSuccess(data) {
|
||||
const statusBox = document.getElementById('statusBox');
|
||||
statusBox.className = 'status status-success';
|
||||
statusBox.innerHTML = '<div class="success-icon">✅</div>' +
|
||||
'<div class="status-message">' +
|
||||
'<strong>Authentication Successful!</strong><br>' +
|
||||
'Token expires: ' + new Date(data.expires_at).toLocaleString() +
|
||||
'</div>';
|
||||
}
|
||||
|
||||
function showError(data) {
|
||||
const statusBox = document.getElementById('statusBox');
|
||||
statusBox.className = 'status status-failed';
|
||||
statusBox.innerHTML = '<div class="error-icon">❌</div>' +
|
||||
'<div class="status-message">' +
|
||||
'<strong>Authentication Failed</strong><br>' +
|
||||
(data.error || 'Unknown error') +
|
||||
'</div>' +
|
||||
'<button class="auth-btn" onclick="location.reload()" style="margin-top: 15px;">' +
|
||||
'🔄 Retry' +
|
||||
'</button>';
|
||||
}
|
||||
|
||||
pollInterval = setInterval(pollStatus, 3000);
|
||||
timerInterval = setInterval(updateTimer, 1000);
|
||||
pollStatus();
|
||||
</script>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
oauthWebErrorPageHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Authentication Failed</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
max-width: 600px;
|
||||
margin: 50px auto;
|
||||
padding: 20px;
|
||||
background: #f5f5f5;
|
||||
}
|
||||
.error {
|
||||
background: #fff;
|
||||
padding: 30px;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||
border-left: 4px solid #dc3545;
|
||||
}
|
||||
h1 { color: #dc3545; margin-top: 0; }
|
||||
.error-message { color: #666; line-height: 1.6; }
|
||||
.retry-btn {
|
||||
display: inline-block;
|
||||
margin-top: 20px;
|
||||
padding: 10px 20px;
|
||||
background: #007bff;
|
||||
color: white;
|
||||
text-decoration: none;
|
||||
border-radius: 4px;
|
||||
}
|
||||
.retry-btn:hover { background: #0056b3; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="error">
|
||||
<h1>❌ Authentication Failed</h1>
|
||||
<div class="error-message">
|
||||
<p><strong>Error:</strong></p>
|
||||
<p>{{.Error}}</p>
|
||||
</div>
|
||||
<a href="/v0/oauth/kiro/start" class="retry-btn">🔄 Retry</a>
|
||||
</div>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
oauthWebSuccessPageHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Authentication Successful</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
max-width: 600px;
|
||||
margin: 50px auto;
|
||||
padding: 20px;
|
||||
background: #f5f5f5;
|
||||
}
|
||||
.success {
|
||||
background: #fff;
|
||||
padding: 30px;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||
border-left: 4px solid #28a745;
|
||||
text-align: center;
|
||||
}
|
||||
h1 { color: #28a745; margin-top: 0; }
|
||||
.success-message { color: #666; line-height: 1.6; }
|
||||
.icon { font-size: 48px; margin-bottom: 15px; }
|
||||
.expires { font-size: 14px; color: #999; margin-top: 15px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="success">
|
||||
<div class="icon">✅</div>
|
||||
<h1>Authentication Successful!</h1>
|
||||
<div class="success-message">
|
||||
<p>You can close this window.</p>
|
||||
</div>
|
||||
<div class="expires">Token expires: {{.ExpiresAt}}</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
oauthWebSelectPageHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Select Authentication Method</title>
|
||||
<style>
|
||||
* { box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
.container {
|
||||
max-width: 500px;
|
||||
width: 100%;
|
||||
background: #fff;
|
||||
padding: 40px;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
|
||||
}
|
||||
h1 {
|
||||
margin: 0 0 10px;
|
||||
color: #333;
|
||||
font-size: 24px;
|
||||
text-align: center;
|
||||
}
|
||||
.subtitle {
|
||||
text-align: center;
|
||||
color: #666;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
.auth-methods {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 15px;
|
||||
}
|
||||
.auth-btn {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
width: 100%;
|
||||
padding: 15px 20px;
|
||||
background: #667eea;
|
||||
color: white;
|
||||
text-decoration: none;
|
||||
border-radius: 8px;
|
||||
font-weight: 600;
|
||||
font-size: 16px;
|
||||
transition: all 0.3s;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
}
|
||||
.auth-btn:hover {
|
||||
background: #5568d3;
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
|
||||
}
|
||||
.auth-btn .icon {
|
||||
font-size: 24px;
|
||||
margin-right: 15px;
|
||||
width: 32px;
|
||||
text-align: center;
|
||||
}
|
||||
.auth-btn.google { background: #4285F4; }
|
||||
.auth-btn.google:hover { background: #3367D6; }
|
||||
.auth-btn.github { background: #24292e; }
|
||||
.auth-btn.github:hover { background: #1a1e22; }
|
||||
.auth-btn.aws { background: #FF9900; }
|
||||
.auth-btn.aws:hover { background: #E68A00; }
|
||||
.auth-btn.idc { background: #232F3E; }
|
||||
.auth-btn.idc:hover { background: #1a242f; }
|
||||
.idc-form {
|
||||
background: #f8f9fa;
|
||||
padding: 20px;
|
||||
border-radius: 8px;
|
||||
margin-top: 15px;
|
||||
display: none;
|
||||
}
|
||||
.idc-form.show {
|
||||
display: block;
|
||||
}
|
||||
.form-group {
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
.form-group label {
|
||||
display: block;
|
||||
font-weight: 600;
|
||||
color: #333;
|
||||
margin-bottom: 8px;
|
||||
font-size: 14px;
|
||||
}
|
||||
.form-group input {
|
||||
width: 100%;
|
||||
padding: 12px;
|
||||
border: 2px solid #e0e0e0;
|
||||
border-radius: 6px;
|
||||
font-size: 14px;
|
||||
transition: border-color 0.3s;
|
||||
}
|
||||
.form-group input:focus {
|
||||
outline: none;
|
||||
border-color: #667eea;
|
||||
}
|
||||
.form-group .hint {
|
||||
font-size: 12px;
|
||||
color: #999;
|
||||
margin-top: 5px;
|
||||
}
|
||||
.submit-btn {
|
||||
display: block;
|
||||
width: 100%;
|
||||
padding: 15px;
|
||||
background: #232F3E;
|
||||
color: white;
|
||||
text-align: center;
|
||||
text-decoration: none;
|
||||
border-radius: 8px;
|
||||
font-weight: 600;
|
||||
font-size: 16px;
|
||||
transition: all 0.3s;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
}
|
||||
.submit-btn:hover {
|
||||
background: #1a242f;
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 12px rgba(35, 47, 62, 0.4);
|
||||
}
|
||||
.divider {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin: 20px 0;
|
||||
}
|
||||
.divider::before,
|
||||
.divider::after {
|
||||
content: "";
|
||||
flex: 1;
|
||||
border-bottom: 1px solid #e0e0e0;
|
||||
}
|
||||
.divider span {
|
||||
padding: 0 15px;
|
||||
color: #999;
|
||||
font-size: 14px;
|
||||
}
|
||||
.info-box {
|
||||
background: #e7f3ff;
|
||||
border-left: 4px solid #2196F3;
|
||||
padding: 15px;
|
||||
margin-top: 20px;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
color: #666;
|
||||
}
|
||||
.warning-box {
|
||||
background: #fff3cd;
|
||||
border-left: 4px solid #ffc107;
|
||||
padding: 15px;
|
||||
margin-top: 20px;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
color: #856404;
|
||||
}
|
||||
.auth-btn.manual { background: #6c757d; }
|
||||
.auth-btn.manual:hover { background: #5a6268; }
|
||||
.manual-form {
|
||||
background: #f8f9fa;
|
||||
padding: 20px;
|
||||
border-radius: 8px;
|
||||
margin-top: 15px;
|
||||
display: none;
|
||||
}
|
||||
.manual-form.show {
|
||||
display: block;
|
||||
}
|
||||
.form-group textarea {
|
||||
width: 100%;
|
||||
padding: 12px;
|
||||
border: 2px solid #e0e0e0;
|
||||
border-radius: 6px;
|
||||
font-size: 14px;
|
||||
font-family: monospace;
|
||||
transition: border-color 0.3s;
|
||||
resize: vertical;
|
||||
min-height: 80px;
|
||||
}
|
||||
.form-group textarea:focus {
|
||||
outline: none;
|
||||
border-color: #667eea;
|
||||
}
|
||||
.status-message {
|
||||
padding: 15px;
|
||||
border-radius: 6px;
|
||||
margin-top: 15px;
|
||||
display: none;
|
||||
}
|
||||
.status-message.success {
|
||||
background: #d4edda;
|
||||
color: #155724;
|
||||
display: block;
|
||||
}
|
||||
.status-message.error {
|
||||
background: #f8d7da;
|
||||
color: #721c24;
|
||||
display: block;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>🔐 Select Authentication Method</h1>
|
||||
<p class="subtitle">Choose how you want to authenticate with Kiro</p>
|
||||
|
||||
<div class="auth-methods">
|
||||
<a href="/v0/oauth/kiro/start?method=builder-id" class="auth-btn aws">
|
||||
<span class="icon">🔶</span>
|
||||
AWS Builder ID (Recommended)
|
||||
</a>
|
||||
|
||||
<button type="button" class="auth-btn idc" onclick="toggleIdcForm()">
|
||||
<span class="icon">🏢</span>
|
||||
AWS Identity Center (IDC)
|
||||
</button>
|
||||
|
||||
<div class="divider"><span>or</span></div>
|
||||
|
||||
<button type="button" class="auth-btn manual" onclick="toggleManualForm()">
|
||||
<span class="icon">📋</span>
|
||||
Import RefreshToken from Kiro IDE
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div class="idc-form" id="idcForm">
|
||||
<form action="/v0/oauth/kiro/start" method="get">
|
||||
<input type="hidden" name="method" value="idc">
|
||||
|
||||
<div class="form-group">
|
||||
<label for="startUrl">Start URL</label>
|
||||
<input type="url" id="startUrl" name="startUrl" placeholder="https://your-org.awsapps.com/start" required>
|
||||
<div class="hint">Your AWS Identity Center Start URL</div>
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label for="region">Region</label>
|
||||
<input type="text" id="region" name="region" value="us-east-1" placeholder="us-east-1">
|
||||
<div class="hint">AWS Region for your Identity Center</div>
|
||||
</div>
|
||||
|
||||
<button type="submit" class="submit-btn">
|
||||
🚀 Continue with IDC
|
||||
</button>
|
||||
</form>
|
||||
</div>
|
||||
|
||||
<div class="manual-form" id="manualForm">
|
||||
<form id="importForm" onsubmit="submitImport(event)">
|
||||
<div class="form-group">
|
||||
<label for="refreshToken">Refresh Token</label>
|
||||
<textarea id="refreshToken" name="refreshToken" placeholder="Paste your refreshToken here (starts with aorAAAAAG...)" required></textarea>
|
||||
<div class="hint">Copy from Kiro IDE: ~/.kiro/kiro-auth-token.json → refreshToken field</div>
|
||||
</div>
|
||||
|
||||
<button type="submit" class="submit-btn" id="importBtn">
|
||||
📥 Import Token
|
||||
</button>
|
||||
|
||||
<div class="status-message" id="importStatus"></div>
|
||||
</form>
|
||||
</div>
|
||||
|
||||
<div class="warning-box">
|
||||
⚠️ <strong>Note:</strong> Google and GitHub login are not available for third-party applications due to AWS Cognito restrictions. Please use AWS Builder ID or import your token from Kiro IDE.
|
||||
</div>
|
||||
|
||||
<div class="info-box">
|
||||
💡 <strong>How to get RefreshToken:</strong><br>
|
||||
1. Open Kiro IDE and login with Google/GitHub<br>
|
||||
2. Find the token file: <code>~/.kiro/kiro-auth-token.json</code><br>
|
||||
3. Copy the <code>refreshToken</code> value and paste it above
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
function toggleIdcForm() {
|
||||
const idcForm = document.getElementById('idcForm');
|
||||
const manualForm = document.getElementById('manualForm');
|
||||
manualForm.classList.remove('show');
|
||||
idcForm.classList.toggle('show');
|
||||
if (idcForm.classList.contains('show')) {
|
||||
document.getElementById('startUrl').focus();
|
||||
}
|
||||
}
|
||||
|
||||
function toggleManualForm() {
|
||||
const idcForm = document.getElementById('idcForm');
|
||||
const manualForm = document.getElementById('manualForm');
|
||||
idcForm.classList.remove('show');
|
||||
manualForm.classList.toggle('show');
|
||||
if (manualForm.classList.contains('show')) {
|
||||
document.getElementById('refreshToken').focus();
|
||||
}
|
||||
}
|
||||
|
||||
async function submitImport(event) {
|
||||
event.preventDefault();
|
||||
const refreshToken = document.getElementById('refreshToken').value.trim();
|
||||
const statusEl = document.getElementById('importStatus');
|
||||
const btn = document.getElementById('importBtn');
|
||||
|
||||
if (!refreshToken) {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = 'Please enter a refresh token';
|
||||
return;
|
||||
}
|
||||
|
||||
if (!refreshToken.startsWith('aorAAAAAG')) {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = 'Invalid token format. Token should start with aorAAAAAG...';
|
||||
return;
|
||||
}
|
||||
|
||||
btn.disabled = true;
|
||||
btn.textContent = '⏳ Importing...';
|
||||
statusEl.className = 'status-message';
|
||||
statusEl.style.display = 'none';
|
||||
|
||||
try {
|
||||
const response = await fetch('/v0/oauth/kiro/import', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ refreshToken: refreshToken })
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (response.ok && data.success) {
|
||||
statusEl.className = 'status-message success';
|
||||
statusEl.textContent = '✅ Token imported successfully! File: ' + (data.fileName || 'kiro-token.json');
|
||||
} else {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = '❌ ' + (data.error || data.message || 'Import failed');
|
||||
}
|
||||
} catch (error) {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = '❌ Network error: ' + error.message;
|
||||
} finally {
|
||||
btn.disabled = false;
|
||||
btn.textContent = '📥 Import Token';
|
||||
}
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>`
|
||||
)
|
||||
316
internal/auth/kiro/rate_limiter.go
Normal file
316
internal/auth/kiro/rate_limiter.go
Normal file
@@ -0,0 +1,316 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMinTokenInterval = 10 * time.Second
|
||||
DefaultMaxTokenInterval = 30 * time.Second
|
||||
DefaultDailyMaxRequests = 500
|
||||
DefaultJitterPercent = 0.3
|
||||
DefaultBackoffBase = 2 * time.Minute
|
||||
DefaultBackoffMax = 60 * time.Minute
|
||||
DefaultBackoffMultiplier = 2.0
|
||||
DefaultSuspendCooldown = 24 * time.Hour
|
||||
)
|
||||
|
||||
// TokenState Token 状态
|
||||
type TokenState struct {
|
||||
LastRequest time.Time
|
||||
RequestCount int
|
||||
CooldownEnd time.Time
|
||||
FailCount int
|
||||
DailyRequests int
|
||||
DailyResetTime time.Time
|
||||
IsSuspended bool
|
||||
SuspendedAt time.Time
|
||||
SuspendReason string
|
||||
}
|
||||
|
||||
// RateLimiter 频率限制器
|
||||
type RateLimiter struct {
|
||||
mu sync.RWMutex
|
||||
states map[string]*TokenState
|
||||
minTokenInterval time.Duration
|
||||
maxTokenInterval time.Duration
|
||||
dailyMaxRequests int
|
||||
jitterPercent float64
|
||||
backoffBase time.Duration
|
||||
backoffMax time.Duration
|
||||
backoffMultiplier float64
|
||||
suspendCooldown time.Duration
|
||||
rng *rand.Rand
|
||||
}
|
||||
|
||||
// NewRateLimiter 创建默认配置的频率限制器
|
||||
func NewRateLimiter() *RateLimiter {
|
||||
return &RateLimiter{
|
||||
states: make(map[string]*TokenState),
|
||||
minTokenInterval: DefaultMinTokenInterval,
|
||||
maxTokenInterval: DefaultMaxTokenInterval,
|
||||
dailyMaxRequests: DefaultDailyMaxRequests,
|
||||
jitterPercent: DefaultJitterPercent,
|
||||
backoffBase: DefaultBackoffBase,
|
||||
backoffMax: DefaultBackoffMax,
|
||||
backoffMultiplier: DefaultBackoffMultiplier,
|
||||
suspendCooldown: DefaultSuspendCooldown,
|
||||
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimiterConfig 频率限制器配置
|
||||
type RateLimiterConfig struct {
|
||||
MinTokenInterval time.Duration
|
||||
MaxTokenInterval time.Duration
|
||||
DailyMaxRequests int
|
||||
JitterPercent float64
|
||||
BackoffBase time.Duration
|
||||
BackoffMax time.Duration
|
||||
BackoffMultiplier float64
|
||||
SuspendCooldown time.Duration
|
||||
}
|
||||
|
||||
// NewRateLimiterWithConfig 使用自定义配置创建频率限制器
|
||||
func NewRateLimiterWithConfig(cfg RateLimiterConfig) *RateLimiter {
|
||||
rl := NewRateLimiter()
|
||||
if cfg.MinTokenInterval > 0 {
|
||||
rl.minTokenInterval = cfg.MinTokenInterval
|
||||
}
|
||||
if cfg.MaxTokenInterval > 0 {
|
||||
rl.maxTokenInterval = cfg.MaxTokenInterval
|
||||
}
|
||||
if cfg.DailyMaxRequests > 0 {
|
||||
rl.dailyMaxRequests = cfg.DailyMaxRequests
|
||||
}
|
||||
if cfg.JitterPercent > 0 {
|
||||
rl.jitterPercent = cfg.JitterPercent
|
||||
}
|
||||
if cfg.BackoffBase > 0 {
|
||||
rl.backoffBase = cfg.BackoffBase
|
||||
}
|
||||
if cfg.BackoffMax > 0 {
|
||||
rl.backoffMax = cfg.BackoffMax
|
||||
}
|
||||
if cfg.BackoffMultiplier > 0 {
|
||||
rl.backoffMultiplier = cfg.BackoffMultiplier
|
||||
}
|
||||
if cfg.SuspendCooldown > 0 {
|
||||
rl.suspendCooldown = cfg.SuspendCooldown
|
||||
}
|
||||
return rl
|
||||
}
|
||||
|
||||
// getOrCreateState 获取或创建 Token 状态
|
||||
func (rl *RateLimiter) getOrCreateState(tokenKey string) *TokenState {
|
||||
state, exists := rl.states[tokenKey]
|
||||
if !exists {
|
||||
state = &TokenState{
|
||||
DailyResetTime: time.Now().Truncate(24 * time.Hour).Add(24 * time.Hour),
|
||||
}
|
||||
rl.states[tokenKey] = state
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
// resetDailyIfNeeded 如果需要则重置每日计数
|
||||
func (rl *RateLimiter) resetDailyIfNeeded(state *TokenState) {
|
||||
now := time.Now()
|
||||
if now.After(state.DailyResetTime) {
|
||||
state.DailyRequests = 0
|
||||
state.DailyResetTime = now.Truncate(24 * time.Hour).Add(24 * time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
// calculateInterval 计算带抖动的随机间隔
|
||||
func (rl *RateLimiter) calculateInterval() time.Duration {
|
||||
baseInterval := rl.minTokenInterval + time.Duration(rl.rng.Int63n(int64(rl.maxTokenInterval-rl.minTokenInterval)))
|
||||
jitter := time.Duration(float64(baseInterval) * rl.jitterPercent * (rl.rng.Float64()*2 - 1))
|
||||
return baseInterval + jitter
|
||||
}
|
||||
|
||||
// WaitForToken 等待 Token 可用(带抖动的随机间隔)
|
||||
func (rl *RateLimiter) WaitForToken(tokenKey string) {
|
||||
rl.mu.Lock()
|
||||
state := rl.getOrCreateState(tokenKey)
|
||||
rl.resetDailyIfNeeded(state)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// 检查是否在冷却期
|
||||
if now.Before(state.CooldownEnd) {
|
||||
waitTime := state.CooldownEnd.Sub(now)
|
||||
rl.mu.Unlock()
|
||||
time.Sleep(waitTime)
|
||||
rl.mu.Lock()
|
||||
state = rl.getOrCreateState(tokenKey)
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
// 计算距离上次请求的间隔
|
||||
interval := rl.calculateInterval()
|
||||
nextAllowedTime := state.LastRequest.Add(interval)
|
||||
|
||||
if now.Before(nextAllowedTime) {
|
||||
waitTime := nextAllowedTime.Sub(now)
|
||||
rl.mu.Unlock()
|
||||
time.Sleep(waitTime)
|
||||
rl.mu.Lock()
|
||||
state = rl.getOrCreateState(tokenKey)
|
||||
}
|
||||
|
||||
state.LastRequest = time.Now()
|
||||
state.RequestCount++
|
||||
state.DailyRequests++
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
|
||||
// MarkTokenFailed 标记 Token 失败
|
||||
func (rl *RateLimiter) MarkTokenFailed(tokenKey string) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
state := rl.getOrCreateState(tokenKey)
|
||||
state.FailCount++
|
||||
state.CooldownEnd = time.Now().Add(rl.calculateBackoff(state.FailCount))
|
||||
}
|
||||
|
||||
// MarkTokenSuccess 标记 Token 成功
|
||||
func (rl *RateLimiter) MarkTokenSuccess(tokenKey string) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
state := rl.getOrCreateState(tokenKey)
|
||||
state.FailCount = 0
|
||||
state.CooldownEnd = time.Time{}
|
||||
}
|
||||
|
||||
// CheckAndMarkSuspended 检测暂停错误并标记
|
||||
func (rl *RateLimiter) CheckAndMarkSuspended(tokenKey string, errorMsg string) bool {
|
||||
suspendKeywords := []string{
|
||||
"suspended",
|
||||
"banned",
|
||||
"disabled",
|
||||
"account has been",
|
||||
"access denied",
|
||||
"rate limit exceeded",
|
||||
"too many requests",
|
||||
"quota exceeded",
|
||||
}
|
||||
|
||||
lowerMsg := strings.ToLower(errorMsg)
|
||||
for _, keyword := range suspendKeywords {
|
||||
if strings.Contains(lowerMsg, keyword) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
state := rl.getOrCreateState(tokenKey)
|
||||
state.IsSuspended = true
|
||||
state.SuspendedAt = time.Now()
|
||||
state.SuspendReason = errorMsg
|
||||
state.CooldownEnd = time.Now().Add(rl.suspendCooldown)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsTokenAvailable 检查 Token 是否可用
|
||||
func (rl *RateLimiter) IsTokenAvailable(tokenKey string) bool {
|
||||
rl.mu.RLock()
|
||||
defer rl.mu.RUnlock()
|
||||
|
||||
state, exists := rl.states[tokenKey]
|
||||
if !exists {
|
||||
return true
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// 检查是否被暂停
|
||||
if state.IsSuspended {
|
||||
if now.After(state.SuspendedAt.Add(rl.suspendCooldown)) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否在冷却期
|
||||
if now.Before(state.CooldownEnd) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查每日请求限制
|
||||
rl.mu.RUnlock()
|
||||
rl.mu.Lock()
|
||||
rl.resetDailyIfNeeded(state)
|
||||
dailyRequests := state.DailyRequests
|
||||
dailyMax := rl.dailyMaxRequests
|
||||
rl.mu.Unlock()
|
||||
rl.mu.RLock()
|
||||
|
||||
if dailyRequests >= dailyMax {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// calculateBackoff 计算指数退避时间
|
||||
func (rl *RateLimiter) calculateBackoff(failCount int) time.Duration {
|
||||
if failCount <= 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
backoff := float64(rl.backoffBase) * math.Pow(rl.backoffMultiplier, float64(failCount-1))
|
||||
|
||||
// 添加抖动
|
||||
jitter := backoff * rl.jitterPercent * (rl.rng.Float64()*2 - 1)
|
||||
backoff += jitter
|
||||
|
||||
if time.Duration(backoff) > rl.backoffMax {
|
||||
return rl.backoffMax
|
||||
}
|
||||
return time.Duration(backoff)
|
||||
}
|
||||
|
||||
// GetTokenState 获取 Token 状态(只读)
|
||||
func (rl *RateLimiter) GetTokenState(tokenKey string) *TokenState {
|
||||
rl.mu.RLock()
|
||||
defer rl.mu.RUnlock()
|
||||
|
||||
state, exists := rl.states[tokenKey]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 返回副本以防止外部修改
|
||||
stateCopy := *state
|
||||
return &stateCopy
|
||||
}
|
||||
|
||||
// ClearTokenState 清除 Token 状态
|
||||
func (rl *RateLimiter) ClearTokenState(tokenKey string) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
delete(rl.states, tokenKey)
|
||||
}
|
||||
|
||||
// ResetSuspension 重置暂停状态
|
||||
func (rl *RateLimiter) ResetSuspension(tokenKey string) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
state, exists := rl.states[tokenKey]
|
||||
if exists {
|
||||
state.IsSuspended = false
|
||||
state.SuspendedAt = time.Time{}
|
||||
state.SuspendReason = ""
|
||||
state.CooldownEnd = time.Time{}
|
||||
state.FailCount = 0
|
||||
}
|
||||
}
|
||||
46
internal/auth/kiro/rate_limiter_singleton.go
Normal file
46
internal/auth/kiro/rate_limiter_singleton.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
globalRateLimiter *RateLimiter
|
||||
globalRateLimiterOnce sync.Once
|
||||
|
||||
globalCooldownManager *CooldownManager
|
||||
globalCooldownManagerOnce sync.Once
|
||||
cooldownStopCh chan struct{}
|
||||
)
|
||||
|
||||
// GetGlobalRateLimiter returns the singleton RateLimiter instance.
|
||||
func GetGlobalRateLimiter() *RateLimiter {
|
||||
globalRateLimiterOnce.Do(func() {
|
||||
globalRateLimiter = NewRateLimiter()
|
||||
log.Info("kiro: global RateLimiter initialized")
|
||||
})
|
||||
return globalRateLimiter
|
||||
}
|
||||
|
||||
// GetGlobalCooldownManager returns the singleton CooldownManager instance.
|
||||
func GetGlobalCooldownManager() *CooldownManager {
|
||||
globalCooldownManagerOnce.Do(func() {
|
||||
globalCooldownManager = NewCooldownManager()
|
||||
cooldownStopCh = make(chan struct{})
|
||||
go globalCooldownManager.StartCleanupRoutine(5*time.Minute, cooldownStopCh)
|
||||
log.Info("kiro: global CooldownManager initialized with cleanup routine")
|
||||
})
|
||||
return globalCooldownManager
|
||||
}
|
||||
|
||||
// ShutdownRateLimiters stops the cooldown cleanup routine.
|
||||
// Should be called during application shutdown.
|
||||
func ShutdownRateLimiters() {
|
||||
if cooldownStopCh != nil {
|
||||
close(cooldownStopCh)
|
||||
log.Info("kiro: rate limiter cleanup routine stopped")
|
||||
}
|
||||
}
|
||||
304
internal/auth/kiro/rate_limiter_test.go
Normal file
304
internal/auth/kiro/rate_limiter_test.go
Normal file
@@ -0,0 +1,304 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewRateLimiter(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
if rl == nil {
|
||||
t.Fatal("expected non-nil RateLimiter")
|
||||
}
|
||||
if rl.states == nil {
|
||||
t.Error("expected non-nil states map")
|
||||
}
|
||||
if rl.minTokenInterval != DefaultMinTokenInterval {
|
||||
t.Errorf("expected minTokenInterval %v, got %v", DefaultMinTokenInterval, rl.minTokenInterval)
|
||||
}
|
||||
if rl.maxTokenInterval != DefaultMaxTokenInterval {
|
||||
t.Errorf("expected maxTokenInterval %v, got %v", DefaultMaxTokenInterval, rl.maxTokenInterval)
|
||||
}
|
||||
if rl.dailyMaxRequests != DefaultDailyMaxRequests {
|
||||
t.Errorf("expected dailyMaxRequests %d, got %d", DefaultDailyMaxRequests, rl.dailyMaxRequests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRateLimiterWithConfig(t *testing.T) {
|
||||
cfg := RateLimiterConfig{
|
||||
MinTokenInterval: 5 * time.Second,
|
||||
MaxTokenInterval: 15 * time.Second,
|
||||
DailyMaxRequests: 100,
|
||||
JitterPercent: 0.2,
|
||||
BackoffBase: 1 * time.Minute,
|
||||
BackoffMax: 30 * time.Minute,
|
||||
BackoffMultiplier: 1.5,
|
||||
SuspendCooldown: 12 * time.Hour,
|
||||
}
|
||||
|
||||
rl := NewRateLimiterWithConfig(cfg)
|
||||
if rl.minTokenInterval != 5*time.Second {
|
||||
t.Errorf("expected minTokenInterval 5s, got %v", rl.minTokenInterval)
|
||||
}
|
||||
if rl.maxTokenInterval != 15*time.Second {
|
||||
t.Errorf("expected maxTokenInterval 15s, got %v", rl.maxTokenInterval)
|
||||
}
|
||||
if rl.dailyMaxRequests != 100 {
|
||||
t.Errorf("expected dailyMaxRequests 100, got %d", rl.dailyMaxRequests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRateLimiterWithConfig_PartialConfig(t *testing.T) {
|
||||
cfg := RateLimiterConfig{
|
||||
MinTokenInterval: 5 * time.Second,
|
||||
}
|
||||
|
||||
rl := NewRateLimiterWithConfig(cfg)
|
||||
if rl.minTokenInterval != 5*time.Second {
|
||||
t.Errorf("expected minTokenInterval 5s, got %v", rl.minTokenInterval)
|
||||
}
|
||||
if rl.maxTokenInterval != DefaultMaxTokenInterval {
|
||||
t.Errorf("expected default maxTokenInterval, got %v", rl.maxTokenInterval)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTokenState_NonExistent(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
state := rl.GetTokenState("nonexistent")
|
||||
if state != nil {
|
||||
t.Error("expected nil state for non-existent token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsTokenAvailable_NewToken(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
if !rl.IsTokenAvailable("newtoken") {
|
||||
t.Error("expected new token to be available")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkTokenFailed(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.MarkTokenFailed("token1")
|
||||
|
||||
state := rl.GetTokenState("token1")
|
||||
if state == nil {
|
||||
t.Fatal("expected non-nil state")
|
||||
}
|
||||
if state.FailCount != 1 {
|
||||
t.Errorf("expected FailCount 1, got %d", state.FailCount)
|
||||
}
|
||||
if state.CooldownEnd.IsZero() {
|
||||
t.Error("expected non-zero CooldownEnd")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkTokenSuccess(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.MarkTokenFailed("token1")
|
||||
rl.MarkTokenFailed("token1")
|
||||
rl.MarkTokenSuccess("token1")
|
||||
|
||||
state := rl.GetTokenState("token1")
|
||||
if state == nil {
|
||||
t.Fatal("expected non-nil state")
|
||||
}
|
||||
if state.FailCount != 0 {
|
||||
t.Errorf("expected FailCount 0, got %d", state.FailCount)
|
||||
}
|
||||
if !state.CooldownEnd.IsZero() {
|
||||
t.Error("expected zero CooldownEnd after success")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckAndMarkSuspended_Suspended(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
|
||||
testCases := []string{
|
||||
"Account has been suspended",
|
||||
"You are banned from this service",
|
||||
"Account disabled",
|
||||
"Access denied permanently",
|
||||
"Rate limit exceeded",
|
||||
"Too many requests",
|
||||
"Quota exceeded for today",
|
||||
}
|
||||
|
||||
for i, msg := range testCases {
|
||||
tokenKey := "token" + string(rune('a'+i))
|
||||
if !rl.CheckAndMarkSuspended(tokenKey, msg) {
|
||||
t.Errorf("expected suspension detected for: %s", msg)
|
||||
}
|
||||
state := rl.GetTokenState(tokenKey)
|
||||
if !state.IsSuspended {
|
||||
t.Errorf("expected IsSuspended true for: %s", msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckAndMarkSuspended_NotSuspended(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
|
||||
normalErrors := []string{
|
||||
"connection timeout",
|
||||
"internal server error",
|
||||
"bad request",
|
||||
"invalid token format",
|
||||
}
|
||||
|
||||
for i, msg := range normalErrors {
|
||||
tokenKey := "token" + string(rune('a'+i))
|
||||
if rl.CheckAndMarkSuspended(tokenKey, msg) {
|
||||
t.Errorf("unexpected suspension for: %s", msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsTokenAvailable_Suspended(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.CheckAndMarkSuspended("token1", "Account suspended")
|
||||
|
||||
if rl.IsTokenAvailable("token1") {
|
||||
t.Error("expected suspended token to be unavailable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearTokenState(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.MarkTokenFailed("token1")
|
||||
rl.ClearTokenState("token1")
|
||||
|
||||
state := rl.GetTokenState("token1")
|
||||
if state != nil {
|
||||
t.Error("expected nil state after clear")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetSuspension(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.CheckAndMarkSuspended("token1", "Account suspended")
|
||||
rl.ResetSuspension("token1")
|
||||
|
||||
state := rl.GetTokenState("token1")
|
||||
if state.IsSuspended {
|
||||
t.Error("expected IsSuspended false after reset")
|
||||
}
|
||||
if state.FailCount != 0 {
|
||||
t.Errorf("expected FailCount 0, got %d", state.FailCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetSuspension_NonExistent(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.ResetSuspension("nonexistent")
|
||||
}
|
||||
|
||||
func TestCalculateBackoff_ZeroFailCount(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
backoff := rl.calculateBackoff(0)
|
||||
if backoff != 0 {
|
||||
t.Errorf("expected 0 backoff for 0 fails, got %v", backoff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateBackoff_Exponential(t *testing.T) {
|
||||
cfg := RateLimiterConfig{
|
||||
BackoffBase: 1 * time.Minute,
|
||||
BackoffMax: 60 * time.Minute,
|
||||
BackoffMultiplier: 2.0,
|
||||
JitterPercent: 0.3,
|
||||
}
|
||||
rl := NewRateLimiterWithConfig(cfg)
|
||||
|
||||
backoff1 := rl.calculateBackoff(1)
|
||||
if backoff1 < 40*time.Second || backoff1 > 80*time.Second {
|
||||
t.Errorf("expected ~1min (with jitter) for fail 1, got %v", backoff1)
|
||||
}
|
||||
|
||||
backoff2 := rl.calculateBackoff(2)
|
||||
if backoff2 < 80*time.Second || backoff2 > 160*time.Second {
|
||||
t.Errorf("expected ~2min (with jitter) for fail 2, got %v", backoff2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateBackoff_MaxCap(t *testing.T) {
|
||||
cfg := RateLimiterConfig{
|
||||
BackoffBase: 1 * time.Minute,
|
||||
BackoffMax: 10 * time.Minute,
|
||||
BackoffMultiplier: 2.0,
|
||||
JitterPercent: 0,
|
||||
}
|
||||
rl := NewRateLimiterWithConfig(cfg)
|
||||
|
||||
backoff := rl.calculateBackoff(10)
|
||||
if backoff > 10*time.Minute {
|
||||
t.Errorf("expected backoff capped at 10min, got %v", backoff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTokenState_ReturnsCopy(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.MarkTokenFailed("token1")
|
||||
|
||||
state1 := rl.GetTokenState("token1")
|
||||
state1.FailCount = 999
|
||||
|
||||
state2 := rl.GetTokenState("token1")
|
||||
if state2.FailCount == 999 {
|
||||
t.Error("GetTokenState should return a copy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_ConcurrentAccess(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
const numGoroutines = 50
|
||||
const numOperations = 50
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
tokenKey := "token" + string(rune('a'+id%10))
|
||||
for j := 0; j < numOperations; j++ {
|
||||
switch j % 6 {
|
||||
case 0:
|
||||
rl.IsTokenAvailable(tokenKey)
|
||||
case 1:
|
||||
rl.MarkTokenFailed(tokenKey)
|
||||
case 2:
|
||||
rl.MarkTokenSuccess(tokenKey)
|
||||
case 3:
|
||||
rl.GetTokenState(tokenKey)
|
||||
case 4:
|
||||
rl.CheckAndMarkSuspended(tokenKey, "test error")
|
||||
case 5:
|
||||
rl.ResetSuspension(tokenKey)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestCalculateInterval_WithinRange(t *testing.T) {
|
||||
cfg := RateLimiterConfig{
|
||||
MinTokenInterval: 10 * time.Second,
|
||||
MaxTokenInterval: 30 * time.Second,
|
||||
JitterPercent: 0.3,
|
||||
}
|
||||
rl := NewRateLimiterWithConfig(cfg)
|
||||
|
||||
minAllowed := 7 * time.Second
|
||||
maxAllowed := 40 * time.Second
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
interval := rl.calculateInterval()
|
||||
if interval < minAllowed || interval > maxAllowed {
|
||||
t.Errorf("interval %v outside expected range [%v, %v]", interval, minAllowed, maxAllowed)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -9,7 +9,9 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -31,6 +33,9 @@ const (
|
||||
|
||||
// OAuth timeout
|
||||
socialAuthTimeout = 10 * time.Minute
|
||||
|
||||
// Default callback port for social auth HTTP server
|
||||
socialAuthCallbackPort = 9876
|
||||
)
|
||||
|
||||
// SocialProvider represents the social login provider.
|
||||
@@ -67,6 +72,13 @@ type RefreshTokenRequest struct {
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
// WebCallbackResult contains the OAuth callback result from HTTP server.
|
||||
type WebCallbackResult struct {
|
||||
Code string
|
||||
State string
|
||||
Error string
|
||||
}
|
||||
|
||||
// SocialAuthClient handles social authentication with Kiro.
|
||||
type SocialAuthClient struct {
|
||||
httpClient *http.Client
|
||||
@@ -87,6 +99,83 @@ func NewSocialAuthClient(cfg *config.Config) *SocialAuthClient {
|
||||
}
|
||||
}
|
||||
|
||||
// startWebCallbackServer starts a local HTTP server to receive the OAuth callback.
|
||||
// This is used instead of the kiro:// protocol handler to avoid redirect_mismatch errors.
|
||||
func (c *SocialAuthClient) startWebCallbackServer(ctx context.Context, expectedState string) (string, <-chan WebCallbackResult, error) {
|
||||
// Try to find an available port - use localhost like Kiro does
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", socialAuthCallbackPort))
|
||||
if err != nil {
|
||||
// Try with dynamic port (RFC 8252 allows dynamic ports for native apps)
|
||||
log.Warnf("kiro social auth: default port %d is busy, falling back to dynamic port", socialAuthCallbackPort)
|
||||
listener, err = net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to start callback server: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
// Use http scheme for local callback server
|
||||
redirectURI := fmt.Sprintf("http://localhost:%d/oauth/callback", port)
|
||||
resultChan := make(chan WebCallbackResult, 1)
|
||||
|
||||
server := &http.Server{
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
errParam := r.URL.Query().Get("error")
|
||||
|
||||
if errParam != "" {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(w, `<!DOCTYPE html>
|
||||
<html><head><title>Login Failed</title></head>
|
||||
<body><h1>Login Failed</h1><p>%s</p><p>You can close this window.</p></body></html>`, html.EscapeString(errParam))
|
||||
resultChan <- WebCallbackResult{Error: errParam}
|
||||
return
|
||||
}
|
||||
|
||||
if state != expectedState {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(w, `<!DOCTYPE html>
|
||||
<html><head><title>Login Failed</title></head>
|
||||
<body><h1>Login Failed</h1><p>Invalid state parameter</p><p>You can close this window.</p></body></html>`)
|
||||
resultChan <- WebCallbackResult{Error: "state mismatch"}
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
fmt.Fprint(w, `<!DOCTYPE html>
|
||||
<html><head><title>Login Successful</title></head>
|
||||
<body><h1>Login Successful!</h1><p>You can close this window and return to the terminal.</p>
|
||||
<script>window.close();</script></body></html>`)
|
||||
resultChan <- WebCallbackResult{Code: code, State: state}
|
||||
})
|
||||
|
||||
server.Handler = mux
|
||||
|
||||
go func() {
|
||||
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||
log.Debugf("kiro social auth callback server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(socialAuthTimeout):
|
||||
case <-resultChan:
|
||||
}
|
||||
_ = server.Shutdown(context.Background())
|
||||
}()
|
||||
|
||||
return redirectURI, resultChan, nil
|
||||
}
|
||||
|
||||
// generatePKCE generates PKCE code verifier and challenge.
|
||||
func generatePKCE() (verifier, challenge string, err error) {
|
||||
// Generate 32 bytes of random data for verifier
|
||||
@@ -217,10 +306,12 @@ func (c *SocialAuthClient) RefreshSocialToken(ctx context.Context, refreshToken
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: "", // Caller should preserve original provider
|
||||
Region: "us-east-1",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LoginWithSocial performs OAuth login with Google.
|
||||
// LoginWithSocial performs OAuth login with Google or GitHub.
|
||||
// Uses local HTTP callback server instead of custom protocol handler to avoid redirect_mismatch errors.
|
||||
func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialProvider) (*KiroTokenData, error) {
|
||||
providerName := string(provider)
|
||||
|
||||
@@ -228,28 +319,10 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
|
||||
fmt.Printf("║ Kiro Authentication (%s) ║\n", providerName)
|
||||
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||
|
||||
// Step 1: Setup protocol handler
|
||||
// Step 1: Start local HTTP callback server (instead of kiro:// protocol handler)
|
||||
// This avoids redirect_mismatch errors with AWS Cognito
|
||||
fmt.Println("\nSetting up authentication...")
|
||||
|
||||
// Start the local callback server
|
||||
handlerPort, err := c.protocolHandler.Start(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start callback server: %w", err)
|
||||
}
|
||||
defer c.protocolHandler.Stop()
|
||||
|
||||
// Ensure protocol handler is installed and set as default
|
||||
if err := SetupProtocolHandlerIfNeeded(handlerPort); err != nil {
|
||||
fmt.Println("\n⚠ Protocol handler setup failed. Trying alternative method...")
|
||||
fmt.Println(" If you see a browser 'Open with' dialog, select your default browser.")
|
||||
fmt.Println(" For manual setup instructions, run: cliproxy kiro --help-protocol")
|
||||
log.Debugf("kiro: protocol handler setup error: %v", err)
|
||||
// Continue anyway - user might have set it up manually or select browser manually
|
||||
} else {
|
||||
// Force set our handler as default (prevents "Open with" dialog)
|
||||
forceDefaultProtocolHandler()
|
||||
}
|
||||
|
||||
// Step 2: Generate PKCE codes
|
||||
codeVerifier, codeChallenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
@@ -262,8 +335,15 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Build the login URL (Kiro uses GET request with query params)
|
||||
authURL := c.buildLoginURL(providerName, KiroRedirectURI, codeChallenge, state)
|
||||
// Step 4: Start local HTTP callback server
|
||||
redirectURI, resultChan, err := c.startWebCallbackServer(ctx, state)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start callback server: %w", err)
|
||||
}
|
||||
log.Debugf("kiro social auth: callback server started at %s", redirectURI)
|
||||
|
||||
// Step 5: Build the login URL using HTTP redirect URI
|
||||
authURL := c.buildLoginURL(providerName, redirectURI, codeChallenge, state)
|
||||
|
||||
// Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito)
|
||||
// Incognito mode enables multi-account support by bypassing cached sessions
|
||||
@@ -279,7 +359,7 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
|
||||
log.Debug("kiro: using incognito mode for multi-account support (default)")
|
||||
}
|
||||
|
||||
// Step 5: Open browser for user authentication
|
||||
// Step 6: Open browser for user authentication
|
||||
fmt.Println("\n════════════════════════════════════════════════════════════")
|
||||
fmt.Printf(" Opening browser for %s authentication...\n", providerName)
|
||||
fmt.Println("════════════════════════════════════════════════════════════")
|
||||
@@ -295,80 +375,78 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
|
||||
|
||||
fmt.Println("\n Waiting for authentication callback...")
|
||||
|
||||
// Step 6: Wait for callback
|
||||
callback, err := c.protocolHandler.WaitForCallback(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to receive callback: %w", err)
|
||||
}
|
||||
|
||||
if callback.Error != "" {
|
||||
return nil, fmt.Errorf("authentication error: %s", callback.Error)
|
||||
}
|
||||
|
||||
if callback.State != state {
|
||||
// Log state values for debugging, but don't expose in user-facing error
|
||||
log.Debugf("kiro: OAuth state mismatch - expected %s, got %s", state, callback.State)
|
||||
return nil, fmt.Errorf("OAuth state validation failed - please try again")
|
||||
}
|
||||
|
||||
if callback.Code == "" {
|
||||
return nil, fmt.Errorf("no authorization code received")
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Authorization received!")
|
||||
|
||||
// Step 7: Exchange code for tokens
|
||||
fmt.Println("Exchanging code for tokens...")
|
||||
|
||||
tokenReq := &CreateTokenRequest{
|
||||
Code: callback.Code,
|
||||
CodeVerifier: codeVerifier,
|
||||
RedirectURI: KiroRedirectURI,
|
||||
}
|
||||
|
||||
tokenResp, err := c.CreateToken(ctx, tokenReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Authentication successful!")
|
||||
|
||||
// Close the browser window
|
||||
if err := browser.CloseBrowser(); err != nil {
|
||||
log.Debugf("Failed to close browser: %v", err)
|
||||
}
|
||||
|
||||
// Validate ExpiresIn - use default 1 hour if invalid
|
||||
expiresIn := tokenResp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
// Try to extract email from JWT access token first
|
||||
email := ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||
|
||||
// If no email in JWT, ask user for account label (only in interactive mode)
|
||||
if email == "" && isInteractiveTerminal() {
|
||||
fmt.Print("\n Enter account label for file naming (optional, press Enter to skip): ")
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
var err error
|
||||
email, err = reader.ReadString('\n')
|
||||
if err != nil {
|
||||
log.Debugf("Failed to read account label: %v", err)
|
||||
// Step 7: Wait for callback from HTTP server
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(socialAuthTimeout):
|
||||
return nil, fmt.Errorf("authentication timed out")
|
||||
case callback := <-resultChan:
|
||||
if callback.Error != "" {
|
||||
return nil, fmt.Errorf("authentication error: %s", callback.Error)
|
||||
}
|
||||
email = strings.TrimSpace(email)
|
||||
}
|
||||
|
||||
return &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: tokenResp.ProfileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: providerName,
|
||||
Email: email, // JWT email or user-provided label
|
||||
}, nil
|
||||
// State is already validated by the callback server
|
||||
if callback.Code == "" {
|
||||
return nil, fmt.Errorf("no authorization code received")
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Authorization received!")
|
||||
|
||||
// Step 8: Exchange code for tokens
|
||||
fmt.Println("Exchanging code for tokens...")
|
||||
|
||||
tokenReq := &CreateTokenRequest{
|
||||
Code: callback.Code,
|
||||
CodeVerifier: codeVerifier,
|
||||
RedirectURI: redirectURI, // Use HTTP redirect URI, not kiro:// protocol
|
||||
}
|
||||
|
||||
tokenResp, err := c.CreateToken(ctx, tokenReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Authentication successful!")
|
||||
|
||||
// Close the browser window
|
||||
if err := browser.CloseBrowser(); err != nil {
|
||||
log.Debugf("Failed to close browser: %v", err)
|
||||
}
|
||||
|
||||
// Validate ExpiresIn - use default 1 hour if invalid
|
||||
expiresIn := tokenResp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
// Try to extract email from JWT access token first
|
||||
email := ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||
|
||||
// If no email in JWT, ask user for account label (only in interactive mode)
|
||||
if email == "" && isInteractiveTerminal() {
|
||||
fmt.Print("\n Enter account label for file naming (optional, press Enter to skip): ")
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
var err error
|
||||
email, err = reader.ReadString('\n')
|
||||
if err != nil {
|
||||
log.Debugf("Failed to read account label: %v", err)
|
||||
}
|
||||
email = strings.TrimSpace(email)
|
||||
}
|
||||
|
||||
return &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: tokenResp.ProfileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: providerName,
|
||||
Email: email, // JWT email or user-provided label
|
||||
Region: "us-east-1",
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// LoginWithGoogle performs OAuth login with Google.
|
||||
|
||||
@@ -735,6 +735,7 @@ func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret
|
||||
Provider: "AWS",
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
Region: defaultIDCRegion,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -850,16 +851,17 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
|
||||
ClientID: regResp.ClientID,
|
||||
ClientSecret: regResp.ClientSecret,
|
||||
Email: email,
|
||||
Region: defaultIDCRegion,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close browser on timeout for better UX
|
||||
if err := browser.CloseBrowser(); err != nil {
|
||||
log.Debugf("Failed to close browser on timeout: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("authorization timed out")
|
||||
}
|
||||
// Close browser on timeout for better UX
|
||||
if err := browser.CloseBrowser(); err != nil {
|
||||
log.Debugf("Failed to close browser on timeout: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("authorization timed out")
|
||||
}
|
||||
|
||||
// FetchUserEmail retrieves the user's email from AWS SSO OIDC userinfo endpoint.
|
||||
// Falls back to JWT parsing if userinfo fails.
|
||||
@@ -1366,6 +1368,7 @@ func (c *SSOOIDCClient) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTo
|
||||
ClientID: regResp.ClientID,
|
||||
ClientSecret: regResp.ClientSecret,
|
||||
Email: email,
|
||||
Region: defaultIDCRegion,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
1371
internal/auth/kiro/sso_oidc.go.bak
Normal file
1371
internal/auth/kiro/sso_oidc.go.bak
Normal file
File diff suppressed because it is too large
Load Diff
@@ -9,6 +9,8 @@ import (
|
||||
|
||||
// KiroTokenStorage holds the persistent token data for Kiro authentication.
|
||||
type KiroTokenStorage struct {
|
||||
// Type is the provider type for management UI recognition (must be "kiro")
|
||||
Type string `json:"type"`
|
||||
// AccessToken is the OAuth2 access token for API access
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is used to obtain new access tokens
|
||||
@@ -23,6 +25,16 @@ type KiroTokenStorage struct {
|
||||
Provider string `json:"provider"`
|
||||
// LastRefresh is the timestamp of the last token refresh
|
||||
LastRefresh string `json:"last_refresh"`
|
||||
// ClientID is the OAuth client ID (required for token refresh)
|
||||
ClientID string `json:"clientId,omitempty"`
|
||||
// ClientSecret is the OAuth client secret (required for token refresh)
|
||||
ClientSecret string `json:"clientSecret,omitempty"`
|
||||
// Region is the AWS region
|
||||
Region string `json:"region,omitempty"`
|
||||
// StartURL is the AWS Identity Center start URL (for IDC auth)
|
||||
StartURL string `json:"startUrl,omitempty"`
|
||||
// Email is the user's email address
|
||||
Email string `json:"email,omitempty"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile persists the token storage to the specified file path.
|
||||
@@ -68,5 +80,10 @@ func (s *KiroTokenStorage) ToTokenData() *KiroTokenData {
|
||||
ExpiresAt: s.ExpiresAt,
|
||||
AuthMethod: s.AuthMethod,
|
||||
Provider: s.Provider,
|
||||
ClientID: s.ClientID,
|
||||
ClientSecret: s.ClientSecret,
|
||||
Region: s.Region,
|
||||
StartURL: s.StartURL,
|
||||
Email: s.Email,
|
||||
}
|
||||
}
|
||||
|
||||
243
internal/auth/kiro/usage_checker.go
Normal file
243
internal/auth/kiro/usage_checker.go
Normal file
@@ -0,0 +1,243 @@
|
||||
// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API.
|
||||
// This file implements usage quota checking and monitoring.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
|
||||
// UsageQuotaResponse represents the API response structure for usage quota checking.
|
||||
type UsageQuotaResponse struct {
|
||||
UsageBreakdownList []UsageBreakdownExtended `json:"usageBreakdownList"`
|
||||
SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
|
||||
NextDateReset float64 `json:"nextDateReset,omitempty"`
|
||||
}
|
||||
|
||||
// UsageBreakdownExtended represents detailed usage information for quota checking.
|
||||
// Note: UsageBreakdown is already defined in codewhisperer_client.go
|
||||
type UsageBreakdownExtended struct {
|
||||
ResourceType string `json:"resourceType"`
|
||||
UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
|
||||
CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
|
||||
FreeTrialInfo *FreeTrialInfoExtended `json:"freeTrialInfo,omitempty"`
|
||||
}
|
||||
|
||||
// FreeTrialInfoExtended represents free trial usage information.
|
||||
type FreeTrialInfoExtended struct {
|
||||
FreeTrialStatus string `json:"freeTrialStatus"`
|
||||
UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
|
||||
CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
|
||||
}
|
||||
|
||||
// QuotaStatus represents the quota status for a token.
|
||||
type QuotaStatus struct {
|
||||
TotalLimit float64
|
||||
CurrentUsage float64
|
||||
RemainingQuota float64
|
||||
IsExhausted bool
|
||||
ResourceType string
|
||||
NextReset time.Time
|
||||
}
|
||||
|
||||
// UsageChecker provides methods for checking token quota usage.
|
||||
type UsageChecker struct {
|
||||
httpClient *http.Client
|
||||
endpoint string
|
||||
}
|
||||
|
||||
// NewUsageChecker creates a new UsageChecker instance.
|
||||
func NewUsageChecker(cfg *config.Config) *UsageChecker {
|
||||
return &UsageChecker{
|
||||
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}),
|
||||
endpoint: awsKiroEndpoint,
|
||||
}
|
||||
}
|
||||
|
||||
// NewUsageCheckerWithClient creates a UsageChecker with a custom HTTP client.
|
||||
func NewUsageCheckerWithClient(client *http.Client) *UsageChecker {
|
||||
return &UsageChecker{
|
||||
httpClient: client,
|
||||
endpoint: awsKiroEndpoint,
|
||||
}
|
||||
}
|
||||
|
||||
// CheckUsage retrieves usage limits for the given token.
|
||||
func (c *UsageChecker) CheckUsage(ctx context.Context, tokenData *KiroTokenData) (*UsageQuotaResponse, error) {
|
||||
if tokenData == nil {
|
||||
return nil, fmt.Errorf("token data is nil")
|
||||
}
|
||||
|
||||
if tokenData.AccessToken == "" {
|
||||
return nil, fmt.Errorf("access token is empty")
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"origin": "AI_EDITOR",
|
||||
"profileArn": tokenData.ProfileArn,
|
||||
"resourceType": "AGENTIC_REQUEST",
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, strings.NewReader(string(jsonBody)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
||||
req.Header.Set("x-amz-target", targetGetUsage)
|
||||
req.Header.Set("Authorization", "Bearer "+tokenData.AccessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result UsageQuotaResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse usage response: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// CheckUsageByAccessToken retrieves usage limits using an access token and profile ARN directly.
|
||||
func (c *UsageChecker) CheckUsageByAccessToken(ctx context.Context, accessToken, profileArn string) (*UsageQuotaResponse, error) {
|
||||
tokenData := &KiroTokenData{
|
||||
AccessToken: accessToken,
|
||||
ProfileArn: profileArn,
|
||||
}
|
||||
return c.CheckUsage(ctx, tokenData)
|
||||
}
|
||||
|
||||
// GetRemainingQuota calculates the remaining quota from usage limits.
|
||||
func GetRemainingQuota(usage *UsageQuotaResponse) float64 {
|
||||
if usage == nil || len(usage.UsageBreakdownList) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
var totalRemaining float64
|
||||
for _, breakdown := range usage.UsageBreakdownList {
|
||||
remaining := breakdown.UsageLimitWithPrecision - breakdown.CurrentUsageWithPrecision
|
||||
if remaining > 0 {
|
||||
totalRemaining += remaining
|
||||
}
|
||||
|
||||
if breakdown.FreeTrialInfo != nil {
|
||||
freeRemaining := breakdown.FreeTrialInfo.UsageLimitWithPrecision - breakdown.FreeTrialInfo.CurrentUsageWithPrecision
|
||||
if freeRemaining > 0 {
|
||||
totalRemaining += freeRemaining
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return totalRemaining
|
||||
}
|
||||
|
||||
// IsQuotaExhausted checks if the quota is exhausted based on usage limits.
|
||||
func IsQuotaExhausted(usage *UsageQuotaResponse) bool {
|
||||
if usage == nil || len(usage.UsageBreakdownList) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, breakdown := range usage.UsageBreakdownList {
|
||||
if breakdown.CurrentUsageWithPrecision < breakdown.UsageLimitWithPrecision {
|
||||
return false
|
||||
}
|
||||
|
||||
if breakdown.FreeTrialInfo != nil {
|
||||
if breakdown.FreeTrialInfo.CurrentUsageWithPrecision < breakdown.FreeTrialInfo.UsageLimitWithPrecision {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// GetQuotaStatus retrieves a comprehensive quota status for a token.
|
||||
func (c *UsageChecker) GetQuotaStatus(ctx context.Context, tokenData *KiroTokenData) (*QuotaStatus, error) {
|
||||
usage, err := c.CheckUsage(ctx, tokenData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
status := &QuotaStatus{
|
||||
IsExhausted: IsQuotaExhausted(usage),
|
||||
}
|
||||
|
||||
if len(usage.UsageBreakdownList) > 0 {
|
||||
breakdown := usage.UsageBreakdownList[0]
|
||||
status.TotalLimit = breakdown.UsageLimitWithPrecision
|
||||
status.CurrentUsage = breakdown.CurrentUsageWithPrecision
|
||||
status.RemainingQuota = breakdown.UsageLimitWithPrecision - breakdown.CurrentUsageWithPrecision
|
||||
status.ResourceType = breakdown.ResourceType
|
||||
|
||||
if breakdown.FreeTrialInfo != nil {
|
||||
status.TotalLimit += breakdown.FreeTrialInfo.UsageLimitWithPrecision
|
||||
status.CurrentUsage += breakdown.FreeTrialInfo.CurrentUsageWithPrecision
|
||||
freeRemaining := breakdown.FreeTrialInfo.UsageLimitWithPrecision - breakdown.FreeTrialInfo.CurrentUsageWithPrecision
|
||||
if freeRemaining > 0 {
|
||||
status.RemainingQuota += freeRemaining
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if usage.NextDateReset > 0 {
|
||||
status.NextReset = time.Unix(int64(usage.NextDateReset/1000), 0)
|
||||
}
|
||||
|
||||
return status, nil
|
||||
}
|
||||
|
||||
// CalculateAvailableCount calculates the available request count based on usage limits.
|
||||
func CalculateAvailableCount(usage *UsageQuotaResponse) float64 {
|
||||
return GetRemainingQuota(usage)
|
||||
}
|
||||
|
||||
// GetUsagePercentage calculates the usage percentage.
|
||||
func GetUsagePercentage(usage *UsageQuotaResponse) float64 {
|
||||
if usage == nil || len(usage.UsageBreakdownList) == 0 {
|
||||
return 100.0
|
||||
}
|
||||
|
||||
var totalLimit, totalUsage float64
|
||||
for _, breakdown := range usage.UsageBreakdownList {
|
||||
totalLimit += breakdown.UsageLimitWithPrecision
|
||||
totalUsage += breakdown.CurrentUsageWithPrecision
|
||||
|
||||
if breakdown.FreeTrialInfo != nil {
|
||||
totalLimit += breakdown.FreeTrialInfo.UsageLimitWithPrecision
|
||||
totalUsage += breakdown.FreeTrialInfo.CurrentUsageWithPrecision
|
||||
}
|
||||
}
|
||||
|
||||
if totalLimit == 0 {
|
||||
return 100.0
|
||||
}
|
||||
|
||||
return (totalUsage / totalLimit) * 100
|
||||
}
|
||||
Reference in New Issue
Block a user