mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-08 06:43:41 +00:00
feat: add Kiro OAuth web, rate limiter, metrics, fingerprint, background refresh and model converter
This commit is contained in:
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/middleware"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
|
||||
ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
||||
@@ -295,6 +296,11 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
s.registerManagementRoutes()
|
||||
}
|
||||
|
||||
// === CLIProxyAPIPlus 扩展: 注册 Kiro OAuth Web 路由 ===
|
||||
kiroOAuthHandler := kiro.NewOAuthWebHandler(cfg)
|
||||
kiroOAuthHandler.RegisterRoutes(engine)
|
||||
log.Info("Kiro OAuth Web routes registered at /v0/oauth/kiro/*")
|
||||
|
||||
if optionState.keepAliveEnabled {
|
||||
s.enableKeepAlive(optionState.keepAliveTimeout, optionState.keepAliveOnTimeout)
|
||||
}
|
||||
|
||||
@@ -5,10 +5,12 @@ package kiro
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
|
||||
@@ -85,6 +87,87 @@ type KiroModel struct {
|
||||
// KiroIDETokenFile is the default path to Kiro IDE's token file
|
||||
const KiroIDETokenFile = ".aws/sso/cache/kiro-auth-token.json"
|
||||
|
||||
// Default retry configuration for file reading
|
||||
const (
|
||||
defaultTokenReadMaxAttempts = 10 // Maximum retry attempts
|
||||
defaultTokenReadBaseDelay = 50 * time.Millisecond // Base delay between retries
|
||||
)
|
||||
|
||||
// isTransientFileError checks if the error is a transient file access error
|
||||
// that may be resolved by retrying (e.g., file locked by another process on Windows).
|
||||
func isTransientFileError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for OS-level file access errors (Windows sharing violation, etc.)
|
||||
var pathErr *os.PathError
|
||||
if errors.As(err, &pathErr) {
|
||||
// Windows sharing violation (ERROR_SHARING_VIOLATION = 32)
|
||||
// Windows lock violation (ERROR_LOCK_VIOLATION = 33)
|
||||
errStr := pathErr.Err.Error()
|
||||
if strings.Contains(errStr, "being used by another process") ||
|
||||
strings.Contains(errStr, "sharing violation") ||
|
||||
strings.Contains(errStr, "lock violation") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check error message for common transient patterns
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
transientPatterns := []string{
|
||||
"being used by another process",
|
||||
"sharing violation",
|
||||
"lock violation",
|
||||
"access is denied",
|
||||
"unexpected end of json",
|
||||
"unexpected eof",
|
||||
}
|
||||
for _, pattern := range transientPatterns {
|
||||
if strings.Contains(errMsg, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// LoadKiroIDETokenWithRetry loads token data from Kiro IDE's token file with retry logic.
|
||||
// This handles transient file access errors (e.g., file locked by Kiro IDE during write).
|
||||
// maxAttempts: maximum number of retry attempts (default 10 if <= 0)
|
||||
// baseDelay: base delay between retries with exponential backoff (default 50ms if <= 0)
|
||||
func LoadKiroIDETokenWithRetry(maxAttempts int, baseDelay time.Duration) (*KiroTokenData, error) {
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = defaultTokenReadMaxAttempts
|
||||
}
|
||||
if baseDelay <= 0 {
|
||||
baseDelay = defaultTokenReadBaseDelay
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||
token, err := LoadKiroIDEToken()
|
||||
if err == nil {
|
||||
return token, nil
|
||||
}
|
||||
lastErr = err
|
||||
|
||||
// Only retry for transient errors
|
||||
if !isTransientFileError(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Exponential backoff: delay * 2^attempt, capped at 500ms
|
||||
delay := baseDelay * time.Duration(1<<uint(attempt))
|
||||
if delay > 500*time.Millisecond {
|
||||
delay = 500 * time.Millisecond
|
||||
}
|
||||
time.Sleep(delay)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to read token file after %d attempts: %w", maxAttempts, lastErr)
|
||||
}
|
||||
|
||||
// LoadKiroIDEToken loads token data from Kiro IDE's token file.
|
||||
func LoadKiroIDEToken() (*KiroTokenData, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
|
||||
305
internal/auth/kiro/aws.go.bak
Normal file
305
internal/auth/kiro/aws.go.bak
Normal file
@@ -0,0 +1,305 @@
|
||||
// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API.
|
||||
// It includes interfaces and implementations for token storage and authentication methods.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
|
||||
type PKCECodes struct {
|
||||
// CodeVerifier is the cryptographically random string used to correlate
|
||||
// the authorization request to the token request
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
// CodeChallenge is the SHA256 hash of the code verifier, base64url-encoded
|
||||
CodeChallenge string `json:"code_challenge"`
|
||||
}
|
||||
|
||||
// KiroTokenData holds OAuth token information from AWS CodeWhisperer (Kiro)
|
||||
type KiroTokenData struct {
|
||||
// AccessToken is the OAuth2 access token for API access
|
||||
AccessToken string `json:"accessToken"`
|
||||
// RefreshToken is used to obtain new access tokens
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
// ProfileArn is the AWS CodeWhisperer profile ARN
|
||||
ProfileArn string `json:"profileArn"`
|
||||
// ExpiresAt is the timestamp when the token expires
|
||||
ExpiresAt string `json:"expiresAt"`
|
||||
// AuthMethod indicates the authentication method used (e.g., "builder-id", "social")
|
||||
AuthMethod string `json:"authMethod"`
|
||||
// Provider indicates the OAuth provider (e.g., "AWS", "Google")
|
||||
Provider string `json:"provider"`
|
||||
// ClientID is the OIDC client ID (needed for token refresh)
|
||||
ClientID string `json:"clientId,omitempty"`
|
||||
// ClientSecret is the OIDC client secret (needed for token refresh)
|
||||
ClientSecret string `json:"clientSecret,omitempty"`
|
||||
// Email is the user's email address (used for file naming)
|
||||
Email string `json:"email,omitempty"`
|
||||
// StartURL is the IDC/Identity Center start URL (only for IDC auth method)
|
||||
StartURL string `json:"startUrl,omitempty"`
|
||||
// Region is the AWS region for IDC authentication (only for IDC auth method)
|
||||
Region string `json:"region,omitempty"`
|
||||
}
|
||||
|
||||
// KiroAuthBundle aggregates authentication data after OAuth flow completion
|
||||
type KiroAuthBundle struct {
|
||||
// TokenData contains the OAuth tokens from the authentication flow
|
||||
TokenData KiroTokenData `json:"token_data"`
|
||||
// LastRefresh is the timestamp of the last token refresh
|
||||
LastRefresh string `json:"last_refresh"`
|
||||
}
|
||||
|
||||
// KiroUsageInfo represents usage information from CodeWhisperer API
|
||||
type KiroUsageInfo struct {
|
||||
// SubscriptionTitle is the subscription plan name (e.g., "KIRO FREE")
|
||||
SubscriptionTitle string `json:"subscription_title"`
|
||||
// CurrentUsage is the current credit usage
|
||||
CurrentUsage float64 `json:"current_usage"`
|
||||
// UsageLimit is the maximum credit limit
|
||||
UsageLimit float64 `json:"usage_limit"`
|
||||
// NextReset is the timestamp of the next usage reset
|
||||
NextReset string `json:"next_reset"`
|
||||
}
|
||||
|
||||
// KiroModel represents a model available through the CodeWhisperer API
|
||||
type KiroModel struct {
|
||||
// ModelID is the unique identifier for the model
|
||||
ModelID string `json:"modelId"`
|
||||
// ModelName is the human-readable name
|
||||
ModelName string `json:"modelName"`
|
||||
// Description is the model description
|
||||
Description string `json:"description"`
|
||||
// RateMultiplier is the credit multiplier for this model
|
||||
RateMultiplier float64 `json:"rateMultiplier"`
|
||||
// RateUnit is the unit for rate calculation (e.g., "credit")
|
||||
RateUnit string `json:"rateUnit"`
|
||||
// MaxInputTokens is the maximum input token limit
|
||||
MaxInputTokens int `json:"maxInputTokens,omitempty"`
|
||||
}
|
||||
|
||||
// KiroIDETokenFile is the default path to Kiro IDE's token file
|
||||
const KiroIDETokenFile = ".aws/sso/cache/kiro-auth-token.json"
|
||||
|
||||
// LoadKiroIDEToken loads token data from Kiro IDE's token file.
|
||||
func LoadKiroIDEToken() (*KiroTokenData, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||
}
|
||||
|
||||
tokenPath := filepath.Join(homeDir, KiroIDETokenFile)
|
||||
data, err := os.ReadFile(tokenPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read Kiro IDE token file (%s): %w", tokenPath, err)
|
||||
}
|
||||
|
||||
var token KiroTokenData
|
||||
if err := json.Unmarshal(data, &token); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse Kiro IDE token: %w", err)
|
||||
}
|
||||
|
||||
if token.AccessToken == "" {
|
||||
return nil, fmt.Errorf("access token is empty in Kiro IDE token file")
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// LoadKiroTokenFromPath loads token data from a custom path.
|
||||
// This supports multiple accounts by allowing different token files.
|
||||
func LoadKiroTokenFromPath(tokenPath string) (*KiroTokenData, error) {
|
||||
// Expand ~ to home directory
|
||||
if len(tokenPath) > 0 && tokenPath[0] == '~' {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||
}
|
||||
tokenPath = filepath.Join(homeDir, tokenPath[1:])
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(tokenPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read token file (%s): %w", tokenPath, err)
|
||||
}
|
||||
|
||||
var token KiroTokenData
|
||||
if err := json.Unmarshal(data, &token); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token file: %w", err)
|
||||
}
|
||||
|
||||
if token.AccessToken == "" {
|
||||
return nil, fmt.Errorf("access token is empty in token file")
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// ListKiroTokenFiles lists all Kiro token files in the cache directory.
|
||||
// This supports multiple accounts by finding all token files.
|
||||
func ListKiroTokenFiles() ([]string, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||
}
|
||||
|
||||
cacheDir := filepath.Join(homeDir, ".aws", "sso", "cache")
|
||||
|
||||
// Check if directory exists
|
||||
if _, err := os.Stat(cacheDir); os.IsNotExist(err) {
|
||||
return nil, nil // No token files
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(cacheDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read cache directory: %w", err)
|
||||
}
|
||||
|
||||
var tokenFiles []string
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
// Look for kiro token files only (avoid matching unrelated AWS SSO cache files)
|
||||
if strings.HasSuffix(name, ".json") && strings.HasPrefix(name, "kiro") {
|
||||
tokenFiles = append(tokenFiles, filepath.Join(cacheDir, name))
|
||||
}
|
||||
}
|
||||
|
||||
return tokenFiles, nil
|
||||
}
|
||||
|
||||
// LoadAllKiroTokens loads all Kiro tokens from the cache directory.
|
||||
// This supports multiple accounts.
|
||||
func LoadAllKiroTokens() ([]*KiroTokenData, error) {
|
||||
files, err := ListKiroTokenFiles()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tokens []*KiroTokenData
|
||||
for _, file := range files {
|
||||
token, err := LoadKiroTokenFromPath(file)
|
||||
if err != nil {
|
||||
// Skip invalid token files
|
||||
continue
|
||||
}
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// JWTClaims represents the claims we care about from a JWT token.
|
||||
// JWT tokens from Kiro/AWS contain user information in the payload.
|
||||
type JWTClaims struct {
|
||||
Email string `json:"email,omitempty"`
|
||||
Sub string `json:"sub,omitempty"`
|
||||
PreferredUser string `json:"preferred_username,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Iss string `json:"iss,omitempty"`
|
||||
}
|
||||
|
||||
// ExtractEmailFromJWT extracts the user's email from a JWT access token.
|
||||
// JWT tokens typically have format: header.payload.signature
|
||||
// The payload is base64url-encoded JSON containing user claims.
|
||||
func ExtractEmailFromJWT(accessToken string) string {
|
||||
if accessToken == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// JWT format: header.payload.signature
|
||||
parts := strings.Split(accessToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Decode the payload (second part)
|
||||
payload := parts[1]
|
||||
|
||||
// Add padding if needed (base64url requires padding)
|
||||
switch len(payload) % 4 {
|
||||
case 2:
|
||||
payload += "=="
|
||||
case 3:
|
||||
payload += "="
|
||||
}
|
||||
|
||||
decoded, err := base64.URLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
// Try RawURLEncoding (no padding)
|
||||
decoded, err = base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
var claims JWTClaims
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Return email if available
|
||||
if claims.Email != "" {
|
||||
return claims.Email
|
||||
}
|
||||
|
||||
// Fallback to preferred_username (some providers use this)
|
||||
if claims.PreferredUser != "" && strings.Contains(claims.PreferredUser, "@") {
|
||||
return claims.PreferredUser
|
||||
}
|
||||
|
||||
// Fallback to sub if it looks like an email
|
||||
if claims.Sub != "" && strings.Contains(claims.Sub, "@") {
|
||||
return claims.Sub
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// SanitizeEmailForFilename sanitizes an email address for use in a filename.
|
||||
// Replaces special characters with underscores and prevents path traversal attacks.
|
||||
// Also handles URL-encoded characters to prevent encoded path traversal attempts.
|
||||
func SanitizeEmailForFilename(email string) string {
|
||||
if email == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
result := email
|
||||
|
||||
// First, handle URL-encoded path traversal attempts (%2F, %2E, %5C, etc.)
|
||||
// This prevents encoded characters from bypassing the sanitization.
|
||||
// Note: We replace % last to catch any remaining encodings including double-encoding (%252F)
|
||||
result = strings.ReplaceAll(result, "%2F", "_") // /
|
||||
result = strings.ReplaceAll(result, "%2f", "_")
|
||||
result = strings.ReplaceAll(result, "%5C", "_") // \
|
||||
result = strings.ReplaceAll(result, "%5c", "_")
|
||||
result = strings.ReplaceAll(result, "%2E", "_") // .
|
||||
result = strings.ReplaceAll(result, "%2e", "_")
|
||||
result = strings.ReplaceAll(result, "%00", "_") // null byte
|
||||
result = strings.ReplaceAll(result, "%", "_") // Catch remaining % to prevent double-encoding attacks
|
||||
|
||||
// Replace characters that are problematic in filenames
|
||||
// Keep @ and . in middle but replace other special characters
|
||||
for _, char := range []string{"/", "\\", ":", "*", "?", "\"", "<", ">", "|", " ", "\x00"} {
|
||||
result = strings.ReplaceAll(result, char, "_")
|
||||
}
|
||||
|
||||
// Prevent path traversal: replace leading dots in each path component
|
||||
// This handles cases like "../../../etc/passwd" → "_.._.._.._etc_passwd"
|
||||
parts := strings.Split(result, "_")
|
||||
for i, part := range parts {
|
||||
for strings.HasPrefix(part, ".") {
|
||||
part = "_" + part[1:]
|
||||
}
|
||||
parts[i] = part
|
||||
}
|
||||
result = strings.Join(parts, "_")
|
||||
|
||||
return result
|
||||
}
|
||||
192
internal/auth/kiro/background_refresh.go
Normal file
192
internal/auth/kiro/background_refresh.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"golang.org/x/sync/semaphore"
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
ID string
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
ExpiresAt time.Time
|
||||
LastVerified time.Time
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
AuthMethod string
|
||||
Provider string
|
||||
StartURL string
|
||||
Region string
|
||||
}
|
||||
|
||||
type TokenRepository interface {
|
||||
FindOldestUnverified(limit int) []*Token
|
||||
UpdateToken(token *Token) error
|
||||
}
|
||||
|
||||
type RefresherOption func(*BackgroundRefresher)
|
||||
|
||||
func WithInterval(interval time.Duration) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.interval = interval
|
||||
}
|
||||
}
|
||||
|
||||
func WithBatchSize(size int) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.batchSize = size
|
||||
}
|
||||
}
|
||||
|
||||
func WithConcurrency(concurrency int) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.concurrency = concurrency
|
||||
}
|
||||
}
|
||||
|
||||
type BackgroundRefresher struct {
|
||||
interval time.Duration
|
||||
batchSize int
|
||||
concurrency int
|
||||
tokenRepo TokenRepository
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
oauth *KiroOAuth
|
||||
ssoClient *SSOOIDCClient
|
||||
}
|
||||
|
||||
func NewBackgroundRefresher(repo TokenRepository, opts ...RefresherOption) *BackgroundRefresher {
|
||||
r := &BackgroundRefresher{
|
||||
interval: time.Minute,
|
||||
batchSize: 50,
|
||||
concurrency: 10,
|
||||
tokenRepo: repo,
|
||||
stopCh: make(chan struct{}),
|
||||
oauth: nil, // Lazy init - will be set when config available
|
||||
ssoClient: nil, // Lazy init - will be set when config available
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(r)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// WithConfig sets the configuration for OAuth and SSO clients.
|
||||
func WithConfig(cfg *config.Config) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.oauth = NewKiroOAuth(cfg)
|
||||
r.ssoClient = NewSSOOIDCClient(cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *BackgroundRefresher) Start(ctx context.Context) {
|
||||
r.wg.Add(1)
|
||||
go func() {
|
||||
defer r.wg.Done()
|
||||
ticker := time.NewTicker(r.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
r.refreshBatch(ctx)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-r.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
r.refreshBatch(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (r *BackgroundRefresher) Stop() {
|
||||
close(r.stopCh)
|
||||
r.wg.Wait()
|
||||
}
|
||||
|
||||
func (r *BackgroundRefresher) refreshBatch(ctx context.Context) {
|
||||
tokens := r.tokenRepo.FindOldestUnverified(r.batchSize)
|
||||
if len(tokens) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
sem := semaphore.NewWeighted(int64(r.concurrency))
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i, token := range tokens {
|
||||
if i > 0 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-r.stopCh:
|
||||
return
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
if err := sem.Acquire(ctx, 1); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(t *Token) {
|
||||
defer wg.Done()
|
||||
defer sem.Release(1)
|
||||
r.refreshSingle(ctx, t)
|
||||
}(token)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (r *BackgroundRefresher) refreshSingle(ctx context.Context, token *Token) {
|
||||
var newTokenData *KiroTokenData
|
||||
var err error
|
||||
|
||||
switch token.AuthMethod {
|
||||
case "idc":
|
||||
newTokenData, err = r.ssoClient.RefreshTokenWithRegion(
|
||||
ctx,
|
||||
token.ClientID,
|
||||
token.ClientSecret,
|
||||
token.RefreshToken,
|
||||
token.Region,
|
||||
token.StartURL,
|
||||
)
|
||||
case "builder-id":
|
||||
newTokenData, err = r.ssoClient.RefreshToken(
|
||||
ctx,
|
||||
token.ClientID,
|
||||
token.ClientSecret,
|
||||
token.RefreshToken,
|
||||
)
|
||||
default:
|
||||
newTokenData, err = r.oauth.RefreshToken(ctx, token.RefreshToken)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Printf("failed to refresh token %s: %v", token.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
token.AccessToken = newTokenData.AccessToken
|
||||
token.RefreshToken = newTokenData.RefreshToken
|
||||
token.LastVerified = time.Now()
|
||||
|
||||
if newTokenData.ExpiresAt != "" {
|
||||
if expTime, parseErr := time.Parse(time.RFC3339, newTokenData.ExpiresAt); parseErr == nil {
|
||||
token.ExpiresAt = expTime
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.tokenRepo.UpdateToken(token); err != nil {
|
||||
log.Printf("failed to update token %s: %v", token.ID, err)
|
||||
}
|
||||
}
|
||||
112
internal/auth/kiro/cooldown.go
Normal file
112
internal/auth/kiro/cooldown.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
CooldownReason429 = "rate_limit_exceeded"
|
||||
CooldownReasonSuspended = "account_suspended"
|
||||
CooldownReasonQuotaExhausted = "quota_exhausted"
|
||||
|
||||
DefaultShortCooldown = 1 * time.Minute
|
||||
MaxShortCooldown = 5 * time.Minute
|
||||
LongCooldown = 24 * time.Hour
|
||||
)
|
||||
|
||||
type CooldownManager struct {
|
||||
mu sync.RWMutex
|
||||
cooldowns map[string]time.Time
|
||||
reasons map[string]string
|
||||
}
|
||||
|
||||
func NewCooldownManager() *CooldownManager {
|
||||
return &CooldownManager{
|
||||
cooldowns: make(map[string]time.Time),
|
||||
reasons: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) SetCooldown(tokenKey string, duration time.Duration, reason string) {
|
||||
cm.mu.Lock()
|
||||
defer cm.mu.Unlock()
|
||||
cm.cooldowns[tokenKey] = time.Now().Add(duration)
|
||||
cm.reasons[tokenKey] = reason
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) IsInCooldown(tokenKey string) bool {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
endTime, exists := cm.cooldowns[tokenKey]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
return time.Now().Before(endTime)
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) GetRemainingCooldown(tokenKey string) time.Duration {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
endTime, exists := cm.cooldowns[tokenKey]
|
||||
if !exists {
|
||||
return 0
|
||||
}
|
||||
remaining := time.Until(endTime)
|
||||
if remaining < 0 {
|
||||
return 0
|
||||
}
|
||||
return remaining
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) GetCooldownReason(tokenKey string) string {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return cm.reasons[tokenKey]
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) ClearCooldown(tokenKey string) {
|
||||
cm.mu.Lock()
|
||||
defer cm.mu.Unlock()
|
||||
delete(cm.cooldowns, tokenKey)
|
||||
delete(cm.reasons, tokenKey)
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) CleanupExpired() {
|
||||
cm.mu.Lock()
|
||||
defer cm.mu.Unlock()
|
||||
now := time.Now()
|
||||
for tokenKey, endTime := range cm.cooldowns {
|
||||
if now.After(endTime) {
|
||||
delete(cm.cooldowns, tokenKey)
|
||||
delete(cm.reasons, tokenKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) StartCleanupRoutine(interval time.Duration, stopCh <-chan struct{}) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
cm.CleanupExpired()
|
||||
case <-stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CalculateCooldownFor429(retryCount int) time.Duration {
|
||||
duration := DefaultShortCooldown * time.Duration(1<<retryCount)
|
||||
if duration > MaxShortCooldown {
|
||||
return MaxShortCooldown
|
||||
}
|
||||
return duration
|
||||
}
|
||||
|
||||
func CalculateCooldownUntilNextDay() time.Duration {
|
||||
now := time.Now()
|
||||
nextDay := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, now.Location())
|
||||
return time.Until(nextDay)
|
||||
}
|
||||
240
internal/auth/kiro/cooldown_test.go
Normal file
240
internal/auth/kiro/cooldown_test.go
Normal file
@@ -0,0 +1,240 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewCooldownManager(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
if cm == nil {
|
||||
t.Fatal("expected non-nil CooldownManager")
|
||||
}
|
||||
if cm.cooldowns == nil {
|
||||
t.Error("expected non-nil cooldowns map")
|
||||
}
|
||||
if cm.reasons == nil {
|
||||
t.Error("expected non-nil reasons map")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCooldown(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Minute, CooldownReason429)
|
||||
|
||||
if !cm.IsInCooldown("token1") {
|
||||
t.Error("expected token to be in cooldown")
|
||||
}
|
||||
if cm.GetCooldownReason("token1") != CooldownReason429 {
|
||||
t.Errorf("expected reason %s, got %s", CooldownReason429, cm.GetCooldownReason("token1"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsInCooldown_NotSet(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
if cm.IsInCooldown("nonexistent") {
|
||||
t.Error("expected non-existent token to not be in cooldown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsInCooldown_Expired(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Millisecond, CooldownReason429)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
if cm.IsInCooldown("token1") {
|
||||
t.Error("expected expired cooldown to return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRemainingCooldown(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Second, CooldownReason429)
|
||||
|
||||
remaining := cm.GetRemainingCooldown("token1")
|
||||
if remaining <= 0 || remaining > 1*time.Second {
|
||||
t.Errorf("expected remaining cooldown between 0 and 1s, got %v", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRemainingCooldown_NotSet(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
remaining := cm.GetRemainingCooldown("nonexistent")
|
||||
if remaining != 0 {
|
||||
t.Errorf("expected 0 remaining for non-existent, got %v", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRemainingCooldown_Expired(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Millisecond, CooldownReason429)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
remaining := cm.GetRemainingCooldown("token1")
|
||||
if remaining != 0 {
|
||||
t.Errorf("expected 0 remaining for expired, got %v", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCooldownReason(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Minute, CooldownReasonSuspended)
|
||||
|
||||
reason := cm.GetCooldownReason("token1")
|
||||
if reason != CooldownReasonSuspended {
|
||||
t.Errorf("expected reason %s, got %s", CooldownReasonSuspended, reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCooldownReason_NotSet(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
reason := cm.GetCooldownReason("nonexistent")
|
||||
if reason != "" {
|
||||
t.Errorf("expected empty reason for non-existent, got %s", reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearCooldown(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Minute, CooldownReason429)
|
||||
cm.ClearCooldown("token1")
|
||||
|
||||
if cm.IsInCooldown("token1") {
|
||||
t.Error("expected cooldown to be cleared")
|
||||
}
|
||||
if cm.GetCooldownReason("token1") != "" {
|
||||
t.Error("expected reason to be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearCooldown_NonExistent(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.ClearCooldown("nonexistent")
|
||||
}
|
||||
|
||||
func TestCleanupExpired(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("expired1", 1*time.Millisecond, CooldownReason429)
|
||||
cm.SetCooldown("expired2", 1*time.Millisecond, CooldownReason429)
|
||||
cm.SetCooldown("active", 1*time.Hour, CooldownReason429)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
cm.CleanupExpired()
|
||||
|
||||
if cm.GetCooldownReason("expired1") != "" {
|
||||
t.Error("expected expired1 to be cleaned up")
|
||||
}
|
||||
if cm.GetCooldownReason("expired2") != "" {
|
||||
t.Error("expected expired2 to be cleaned up")
|
||||
}
|
||||
if cm.GetCooldownReason("active") != CooldownReason429 {
|
||||
t.Error("expected active to remain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCooldownFor429_FirstRetry(t *testing.T) {
|
||||
duration := CalculateCooldownFor429(0)
|
||||
if duration != DefaultShortCooldown {
|
||||
t.Errorf("expected %v for retry 0, got %v", DefaultShortCooldown, duration)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCooldownFor429_Exponential(t *testing.T) {
|
||||
d1 := CalculateCooldownFor429(1)
|
||||
d2 := CalculateCooldownFor429(2)
|
||||
|
||||
if d2 <= d1 {
|
||||
t.Errorf("expected d2 > d1, got d1=%v, d2=%v", d1, d2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCooldownFor429_MaxCap(t *testing.T) {
|
||||
duration := CalculateCooldownFor429(10)
|
||||
if duration > MaxShortCooldown {
|
||||
t.Errorf("expected max %v, got %v", MaxShortCooldown, duration)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCooldownUntilNextDay(t *testing.T) {
|
||||
duration := CalculateCooldownUntilNextDay()
|
||||
if duration <= 0 || duration > 24*time.Hour {
|
||||
t.Errorf("expected duration between 0 and 24h, got %v", duration)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldownManager_ConcurrentAccess(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
const numGoroutines = 50
|
||||
const numOperations = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
tokenKey := "token" + string(rune('a'+id%10))
|
||||
for j := 0; j < numOperations; j++ {
|
||||
switch j % 6 {
|
||||
case 0:
|
||||
cm.SetCooldown(tokenKey, time.Duration(j)*time.Millisecond, CooldownReason429)
|
||||
case 1:
|
||||
cm.IsInCooldown(tokenKey)
|
||||
case 2:
|
||||
cm.GetRemainingCooldown(tokenKey)
|
||||
case 3:
|
||||
cm.GetCooldownReason(tokenKey)
|
||||
case 4:
|
||||
cm.ClearCooldown(tokenKey)
|
||||
case 5:
|
||||
cm.CleanupExpired()
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestCooldownReasonConstants(t *testing.T) {
|
||||
if CooldownReason429 != "rate_limit_exceeded" {
|
||||
t.Errorf("unexpected CooldownReason429: %s", CooldownReason429)
|
||||
}
|
||||
if CooldownReasonSuspended != "account_suspended" {
|
||||
t.Errorf("unexpected CooldownReasonSuspended: %s", CooldownReasonSuspended)
|
||||
}
|
||||
if CooldownReasonQuotaExhausted != "quota_exhausted" {
|
||||
t.Errorf("unexpected CooldownReasonQuotaExhausted: %s", CooldownReasonQuotaExhausted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConstants(t *testing.T) {
|
||||
if DefaultShortCooldown != 1*time.Minute {
|
||||
t.Errorf("unexpected DefaultShortCooldown: %v", DefaultShortCooldown)
|
||||
}
|
||||
if MaxShortCooldown != 5*time.Minute {
|
||||
t.Errorf("unexpected MaxShortCooldown: %v", MaxShortCooldown)
|
||||
}
|
||||
if LongCooldown != 24*time.Hour {
|
||||
t.Errorf("unexpected LongCooldown: %v", LongCooldown)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCooldown_OverwritesPrevious(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Hour, CooldownReason429)
|
||||
cm.SetCooldown("token1", 1*time.Minute, CooldownReasonSuspended)
|
||||
|
||||
reason := cm.GetCooldownReason("token1")
|
||||
if reason != CooldownReasonSuspended {
|
||||
t.Errorf("expected reason to be overwritten to %s, got %s", CooldownReasonSuspended, reason)
|
||||
}
|
||||
|
||||
remaining := cm.GetRemainingCooldown("token1")
|
||||
if remaining > 1*time.Minute {
|
||||
t.Errorf("expected remaining <= 1 minute, got %v", remaining)
|
||||
}
|
||||
}
|
||||
197
internal/auth/kiro/fingerprint.go
Normal file
197
internal/auth/kiro/fingerprint.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Fingerprint 多维度指纹信息
|
||||
type Fingerprint struct {
|
||||
SDKVersion string // 1.0.20-1.0.27
|
||||
OSType string // darwin/windows/linux
|
||||
OSVersion string // 10.0.22621
|
||||
NodeVersion string // 18.x/20.x/22.x
|
||||
KiroVersion string // 0.3.x-0.8.x
|
||||
KiroHash string // SHA256
|
||||
AcceptLanguage string
|
||||
ScreenResolution string // 1920x1080
|
||||
ColorDepth int // 24
|
||||
HardwareConcurrency int // CPU 核心数
|
||||
TimezoneOffset int
|
||||
}
|
||||
|
||||
// FingerprintManager 指纹管理器
|
||||
type FingerprintManager struct {
|
||||
mu sync.RWMutex
|
||||
fingerprints map[string]*Fingerprint // tokenKey -> fingerprint
|
||||
rng *rand.Rand
|
||||
}
|
||||
|
||||
var (
|
||||
sdkVersions = []string{
|
||||
"1.0.20", "1.0.21", "1.0.22", "1.0.23",
|
||||
"1.0.24", "1.0.25", "1.0.26", "1.0.27",
|
||||
}
|
||||
osTypes = []string{"darwin", "windows", "linux"}
|
||||
osVersions = map[string][]string{
|
||||
"darwin": {"14.0", "14.1", "14.2", "14.3", "14.4", "14.5", "15.0", "15.1"},
|
||||
"windows": {"10.0.19041", "10.0.19042", "10.0.19043", "10.0.19044", "10.0.22621", "10.0.22631"},
|
||||
"linux": {"5.15.0", "6.1.0", "6.2.0", "6.5.0", "6.6.0", "6.8.0"},
|
||||
}
|
||||
nodeVersions = []string{
|
||||
"18.17.0", "18.18.0", "18.19.0", "18.20.0",
|
||||
"20.9.0", "20.10.0", "20.11.0", "20.12.0", "20.13.0",
|
||||
"22.0.0", "22.1.0", "22.2.0", "22.3.0",
|
||||
}
|
||||
kiroVersions = []string{
|
||||
"0.3.0", "0.3.1", "0.4.0", "0.4.1", "0.5.0", "0.5.1",
|
||||
"0.6.0", "0.6.1", "0.7.0", "0.7.1", "0.8.0", "0.8.1",
|
||||
}
|
||||
acceptLanguages = []string{
|
||||
"en-US,en;q=0.9",
|
||||
"en-GB,en;q=0.9",
|
||||
"zh-CN,zh;q=0.9,en;q=0.8",
|
||||
"zh-TW,zh;q=0.9,en;q=0.8",
|
||||
"ja-JP,ja;q=0.9,en;q=0.8",
|
||||
"ko-KR,ko;q=0.9,en;q=0.8",
|
||||
"de-DE,de;q=0.9,en;q=0.8",
|
||||
"fr-FR,fr;q=0.9,en;q=0.8",
|
||||
}
|
||||
screenResolutions = []string{
|
||||
"1920x1080", "2560x1440", "3840x2160",
|
||||
"1366x768", "1440x900", "1680x1050",
|
||||
"2560x1600", "3440x1440",
|
||||
}
|
||||
colorDepths = []int{24, 32}
|
||||
hardwareConcurrencies = []int{4, 6, 8, 10, 12, 16, 20, 24, 32}
|
||||
timezoneOffsets = []int{-480, -420, -360, -300, -240, 0, 60, 120, 480, 540}
|
||||
)
|
||||
|
||||
// NewFingerprintManager 创建指纹管理器
|
||||
func NewFingerprintManager() *FingerprintManager {
|
||||
return &FingerprintManager{
|
||||
fingerprints: make(map[string]*Fingerprint),
|
||||
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
|
||||
// GetFingerprint 获取或生成 Token 关联的指纹
|
||||
func (fm *FingerprintManager) GetFingerprint(tokenKey string) *Fingerprint {
|
||||
fm.mu.RLock()
|
||||
if fp, exists := fm.fingerprints[tokenKey]; exists {
|
||||
fm.mu.RUnlock()
|
||||
return fp
|
||||
}
|
||||
fm.mu.RUnlock()
|
||||
|
||||
fm.mu.Lock()
|
||||
defer fm.mu.Unlock()
|
||||
|
||||
if fp, exists := fm.fingerprints[tokenKey]; exists {
|
||||
return fp
|
||||
}
|
||||
|
||||
fp := fm.generateFingerprint(tokenKey)
|
||||
fm.fingerprints[tokenKey] = fp
|
||||
return fp
|
||||
}
|
||||
|
||||
// generateFingerprint 生成新的指纹
|
||||
func (fm *FingerprintManager) generateFingerprint(tokenKey string) *Fingerprint {
|
||||
osType := fm.randomChoice(osTypes)
|
||||
osVersion := fm.randomChoice(osVersions[osType])
|
||||
kiroVersion := fm.randomChoice(kiroVersions)
|
||||
|
||||
fp := &Fingerprint{
|
||||
SDKVersion: fm.randomChoice(sdkVersions),
|
||||
OSType: osType,
|
||||
OSVersion: osVersion,
|
||||
NodeVersion: fm.randomChoice(nodeVersions),
|
||||
KiroVersion: kiroVersion,
|
||||
AcceptLanguage: fm.randomChoice(acceptLanguages),
|
||||
ScreenResolution: fm.randomChoice(screenResolutions),
|
||||
ColorDepth: fm.randomIntChoice(colorDepths),
|
||||
HardwareConcurrency: fm.randomIntChoice(hardwareConcurrencies),
|
||||
TimezoneOffset: fm.randomIntChoice(timezoneOffsets),
|
||||
}
|
||||
|
||||
fp.KiroHash = fm.generateKiroHash(tokenKey, kiroVersion, osType)
|
||||
return fp
|
||||
}
|
||||
|
||||
// generateKiroHash 生成 Kiro Hash
|
||||
func (fm *FingerprintManager) generateKiroHash(tokenKey, kiroVersion, osType string) string {
|
||||
data := fmt.Sprintf("%s:%s:%s:%d", tokenKey, kiroVersion, osType, time.Now().UnixNano())
|
||||
hash := sha256.Sum256([]byte(data))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// randomChoice 随机选择字符串
|
||||
func (fm *FingerprintManager) randomChoice(choices []string) string {
|
||||
return choices[fm.rng.Intn(len(choices))]
|
||||
}
|
||||
|
||||
// randomIntChoice 随机选择整数
|
||||
func (fm *FingerprintManager) randomIntChoice(choices []int) int {
|
||||
return choices[fm.rng.Intn(len(choices))]
|
||||
}
|
||||
|
||||
// ApplyToRequest 将指纹信息应用到 HTTP 请求头
|
||||
func (fp *Fingerprint) ApplyToRequest(req *http.Request) {
|
||||
req.Header.Set("X-Kiro-SDK-Version", fp.SDKVersion)
|
||||
req.Header.Set("X-Kiro-OS-Type", fp.OSType)
|
||||
req.Header.Set("X-Kiro-OS-Version", fp.OSVersion)
|
||||
req.Header.Set("X-Kiro-Node-Version", fp.NodeVersion)
|
||||
req.Header.Set("X-Kiro-Version", fp.KiroVersion)
|
||||
req.Header.Set("X-Kiro-Hash", fp.KiroHash)
|
||||
req.Header.Set("Accept-Language", fp.AcceptLanguage)
|
||||
req.Header.Set("X-Screen-Resolution", fp.ScreenResolution)
|
||||
req.Header.Set("X-Color-Depth", fmt.Sprintf("%d", fp.ColorDepth))
|
||||
req.Header.Set("X-Hardware-Concurrency", fmt.Sprintf("%d", fp.HardwareConcurrency))
|
||||
req.Header.Set("X-Timezone-Offset", fmt.Sprintf("%d", fp.TimezoneOffset))
|
||||
}
|
||||
|
||||
// RemoveFingerprint 移除 Token 关联的指纹
|
||||
func (fm *FingerprintManager) RemoveFingerprint(tokenKey string) {
|
||||
fm.mu.Lock()
|
||||
defer fm.mu.Unlock()
|
||||
delete(fm.fingerprints, tokenKey)
|
||||
}
|
||||
|
||||
// Count 返回当前管理的指纹数量
|
||||
func (fm *FingerprintManager) Count() int {
|
||||
fm.mu.RLock()
|
||||
defer fm.mu.RUnlock()
|
||||
return len(fm.fingerprints)
|
||||
}
|
||||
|
||||
// BuildUserAgent 构建 User-Agent 字符串 (Kiro IDE 风格)
|
||||
// 格式: aws-sdk-js/{SDKVersion} ua/2.1 os/{OSType}#{OSVersion} lang/js md/nodejs#{NodeVersion} api/codewhispererstreaming#{SDKVersion} m/E KiroIDE-{KiroVersion}-{KiroHash}
|
||||
func (fp *Fingerprint) BuildUserAgent() string {
|
||||
return fmt.Sprintf(
|
||||
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s",
|
||||
fp.SDKVersion,
|
||||
fp.OSType,
|
||||
fp.OSVersion,
|
||||
fp.NodeVersion,
|
||||
fp.SDKVersion,
|
||||
fp.KiroVersion,
|
||||
fp.KiroHash,
|
||||
)
|
||||
}
|
||||
|
||||
// BuildAmzUserAgent 构建 X-Amz-User-Agent 字符串
|
||||
// 格式: aws-sdk-js/{SDKVersion} KiroIDE-{KiroVersion}-{KiroHash}
|
||||
func (fp *Fingerprint) BuildAmzUserAgent() string {
|
||||
return fmt.Sprintf(
|
||||
"aws-sdk-js/%s KiroIDE-%s-%s",
|
||||
fp.SDKVersion,
|
||||
fp.KiroVersion,
|
||||
fp.KiroHash,
|
||||
)
|
||||
}
|
||||
227
internal/auth/kiro/fingerprint_test.go
Normal file
227
internal/auth/kiro/fingerprint_test.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewFingerprintManager(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
if fm == nil {
|
||||
t.Fatal("expected non-nil FingerprintManager")
|
||||
}
|
||||
if fm.fingerprints == nil {
|
||||
t.Error("expected non-nil fingerprints map")
|
||||
}
|
||||
if fm.rng == nil {
|
||||
t.Error("expected non-nil rng")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFingerprint_NewToken(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
if fp == nil {
|
||||
t.Fatal("expected non-nil Fingerprint")
|
||||
}
|
||||
if fp.SDKVersion == "" {
|
||||
t.Error("expected non-empty SDKVersion")
|
||||
}
|
||||
if fp.OSType == "" {
|
||||
t.Error("expected non-empty OSType")
|
||||
}
|
||||
if fp.OSVersion == "" {
|
||||
t.Error("expected non-empty OSVersion")
|
||||
}
|
||||
if fp.NodeVersion == "" {
|
||||
t.Error("expected non-empty NodeVersion")
|
||||
}
|
||||
if fp.KiroVersion == "" {
|
||||
t.Error("expected non-empty KiroVersion")
|
||||
}
|
||||
if fp.KiroHash == "" {
|
||||
t.Error("expected non-empty KiroHash")
|
||||
}
|
||||
if fp.AcceptLanguage == "" {
|
||||
t.Error("expected non-empty AcceptLanguage")
|
||||
}
|
||||
if fp.ScreenResolution == "" {
|
||||
t.Error("expected non-empty ScreenResolution")
|
||||
}
|
||||
if fp.ColorDepth == 0 {
|
||||
t.Error("expected non-zero ColorDepth")
|
||||
}
|
||||
if fp.HardwareConcurrency == 0 {
|
||||
t.Error("expected non-zero HardwareConcurrency")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFingerprint_SameTokenReturnsSameFingerprint(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp1 := fm.GetFingerprint("token1")
|
||||
fp2 := fm.GetFingerprint("token1")
|
||||
|
||||
if fp1 != fp2 {
|
||||
t.Error("expected same fingerprint for same token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFingerprint_DifferentTokens(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp1 := fm.GetFingerprint("token1")
|
||||
fp2 := fm.GetFingerprint("token2")
|
||||
|
||||
if fp1 == fp2 {
|
||||
t.Error("expected different fingerprints for different tokens")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveFingerprint(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fm.GetFingerprint("token1")
|
||||
if fm.Count() != 1 {
|
||||
t.Fatalf("expected count 1, got %d", fm.Count())
|
||||
}
|
||||
|
||||
fm.RemoveFingerprint("token1")
|
||||
if fm.Count() != 0 {
|
||||
t.Errorf("expected count 0, got %d", fm.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveFingerprint_NonExistent(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fm.RemoveFingerprint("nonexistent")
|
||||
if fm.Count() != 0 {
|
||||
t.Errorf("expected count 0, got %d", fm.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCount(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
if fm.Count() != 0 {
|
||||
t.Errorf("expected count 0, got %d", fm.Count())
|
||||
}
|
||||
|
||||
fm.GetFingerprint("token1")
|
||||
fm.GetFingerprint("token2")
|
||||
fm.GetFingerprint("token3")
|
||||
|
||||
if fm.Count() != 3 {
|
||||
t.Errorf("expected count 3, got %d", fm.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyToRequest(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
fp.ApplyToRequest(req)
|
||||
|
||||
if req.Header.Get("X-Kiro-SDK-Version") != fp.SDKVersion {
|
||||
t.Error("X-Kiro-SDK-Version header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Kiro-OS-Type") != fp.OSType {
|
||||
t.Error("X-Kiro-OS-Type header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Kiro-OS-Version") != fp.OSVersion {
|
||||
t.Error("X-Kiro-OS-Version header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Kiro-Node-Version") != fp.NodeVersion {
|
||||
t.Error("X-Kiro-Node-Version header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Kiro-Version") != fp.KiroVersion {
|
||||
t.Error("X-Kiro-Version header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Kiro-Hash") != fp.KiroHash {
|
||||
t.Error("X-Kiro-Hash header mismatch")
|
||||
}
|
||||
if req.Header.Get("Accept-Language") != fp.AcceptLanguage {
|
||||
t.Error("Accept-Language header mismatch")
|
||||
}
|
||||
if req.Header.Get("X-Screen-Resolution") != fp.ScreenResolution {
|
||||
t.Error("X-Screen-Resolution header mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFingerprint_OSVersionMatchesOSType(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
fp := fm.GetFingerprint("token" + string(rune('a'+i)))
|
||||
validVersions := osVersions[fp.OSType]
|
||||
found := false
|
||||
for _, v := range validVersions {
|
||||
if v == fp.OSVersion {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("OS version %s not valid for OS type %s", fp.OSVersion, fp.OSType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFingerprintManager_ConcurrentAccess(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
const numGoroutines = 100
|
||||
const numOperations = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numOperations; j++ {
|
||||
tokenKey := "token" + string(rune('a'+id%26))
|
||||
switch j % 4 {
|
||||
case 0:
|
||||
fm.GetFingerprint(tokenKey)
|
||||
case 1:
|
||||
fm.Count()
|
||||
case 2:
|
||||
fp := fm.GetFingerprint(tokenKey)
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
fp.ApplyToRequest(req)
|
||||
case 3:
|
||||
fm.RemoveFingerprint(tokenKey)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestKiroHashUniqueness(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
hashes := make(map[string]bool)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
fp := fm.GetFingerprint("token" + string(rune(i)))
|
||||
if hashes[fp.KiroHash] {
|
||||
t.Errorf("duplicate KiroHash detected: %s", fp.KiroHash)
|
||||
}
|
||||
hashes[fp.KiroHash] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestKiroHashFormat(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
if len(fp.KiroHash) != 64 {
|
||||
t.Errorf("expected KiroHash length 64 (SHA256 hex), got %d", len(fp.KiroHash))
|
||||
}
|
||||
|
||||
for _, c := range fp.KiroHash {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
|
||||
t.Errorf("invalid hex character in KiroHash: %c", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
174
internal/auth/kiro/jitter.go
Normal file
174
internal/auth/kiro/jitter.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Jitter configuration constants
|
||||
const (
|
||||
// JitterPercent is the default percentage of jitter to apply (±30%)
|
||||
JitterPercent = 0.30
|
||||
|
||||
// Human-like delay ranges
|
||||
ShortDelayMin = 50 * time.Millisecond // Minimum for rapid consecutive operations
|
||||
ShortDelayMax = 200 * time.Millisecond // Maximum for rapid consecutive operations
|
||||
NormalDelayMin = 1 * time.Second // Minimum for normal thinking time
|
||||
NormalDelayMax = 3 * time.Second // Maximum for normal thinking time
|
||||
LongDelayMin = 5 * time.Second // Minimum for reading/resting
|
||||
LongDelayMax = 10 * time.Second // Maximum for reading/resting
|
||||
|
||||
// Probability thresholds for human-like behavior
|
||||
ShortDelayProbability = 0.20 // 20% chance of short delay (consecutive ops)
|
||||
LongDelayProbability = 0.05 // 5% chance of long delay (reading/resting)
|
||||
NormalDelayProbability = 0.75 // 75% chance of normal delay (thinking)
|
||||
)
|
||||
|
||||
var (
|
||||
jitterRand *rand.Rand
|
||||
jitterRandOnce sync.Once
|
||||
jitterMu sync.Mutex
|
||||
lastRequestTime time.Time
|
||||
)
|
||||
|
||||
// initJitterRand initializes the random number generator for jitter calculations.
|
||||
// Uses a time-based seed for unpredictable but reproducible randomness.
|
||||
func initJitterRand() {
|
||||
jitterRandOnce.Do(func() {
|
||||
jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
})
|
||||
}
|
||||
|
||||
// RandomDelay generates a random delay between min and max duration.
|
||||
// Thread-safe implementation using mutex protection.
|
||||
func RandomDelay(min, max time.Duration) time.Duration {
|
||||
initJitterRand()
|
||||
jitterMu.Lock()
|
||||
defer jitterMu.Unlock()
|
||||
|
||||
if min >= max {
|
||||
return min
|
||||
}
|
||||
|
||||
rangeMs := max.Milliseconds() - min.Milliseconds()
|
||||
randomMs := jitterRand.Int63n(rangeMs)
|
||||
return min + time.Duration(randomMs)*time.Millisecond
|
||||
}
|
||||
|
||||
// JitterDelay adds jitter to a base delay.
|
||||
// Applies ±jitterPercent variation to the base delay.
|
||||
// For example, JitterDelay(1*time.Second, 0.30) returns a value between 700ms and 1300ms.
|
||||
func JitterDelay(baseDelay time.Duration, jitterPercent float64) time.Duration {
|
||||
initJitterRand()
|
||||
jitterMu.Lock()
|
||||
defer jitterMu.Unlock()
|
||||
|
||||
if jitterPercent <= 0 || jitterPercent > 1 {
|
||||
jitterPercent = JitterPercent
|
||||
}
|
||||
|
||||
// Calculate jitter range: base * jitterPercent
|
||||
jitterRange := float64(baseDelay) * jitterPercent
|
||||
|
||||
// Generate random value in range [-jitterRange, +jitterRange]
|
||||
jitter := (jitterRand.Float64()*2 - 1) * jitterRange
|
||||
|
||||
result := time.Duration(float64(baseDelay) + jitter)
|
||||
if result < 0 {
|
||||
return 0
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// JitterDelayDefault applies the default ±30% jitter to a base delay.
|
||||
func JitterDelayDefault(baseDelay time.Duration) time.Duration {
|
||||
return JitterDelay(baseDelay, JitterPercent)
|
||||
}
|
||||
|
||||
// HumanLikeDelay generates a delay that mimics human behavior patterns.
|
||||
// The delay is selected based on probability distribution:
|
||||
// - 20% chance: Short delay (50-200ms) - simulates consecutive rapid operations
|
||||
// - 75% chance: Normal delay (1-3s) - simulates thinking/reading time
|
||||
// - 5% chance: Long delay (5-10s) - simulates breaks/reading longer content
|
||||
//
|
||||
// Returns the delay duration (caller should call time.Sleep with this value).
|
||||
func HumanLikeDelay() time.Duration {
|
||||
initJitterRand()
|
||||
jitterMu.Lock()
|
||||
defer jitterMu.Unlock()
|
||||
|
||||
// Track time since last request for adaptive behavior
|
||||
now := time.Now()
|
||||
timeSinceLastRequest := now.Sub(lastRequestTime)
|
||||
lastRequestTime = now
|
||||
|
||||
// If requests are very close together, use short delay
|
||||
if timeSinceLastRequest < 500*time.Millisecond && timeSinceLastRequest > 0 {
|
||||
rangeMs := ShortDelayMax.Milliseconds() - ShortDelayMin.Milliseconds()
|
||||
randomMs := jitterRand.Int63n(rangeMs)
|
||||
return ShortDelayMin + time.Duration(randomMs)*time.Millisecond
|
||||
}
|
||||
|
||||
// Otherwise, use probability-based selection
|
||||
roll := jitterRand.Float64()
|
||||
|
||||
var min, max time.Duration
|
||||
switch {
|
||||
case roll < ShortDelayProbability:
|
||||
// Short delay - consecutive operations
|
||||
min, max = ShortDelayMin, ShortDelayMax
|
||||
case roll < ShortDelayProbability+LongDelayProbability:
|
||||
// Long delay - reading/resting
|
||||
min, max = LongDelayMin, LongDelayMax
|
||||
default:
|
||||
// Normal delay - thinking time
|
||||
min, max = NormalDelayMin, NormalDelayMax
|
||||
}
|
||||
|
||||
rangeMs := max.Milliseconds() - min.Milliseconds()
|
||||
randomMs := jitterRand.Int63n(rangeMs)
|
||||
return min + time.Duration(randomMs)*time.Millisecond
|
||||
}
|
||||
|
||||
// ApplyHumanLikeDelay applies human-like delay by sleeping.
|
||||
// This is a convenience function that combines HumanLikeDelay with time.Sleep.
|
||||
func ApplyHumanLikeDelay() {
|
||||
delay := HumanLikeDelay()
|
||||
if delay > 0 {
|
||||
time.Sleep(delay)
|
||||
}
|
||||
}
|
||||
|
||||
// ExponentialBackoffWithJitter calculates retry delay using exponential backoff with jitter.
|
||||
// Formula: min(baseDelay * 2^attempt + jitter, maxDelay)
|
||||
// This helps prevent thundering herd problem when multiple clients retry simultaneously.
|
||||
func ExponentialBackoffWithJitter(attempt int, baseDelay, maxDelay time.Duration) time.Duration {
|
||||
if attempt < 0 {
|
||||
attempt = 0
|
||||
}
|
||||
|
||||
// Calculate exponential backoff: baseDelay * 2^attempt
|
||||
backoff := baseDelay * time.Duration(1<<uint(attempt))
|
||||
if backoff > maxDelay {
|
||||
backoff = maxDelay
|
||||
}
|
||||
|
||||
// Add ±30% jitter
|
||||
return JitterDelay(backoff, JitterPercent)
|
||||
}
|
||||
|
||||
// ShouldSkipDelay determines if delay should be skipped based on context.
|
||||
// Returns true for streaming responses, WebSocket connections, etc.
|
||||
// This function can be extended to check additional skip conditions.
|
||||
func ShouldSkipDelay(isStreaming bool) bool {
|
||||
return isStreaming
|
||||
}
|
||||
|
||||
// ResetLastRequestTime resets the last request time tracker.
|
||||
// Useful for testing or when starting a new session.
|
||||
func ResetLastRequestTime() {
|
||||
jitterMu.Lock()
|
||||
defer jitterMu.Unlock()
|
||||
lastRequestTime = time.Time{}
|
||||
}
|
||||
187
internal/auth/kiro/metrics.go
Normal file
187
internal/auth/kiro/metrics.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TokenMetrics holds performance metrics for a single token.
|
||||
type TokenMetrics struct {
|
||||
SuccessRate float64 // Success rate (0.0 - 1.0)
|
||||
AvgLatency float64 // Average latency in milliseconds
|
||||
QuotaRemaining float64 // Remaining quota (0.0 - 1.0)
|
||||
LastUsed time.Time // Last usage timestamp
|
||||
FailCount int // Consecutive failure count
|
||||
TotalRequests int // Total request count
|
||||
successCount int // Internal: successful request count
|
||||
totalLatency float64 // Internal: cumulative latency
|
||||
}
|
||||
|
||||
// TokenScorer manages token metrics and scoring.
|
||||
type TokenScorer struct {
|
||||
mu sync.RWMutex
|
||||
metrics map[string]*TokenMetrics
|
||||
|
||||
// Scoring weights
|
||||
successRateWeight float64
|
||||
quotaWeight float64
|
||||
latencyWeight float64
|
||||
lastUsedWeight float64
|
||||
failPenaltyMultiplier float64
|
||||
}
|
||||
|
||||
// NewTokenScorer creates a new TokenScorer with default weights.
|
||||
func NewTokenScorer() *TokenScorer {
|
||||
return &TokenScorer{
|
||||
metrics: make(map[string]*TokenMetrics),
|
||||
successRateWeight: 0.4,
|
||||
quotaWeight: 0.25,
|
||||
latencyWeight: 0.2,
|
||||
lastUsedWeight: 0.15,
|
||||
failPenaltyMultiplier: 0.1,
|
||||
}
|
||||
}
|
||||
|
||||
// getOrCreateMetrics returns existing metrics or creates new ones.
|
||||
func (s *TokenScorer) getOrCreateMetrics(tokenKey string) *TokenMetrics {
|
||||
if m, ok := s.metrics[tokenKey]; ok {
|
||||
return m
|
||||
}
|
||||
m := &TokenMetrics{
|
||||
SuccessRate: 1.0,
|
||||
QuotaRemaining: 1.0,
|
||||
}
|
||||
s.metrics[tokenKey] = m
|
||||
return m
|
||||
}
|
||||
|
||||
// RecordRequest records the result of a request for a token.
|
||||
func (s *TokenScorer) RecordRequest(tokenKey string, success bool, latency time.Duration) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
m := s.getOrCreateMetrics(tokenKey)
|
||||
m.TotalRequests++
|
||||
m.LastUsed = time.Now()
|
||||
m.totalLatency += float64(latency.Milliseconds())
|
||||
|
||||
if success {
|
||||
m.successCount++
|
||||
m.FailCount = 0
|
||||
} else {
|
||||
m.FailCount++
|
||||
}
|
||||
|
||||
// Update derived metrics
|
||||
if m.TotalRequests > 0 {
|
||||
m.SuccessRate = float64(m.successCount) / float64(m.TotalRequests)
|
||||
m.AvgLatency = m.totalLatency / float64(m.TotalRequests)
|
||||
}
|
||||
}
|
||||
|
||||
// SetQuotaRemaining updates the remaining quota for a token.
|
||||
func (s *TokenScorer) SetQuotaRemaining(tokenKey string, quota float64) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
m := s.getOrCreateMetrics(tokenKey)
|
||||
m.QuotaRemaining = quota
|
||||
}
|
||||
|
||||
// GetMetrics returns a copy of the metrics for a token.
|
||||
func (s *TokenScorer) GetMetrics(tokenKey string) *TokenMetrics {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if m, ok := s.metrics[tokenKey]; ok {
|
||||
copy := *m
|
||||
return ©
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CalculateScore computes the score for a token (higher is better).
|
||||
func (s *TokenScorer) CalculateScore(tokenKey string) float64 {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
m, ok := s.metrics[tokenKey]
|
||||
if !ok {
|
||||
return 1.0 // New tokens get a high initial score
|
||||
}
|
||||
|
||||
// Success rate component (0-1)
|
||||
successScore := m.SuccessRate
|
||||
|
||||
// Quota component (0-1)
|
||||
quotaScore := m.QuotaRemaining
|
||||
|
||||
// Latency component (normalized, lower is better)
|
||||
// Using exponential decay: score = e^(-latency/1000)
|
||||
// 1000ms latency -> ~0.37 score, 100ms -> ~0.90 score
|
||||
latencyScore := math.Exp(-m.AvgLatency / 1000.0)
|
||||
if m.TotalRequests == 0 {
|
||||
latencyScore = 1.0
|
||||
}
|
||||
|
||||
// Last used component (prefer tokens not recently used)
|
||||
// Score increases as time since last use increases
|
||||
timeSinceUse := time.Since(m.LastUsed).Seconds()
|
||||
// Normalize: 60 seconds -> ~0.63 score, 0 seconds -> 0 score
|
||||
lastUsedScore := 1.0 - math.Exp(-timeSinceUse/60.0)
|
||||
if m.LastUsed.IsZero() {
|
||||
lastUsedScore = 1.0
|
||||
}
|
||||
|
||||
// Calculate weighted score
|
||||
score := s.successRateWeight*successScore +
|
||||
s.quotaWeight*quotaScore +
|
||||
s.latencyWeight*latencyScore +
|
||||
s.lastUsedWeight*lastUsedScore
|
||||
|
||||
// Apply consecutive failure penalty
|
||||
if m.FailCount > 0 {
|
||||
penalty := s.failPenaltyMultiplier * float64(m.FailCount)
|
||||
score = score * math.Max(0, 1.0-penalty)
|
||||
}
|
||||
|
||||
return score
|
||||
}
|
||||
|
||||
// SelectBestToken selects the token with the highest score.
|
||||
func (s *TokenScorer) SelectBestToken(tokens []string) string {
|
||||
if len(tokens) == 0 {
|
||||
return ""
|
||||
}
|
||||
if len(tokens) == 1 {
|
||||
return tokens[0]
|
||||
}
|
||||
|
||||
bestToken := tokens[0]
|
||||
bestScore := s.CalculateScore(tokens[0])
|
||||
|
||||
for _, token := range tokens[1:] {
|
||||
score := s.CalculateScore(token)
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
bestToken = token
|
||||
}
|
||||
}
|
||||
|
||||
return bestToken
|
||||
}
|
||||
|
||||
// ResetMetrics clears all metrics for a token.
|
||||
func (s *TokenScorer) ResetMetrics(tokenKey string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.metrics, tokenKey)
|
||||
}
|
||||
|
||||
// ResetAllMetrics clears all stored metrics.
|
||||
func (s *TokenScorer) ResetAllMetrics() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.metrics = make(map[string]*TokenMetrics)
|
||||
}
|
||||
301
internal/auth/kiro/metrics_test.go
Normal file
301
internal/auth/kiro/metrics_test.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewTokenScorer(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
if s == nil {
|
||||
t.Fatal("expected non-nil TokenScorer")
|
||||
}
|
||||
if s.metrics == nil {
|
||||
t.Error("expected non-nil metrics map")
|
||||
}
|
||||
if s.successRateWeight != 0.4 {
|
||||
t.Errorf("expected successRateWeight 0.4, got %f", s.successRateWeight)
|
||||
}
|
||||
if s.quotaWeight != 0.25 {
|
||||
t.Errorf("expected quotaWeight 0.25, got %f", s.quotaWeight)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordRequest_Success(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m == nil {
|
||||
t.Fatal("expected non-nil metrics")
|
||||
}
|
||||
if m.TotalRequests != 1 {
|
||||
t.Errorf("expected TotalRequests 1, got %d", m.TotalRequests)
|
||||
}
|
||||
if m.SuccessRate != 1.0 {
|
||||
t.Errorf("expected SuccessRate 1.0, got %f", m.SuccessRate)
|
||||
}
|
||||
if m.FailCount != 0 {
|
||||
t.Errorf("expected FailCount 0, got %d", m.FailCount)
|
||||
}
|
||||
if m.AvgLatency != 100 {
|
||||
t.Errorf("expected AvgLatency 100, got %f", m.AvgLatency)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordRequest_Failure(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", false, 200*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.SuccessRate != 0.0 {
|
||||
t.Errorf("expected SuccessRate 0.0, got %f", m.SuccessRate)
|
||||
}
|
||||
if m.FailCount != 1 {
|
||||
t.Errorf("expected FailCount 1, got %d", m.FailCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordRequest_MixedResults(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.TotalRequests != 4 {
|
||||
t.Errorf("expected TotalRequests 4, got %d", m.TotalRequests)
|
||||
}
|
||||
if m.SuccessRate != 0.75 {
|
||||
t.Errorf("expected SuccessRate 0.75, got %f", m.SuccessRate)
|
||||
}
|
||||
if m.FailCount != 0 {
|
||||
t.Errorf("expected FailCount 0 (reset on success), got %d", m.FailCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordRequest_ConsecutiveFailures(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.FailCount != 3 {
|
||||
t.Errorf("expected FailCount 3, got %d", m.FailCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetQuotaRemaining(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.SetQuotaRemaining("token1", 0.5)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.QuotaRemaining != 0.5 {
|
||||
t.Errorf("expected QuotaRemaining 0.5, got %f", m.QuotaRemaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMetrics_NonExistent(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
m := s.GetMetrics("nonexistent")
|
||||
if m != nil {
|
||||
t.Error("expected nil metrics for non-existent token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMetrics_ReturnsCopy(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m1 := s.GetMetrics("token1")
|
||||
m1.TotalRequests = 999
|
||||
|
||||
m2 := s.GetMetrics("token1")
|
||||
if m2.TotalRequests == 999 {
|
||||
t.Error("GetMetrics should return a copy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateScore_NewToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
score := s.CalculateScore("newtoken")
|
||||
if score != 1.0 {
|
||||
t.Errorf("expected score 1.0 for new token, got %f", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateScore_PerfectToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 50*time.Millisecond)
|
||||
s.SetQuotaRemaining("token1", 1.0)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
score := s.CalculateScore("token1")
|
||||
if score < 0.5 || score > 1.0 {
|
||||
t.Errorf("expected high score for perfect token, got %f", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateScore_FailedToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
for i := 0; i < 5; i++ {
|
||||
s.RecordRequest("token1", false, 1000*time.Millisecond)
|
||||
}
|
||||
s.SetQuotaRemaining("token1", 0.1)
|
||||
|
||||
score := s.CalculateScore("token1")
|
||||
if score > 0.5 {
|
||||
t.Errorf("expected low score for failed token, got %f", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateScore_FailPenalty(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
scoreNoFail := s.CalculateScore("token1")
|
||||
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
scoreWithFail := s.CalculateScore("token1")
|
||||
|
||||
if scoreWithFail >= scoreNoFail {
|
||||
t.Errorf("expected lower score with consecutive failures: noFail=%f, withFail=%f", scoreNoFail, scoreWithFail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectBestToken_Empty(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
best := s.SelectBestToken([]string{})
|
||||
if best != "" {
|
||||
t.Errorf("expected empty string for empty tokens, got %s", best)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectBestToken_SingleToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
best := s.SelectBestToken([]string{"token1"})
|
||||
if best != "token1" {
|
||||
t.Errorf("expected token1, got %s", best)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectBestToken_MultipleTokens(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
|
||||
s.RecordRequest("bad", false, 1000*time.Millisecond)
|
||||
s.RecordRequest("bad", false, 1000*time.Millisecond)
|
||||
s.SetQuotaRemaining("bad", 0.1)
|
||||
|
||||
s.RecordRequest("good", true, 50*time.Millisecond)
|
||||
s.SetQuotaRemaining("good", 0.9)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
best := s.SelectBestToken([]string{"bad", "good"})
|
||||
if best != "good" {
|
||||
t.Errorf("expected good token to be selected, got %s", best)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetMetrics(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.ResetMetrics("token1")
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m != nil {
|
||||
t.Error("expected nil metrics after reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetAllMetrics(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token2", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token3", true, 100*time.Millisecond)
|
||||
|
||||
s.ResetAllMetrics()
|
||||
|
||||
if s.GetMetrics("token1") != nil {
|
||||
t.Error("expected nil metrics for token1 after reset all")
|
||||
}
|
||||
if s.GetMetrics("token2") != nil {
|
||||
t.Error("expected nil metrics for token2 after reset all")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenScorer_ConcurrentAccess(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
const numGoroutines = 50
|
||||
const numOperations = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
tokenKey := "token" + string(rune('a'+id%10))
|
||||
for j := 0; j < numOperations; j++ {
|
||||
switch j % 6 {
|
||||
case 0:
|
||||
s.RecordRequest(tokenKey, j%2 == 0, time.Duration(j)*time.Millisecond)
|
||||
case 1:
|
||||
s.SetQuotaRemaining(tokenKey, float64(j%100)/100)
|
||||
case 2:
|
||||
s.GetMetrics(tokenKey)
|
||||
case 3:
|
||||
s.CalculateScore(tokenKey)
|
||||
case 4:
|
||||
s.SelectBestToken([]string{tokenKey, "token_x", "token_y"})
|
||||
case 5:
|
||||
if j%20 == 0 {
|
||||
s.ResetMetrics(tokenKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAvgLatencyCalculation(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", true, 200*time.Millisecond)
|
||||
s.RecordRequest("token1", true, 300*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.AvgLatency != 200 {
|
||||
t.Errorf("expected AvgLatency 200, got %f", m.AvgLatency)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLastUsedUpdated(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
before := time.Now()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.LastUsed.Before(before) {
|
||||
t.Error("expected LastUsed to be after test start time")
|
||||
}
|
||||
if m.LastUsed.After(time.Now()) {
|
||||
t.Error("expected LastUsed to be before or equal to now")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultQuotaForNewToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.QuotaRemaining != 1.0 {
|
||||
t.Errorf("expected default QuotaRemaining 1.0, got %f", m.QuotaRemaining)
|
||||
}
|
||||
}
|
||||
@@ -227,6 +227,7 @@ func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: "", // Caller should preserve original provider
|
||||
Region: "us-east-1",
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -285,6 +286,7 @@ func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*Kir
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: "", // Caller should preserve original provider
|
||||
Region: "us-east-1",
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
825
internal/auth/kiro/oauth_web.go
Normal file
825
internal/auth/kiro/oauth_web.go
Normal file
@@ -0,0 +1,825 @@
|
||||
// Package kiro provides OAuth Web authentication for Kiro.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultSessionExpiry = 10 * time.Minute
|
||||
pollIntervalSeconds = 5
|
||||
)
|
||||
|
||||
type authSessionStatus string
|
||||
|
||||
const (
|
||||
statusPending authSessionStatus = "pending"
|
||||
statusSuccess authSessionStatus = "success"
|
||||
statusFailed authSessionStatus = "failed"
|
||||
)
|
||||
|
||||
type webAuthSession struct {
|
||||
stateID string
|
||||
deviceCode string
|
||||
userCode string
|
||||
authURL string
|
||||
verificationURI string
|
||||
expiresIn int
|
||||
interval int
|
||||
status authSessionStatus
|
||||
startedAt time.Time
|
||||
completedAt time.Time
|
||||
expiresAt time.Time
|
||||
error string
|
||||
tokenData *KiroTokenData
|
||||
ssoClient *SSOOIDCClient
|
||||
clientID string
|
||||
clientSecret string
|
||||
region string
|
||||
cancelFunc context.CancelFunc
|
||||
authMethod string // "google", "github", "builder-id", "idc"
|
||||
startURL string // Used for IDC
|
||||
codeVerifier string // Used for social auth PKCE
|
||||
codeChallenge string // Used for social auth PKCE
|
||||
}
|
||||
|
||||
type OAuthWebHandler struct {
|
||||
cfg *config.Config
|
||||
sessions map[string]*webAuthSession
|
||||
mu sync.RWMutex
|
||||
onTokenObtained func(*KiroTokenData)
|
||||
}
|
||||
|
||||
func NewOAuthWebHandler(cfg *config.Config) *OAuthWebHandler {
|
||||
return &OAuthWebHandler{
|
||||
cfg: cfg,
|
||||
sessions: make(map[string]*webAuthSession),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) SetTokenCallback(callback func(*KiroTokenData)) {
|
||||
h.onTokenObtained = callback
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) RegisterRoutes(router gin.IRouter) {
|
||||
oauth := router.Group("/v0/oauth/kiro")
|
||||
{
|
||||
oauth.GET("", h.handleSelect)
|
||||
oauth.GET("/start", h.handleStart)
|
||||
oauth.GET("/callback", h.handleCallback)
|
||||
oauth.GET("/social/callback", h.handleSocialCallback)
|
||||
oauth.GET("/status", h.handleStatus)
|
||||
oauth.POST("/import", h.handleImportToken)
|
||||
}
|
||||
}
|
||||
|
||||
func generateStateID() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleSelect(c *gin.Context) {
|
||||
h.renderSelectPage(c)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleStart(c *gin.Context) {
|
||||
method := c.Query("method")
|
||||
|
||||
if method == "" {
|
||||
c.Redirect(http.StatusFound, "/v0/oauth/kiro")
|
||||
return
|
||||
}
|
||||
|
||||
switch method {
|
||||
case "google", "github":
|
||||
// Google/GitHub social login is not supported for third-party apps
|
||||
// due to AWS Cognito redirect_uri restrictions
|
||||
h.renderError(c, "Google/GitHub login is not available for third-party applications. Please use AWS Builder ID or import your token from Kiro IDE.")
|
||||
case "builder-id":
|
||||
h.startBuilderIDAuth(c)
|
||||
case "idc":
|
||||
h.startIDCAuth(c)
|
||||
default:
|
||||
h.renderError(c, fmt.Sprintf("Unknown authentication method: %s", method))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) startSocialAuth(c *gin.Context, method string) {
|
||||
stateID, err := generateStateID()
|
||||
if err != nil {
|
||||
h.renderError(c, "Failed to generate state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
codeVerifier, codeChallenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
h.renderError(c, "Failed to generate PKCE parameters")
|
||||
return
|
||||
}
|
||||
|
||||
socialClient := NewSocialAuthClient(h.cfg)
|
||||
|
||||
var provider string
|
||||
if method == "google" {
|
||||
provider = string(ProviderGoogle)
|
||||
} else {
|
||||
provider = string(ProviderGitHub)
|
||||
}
|
||||
|
||||
redirectURI := h.getSocialCallbackURL(c)
|
||||
authURL := socialClient.buildLoginURL(provider, redirectURI, codeChallenge, stateID)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
|
||||
session := &webAuthSession{
|
||||
stateID: stateID,
|
||||
authMethod: method,
|
||||
authURL: authURL,
|
||||
status: statusPending,
|
||||
startedAt: time.Now(),
|
||||
expiresIn: 600,
|
||||
codeVerifier: codeVerifier,
|
||||
codeChallenge: codeChallenge,
|
||||
region: "us-east-1",
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
h.sessions[stateID] = session
|
||||
h.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
h.mu.Lock()
|
||||
if session.status == statusPending {
|
||||
session.status = statusFailed
|
||||
session.error = "Authentication timed out"
|
||||
}
|
||||
h.mu.Unlock()
|
||||
}()
|
||||
|
||||
c.Redirect(http.StatusFound, authURL)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) getSocialCallbackURL(c *gin.Context) string {
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https" {
|
||||
scheme = "https"
|
||||
}
|
||||
return fmt.Sprintf("%s://%s/v0/oauth/kiro/social/callback", scheme, c.Request.Host)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) startBuilderIDAuth(c *gin.Context) {
|
||||
stateID, err := generateStateID()
|
||||
if err != nil {
|
||||
h.renderError(c, "Failed to generate state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
region := defaultIDCRegion
|
||||
startURL := builderIDStartURL
|
||||
|
||||
ssoClient := NewSSOOIDCClient(h.cfg)
|
||||
|
||||
regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to register client: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to register client: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
authResp, err := ssoClient.StartDeviceAuthorizationWithIDC(
|
||||
c.Request.Context(),
|
||||
regResp.ClientID,
|
||||
regResp.ClientSecret,
|
||||
startURL,
|
||||
region,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to start device authorization: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second)
|
||||
|
||||
session := &webAuthSession{
|
||||
stateID: stateID,
|
||||
deviceCode: authResp.DeviceCode,
|
||||
userCode: authResp.UserCode,
|
||||
authURL: authResp.VerificationURIComplete,
|
||||
verificationURI: authResp.VerificationURI,
|
||||
expiresIn: authResp.ExpiresIn,
|
||||
interval: authResp.Interval,
|
||||
status: statusPending,
|
||||
startedAt: time.Now(),
|
||||
ssoClient: ssoClient,
|
||||
clientID: regResp.ClientID,
|
||||
clientSecret: regResp.ClientSecret,
|
||||
region: region,
|
||||
authMethod: "builder-id",
|
||||
startURL: startURL,
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
h.sessions[stateID] = session
|
||||
h.mu.Unlock()
|
||||
|
||||
go h.pollForToken(ctx, session)
|
||||
|
||||
h.renderStartPage(c, session)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) startIDCAuth(c *gin.Context) {
|
||||
startURL := c.Query("startUrl")
|
||||
region := c.Query("region")
|
||||
|
||||
if startURL == "" {
|
||||
h.renderError(c, "Missing startUrl parameter for IDC authentication")
|
||||
return
|
||||
}
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
|
||||
stateID, err := generateStateID()
|
||||
if err != nil {
|
||||
h.renderError(c, "Failed to generate state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
ssoClient := NewSSOOIDCClient(h.cfg)
|
||||
|
||||
regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to register client: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to register client: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
authResp, err := ssoClient.StartDeviceAuthorizationWithIDC(
|
||||
c.Request.Context(),
|
||||
regResp.ClientID,
|
||||
regResp.ClientSecret,
|
||||
startURL,
|
||||
region,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to start device authorization: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second)
|
||||
|
||||
session := &webAuthSession{
|
||||
stateID: stateID,
|
||||
deviceCode: authResp.DeviceCode,
|
||||
userCode: authResp.UserCode,
|
||||
authURL: authResp.VerificationURIComplete,
|
||||
verificationURI: authResp.VerificationURI,
|
||||
expiresIn: authResp.ExpiresIn,
|
||||
interval: authResp.Interval,
|
||||
status: statusPending,
|
||||
startedAt: time.Now(),
|
||||
ssoClient: ssoClient,
|
||||
clientID: regResp.ClientID,
|
||||
clientSecret: regResp.ClientSecret,
|
||||
region: region,
|
||||
authMethod: "idc",
|
||||
startURL: startURL,
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
h.sessions[stateID] = session
|
||||
h.mu.Unlock()
|
||||
|
||||
go h.pollForToken(ctx, session)
|
||||
|
||||
h.renderStartPage(c, session)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) pollForToken(ctx context.Context, session *webAuthSession) {
|
||||
defer session.cancelFunc()
|
||||
|
||||
interval := time.Duration(session.interval) * time.Second
|
||||
if interval < time.Duration(pollIntervalSeconds)*time.Second {
|
||||
interval = time.Duration(pollIntervalSeconds) * time.Second
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
h.mu.Lock()
|
||||
if session.status == statusPending {
|
||||
session.status = statusFailed
|
||||
session.error = "Authentication timed out"
|
||||
}
|
||||
h.mu.Unlock()
|
||||
return
|
||||
case <-ticker.C:
|
||||
tokenResp, err := h.ssoClient(session).CreateTokenWithRegion(
|
||||
ctx,
|
||||
session.clientID,
|
||||
session.clientSecret,
|
||||
session.deviceCode,
|
||||
session.region,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
if errStr == ErrAuthorizationPending.Error() {
|
||||
continue
|
||||
}
|
||||
if errStr == ErrSlowDown.Error() {
|
||||
interval += 5 * time.Second
|
||||
ticker.Reset(interval)
|
||||
continue
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
session.status = statusFailed
|
||||
session.error = errStr
|
||||
session.completedAt = time.Now()
|
||||
h.mu.Unlock()
|
||||
|
||||
log.Errorf("OAuth Web: token polling failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||
profileArn := session.ssoClient.fetchProfileArn(ctx, tokenResp.AccessToken)
|
||||
email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken)
|
||||
|
||||
tokenData := &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: profileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: session.authMethod,
|
||||
Provider: "AWS",
|
||||
ClientID: session.clientID,
|
||||
ClientSecret: session.clientSecret,
|
||||
Email: email,
|
||||
Region: session.region,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
session.status = statusSuccess
|
||||
session.completedAt = time.Now()
|
||||
session.expiresAt = expiresAt
|
||||
session.tokenData = tokenData
|
||||
h.mu.Unlock()
|
||||
|
||||
if h.onTokenObtained != nil {
|
||||
h.onTokenObtained(tokenData)
|
||||
}
|
||||
|
||||
// Save token to file
|
||||
h.saveTokenToFile(tokenData)
|
||||
|
||||
log.Infof("OAuth Web: authentication successful for %s", email)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// saveTokenToFile saves the token data to the auth directory
|
||||
func (h *OAuthWebHandler) saveTokenToFile(tokenData *KiroTokenData) {
|
||||
// Get auth directory from config or use default
|
||||
authDir := ""
|
||||
if h.cfg != nil && h.cfg.AuthDir != "" {
|
||||
var err error
|
||||
authDir, err = util.ResolveAuthDir(h.cfg.AuthDir)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to resolve auth directory: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to default location
|
||||
if authDir == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to get home directory: %v", err)
|
||||
return
|
||||
}
|
||||
authDir = filepath.Join(home, ".cli-proxy-api")
|
||||
}
|
||||
|
||||
// Create directory if not exists
|
||||
if err := os.MkdirAll(authDir, 0700); err != nil {
|
||||
log.Errorf("OAuth Web: failed to create auth directory: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate filename based on auth method
|
||||
// Format: kiro-{authMethod}.json or kiro-{authMethod}-{email}.json
|
||||
fileName := fmt.Sprintf("kiro-%s.json", tokenData.AuthMethod)
|
||||
if tokenData.Email != "" {
|
||||
// Sanitize email for filename (replace @ and . with -)
|
||||
sanitizedEmail := tokenData.Email
|
||||
sanitizedEmail = strings.ReplaceAll(sanitizedEmail, "@", "-")
|
||||
sanitizedEmail = strings.ReplaceAll(sanitizedEmail, ".", "-")
|
||||
fileName = fmt.Sprintf("kiro-%s-%s.json", tokenData.AuthMethod, sanitizedEmail)
|
||||
}
|
||||
|
||||
authFilePath := filepath.Join(authDir, fileName)
|
||||
|
||||
// Convert to storage format and save
|
||||
storage := &KiroTokenStorage{
|
||||
Type: "kiro",
|
||||
AccessToken: tokenData.AccessToken,
|
||||
RefreshToken: tokenData.RefreshToken,
|
||||
ProfileArn: tokenData.ProfileArn,
|
||||
ExpiresAt: tokenData.ExpiresAt,
|
||||
AuthMethod: tokenData.AuthMethod,
|
||||
Provider: tokenData.Provider,
|
||||
LastRefresh: time.Now().Format(time.RFC3339),
|
||||
ClientID: tokenData.ClientID,
|
||||
ClientSecret: tokenData.ClientSecret,
|
||||
Region: tokenData.Region,
|
||||
StartURL: tokenData.StartURL,
|
||||
Email: tokenData.Email,
|
||||
}
|
||||
|
||||
if err := storage.SaveTokenToFile(authFilePath); err != nil {
|
||||
log.Errorf("OAuth Web: failed to save token to file: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("OAuth Web: token saved to %s", authFilePath)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) ssoClient(session *webAuthSession) *SSOOIDCClient {
|
||||
return session.ssoClient
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleCallback(c *gin.Context) {
|
||||
stateID := c.Query("state")
|
||||
errParam := c.Query("error")
|
||||
|
||||
if errParam != "" {
|
||||
h.renderError(c, errParam)
|
||||
return
|
||||
}
|
||||
|
||||
if stateID == "" {
|
||||
h.renderError(c, "Missing state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
session, exists := h.sessions[stateID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
h.renderError(c, "Invalid or expired session")
|
||||
return
|
||||
}
|
||||
|
||||
if session.status == statusSuccess {
|
||||
h.renderSuccess(c, session)
|
||||
} else if session.status == statusFailed {
|
||||
h.renderError(c, session.error)
|
||||
} else {
|
||||
c.Redirect(http.StatusFound, "/v0/oauth/kiro/start")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleSocialCallback(c *gin.Context) {
|
||||
stateID := c.Query("state")
|
||||
code := c.Query("code")
|
||||
errParam := c.Query("error")
|
||||
|
||||
if errParam != "" {
|
||||
h.renderError(c, errParam)
|
||||
return
|
||||
}
|
||||
|
||||
if stateID == "" {
|
||||
h.renderError(c, "Missing state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
if code == "" {
|
||||
h.renderError(c, "Missing authorization code")
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
session, exists := h.sessions[stateID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
h.renderError(c, "Invalid or expired session")
|
||||
return
|
||||
}
|
||||
|
||||
if session.authMethod != "google" && session.authMethod != "github" {
|
||||
h.renderError(c, "Invalid session type for social callback")
|
||||
return
|
||||
}
|
||||
|
||||
socialClient := NewSocialAuthClient(h.cfg)
|
||||
redirectURI := h.getSocialCallbackURL(c)
|
||||
|
||||
tokenReq := &CreateTokenRequest{
|
||||
Code: code,
|
||||
CodeVerifier: session.codeVerifier,
|
||||
RedirectURI: redirectURI,
|
||||
}
|
||||
|
||||
tokenResp, err := socialClient.CreateToken(c.Request.Context(), tokenReq)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: social token exchange failed: %v", err)
|
||||
h.mu.Lock()
|
||||
session.status = statusFailed
|
||||
session.error = fmt.Sprintf("Token exchange failed: %v", err)
|
||||
session.completedAt = time.Now()
|
||||
h.mu.Unlock()
|
||||
h.renderError(c, session.error)
|
||||
return
|
||||
}
|
||||
|
||||
expiresIn := tokenResp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
email := ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||
|
||||
var provider string
|
||||
if session.authMethod == "google" {
|
||||
provider = string(ProviderGoogle)
|
||||
} else {
|
||||
provider = string(ProviderGitHub)
|
||||
}
|
||||
|
||||
tokenData := &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: tokenResp.ProfileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: session.authMethod,
|
||||
Provider: provider,
|
||||
Email: email,
|
||||
Region: "us-east-1",
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
session.status = statusSuccess
|
||||
session.completedAt = time.Now()
|
||||
session.expiresAt = expiresAt
|
||||
session.tokenData = tokenData
|
||||
h.mu.Unlock()
|
||||
|
||||
if session.cancelFunc != nil {
|
||||
session.cancelFunc()
|
||||
}
|
||||
|
||||
if h.onTokenObtained != nil {
|
||||
h.onTokenObtained(tokenData)
|
||||
}
|
||||
|
||||
// Save token to file
|
||||
h.saveTokenToFile(tokenData)
|
||||
|
||||
log.Infof("OAuth Web: social authentication successful for %s via %s", email, provider)
|
||||
h.renderSuccess(c, session)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleStatus(c *gin.Context) {
|
||||
stateID := c.Query("state")
|
||||
if stateID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing state parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
session, exists := h.sessions[stateID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
|
||||
return
|
||||
}
|
||||
|
||||
response := gin.H{
|
||||
"status": string(session.status),
|
||||
}
|
||||
|
||||
switch session.status {
|
||||
case statusPending:
|
||||
elapsed := time.Since(session.startedAt).Seconds()
|
||||
remaining := float64(session.expiresIn) - elapsed
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
response["remaining_seconds"] = int(remaining)
|
||||
case statusSuccess:
|
||||
response["completed_at"] = session.completedAt.Format(time.RFC3339)
|
||||
response["expires_at"] = session.expiresAt.Format(time.RFC3339)
|
||||
case statusFailed:
|
||||
response["error"] = session.error
|
||||
response["failed_at"] = session.completedAt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderStartPage(c *gin.Context, session *webAuthSession) {
|
||||
tmpl, err := template.New("start").Parse(oauthWebStartPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"AuthURL": session.authURL,
|
||||
"UserCode": session.userCode,
|
||||
"ExpiresIn": session.expiresIn,
|
||||
"StateID": session.stateID,
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
if err := tmpl.Execute(c.Writer, data); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderSelectPage(c *gin.Context) {
|
||||
tmpl, err := template.New("select").Parse(oauthWebSelectPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse select template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
if err := tmpl.Execute(c.Writer, nil); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render select template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderError(c *gin.Context, errMsg string) {
|
||||
tmpl, err := template.New("error").Parse(oauthWebErrorPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse error template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"Error": errMsg,
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.Status(http.StatusBadRequest)
|
||||
if err := tmpl.Execute(c.Writer, data); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render error template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderSuccess(c *gin.Context, session *webAuthSession) {
|
||||
tmpl, err := template.New("success").Parse(oauthWebSuccessPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse success template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"ExpiresAt": session.expiresAt.Format(time.RFC3339),
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
if err := tmpl.Execute(c.Writer, data); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render success template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) CleanupExpiredSessions() {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for id, session := range h.sessions {
|
||||
if session.status != statusPending && now.Sub(session.completedAt) > 30*time.Minute {
|
||||
delete(h.sessions, id)
|
||||
} else if session.status == statusPending && now.Sub(session.startedAt) > defaultSessionExpiry {
|
||||
session.cancelFunc()
|
||||
delete(h.sessions, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) GetSession(stateID string) (*webAuthSession, bool) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
session, exists := h.sessions[stateID]
|
||||
return session, exists
|
||||
}
|
||||
|
||||
// ImportTokenRequest represents the request body for token import
|
||||
type ImportTokenRequest struct {
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
// handleImportToken handles manual refresh token import from Kiro IDE
|
||||
func (h *OAuthWebHandler) handleImportToken(c *gin.Context) {
|
||||
var req ImportTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"error": "Invalid request body",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
refreshToken := strings.TrimSpace(req.RefreshToken)
|
||||
if refreshToken == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"error": "Refresh token is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate token format
|
||||
if !strings.HasPrefix(refreshToken, "aorAAAAAG") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"error": "Invalid token format. Token should start with aorAAAAAG...",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Create social auth client to refresh and validate the token
|
||||
socialClient := NewSocialAuthClient(h.cfg)
|
||||
|
||||
// Refresh the token to validate it and get access token
|
||||
tokenData, err := socialClient.RefreshSocialToken(c.Request.Context(), refreshToken)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: token refresh failed during import: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"error": fmt.Sprintf("Token validation failed: %v", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Set the original refresh token (the refreshed one might be empty)
|
||||
if tokenData.RefreshToken == "" {
|
||||
tokenData.RefreshToken = refreshToken
|
||||
}
|
||||
tokenData.AuthMethod = "social"
|
||||
tokenData.Provider = "imported"
|
||||
|
||||
// Notify callback if set
|
||||
if h.onTokenObtained != nil {
|
||||
h.onTokenObtained(tokenData)
|
||||
}
|
||||
|
||||
// Save token to file
|
||||
h.saveTokenToFile(tokenData)
|
||||
|
||||
// Generate filename for response
|
||||
fileName := fmt.Sprintf("kiro-%s.json", tokenData.AuthMethod)
|
||||
if tokenData.Email != "" {
|
||||
sanitizedEmail := strings.ReplaceAll(tokenData.Email, "@", "-")
|
||||
sanitizedEmail = strings.ReplaceAll(sanitizedEmail, ".", "-")
|
||||
fileName = fmt.Sprintf("kiro-%s-%s.json", tokenData.AuthMethod, sanitizedEmail)
|
||||
}
|
||||
|
||||
log.Infof("OAuth Web: token imported successfully")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Token imported successfully",
|
||||
"fileName": fileName,
|
||||
})
|
||||
}
|
||||
385
internal/auth/kiro/oauth_web.go.bak
Normal file
385
internal/auth/kiro/oauth_web.go.bak
Normal file
@@ -0,0 +1,385 @@
|
||||
// Package kiro provides OAuth Web authentication for Kiro.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultSessionExpiry = 10 * time.Minute
|
||||
pollIntervalSeconds = 5
|
||||
)
|
||||
|
||||
type authSessionStatus string
|
||||
|
||||
const (
|
||||
statusPending authSessionStatus = "pending"
|
||||
statusSuccess authSessionStatus = "success"
|
||||
statusFailed authSessionStatus = "failed"
|
||||
)
|
||||
|
||||
type webAuthSession struct {
|
||||
stateID string
|
||||
deviceCode string
|
||||
userCode string
|
||||
authURL string
|
||||
verificationURI string
|
||||
expiresIn int
|
||||
interval int
|
||||
status authSessionStatus
|
||||
startedAt time.Time
|
||||
completedAt time.Time
|
||||
expiresAt time.Time
|
||||
error string
|
||||
tokenData *KiroTokenData
|
||||
ssoClient *SSOOIDCClient
|
||||
clientID string
|
||||
clientSecret string
|
||||
region string
|
||||
cancelFunc context.CancelFunc
|
||||
}
|
||||
|
||||
type OAuthWebHandler struct {
|
||||
cfg *config.Config
|
||||
sessions map[string]*webAuthSession
|
||||
mu sync.RWMutex
|
||||
onTokenObtained func(*KiroTokenData)
|
||||
}
|
||||
|
||||
func NewOAuthWebHandler(cfg *config.Config) *OAuthWebHandler {
|
||||
return &OAuthWebHandler{
|
||||
cfg: cfg,
|
||||
sessions: make(map[string]*webAuthSession),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) SetTokenCallback(callback func(*KiroTokenData)) {
|
||||
h.onTokenObtained = callback
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) RegisterRoutes(router gin.IRouter) {
|
||||
oauth := router.Group("/v0/oauth/kiro")
|
||||
{
|
||||
oauth.GET("/start", h.handleStart)
|
||||
oauth.GET("/callback", h.handleCallback)
|
||||
oauth.GET("/status", h.handleStatus)
|
||||
}
|
||||
}
|
||||
|
||||
func generateStateID() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleStart(c *gin.Context) {
|
||||
stateID, err := generateStateID()
|
||||
if err != nil {
|
||||
h.renderError(c, "Failed to generate state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
region := defaultIDCRegion
|
||||
startURL := builderIDStartURL
|
||||
|
||||
ssoClient := NewSSOOIDCClient(h.cfg)
|
||||
|
||||
regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to register client: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to register client: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
authResp, err := ssoClient.StartDeviceAuthorizationWithIDC(
|
||||
c.Request.Context(),
|
||||
regResp.ClientID,
|
||||
regResp.ClientSecret,
|
||||
startURL,
|
||||
region,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to start device authorization: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second)
|
||||
|
||||
session := &webAuthSession{
|
||||
stateID: stateID,
|
||||
deviceCode: authResp.DeviceCode,
|
||||
userCode: authResp.UserCode,
|
||||
authURL: authResp.VerificationURIComplete,
|
||||
verificationURI: authResp.VerificationURI,
|
||||
expiresIn: authResp.ExpiresIn,
|
||||
interval: authResp.Interval,
|
||||
status: statusPending,
|
||||
startedAt: time.Now(),
|
||||
ssoClient: ssoClient,
|
||||
clientID: regResp.ClientID,
|
||||
clientSecret: regResp.ClientSecret,
|
||||
region: region,
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
h.sessions[stateID] = session
|
||||
h.mu.Unlock()
|
||||
|
||||
go h.pollForToken(ctx, session)
|
||||
|
||||
h.renderStartPage(c, session)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) pollForToken(ctx context.Context, session *webAuthSession) {
|
||||
defer session.cancelFunc()
|
||||
|
||||
interval := time.Duration(session.interval) * time.Second
|
||||
if interval < time.Duration(pollIntervalSeconds)*time.Second {
|
||||
interval = time.Duration(pollIntervalSeconds) * time.Second
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
h.mu.Lock()
|
||||
if session.status == statusPending {
|
||||
session.status = statusFailed
|
||||
session.error = "Authentication timed out"
|
||||
}
|
||||
h.mu.Unlock()
|
||||
return
|
||||
case <-ticker.C:
|
||||
tokenResp, err := h.ssoClient(session).CreateTokenWithRegion(
|
||||
ctx,
|
||||
session.clientID,
|
||||
session.clientSecret,
|
||||
session.deviceCode,
|
||||
session.region,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
if errStr == ErrAuthorizationPending.Error() {
|
||||
continue
|
||||
}
|
||||
if errStr == ErrSlowDown.Error() {
|
||||
interval += 5 * time.Second
|
||||
ticker.Reset(interval)
|
||||
continue
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
session.status = statusFailed
|
||||
session.error = errStr
|
||||
session.completedAt = time.Now()
|
||||
h.mu.Unlock()
|
||||
|
||||
log.Errorf("OAuth Web: token polling failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||
profileArn := session.ssoClient.fetchProfileArn(ctx, tokenResp.AccessToken)
|
||||
email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken)
|
||||
|
||||
tokenData := &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: profileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "builder-id",
|
||||
Provider: "AWS",
|
||||
ClientID: session.clientID,
|
||||
ClientSecret: session.clientSecret,
|
||||
Email: email,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
session.status = statusSuccess
|
||||
session.completedAt = time.Now()
|
||||
session.expiresAt = expiresAt
|
||||
session.tokenData = tokenData
|
||||
h.mu.Unlock()
|
||||
|
||||
if h.onTokenObtained != nil {
|
||||
h.onTokenObtained(tokenData)
|
||||
}
|
||||
|
||||
log.Infof("OAuth Web: authentication successful for %s", email)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) ssoClient(session *webAuthSession) *SSOOIDCClient {
|
||||
return session.ssoClient
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleCallback(c *gin.Context) {
|
||||
stateID := c.Query("state")
|
||||
errParam := c.Query("error")
|
||||
|
||||
if errParam != "" {
|
||||
h.renderError(c, errParam)
|
||||
return
|
||||
}
|
||||
|
||||
if stateID == "" {
|
||||
h.renderError(c, "Missing state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
session, exists := h.sessions[stateID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
h.renderError(c, "Invalid or expired session")
|
||||
return
|
||||
}
|
||||
|
||||
if session.status == statusSuccess {
|
||||
h.renderSuccess(c, session)
|
||||
} else if session.status == statusFailed {
|
||||
h.renderError(c, session.error)
|
||||
} else {
|
||||
c.Redirect(http.StatusFound, "/v0/oauth/kiro/start")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleStatus(c *gin.Context) {
|
||||
stateID := c.Query("state")
|
||||
if stateID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing state parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
session, exists := h.sessions[stateID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
|
||||
return
|
||||
}
|
||||
|
||||
response := gin.H{
|
||||
"status": string(session.status),
|
||||
}
|
||||
|
||||
switch session.status {
|
||||
case statusPending:
|
||||
elapsed := time.Since(session.startedAt).Seconds()
|
||||
remaining := float64(session.expiresIn) - elapsed
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
response["remaining_seconds"] = int(remaining)
|
||||
case statusSuccess:
|
||||
response["completed_at"] = session.completedAt.Format(time.RFC3339)
|
||||
response["expires_at"] = session.expiresAt.Format(time.RFC3339)
|
||||
case statusFailed:
|
||||
response["error"] = session.error
|
||||
response["failed_at"] = session.completedAt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderStartPage(c *gin.Context, session *webAuthSession) {
|
||||
tmpl, err := template.New("start").Parse(oauthWebStartPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"AuthURL": session.authURL,
|
||||
"UserCode": session.userCode,
|
||||
"ExpiresIn": session.expiresIn,
|
||||
"StateID": session.stateID,
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
if err := tmpl.Execute(c.Writer, data); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderError(c *gin.Context, errMsg string) {
|
||||
tmpl, err := template.New("error").Parse(oauthWebErrorPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse error template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"Error": errMsg,
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.Status(http.StatusBadRequest)
|
||||
if err := tmpl.Execute(c.Writer, data); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render error template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderSuccess(c *gin.Context, session *webAuthSession) {
|
||||
tmpl, err := template.New("success").Parse(oauthWebSuccessPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse success template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"ExpiresAt": session.expiresAt.Format(time.RFC3339),
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
if err := tmpl.Execute(c.Writer, data); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render success template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) CleanupExpiredSessions() {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for id, session := range h.sessions {
|
||||
if session.status != statusPending && now.Sub(session.completedAt) > 30*time.Minute {
|
||||
delete(h.sessions, id)
|
||||
} else if session.status == statusPending && now.Sub(session.startedAt) > defaultSessionExpiry {
|
||||
session.cancelFunc()
|
||||
delete(h.sessions, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) GetSession(stateID string) (*webAuthSession, bool) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
session, exists := h.sessions[stateID]
|
||||
return session, exists
|
||||
}
|
||||
732
internal/auth/kiro/oauth_web_templates.go
Normal file
732
internal/auth/kiro/oauth_web_templates.go
Normal file
@@ -0,0 +1,732 @@
|
||||
// Package kiro provides OAuth Web authentication templates.
|
||||
package kiro
|
||||
|
||||
const (
|
||||
oauthWebStartPageHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>AWS SSO Authentication</title>
|
||||
<style>
|
||||
* { box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
.container {
|
||||
max-width: 500px;
|
||||
width: 100%;
|
||||
background: #fff;
|
||||
padding: 40px;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
|
||||
}
|
||||
h1 {
|
||||
margin: 0 0 10px;
|
||||
color: #333;
|
||||
font-size: 24px;
|
||||
text-align: center;
|
||||
}
|
||||
.subtitle {
|
||||
text-align: center;
|
||||
color: #666;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
.step {
|
||||
background: #f8f9fa;
|
||||
padding: 20px;
|
||||
border-radius: 8px;
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
.step-title {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
font-weight: 600;
|
||||
color: #333;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
.step-number {
|
||||
width: 28px;
|
||||
height: 28px;
|
||||
background: #667eea;
|
||||
color: white;
|
||||
border-radius: 50%;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: 14px;
|
||||
margin-right: 12px;
|
||||
}
|
||||
.user-code {
|
||||
background: #e7f3ff;
|
||||
border: 2px dashed #2196F3;
|
||||
border-radius: 8px;
|
||||
padding: 20px;
|
||||
text-align: center;
|
||||
margin-top: 10px;
|
||||
}
|
||||
.user-code-label {
|
||||
font-size: 12px;
|
||||
color: #666;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 1px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.user-code-value {
|
||||
font-size: 32px;
|
||||
font-weight: bold;
|
||||
font-family: monospace;
|
||||
color: #2196F3;
|
||||
letter-spacing: 4px;
|
||||
}
|
||||
.auth-btn {
|
||||
display: block;
|
||||
width: 100%;
|
||||
padding: 15px;
|
||||
background: #667eea;
|
||||
color: white;
|
||||
text-align: center;
|
||||
text-decoration: none;
|
||||
border-radius: 8px;
|
||||
font-weight: 600;
|
||||
font-size: 16px;
|
||||
transition: all 0.3s;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
margin-top: 20px;
|
||||
}
|
||||
.auth-btn:hover {
|
||||
background: #5568d3;
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
|
||||
}
|
||||
.status {
|
||||
margin-top: 30px;
|
||||
padding: 20px;
|
||||
background: #f8f9fa;
|
||||
border-radius: 8px;
|
||||
text-align: center;
|
||||
}
|
||||
.status-pending { border-left: 4px solid #ffc107; }
|
||||
.status-success { border-left: 4px solid #28a745; }
|
||||
.status-failed { border-left: 4px solid #dc3545; }
|
||||
.spinner {
|
||||
border: 3px solid #f3f3f3;
|
||||
border-top: 3px solid #667eea;
|
||||
border-radius: 50%;
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
animation: spin 1s linear infinite;
|
||||
margin: 0 auto 15px;
|
||||
}
|
||||
@keyframes spin {
|
||||
0% { transform: rotate(0deg); }
|
||||
100% { transform: rotate(360deg); }
|
||||
}
|
||||
.timer {
|
||||
font-size: 24px;
|
||||
font-weight: bold;
|
||||
color: #667eea;
|
||||
margin: 10px 0;
|
||||
}
|
||||
.timer.warning { color: #ffc107; }
|
||||
.timer.danger { color: #dc3545; }
|
||||
.status-message { color: #666; line-height: 1.6; }
|
||||
.success-icon, .error-icon { font-size: 48px; margin-bottom: 15px; }
|
||||
.info-box {
|
||||
background: #e7f3ff;
|
||||
border-left: 4px solid #2196F3;
|
||||
padding: 15px;
|
||||
margin-top: 20px;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
color: #666;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>🔐 AWS SSO Authentication</h1>
|
||||
<p class="subtitle">Follow the steps below to complete authentication</p>
|
||||
|
||||
<div class="step">
|
||||
<div class="step-title">
|
||||
<span class="step-number">1</span>
|
||||
Click the button below to open the authorization page
|
||||
</div>
|
||||
<a href="{{.AuthURL}}" target="_blank" class="auth-btn" id="authBtn">
|
||||
🚀 Open Authorization Page
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div class="step">
|
||||
<div class="step-title">
|
||||
<span class="step-number">2</span>
|
||||
Enter the verification code below
|
||||
</div>
|
||||
<div class="user-code">
|
||||
<div class="user-code-label">Verification Code</div>
|
||||
<div class="user-code-value">{{.UserCode}}</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="step">
|
||||
<div class="step-title">
|
||||
<span class="step-number">3</span>
|
||||
Complete AWS SSO login
|
||||
</div>
|
||||
<p style="color: #666; font-size: 14px; margin-top: 10px;">
|
||||
Use your AWS SSO account to login and authorize
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div class="status status-pending" id="statusBox">
|
||||
<div class="spinner" id="spinner"></div>
|
||||
<div class="timer" id="timer">{{.ExpiresIn}}s</div>
|
||||
<div class="status-message" id="statusMessage">
|
||||
Waiting for authorization...
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="info-box">
|
||||
💡 <strong>Tip:</strong> The authorization page will open in a new tab. This page will automatically update once authorization is complete.
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let pollInterval;
|
||||
let timerInterval;
|
||||
let remainingSeconds = {{.ExpiresIn}};
|
||||
const stateID = "{{.StateID}}";
|
||||
|
||||
setTimeout(() => {
|
||||
document.getElementById('authBtn').click();
|
||||
}, 500);
|
||||
|
||||
function pollStatus() {
|
||||
fetch('/v0/oauth/kiro/status?state=' + stateID)
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
console.log('Status:', data);
|
||||
if (data.status === 'success') {
|
||||
clearInterval(pollInterval);
|
||||
clearInterval(timerInterval);
|
||||
showSuccess(data);
|
||||
} else if (data.status === 'failed') {
|
||||
clearInterval(pollInterval);
|
||||
clearInterval(timerInterval);
|
||||
showError(data);
|
||||
} else {
|
||||
remainingSeconds = data.remaining_seconds || 0;
|
||||
}
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Poll error:', error);
|
||||
});
|
||||
}
|
||||
|
||||
function updateTimer() {
|
||||
const timerEl = document.getElementById('timer');
|
||||
const minutes = Math.floor(remainingSeconds / 60);
|
||||
const seconds = remainingSeconds % 60;
|
||||
timerEl.textContent = minutes + ':' + seconds.toString().padStart(2, '0');
|
||||
|
||||
if (remainingSeconds < 60) {
|
||||
timerEl.className = 'timer danger';
|
||||
} else if (remainingSeconds < 180) {
|
||||
timerEl.className = 'timer warning';
|
||||
} else {
|
||||
timerEl.className = 'timer';
|
||||
}
|
||||
|
||||
remainingSeconds--;
|
||||
|
||||
if (remainingSeconds < 0) {
|
||||
clearInterval(timerInterval);
|
||||
clearInterval(pollInterval);
|
||||
showError({ error: 'Authentication timed out. Please refresh and try again.' });
|
||||
}
|
||||
}
|
||||
|
||||
function showSuccess(data) {
|
||||
const statusBox = document.getElementById('statusBox');
|
||||
statusBox.className = 'status status-success';
|
||||
statusBox.innerHTML = '<div class="success-icon">✅</div>' +
|
||||
'<div class="status-message">' +
|
||||
'<strong>Authentication Successful!</strong><br>' +
|
||||
'Token expires: ' + new Date(data.expires_at).toLocaleString() +
|
||||
'</div>';
|
||||
}
|
||||
|
||||
function showError(data) {
|
||||
const statusBox = document.getElementById('statusBox');
|
||||
statusBox.className = 'status status-failed';
|
||||
statusBox.innerHTML = '<div class="error-icon">❌</div>' +
|
||||
'<div class="status-message">' +
|
||||
'<strong>Authentication Failed</strong><br>' +
|
||||
(data.error || 'Unknown error') +
|
||||
'</div>' +
|
||||
'<button class="auth-btn" onclick="location.reload()" style="margin-top: 15px;">' +
|
||||
'🔄 Retry' +
|
||||
'</button>';
|
||||
}
|
||||
|
||||
pollInterval = setInterval(pollStatus, 3000);
|
||||
timerInterval = setInterval(updateTimer, 1000);
|
||||
pollStatus();
|
||||
</script>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
oauthWebErrorPageHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Authentication Failed</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
max-width: 600px;
|
||||
margin: 50px auto;
|
||||
padding: 20px;
|
||||
background: #f5f5f5;
|
||||
}
|
||||
.error {
|
||||
background: #fff;
|
||||
padding: 30px;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||
border-left: 4px solid #dc3545;
|
||||
}
|
||||
h1 { color: #dc3545; margin-top: 0; }
|
||||
.error-message { color: #666; line-height: 1.6; }
|
||||
.retry-btn {
|
||||
display: inline-block;
|
||||
margin-top: 20px;
|
||||
padding: 10px 20px;
|
||||
background: #007bff;
|
||||
color: white;
|
||||
text-decoration: none;
|
||||
border-radius: 4px;
|
||||
}
|
||||
.retry-btn:hover { background: #0056b3; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="error">
|
||||
<h1>❌ Authentication Failed</h1>
|
||||
<div class="error-message">
|
||||
<p><strong>Error:</strong></p>
|
||||
<p>{{.Error}}</p>
|
||||
</div>
|
||||
<a href="/v0/oauth/kiro/start" class="retry-btn">🔄 Retry</a>
|
||||
</div>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
oauthWebSuccessPageHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Authentication Successful</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
max-width: 600px;
|
||||
margin: 50px auto;
|
||||
padding: 20px;
|
||||
background: #f5f5f5;
|
||||
}
|
||||
.success {
|
||||
background: #fff;
|
||||
padding: 30px;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||
border-left: 4px solid #28a745;
|
||||
text-align: center;
|
||||
}
|
||||
h1 { color: #28a745; margin-top: 0; }
|
||||
.success-message { color: #666; line-height: 1.6; }
|
||||
.icon { font-size: 48px; margin-bottom: 15px; }
|
||||
.expires { font-size: 14px; color: #999; margin-top: 15px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="success">
|
||||
<div class="icon">✅</div>
|
||||
<h1>Authentication Successful!</h1>
|
||||
<div class="success-message">
|
||||
<p>You can close this window.</p>
|
||||
</div>
|
||||
<div class="expires">Token expires: {{.ExpiresAt}}</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
oauthWebSelectPageHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Select Authentication Method</title>
|
||||
<style>
|
||||
* { box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
.container {
|
||||
max-width: 500px;
|
||||
width: 100%;
|
||||
background: #fff;
|
||||
padding: 40px;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
|
||||
}
|
||||
h1 {
|
||||
margin: 0 0 10px;
|
||||
color: #333;
|
||||
font-size: 24px;
|
||||
text-align: center;
|
||||
}
|
||||
.subtitle {
|
||||
text-align: center;
|
||||
color: #666;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
.auth-methods {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 15px;
|
||||
}
|
||||
.auth-btn {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
width: 100%;
|
||||
padding: 15px 20px;
|
||||
background: #667eea;
|
||||
color: white;
|
||||
text-decoration: none;
|
||||
border-radius: 8px;
|
||||
font-weight: 600;
|
||||
font-size: 16px;
|
||||
transition: all 0.3s;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
}
|
||||
.auth-btn:hover {
|
||||
background: #5568d3;
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
|
||||
}
|
||||
.auth-btn .icon {
|
||||
font-size: 24px;
|
||||
margin-right: 15px;
|
||||
width: 32px;
|
||||
text-align: center;
|
||||
}
|
||||
.auth-btn.google { background: #4285F4; }
|
||||
.auth-btn.google:hover { background: #3367D6; }
|
||||
.auth-btn.github { background: #24292e; }
|
||||
.auth-btn.github:hover { background: #1a1e22; }
|
||||
.auth-btn.aws { background: #FF9900; }
|
||||
.auth-btn.aws:hover { background: #E68A00; }
|
||||
.auth-btn.idc { background: #232F3E; }
|
||||
.auth-btn.idc:hover { background: #1a242f; }
|
||||
.idc-form {
|
||||
background: #f8f9fa;
|
||||
padding: 20px;
|
||||
border-radius: 8px;
|
||||
margin-top: 15px;
|
||||
display: none;
|
||||
}
|
||||
.idc-form.show {
|
||||
display: block;
|
||||
}
|
||||
.form-group {
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
.form-group label {
|
||||
display: block;
|
||||
font-weight: 600;
|
||||
color: #333;
|
||||
margin-bottom: 8px;
|
||||
font-size: 14px;
|
||||
}
|
||||
.form-group input {
|
||||
width: 100%;
|
||||
padding: 12px;
|
||||
border: 2px solid #e0e0e0;
|
||||
border-radius: 6px;
|
||||
font-size: 14px;
|
||||
transition: border-color 0.3s;
|
||||
}
|
||||
.form-group input:focus {
|
||||
outline: none;
|
||||
border-color: #667eea;
|
||||
}
|
||||
.form-group .hint {
|
||||
font-size: 12px;
|
||||
color: #999;
|
||||
margin-top: 5px;
|
||||
}
|
||||
.submit-btn {
|
||||
display: block;
|
||||
width: 100%;
|
||||
padding: 15px;
|
||||
background: #232F3E;
|
||||
color: white;
|
||||
text-align: center;
|
||||
text-decoration: none;
|
||||
border-radius: 8px;
|
||||
font-weight: 600;
|
||||
font-size: 16px;
|
||||
transition: all 0.3s;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
}
|
||||
.submit-btn:hover {
|
||||
background: #1a242f;
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 12px rgba(35, 47, 62, 0.4);
|
||||
}
|
||||
.divider {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin: 20px 0;
|
||||
}
|
||||
.divider::before,
|
||||
.divider::after {
|
||||
content: "";
|
||||
flex: 1;
|
||||
border-bottom: 1px solid #e0e0e0;
|
||||
}
|
||||
.divider span {
|
||||
padding: 0 15px;
|
||||
color: #999;
|
||||
font-size: 14px;
|
||||
}
|
||||
.info-box {
|
||||
background: #e7f3ff;
|
||||
border-left: 4px solid #2196F3;
|
||||
padding: 15px;
|
||||
margin-top: 20px;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
color: #666;
|
||||
}
|
||||
.warning-box {
|
||||
background: #fff3cd;
|
||||
border-left: 4px solid #ffc107;
|
||||
padding: 15px;
|
||||
margin-top: 20px;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
color: #856404;
|
||||
}
|
||||
.auth-btn.manual { background: #6c757d; }
|
||||
.auth-btn.manual:hover { background: #5a6268; }
|
||||
.manual-form {
|
||||
background: #f8f9fa;
|
||||
padding: 20px;
|
||||
border-radius: 8px;
|
||||
margin-top: 15px;
|
||||
display: none;
|
||||
}
|
||||
.manual-form.show {
|
||||
display: block;
|
||||
}
|
||||
.form-group textarea {
|
||||
width: 100%;
|
||||
padding: 12px;
|
||||
border: 2px solid #e0e0e0;
|
||||
border-radius: 6px;
|
||||
font-size: 14px;
|
||||
font-family: monospace;
|
||||
transition: border-color 0.3s;
|
||||
resize: vertical;
|
||||
min-height: 80px;
|
||||
}
|
||||
.form-group textarea:focus {
|
||||
outline: none;
|
||||
border-color: #667eea;
|
||||
}
|
||||
.status-message {
|
||||
padding: 15px;
|
||||
border-radius: 6px;
|
||||
margin-top: 15px;
|
||||
display: none;
|
||||
}
|
||||
.status-message.success {
|
||||
background: #d4edda;
|
||||
color: #155724;
|
||||
display: block;
|
||||
}
|
||||
.status-message.error {
|
||||
background: #f8d7da;
|
||||
color: #721c24;
|
||||
display: block;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>🔐 Select Authentication Method</h1>
|
||||
<p class="subtitle">Choose how you want to authenticate with Kiro</p>
|
||||
|
||||
<div class="auth-methods">
|
||||
<a href="/v0/oauth/kiro/start?method=builder-id" class="auth-btn aws">
|
||||
<span class="icon">🔶</span>
|
||||
AWS Builder ID (Recommended)
|
||||
</a>
|
||||
|
||||
<button type="button" class="auth-btn idc" onclick="toggleIdcForm()">
|
||||
<span class="icon">🏢</span>
|
||||
AWS Identity Center (IDC)
|
||||
</button>
|
||||
|
||||
<div class="divider"><span>or</span></div>
|
||||
|
||||
<button type="button" class="auth-btn manual" onclick="toggleManualForm()">
|
||||
<span class="icon">📋</span>
|
||||
Import RefreshToken from Kiro IDE
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div class="idc-form" id="idcForm">
|
||||
<form action="/v0/oauth/kiro/start" method="get">
|
||||
<input type="hidden" name="method" value="idc">
|
||||
|
||||
<div class="form-group">
|
||||
<label for="startUrl">Start URL</label>
|
||||
<input type="url" id="startUrl" name="startUrl" placeholder="https://your-org.awsapps.com/start" required>
|
||||
<div class="hint">Your AWS Identity Center Start URL</div>
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label for="region">Region</label>
|
||||
<input type="text" id="region" name="region" value="us-east-1" placeholder="us-east-1">
|
||||
<div class="hint">AWS Region for your Identity Center</div>
|
||||
</div>
|
||||
|
||||
<button type="submit" class="submit-btn">
|
||||
🚀 Continue with IDC
|
||||
</button>
|
||||
</form>
|
||||
</div>
|
||||
|
||||
<div class="manual-form" id="manualForm">
|
||||
<form id="importForm" onsubmit="submitImport(event)">
|
||||
<div class="form-group">
|
||||
<label for="refreshToken">Refresh Token</label>
|
||||
<textarea id="refreshToken" name="refreshToken" placeholder="Paste your refreshToken here (starts with aorAAAAAG...)" required></textarea>
|
||||
<div class="hint">Copy from Kiro IDE: ~/.kiro/kiro-auth-token.json → refreshToken field</div>
|
||||
</div>
|
||||
|
||||
<button type="submit" class="submit-btn" id="importBtn">
|
||||
📥 Import Token
|
||||
</button>
|
||||
|
||||
<div class="status-message" id="importStatus"></div>
|
||||
</form>
|
||||
</div>
|
||||
|
||||
<div class="warning-box">
|
||||
⚠️ <strong>Note:</strong> Google and GitHub login are not available for third-party applications due to AWS Cognito restrictions. Please use AWS Builder ID or import your token from Kiro IDE.
|
||||
</div>
|
||||
|
||||
<div class="info-box">
|
||||
💡 <strong>How to get RefreshToken:</strong><br>
|
||||
1. Open Kiro IDE and login with Google/GitHub<br>
|
||||
2. Find the token file: <code>~/.kiro/kiro-auth-token.json</code><br>
|
||||
3. Copy the <code>refreshToken</code> value and paste it above
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
function toggleIdcForm() {
|
||||
const idcForm = document.getElementById('idcForm');
|
||||
const manualForm = document.getElementById('manualForm');
|
||||
manualForm.classList.remove('show');
|
||||
idcForm.classList.toggle('show');
|
||||
if (idcForm.classList.contains('show')) {
|
||||
document.getElementById('startUrl').focus();
|
||||
}
|
||||
}
|
||||
|
||||
function toggleManualForm() {
|
||||
const idcForm = document.getElementById('idcForm');
|
||||
const manualForm = document.getElementById('manualForm');
|
||||
idcForm.classList.remove('show');
|
||||
manualForm.classList.toggle('show');
|
||||
if (manualForm.classList.contains('show')) {
|
||||
document.getElementById('refreshToken').focus();
|
||||
}
|
||||
}
|
||||
|
||||
async function submitImport(event) {
|
||||
event.preventDefault();
|
||||
const refreshToken = document.getElementById('refreshToken').value.trim();
|
||||
const statusEl = document.getElementById('importStatus');
|
||||
const btn = document.getElementById('importBtn');
|
||||
|
||||
if (!refreshToken) {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = 'Please enter a refresh token';
|
||||
return;
|
||||
}
|
||||
|
||||
if (!refreshToken.startsWith('aorAAAAAG')) {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = 'Invalid token format. Token should start with aorAAAAAG...';
|
||||
return;
|
||||
}
|
||||
|
||||
btn.disabled = true;
|
||||
btn.textContent = '⏳ Importing...';
|
||||
statusEl.className = 'status-message';
|
||||
statusEl.style.display = 'none';
|
||||
|
||||
try {
|
||||
const response = await fetch('/v0/oauth/kiro/import', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ refreshToken: refreshToken })
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (response.ok && data.success) {
|
||||
statusEl.className = 'status-message success';
|
||||
statusEl.textContent = '✅ Token imported successfully! File: ' + (data.fileName || 'kiro-token.json');
|
||||
} else {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = '❌ ' + (data.error || data.message || 'Import failed');
|
||||
}
|
||||
} catch (error) {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = '❌ Network error: ' + error.message;
|
||||
} finally {
|
||||
btn.disabled = false;
|
||||
btn.textContent = '📥 Import Token';
|
||||
}
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>`
|
||||
)
|
||||
316
internal/auth/kiro/rate_limiter.go
Normal file
316
internal/auth/kiro/rate_limiter.go
Normal file
@@ -0,0 +1,316 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMinTokenInterval = 10 * time.Second
|
||||
DefaultMaxTokenInterval = 30 * time.Second
|
||||
DefaultDailyMaxRequests = 500
|
||||
DefaultJitterPercent = 0.3
|
||||
DefaultBackoffBase = 2 * time.Minute
|
||||
DefaultBackoffMax = 60 * time.Minute
|
||||
DefaultBackoffMultiplier = 2.0
|
||||
DefaultSuspendCooldown = 24 * time.Hour
|
||||
)
|
||||
|
||||
// TokenState Token 状态
|
||||
type TokenState struct {
|
||||
LastRequest time.Time
|
||||
RequestCount int
|
||||
CooldownEnd time.Time
|
||||
FailCount int
|
||||
DailyRequests int
|
||||
DailyResetTime time.Time
|
||||
IsSuspended bool
|
||||
SuspendedAt time.Time
|
||||
SuspendReason string
|
||||
}
|
||||
|
||||
// RateLimiter 频率限制器
|
||||
type RateLimiter struct {
|
||||
mu sync.RWMutex
|
||||
states map[string]*TokenState
|
||||
minTokenInterval time.Duration
|
||||
maxTokenInterval time.Duration
|
||||
dailyMaxRequests int
|
||||
jitterPercent float64
|
||||
backoffBase time.Duration
|
||||
backoffMax time.Duration
|
||||
backoffMultiplier float64
|
||||
suspendCooldown time.Duration
|
||||
rng *rand.Rand
|
||||
}
|
||||
|
||||
// NewRateLimiter 创建默认配置的频率限制器
|
||||
func NewRateLimiter() *RateLimiter {
|
||||
return &RateLimiter{
|
||||
states: make(map[string]*TokenState),
|
||||
minTokenInterval: DefaultMinTokenInterval,
|
||||
maxTokenInterval: DefaultMaxTokenInterval,
|
||||
dailyMaxRequests: DefaultDailyMaxRequests,
|
||||
jitterPercent: DefaultJitterPercent,
|
||||
backoffBase: DefaultBackoffBase,
|
||||
backoffMax: DefaultBackoffMax,
|
||||
backoffMultiplier: DefaultBackoffMultiplier,
|
||||
suspendCooldown: DefaultSuspendCooldown,
|
||||
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimiterConfig 频率限制器配置
|
||||
type RateLimiterConfig struct {
|
||||
MinTokenInterval time.Duration
|
||||
MaxTokenInterval time.Duration
|
||||
DailyMaxRequests int
|
||||
JitterPercent float64
|
||||
BackoffBase time.Duration
|
||||
BackoffMax time.Duration
|
||||
BackoffMultiplier float64
|
||||
SuspendCooldown time.Duration
|
||||
}
|
||||
|
||||
// NewRateLimiterWithConfig 使用自定义配置创建频率限制器
|
||||
func NewRateLimiterWithConfig(cfg RateLimiterConfig) *RateLimiter {
|
||||
rl := NewRateLimiter()
|
||||
if cfg.MinTokenInterval > 0 {
|
||||
rl.minTokenInterval = cfg.MinTokenInterval
|
||||
}
|
||||
if cfg.MaxTokenInterval > 0 {
|
||||
rl.maxTokenInterval = cfg.MaxTokenInterval
|
||||
}
|
||||
if cfg.DailyMaxRequests > 0 {
|
||||
rl.dailyMaxRequests = cfg.DailyMaxRequests
|
||||
}
|
||||
if cfg.JitterPercent > 0 {
|
||||
rl.jitterPercent = cfg.JitterPercent
|
||||
}
|
||||
if cfg.BackoffBase > 0 {
|
||||
rl.backoffBase = cfg.BackoffBase
|
||||
}
|
||||
if cfg.BackoffMax > 0 {
|
||||
rl.backoffMax = cfg.BackoffMax
|
||||
}
|
||||
if cfg.BackoffMultiplier > 0 {
|
||||
rl.backoffMultiplier = cfg.BackoffMultiplier
|
||||
}
|
||||
if cfg.SuspendCooldown > 0 {
|
||||
rl.suspendCooldown = cfg.SuspendCooldown
|
||||
}
|
||||
return rl
|
||||
}
|
||||
|
||||
// getOrCreateState 获取或创建 Token 状态
|
||||
func (rl *RateLimiter) getOrCreateState(tokenKey string) *TokenState {
|
||||
state, exists := rl.states[tokenKey]
|
||||
if !exists {
|
||||
state = &TokenState{
|
||||
DailyResetTime: time.Now().Truncate(24 * time.Hour).Add(24 * time.Hour),
|
||||
}
|
||||
rl.states[tokenKey] = state
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
// resetDailyIfNeeded 如果需要则重置每日计数
|
||||
func (rl *RateLimiter) resetDailyIfNeeded(state *TokenState) {
|
||||
now := time.Now()
|
||||
if now.After(state.DailyResetTime) {
|
||||
state.DailyRequests = 0
|
||||
state.DailyResetTime = now.Truncate(24 * time.Hour).Add(24 * time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
// calculateInterval 计算带抖动的随机间隔
|
||||
func (rl *RateLimiter) calculateInterval() time.Duration {
|
||||
baseInterval := rl.minTokenInterval + time.Duration(rl.rng.Int63n(int64(rl.maxTokenInterval-rl.minTokenInterval)))
|
||||
jitter := time.Duration(float64(baseInterval) * rl.jitterPercent * (rl.rng.Float64()*2 - 1))
|
||||
return baseInterval + jitter
|
||||
}
|
||||
|
||||
// WaitForToken 等待 Token 可用(带抖动的随机间隔)
|
||||
func (rl *RateLimiter) WaitForToken(tokenKey string) {
|
||||
rl.mu.Lock()
|
||||
state := rl.getOrCreateState(tokenKey)
|
||||
rl.resetDailyIfNeeded(state)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// 检查是否在冷却期
|
||||
if now.Before(state.CooldownEnd) {
|
||||
waitTime := state.CooldownEnd.Sub(now)
|
||||
rl.mu.Unlock()
|
||||
time.Sleep(waitTime)
|
||||
rl.mu.Lock()
|
||||
state = rl.getOrCreateState(tokenKey)
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
// 计算距离上次请求的间隔
|
||||
interval := rl.calculateInterval()
|
||||
nextAllowedTime := state.LastRequest.Add(interval)
|
||||
|
||||
if now.Before(nextAllowedTime) {
|
||||
waitTime := nextAllowedTime.Sub(now)
|
||||
rl.mu.Unlock()
|
||||
time.Sleep(waitTime)
|
||||
rl.mu.Lock()
|
||||
state = rl.getOrCreateState(tokenKey)
|
||||
}
|
||||
|
||||
state.LastRequest = time.Now()
|
||||
state.RequestCount++
|
||||
state.DailyRequests++
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
|
||||
// MarkTokenFailed 标记 Token 失败
|
||||
func (rl *RateLimiter) MarkTokenFailed(tokenKey string) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
state := rl.getOrCreateState(tokenKey)
|
||||
state.FailCount++
|
||||
state.CooldownEnd = time.Now().Add(rl.calculateBackoff(state.FailCount))
|
||||
}
|
||||
|
||||
// MarkTokenSuccess 标记 Token 成功
|
||||
func (rl *RateLimiter) MarkTokenSuccess(tokenKey string) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
state := rl.getOrCreateState(tokenKey)
|
||||
state.FailCount = 0
|
||||
state.CooldownEnd = time.Time{}
|
||||
}
|
||||
|
||||
// CheckAndMarkSuspended 检测暂停错误并标记
|
||||
func (rl *RateLimiter) CheckAndMarkSuspended(tokenKey string, errorMsg string) bool {
|
||||
suspendKeywords := []string{
|
||||
"suspended",
|
||||
"banned",
|
||||
"disabled",
|
||||
"account has been",
|
||||
"access denied",
|
||||
"rate limit exceeded",
|
||||
"too many requests",
|
||||
"quota exceeded",
|
||||
}
|
||||
|
||||
lowerMsg := strings.ToLower(errorMsg)
|
||||
for _, keyword := range suspendKeywords {
|
||||
if strings.Contains(lowerMsg, keyword) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
state := rl.getOrCreateState(tokenKey)
|
||||
state.IsSuspended = true
|
||||
state.SuspendedAt = time.Now()
|
||||
state.SuspendReason = errorMsg
|
||||
state.CooldownEnd = time.Now().Add(rl.suspendCooldown)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsTokenAvailable 检查 Token 是否可用
|
||||
func (rl *RateLimiter) IsTokenAvailable(tokenKey string) bool {
|
||||
rl.mu.RLock()
|
||||
defer rl.mu.RUnlock()
|
||||
|
||||
state, exists := rl.states[tokenKey]
|
||||
if !exists {
|
||||
return true
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// 检查是否被暂停
|
||||
if state.IsSuspended {
|
||||
if now.After(state.SuspendedAt.Add(rl.suspendCooldown)) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否在冷却期
|
||||
if now.Before(state.CooldownEnd) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查每日请求限制
|
||||
rl.mu.RUnlock()
|
||||
rl.mu.Lock()
|
||||
rl.resetDailyIfNeeded(state)
|
||||
dailyRequests := state.DailyRequests
|
||||
dailyMax := rl.dailyMaxRequests
|
||||
rl.mu.Unlock()
|
||||
rl.mu.RLock()
|
||||
|
||||
if dailyRequests >= dailyMax {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// calculateBackoff 计算指数退避时间
|
||||
func (rl *RateLimiter) calculateBackoff(failCount int) time.Duration {
|
||||
if failCount <= 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
backoff := float64(rl.backoffBase) * math.Pow(rl.backoffMultiplier, float64(failCount-1))
|
||||
|
||||
// 添加抖动
|
||||
jitter := backoff * rl.jitterPercent * (rl.rng.Float64()*2 - 1)
|
||||
backoff += jitter
|
||||
|
||||
if time.Duration(backoff) > rl.backoffMax {
|
||||
return rl.backoffMax
|
||||
}
|
||||
return time.Duration(backoff)
|
||||
}
|
||||
|
||||
// GetTokenState 获取 Token 状态(只读)
|
||||
func (rl *RateLimiter) GetTokenState(tokenKey string) *TokenState {
|
||||
rl.mu.RLock()
|
||||
defer rl.mu.RUnlock()
|
||||
|
||||
state, exists := rl.states[tokenKey]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 返回副本以防止外部修改
|
||||
stateCopy := *state
|
||||
return &stateCopy
|
||||
}
|
||||
|
||||
// ClearTokenState 清除 Token 状态
|
||||
func (rl *RateLimiter) ClearTokenState(tokenKey string) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
delete(rl.states, tokenKey)
|
||||
}
|
||||
|
||||
// ResetSuspension 重置暂停状态
|
||||
func (rl *RateLimiter) ResetSuspension(tokenKey string) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
state, exists := rl.states[tokenKey]
|
||||
if exists {
|
||||
state.IsSuspended = false
|
||||
state.SuspendedAt = time.Time{}
|
||||
state.SuspendReason = ""
|
||||
state.CooldownEnd = time.Time{}
|
||||
state.FailCount = 0
|
||||
}
|
||||
}
|
||||
46
internal/auth/kiro/rate_limiter_singleton.go
Normal file
46
internal/auth/kiro/rate_limiter_singleton.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
globalRateLimiter *RateLimiter
|
||||
globalRateLimiterOnce sync.Once
|
||||
|
||||
globalCooldownManager *CooldownManager
|
||||
globalCooldownManagerOnce sync.Once
|
||||
cooldownStopCh chan struct{}
|
||||
)
|
||||
|
||||
// GetGlobalRateLimiter returns the singleton RateLimiter instance.
|
||||
func GetGlobalRateLimiter() *RateLimiter {
|
||||
globalRateLimiterOnce.Do(func() {
|
||||
globalRateLimiter = NewRateLimiter()
|
||||
log.Info("kiro: global RateLimiter initialized")
|
||||
})
|
||||
return globalRateLimiter
|
||||
}
|
||||
|
||||
// GetGlobalCooldownManager returns the singleton CooldownManager instance.
|
||||
func GetGlobalCooldownManager() *CooldownManager {
|
||||
globalCooldownManagerOnce.Do(func() {
|
||||
globalCooldownManager = NewCooldownManager()
|
||||
cooldownStopCh = make(chan struct{})
|
||||
go globalCooldownManager.StartCleanupRoutine(5*time.Minute, cooldownStopCh)
|
||||
log.Info("kiro: global CooldownManager initialized with cleanup routine")
|
||||
})
|
||||
return globalCooldownManager
|
||||
}
|
||||
|
||||
// ShutdownRateLimiters stops the cooldown cleanup routine.
|
||||
// Should be called during application shutdown.
|
||||
func ShutdownRateLimiters() {
|
||||
if cooldownStopCh != nil {
|
||||
close(cooldownStopCh)
|
||||
log.Info("kiro: rate limiter cleanup routine stopped")
|
||||
}
|
||||
}
|
||||
304
internal/auth/kiro/rate_limiter_test.go
Normal file
304
internal/auth/kiro/rate_limiter_test.go
Normal file
@@ -0,0 +1,304 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewRateLimiter(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
if rl == nil {
|
||||
t.Fatal("expected non-nil RateLimiter")
|
||||
}
|
||||
if rl.states == nil {
|
||||
t.Error("expected non-nil states map")
|
||||
}
|
||||
if rl.minTokenInterval != DefaultMinTokenInterval {
|
||||
t.Errorf("expected minTokenInterval %v, got %v", DefaultMinTokenInterval, rl.minTokenInterval)
|
||||
}
|
||||
if rl.maxTokenInterval != DefaultMaxTokenInterval {
|
||||
t.Errorf("expected maxTokenInterval %v, got %v", DefaultMaxTokenInterval, rl.maxTokenInterval)
|
||||
}
|
||||
if rl.dailyMaxRequests != DefaultDailyMaxRequests {
|
||||
t.Errorf("expected dailyMaxRequests %d, got %d", DefaultDailyMaxRequests, rl.dailyMaxRequests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRateLimiterWithConfig(t *testing.T) {
|
||||
cfg := RateLimiterConfig{
|
||||
MinTokenInterval: 5 * time.Second,
|
||||
MaxTokenInterval: 15 * time.Second,
|
||||
DailyMaxRequests: 100,
|
||||
JitterPercent: 0.2,
|
||||
BackoffBase: 1 * time.Minute,
|
||||
BackoffMax: 30 * time.Minute,
|
||||
BackoffMultiplier: 1.5,
|
||||
SuspendCooldown: 12 * time.Hour,
|
||||
}
|
||||
|
||||
rl := NewRateLimiterWithConfig(cfg)
|
||||
if rl.minTokenInterval != 5*time.Second {
|
||||
t.Errorf("expected minTokenInterval 5s, got %v", rl.minTokenInterval)
|
||||
}
|
||||
if rl.maxTokenInterval != 15*time.Second {
|
||||
t.Errorf("expected maxTokenInterval 15s, got %v", rl.maxTokenInterval)
|
||||
}
|
||||
if rl.dailyMaxRequests != 100 {
|
||||
t.Errorf("expected dailyMaxRequests 100, got %d", rl.dailyMaxRequests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRateLimiterWithConfig_PartialConfig(t *testing.T) {
|
||||
cfg := RateLimiterConfig{
|
||||
MinTokenInterval: 5 * time.Second,
|
||||
}
|
||||
|
||||
rl := NewRateLimiterWithConfig(cfg)
|
||||
if rl.minTokenInterval != 5*time.Second {
|
||||
t.Errorf("expected minTokenInterval 5s, got %v", rl.minTokenInterval)
|
||||
}
|
||||
if rl.maxTokenInterval != DefaultMaxTokenInterval {
|
||||
t.Errorf("expected default maxTokenInterval, got %v", rl.maxTokenInterval)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTokenState_NonExistent(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
state := rl.GetTokenState("nonexistent")
|
||||
if state != nil {
|
||||
t.Error("expected nil state for non-existent token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsTokenAvailable_NewToken(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
if !rl.IsTokenAvailable("newtoken") {
|
||||
t.Error("expected new token to be available")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkTokenFailed(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.MarkTokenFailed("token1")
|
||||
|
||||
state := rl.GetTokenState("token1")
|
||||
if state == nil {
|
||||
t.Fatal("expected non-nil state")
|
||||
}
|
||||
if state.FailCount != 1 {
|
||||
t.Errorf("expected FailCount 1, got %d", state.FailCount)
|
||||
}
|
||||
if state.CooldownEnd.IsZero() {
|
||||
t.Error("expected non-zero CooldownEnd")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkTokenSuccess(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.MarkTokenFailed("token1")
|
||||
rl.MarkTokenFailed("token1")
|
||||
rl.MarkTokenSuccess("token1")
|
||||
|
||||
state := rl.GetTokenState("token1")
|
||||
if state == nil {
|
||||
t.Fatal("expected non-nil state")
|
||||
}
|
||||
if state.FailCount != 0 {
|
||||
t.Errorf("expected FailCount 0, got %d", state.FailCount)
|
||||
}
|
||||
if !state.CooldownEnd.IsZero() {
|
||||
t.Error("expected zero CooldownEnd after success")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckAndMarkSuspended_Suspended(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
|
||||
testCases := []string{
|
||||
"Account has been suspended",
|
||||
"You are banned from this service",
|
||||
"Account disabled",
|
||||
"Access denied permanently",
|
||||
"Rate limit exceeded",
|
||||
"Too many requests",
|
||||
"Quota exceeded for today",
|
||||
}
|
||||
|
||||
for i, msg := range testCases {
|
||||
tokenKey := "token" + string(rune('a'+i))
|
||||
if !rl.CheckAndMarkSuspended(tokenKey, msg) {
|
||||
t.Errorf("expected suspension detected for: %s", msg)
|
||||
}
|
||||
state := rl.GetTokenState(tokenKey)
|
||||
if !state.IsSuspended {
|
||||
t.Errorf("expected IsSuspended true for: %s", msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckAndMarkSuspended_NotSuspended(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
|
||||
normalErrors := []string{
|
||||
"connection timeout",
|
||||
"internal server error",
|
||||
"bad request",
|
||||
"invalid token format",
|
||||
}
|
||||
|
||||
for i, msg := range normalErrors {
|
||||
tokenKey := "token" + string(rune('a'+i))
|
||||
if rl.CheckAndMarkSuspended(tokenKey, msg) {
|
||||
t.Errorf("unexpected suspension for: %s", msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsTokenAvailable_Suspended(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.CheckAndMarkSuspended("token1", "Account suspended")
|
||||
|
||||
if rl.IsTokenAvailable("token1") {
|
||||
t.Error("expected suspended token to be unavailable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearTokenState(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.MarkTokenFailed("token1")
|
||||
rl.ClearTokenState("token1")
|
||||
|
||||
state := rl.GetTokenState("token1")
|
||||
if state != nil {
|
||||
t.Error("expected nil state after clear")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetSuspension(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.CheckAndMarkSuspended("token1", "Account suspended")
|
||||
rl.ResetSuspension("token1")
|
||||
|
||||
state := rl.GetTokenState("token1")
|
||||
if state.IsSuspended {
|
||||
t.Error("expected IsSuspended false after reset")
|
||||
}
|
||||
if state.FailCount != 0 {
|
||||
t.Errorf("expected FailCount 0, got %d", state.FailCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetSuspension_NonExistent(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.ResetSuspension("nonexistent")
|
||||
}
|
||||
|
||||
func TestCalculateBackoff_ZeroFailCount(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
backoff := rl.calculateBackoff(0)
|
||||
if backoff != 0 {
|
||||
t.Errorf("expected 0 backoff for 0 fails, got %v", backoff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateBackoff_Exponential(t *testing.T) {
|
||||
cfg := RateLimiterConfig{
|
||||
BackoffBase: 1 * time.Minute,
|
||||
BackoffMax: 60 * time.Minute,
|
||||
BackoffMultiplier: 2.0,
|
||||
JitterPercent: 0.3,
|
||||
}
|
||||
rl := NewRateLimiterWithConfig(cfg)
|
||||
|
||||
backoff1 := rl.calculateBackoff(1)
|
||||
if backoff1 < 40*time.Second || backoff1 > 80*time.Second {
|
||||
t.Errorf("expected ~1min (with jitter) for fail 1, got %v", backoff1)
|
||||
}
|
||||
|
||||
backoff2 := rl.calculateBackoff(2)
|
||||
if backoff2 < 80*time.Second || backoff2 > 160*time.Second {
|
||||
t.Errorf("expected ~2min (with jitter) for fail 2, got %v", backoff2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateBackoff_MaxCap(t *testing.T) {
|
||||
cfg := RateLimiterConfig{
|
||||
BackoffBase: 1 * time.Minute,
|
||||
BackoffMax: 10 * time.Minute,
|
||||
BackoffMultiplier: 2.0,
|
||||
JitterPercent: 0,
|
||||
}
|
||||
rl := NewRateLimiterWithConfig(cfg)
|
||||
|
||||
backoff := rl.calculateBackoff(10)
|
||||
if backoff > 10*time.Minute {
|
||||
t.Errorf("expected backoff capped at 10min, got %v", backoff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTokenState_ReturnsCopy(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.MarkTokenFailed("token1")
|
||||
|
||||
state1 := rl.GetTokenState("token1")
|
||||
state1.FailCount = 999
|
||||
|
||||
state2 := rl.GetTokenState("token1")
|
||||
if state2.FailCount == 999 {
|
||||
t.Error("GetTokenState should return a copy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_ConcurrentAccess(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
const numGoroutines = 50
|
||||
const numOperations = 50
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
tokenKey := "token" + string(rune('a'+id%10))
|
||||
for j := 0; j < numOperations; j++ {
|
||||
switch j % 6 {
|
||||
case 0:
|
||||
rl.IsTokenAvailable(tokenKey)
|
||||
case 1:
|
||||
rl.MarkTokenFailed(tokenKey)
|
||||
case 2:
|
||||
rl.MarkTokenSuccess(tokenKey)
|
||||
case 3:
|
||||
rl.GetTokenState(tokenKey)
|
||||
case 4:
|
||||
rl.CheckAndMarkSuspended(tokenKey, "test error")
|
||||
case 5:
|
||||
rl.ResetSuspension(tokenKey)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestCalculateInterval_WithinRange(t *testing.T) {
|
||||
cfg := RateLimiterConfig{
|
||||
MinTokenInterval: 10 * time.Second,
|
||||
MaxTokenInterval: 30 * time.Second,
|
||||
JitterPercent: 0.3,
|
||||
}
|
||||
rl := NewRateLimiterWithConfig(cfg)
|
||||
|
||||
minAllowed := 7 * time.Second
|
||||
maxAllowed := 40 * time.Second
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
interval := rl.calculateInterval()
|
||||
if interval < minAllowed || interval > maxAllowed {
|
||||
t.Errorf("interval %v outside expected range [%v, %v]", interval, minAllowed, maxAllowed)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -9,7 +9,9 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -31,6 +33,9 @@ const (
|
||||
|
||||
// OAuth timeout
|
||||
socialAuthTimeout = 10 * time.Minute
|
||||
|
||||
// Default callback port for social auth HTTP server
|
||||
socialAuthCallbackPort = 9876
|
||||
)
|
||||
|
||||
// SocialProvider represents the social login provider.
|
||||
@@ -67,6 +72,13 @@ type RefreshTokenRequest struct {
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
// WebCallbackResult contains the OAuth callback result from HTTP server.
|
||||
type WebCallbackResult struct {
|
||||
Code string
|
||||
State string
|
||||
Error string
|
||||
}
|
||||
|
||||
// SocialAuthClient handles social authentication with Kiro.
|
||||
type SocialAuthClient struct {
|
||||
httpClient *http.Client
|
||||
@@ -87,6 +99,83 @@ func NewSocialAuthClient(cfg *config.Config) *SocialAuthClient {
|
||||
}
|
||||
}
|
||||
|
||||
// startWebCallbackServer starts a local HTTP server to receive the OAuth callback.
|
||||
// This is used instead of the kiro:// protocol handler to avoid redirect_mismatch errors.
|
||||
func (c *SocialAuthClient) startWebCallbackServer(ctx context.Context, expectedState string) (string, <-chan WebCallbackResult, error) {
|
||||
// Try to find an available port - use localhost like Kiro does
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", socialAuthCallbackPort))
|
||||
if err != nil {
|
||||
// Try with dynamic port (RFC 8252 allows dynamic ports for native apps)
|
||||
log.Warnf("kiro social auth: default port %d is busy, falling back to dynamic port", socialAuthCallbackPort)
|
||||
listener, err = net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to start callback server: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
// Use http scheme for local callback server
|
||||
redirectURI := fmt.Sprintf("http://localhost:%d/oauth/callback", port)
|
||||
resultChan := make(chan WebCallbackResult, 1)
|
||||
|
||||
server := &http.Server{
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
errParam := r.URL.Query().Get("error")
|
||||
|
||||
if errParam != "" {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(w, `<!DOCTYPE html>
|
||||
<html><head><title>Login Failed</title></head>
|
||||
<body><h1>Login Failed</h1><p>%s</p><p>You can close this window.</p></body></html>`, html.EscapeString(errParam))
|
||||
resultChan <- WebCallbackResult{Error: errParam}
|
||||
return
|
||||
}
|
||||
|
||||
if state != expectedState {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(w, `<!DOCTYPE html>
|
||||
<html><head><title>Login Failed</title></head>
|
||||
<body><h1>Login Failed</h1><p>Invalid state parameter</p><p>You can close this window.</p></body></html>`)
|
||||
resultChan <- WebCallbackResult{Error: "state mismatch"}
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
fmt.Fprint(w, `<!DOCTYPE html>
|
||||
<html><head><title>Login Successful</title></head>
|
||||
<body><h1>Login Successful!</h1><p>You can close this window and return to the terminal.</p>
|
||||
<script>window.close();</script></body></html>`)
|
||||
resultChan <- WebCallbackResult{Code: code, State: state}
|
||||
})
|
||||
|
||||
server.Handler = mux
|
||||
|
||||
go func() {
|
||||
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||
log.Debugf("kiro social auth callback server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(socialAuthTimeout):
|
||||
case <-resultChan:
|
||||
}
|
||||
_ = server.Shutdown(context.Background())
|
||||
}()
|
||||
|
||||
return redirectURI, resultChan, nil
|
||||
}
|
||||
|
||||
// generatePKCE generates PKCE code verifier and challenge.
|
||||
func generatePKCE() (verifier, challenge string, err error) {
|
||||
// Generate 32 bytes of random data for verifier
|
||||
@@ -217,10 +306,12 @@ func (c *SocialAuthClient) RefreshSocialToken(ctx context.Context, refreshToken
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: "", // Caller should preserve original provider
|
||||
Region: "us-east-1",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LoginWithSocial performs OAuth login with Google.
|
||||
// LoginWithSocial performs OAuth login with Google or GitHub.
|
||||
// Uses local HTTP callback server instead of custom protocol handler to avoid redirect_mismatch errors.
|
||||
func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialProvider) (*KiroTokenData, error) {
|
||||
providerName := string(provider)
|
||||
|
||||
@@ -228,28 +319,10 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
|
||||
fmt.Printf("║ Kiro Authentication (%s) ║\n", providerName)
|
||||
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||
|
||||
// Step 1: Setup protocol handler
|
||||
// Step 1: Start local HTTP callback server (instead of kiro:// protocol handler)
|
||||
// This avoids redirect_mismatch errors with AWS Cognito
|
||||
fmt.Println("\nSetting up authentication...")
|
||||
|
||||
// Start the local callback server
|
||||
handlerPort, err := c.protocolHandler.Start(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start callback server: %w", err)
|
||||
}
|
||||
defer c.protocolHandler.Stop()
|
||||
|
||||
// Ensure protocol handler is installed and set as default
|
||||
if err := SetupProtocolHandlerIfNeeded(handlerPort); err != nil {
|
||||
fmt.Println("\n⚠ Protocol handler setup failed. Trying alternative method...")
|
||||
fmt.Println(" If you see a browser 'Open with' dialog, select your default browser.")
|
||||
fmt.Println(" For manual setup instructions, run: cliproxy kiro --help-protocol")
|
||||
log.Debugf("kiro: protocol handler setup error: %v", err)
|
||||
// Continue anyway - user might have set it up manually or select browser manually
|
||||
} else {
|
||||
// Force set our handler as default (prevents "Open with" dialog)
|
||||
forceDefaultProtocolHandler()
|
||||
}
|
||||
|
||||
// Step 2: Generate PKCE codes
|
||||
codeVerifier, codeChallenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
@@ -262,8 +335,15 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Build the login URL (Kiro uses GET request with query params)
|
||||
authURL := c.buildLoginURL(providerName, KiroRedirectURI, codeChallenge, state)
|
||||
// Step 4: Start local HTTP callback server
|
||||
redirectURI, resultChan, err := c.startWebCallbackServer(ctx, state)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start callback server: %w", err)
|
||||
}
|
||||
log.Debugf("kiro social auth: callback server started at %s", redirectURI)
|
||||
|
||||
// Step 5: Build the login URL using HTTP redirect URI
|
||||
authURL := c.buildLoginURL(providerName, redirectURI, codeChallenge, state)
|
||||
|
||||
// Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito)
|
||||
// Incognito mode enables multi-account support by bypassing cached sessions
|
||||
@@ -279,7 +359,7 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
|
||||
log.Debug("kiro: using incognito mode for multi-account support (default)")
|
||||
}
|
||||
|
||||
// Step 5: Open browser for user authentication
|
||||
// Step 6: Open browser for user authentication
|
||||
fmt.Println("\n════════════════════════════════════════════════════════════")
|
||||
fmt.Printf(" Opening browser for %s authentication...\n", providerName)
|
||||
fmt.Println("════════════════════════════════════════════════════════════")
|
||||
@@ -295,80 +375,78 @@ func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialP
|
||||
|
||||
fmt.Println("\n Waiting for authentication callback...")
|
||||
|
||||
// Step 6: Wait for callback
|
||||
callback, err := c.protocolHandler.WaitForCallback(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to receive callback: %w", err)
|
||||
}
|
||||
|
||||
if callback.Error != "" {
|
||||
return nil, fmt.Errorf("authentication error: %s", callback.Error)
|
||||
}
|
||||
|
||||
if callback.State != state {
|
||||
// Log state values for debugging, but don't expose in user-facing error
|
||||
log.Debugf("kiro: OAuth state mismatch - expected %s, got %s", state, callback.State)
|
||||
return nil, fmt.Errorf("OAuth state validation failed - please try again")
|
||||
}
|
||||
|
||||
if callback.Code == "" {
|
||||
return nil, fmt.Errorf("no authorization code received")
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Authorization received!")
|
||||
|
||||
// Step 7: Exchange code for tokens
|
||||
fmt.Println("Exchanging code for tokens...")
|
||||
|
||||
tokenReq := &CreateTokenRequest{
|
||||
Code: callback.Code,
|
||||
CodeVerifier: codeVerifier,
|
||||
RedirectURI: KiroRedirectURI,
|
||||
}
|
||||
|
||||
tokenResp, err := c.CreateToken(ctx, tokenReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Authentication successful!")
|
||||
|
||||
// Close the browser window
|
||||
if err := browser.CloseBrowser(); err != nil {
|
||||
log.Debugf("Failed to close browser: %v", err)
|
||||
}
|
||||
|
||||
// Validate ExpiresIn - use default 1 hour if invalid
|
||||
expiresIn := tokenResp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
// Try to extract email from JWT access token first
|
||||
email := ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||
|
||||
// If no email in JWT, ask user for account label (only in interactive mode)
|
||||
if email == "" && isInteractiveTerminal() {
|
||||
fmt.Print("\n Enter account label for file naming (optional, press Enter to skip): ")
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
var err error
|
||||
email, err = reader.ReadString('\n')
|
||||
if err != nil {
|
||||
log.Debugf("Failed to read account label: %v", err)
|
||||
// Step 7: Wait for callback from HTTP server
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(socialAuthTimeout):
|
||||
return nil, fmt.Errorf("authentication timed out")
|
||||
case callback := <-resultChan:
|
||||
if callback.Error != "" {
|
||||
return nil, fmt.Errorf("authentication error: %s", callback.Error)
|
||||
}
|
||||
email = strings.TrimSpace(email)
|
||||
}
|
||||
|
||||
return &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: tokenResp.ProfileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: providerName,
|
||||
Email: email, // JWT email or user-provided label
|
||||
}, nil
|
||||
// State is already validated by the callback server
|
||||
if callback.Code == "" {
|
||||
return nil, fmt.Errorf("no authorization code received")
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Authorization received!")
|
||||
|
||||
// Step 8: Exchange code for tokens
|
||||
fmt.Println("Exchanging code for tokens...")
|
||||
|
||||
tokenReq := &CreateTokenRequest{
|
||||
Code: callback.Code,
|
||||
CodeVerifier: codeVerifier,
|
||||
RedirectURI: redirectURI, // Use HTTP redirect URI, not kiro:// protocol
|
||||
}
|
||||
|
||||
tokenResp, err := c.CreateToken(ctx, tokenReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Authentication successful!")
|
||||
|
||||
// Close the browser window
|
||||
if err := browser.CloseBrowser(); err != nil {
|
||||
log.Debugf("Failed to close browser: %v", err)
|
||||
}
|
||||
|
||||
// Validate ExpiresIn - use default 1 hour if invalid
|
||||
expiresIn := tokenResp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
// Try to extract email from JWT access token first
|
||||
email := ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||
|
||||
// If no email in JWT, ask user for account label (only in interactive mode)
|
||||
if email == "" && isInteractiveTerminal() {
|
||||
fmt.Print("\n Enter account label for file naming (optional, press Enter to skip): ")
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
var err error
|
||||
email, err = reader.ReadString('\n')
|
||||
if err != nil {
|
||||
log.Debugf("Failed to read account label: %v", err)
|
||||
}
|
||||
email = strings.TrimSpace(email)
|
||||
}
|
||||
|
||||
return &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: tokenResp.ProfileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: providerName,
|
||||
Email: email, // JWT email or user-provided label
|
||||
Region: "us-east-1",
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// LoginWithGoogle performs OAuth login with Google.
|
||||
|
||||
@@ -735,6 +735,7 @@ func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret
|
||||
Provider: "AWS",
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
Region: defaultIDCRegion,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -850,16 +851,17 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData,
|
||||
ClientID: regResp.ClientID,
|
||||
ClientSecret: regResp.ClientSecret,
|
||||
Email: email,
|
||||
Region: defaultIDCRegion,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close browser on timeout for better UX
|
||||
if err := browser.CloseBrowser(); err != nil {
|
||||
log.Debugf("Failed to close browser on timeout: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("authorization timed out")
|
||||
}
|
||||
// Close browser on timeout for better UX
|
||||
if err := browser.CloseBrowser(); err != nil {
|
||||
log.Debugf("Failed to close browser on timeout: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("authorization timed out")
|
||||
}
|
||||
|
||||
// FetchUserEmail retrieves the user's email from AWS SSO OIDC userinfo endpoint.
|
||||
// Falls back to JWT parsing if userinfo fails.
|
||||
@@ -1366,6 +1368,7 @@ func (c *SSOOIDCClient) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTo
|
||||
ClientID: regResp.ClientID,
|
||||
ClientSecret: regResp.ClientSecret,
|
||||
Email: email,
|
||||
Region: defaultIDCRegion,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
1371
internal/auth/kiro/sso_oidc.go.bak
Normal file
1371
internal/auth/kiro/sso_oidc.go.bak
Normal file
File diff suppressed because it is too large
Load Diff
@@ -9,6 +9,8 @@ import (
|
||||
|
||||
// KiroTokenStorage holds the persistent token data for Kiro authentication.
|
||||
type KiroTokenStorage struct {
|
||||
// Type is the provider type for management UI recognition (must be "kiro")
|
||||
Type string `json:"type"`
|
||||
// AccessToken is the OAuth2 access token for API access
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is used to obtain new access tokens
|
||||
@@ -23,6 +25,16 @@ type KiroTokenStorage struct {
|
||||
Provider string `json:"provider"`
|
||||
// LastRefresh is the timestamp of the last token refresh
|
||||
LastRefresh string `json:"last_refresh"`
|
||||
// ClientID is the OAuth client ID (required for token refresh)
|
||||
ClientID string `json:"clientId,omitempty"`
|
||||
// ClientSecret is the OAuth client secret (required for token refresh)
|
||||
ClientSecret string `json:"clientSecret,omitempty"`
|
||||
// Region is the AWS region
|
||||
Region string `json:"region,omitempty"`
|
||||
// StartURL is the AWS Identity Center start URL (for IDC auth)
|
||||
StartURL string `json:"startUrl,omitempty"`
|
||||
// Email is the user's email address
|
||||
Email string `json:"email,omitempty"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile persists the token storage to the specified file path.
|
||||
@@ -68,5 +80,10 @@ func (s *KiroTokenStorage) ToTokenData() *KiroTokenData {
|
||||
ExpiresAt: s.ExpiresAt,
|
||||
AuthMethod: s.AuthMethod,
|
||||
Provider: s.Provider,
|
||||
ClientID: s.ClientID,
|
||||
ClientSecret: s.ClientSecret,
|
||||
Region: s.Region,
|
||||
StartURL: s.StartURL,
|
||||
Email: s.Email,
|
||||
}
|
||||
}
|
||||
|
||||
243
internal/auth/kiro/usage_checker.go
Normal file
243
internal/auth/kiro/usage_checker.go
Normal file
@@ -0,0 +1,243 @@
|
||||
// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API.
|
||||
// This file implements usage quota checking and monitoring.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
|
||||
// UsageQuotaResponse represents the API response structure for usage quota checking.
|
||||
type UsageQuotaResponse struct {
|
||||
UsageBreakdownList []UsageBreakdownExtended `json:"usageBreakdownList"`
|
||||
SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
|
||||
NextDateReset float64 `json:"nextDateReset,omitempty"`
|
||||
}
|
||||
|
||||
// UsageBreakdownExtended represents detailed usage information for quota checking.
|
||||
// Note: UsageBreakdown is already defined in codewhisperer_client.go
|
||||
type UsageBreakdownExtended struct {
|
||||
ResourceType string `json:"resourceType"`
|
||||
UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
|
||||
CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
|
||||
FreeTrialInfo *FreeTrialInfoExtended `json:"freeTrialInfo,omitempty"`
|
||||
}
|
||||
|
||||
// FreeTrialInfoExtended represents free trial usage information.
|
||||
type FreeTrialInfoExtended struct {
|
||||
FreeTrialStatus string `json:"freeTrialStatus"`
|
||||
UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
|
||||
CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
|
||||
}
|
||||
|
||||
// QuotaStatus represents the quota status for a token.
|
||||
type QuotaStatus struct {
|
||||
TotalLimit float64
|
||||
CurrentUsage float64
|
||||
RemainingQuota float64
|
||||
IsExhausted bool
|
||||
ResourceType string
|
||||
NextReset time.Time
|
||||
}
|
||||
|
||||
// UsageChecker provides methods for checking token quota usage.
|
||||
type UsageChecker struct {
|
||||
httpClient *http.Client
|
||||
endpoint string
|
||||
}
|
||||
|
||||
// NewUsageChecker creates a new UsageChecker instance.
|
||||
func NewUsageChecker(cfg *config.Config) *UsageChecker {
|
||||
return &UsageChecker{
|
||||
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}),
|
||||
endpoint: awsKiroEndpoint,
|
||||
}
|
||||
}
|
||||
|
||||
// NewUsageCheckerWithClient creates a UsageChecker with a custom HTTP client.
|
||||
func NewUsageCheckerWithClient(client *http.Client) *UsageChecker {
|
||||
return &UsageChecker{
|
||||
httpClient: client,
|
||||
endpoint: awsKiroEndpoint,
|
||||
}
|
||||
}
|
||||
|
||||
// CheckUsage retrieves usage limits for the given token.
|
||||
func (c *UsageChecker) CheckUsage(ctx context.Context, tokenData *KiroTokenData) (*UsageQuotaResponse, error) {
|
||||
if tokenData == nil {
|
||||
return nil, fmt.Errorf("token data is nil")
|
||||
}
|
||||
|
||||
if tokenData.AccessToken == "" {
|
||||
return nil, fmt.Errorf("access token is empty")
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"origin": "AI_EDITOR",
|
||||
"profileArn": tokenData.ProfileArn,
|
||||
"resourceType": "AGENTIC_REQUEST",
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, strings.NewReader(string(jsonBody)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
||||
req.Header.Set("x-amz-target", targetGetUsage)
|
||||
req.Header.Set("Authorization", "Bearer "+tokenData.AccessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result UsageQuotaResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse usage response: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// CheckUsageByAccessToken retrieves usage limits using an access token and profile ARN directly.
|
||||
func (c *UsageChecker) CheckUsageByAccessToken(ctx context.Context, accessToken, profileArn string) (*UsageQuotaResponse, error) {
|
||||
tokenData := &KiroTokenData{
|
||||
AccessToken: accessToken,
|
||||
ProfileArn: profileArn,
|
||||
}
|
||||
return c.CheckUsage(ctx, tokenData)
|
||||
}
|
||||
|
||||
// GetRemainingQuota calculates the remaining quota from usage limits.
|
||||
func GetRemainingQuota(usage *UsageQuotaResponse) float64 {
|
||||
if usage == nil || len(usage.UsageBreakdownList) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
var totalRemaining float64
|
||||
for _, breakdown := range usage.UsageBreakdownList {
|
||||
remaining := breakdown.UsageLimitWithPrecision - breakdown.CurrentUsageWithPrecision
|
||||
if remaining > 0 {
|
||||
totalRemaining += remaining
|
||||
}
|
||||
|
||||
if breakdown.FreeTrialInfo != nil {
|
||||
freeRemaining := breakdown.FreeTrialInfo.UsageLimitWithPrecision - breakdown.FreeTrialInfo.CurrentUsageWithPrecision
|
||||
if freeRemaining > 0 {
|
||||
totalRemaining += freeRemaining
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return totalRemaining
|
||||
}
|
||||
|
||||
// IsQuotaExhausted checks if the quota is exhausted based on usage limits.
|
||||
func IsQuotaExhausted(usage *UsageQuotaResponse) bool {
|
||||
if usage == nil || len(usage.UsageBreakdownList) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, breakdown := range usage.UsageBreakdownList {
|
||||
if breakdown.CurrentUsageWithPrecision < breakdown.UsageLimitWithPrecision {
|
||||
return false
|
||||
}
|
||||
|
||||
if breakdown.FreeTrialInfo != nil {
|
||||
if breakdown.FreeTrialInfo.CurrentUsageWithPrecision < breakdown.FreeTrialInfo.UsageLimitWithPrecision {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// GetQuotaStatus retrieves a comprehensive quota status for a token.
|
||||
func (c *UsageChecker) GetQuotaStatus(ctx context.Context, tokenData *KiroTokenData) (*QuotaStatus, error) {
|
||||
usage, err := c.CheckUsage(ctx, tokenData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
status := &QuotaStatus{
|
||||
IsExhausted: IsQuotaExhausted(usage),
|
||||
}
|
||||
|
||||
if len(usage.UsageBreakdownList) > 0 {
|
||||
breakdown := usage.UsageBreakdownList[0]
|
||||
status.TotalLimit = breakdown.UsageLimitWithPrecision
|
||||
status.CurrentUsage = breakdown.CurrentUsageWithPrecision
|
||||
status.RemainingQuota = breakdown.UsageLimitWithPrecision - breakdown.CurrentUsageWithPrecision
|
||||
status.ResourceType = breakdown.ResourceType
|
||||
|
||||
if breakdown.FreeTrialInfo != nil {
|
||||
status.TotalLimit += breakdown.FreeTrialInfo.UsageLimitWithPrecision
|
||||
status.CurrentUsage += breakdown.FreeTrialInfo.CurrentUsageWithPrecision
|
||||
freeRemaining := breakdown.FreeTrialInfo.UsageLimitWithPrecision - breakdown.FreeTrialInfo.CurrentUsageWithPrecision
|
||||
if freeRemaining > 0 {
|
||||
status.RemainingQuota += freeRemaining
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if usage.NextDateReset > 0 {
|
||||
status.NextReset = time.Unix(int64(usage.NextDateReset/1000), 0)
|
||||
}
|
||||
|
||||
return status, nil
|
||||
}
|
||||
|
||||
// CalculateAvailableCount calculates the available request count based on usage limits.
|
||||
func CalculateAvailableCount(usage *UsageQuotaResponse) float64 {
|
||||
return GetRemainingQuota(usage)
|
||||
}
|
||||
|
||||
// GetUsagePercentage calculates the usage percentage.
|
||||
func GetUsagePercentage(usage *UsageQuotaResponse) float64 {
|
||||
if usage == nil || len(usage.UsageBreakdownList) == 0 {
|
||||
return 100.0
|
||||
}
|
||||
|
||||
var totalLimit, totalUsage float64
|
||||
for _, breakdown := range usage.UsageBreakdownList {
|
||||
totalLimit += breakdown.UsageLimitWithPrecision
|
||||
totalUsage += breakdown.CurrentUsageWithPrecision
|
||||
|
||||
if breakdown.FreeTrialInfo != nil {
|
||||
totalLimit += breakdown.FreeTrialInfo.UsageLimitWithPrecision
|
||||
totalUsage += breakdown.FreeTrialInfo.CurrentUsageWithPrecision
|
||||
}
|
||||
}
|
||||
|
||||
if totalLimit == 0 {
|
||||
return 100.0
|
||||
}
|
||||
|
||||
return (totalUsage / totalLimit) * 100
|
||||
}
|
||||
303
internal/registry/kiro_model_converter.go
Normal file
303
internal/registry/kiro_model_converter.go
Normal file
@@ -0,0 +1,303 @@
|
||||
// Package registry provides Kiro model conversion utilities.
|
||||
// This file handles converting dynamic Kiro API model lists to the internal ModelInfo format,
|
||||
// and merging with static metadata for thinking support and other capabilities.
|
||||
package registry
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// KiroAPIModel represents a model from Kiro API response.
|
||||
// This is a local copy to avoid import cycles with the kiro package.
|
||||
// The structure mirrors kiro.KiroModel for easy data conversion.
|
||||
type KiroAPIModel struct {
|
||||
// ModelID is the unique identifier for the model (e.g., "claude-sonnet-4.5")
|
||||
ModelID string
|
||||
// ModelName is the human-readable name
|
||||
ModelName string
|
||||
// Description is the model description
|
||||
Description string
|
||||
// RateMultiplier is the credit multiplier for this model
|
||||
RateMultiplier float64
|
||||
// RateUnit is the unit for rate calculation (e.g., "credit")
|
||||
RateUnit string
|
||||
// MaxInputTokens is the maximum input token limit
|
||||
MaxInputTokens int
|
||||
}
|
||||
|
||||
// DefaultKiroThinkingSupport defines the default thinking configuration for Kiro models.
|
||||
// All Kiro models support thinking with the following budget range.
|
||||
var DefaultKiroThinkingSupport = &ThinkingSupport{
|
||||
Min: 1024, // Minimum thinking budget tokens
|
||||
Max: 32000, // Maximum thinking budget tokens
|
||||
ZeroAllowed: true, // Allow disabling thinking with 0
|
||||
DynamicAllowed: true, // Allow dynamic thinking budget (-1)
|
||||
}
|
||||
|
||||
// DefaultKiroContextLength is the default context window size for Kiro models.
|
||||
const DefaultKiroContextLength = 200000
|
||||
|
||||
// DefaultKiroMaxCompletionTokens is the default max completion tokens for Kiro models.
|
||||
const DefaultKiroMaxCompletionTokens = 64000
|
||||
|
||||
// ConvertKiroAPIModels converts Kiro API models to internal ModelInfo format.
|
||||
// It performs the following transformations:
|
||||
// - Normalizes model ID (e.g., claude-sonnet-4.5 → kiro-claude-sonnet-4-5)
|
||||
// - Adds default thinking support metadata
|
||||
// - Sets default context length and max completion tokens if not provided
|
||||
//
|
||||
// Parameters:
|
||||
// - kiroModels: List of models from Kiro API response
|
||||
//
|
||||
// Returns:
|
||||
// - []*ModelInfo: Converted model information list
|
||||
func ConvertKiroAPIModels(kiroModels []*KiroAPIModel) []*ModelInfo {
|
||||
if len(kiroModels) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
result := make([]*ModelInfo, 0, len(kiroModels))
|
||||
|
||||
for _, km := range kiroModels {
|
||||
// Skip nil models
|
||||
if km == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip models without valid ID
|
||||
if km.ModelID == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Normalize the model ID to kiro-* format
|
||||
normalizedID := normalizeKiroModelID(km.ModelID)
|
||||
|
||||
// Create ModelInfo with converted data
|
||||
info := &ModelInfo{
|
||||
ID: normalizedID,
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: generateKiroDisplayName(km.ModelName, normalizedID),
|
||||
Description: km.Description,
|
||||
// Use MaxInputTokens from API if available, otherwise use default
|
||||
ContextLength: getContextLength(km.MaxInputTokens),
|
||||
MaxCompletionTokens: DefaultKiroMaxCompletionTokens,
|
||||
// All Kiro models support thinking
|
||||
Thinking: cloneThinkingSupport(DefaultKiroThinkingSupport),
|
||||
}
|
||||
|
||||
result = append(result, info)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// GenerateAgenticVariants creates -agentic variants for each model.
|
||||
// Agentic variants are optimized for coding agents with chunked writes.
|
||||
//
|
||||
// Parameters:
|
||||
// - models: Base models to generate variants for
|
||||
//
|
||||
// Returns:
|
||||
// - []*ModelInfo: Combined list of base models and their agentic variants
|
||||
func GenerateAgenticVariants(models []*ModelInfo) []*ModelInfo {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pre-allocate result with capacity for both base models and variants
|
||||
result := make([]*ModelInfo, 0, len(models)*2)
|
||||
|
||||
for _, model := range models {
|
||||
if model == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Add the base model first
|
||||
result = append(result, model)
|
||||
|
||||
// Skip if model already has -agentic suffix
|
||||
if strings.HasSuffix(model.ID, "-agentic") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip special models that shouldn't have agentic variants
|
||||
if model.ID == "kiro-auto" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create agentic variant
|
||||
agenticModel := &ModelInfo{
|
||||
ID: model.ID + "-agentic",
|
||||
Object: model.Object,
|
||||
Created: model.Created,
|
||||
OwnedBy: model.OwnedBy,
|
||||
Type: model.Type,
|
||||
DisplayName: model.DisplayName + " (Agentic)",
|
||||
Description: generateAgenticDescription(model.Description),
|
||||
ContextLength: model.ContextLength,
|
||||
MaxCompletionTokens: model.MaxCompletionTokens,
|
||||
Thinking: cloneThinkingSupport(model.Thinking),
|
||||
}
|
||||
|
||||
result = append(result, agenticModel)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// MergeWithStaticMetadata merges dynamic models with static metadata.
|
||||
// Static metadata takes priority for any overlapping fields.
|
||||
// This allows manual overrides for specific models while keeping dynamic discovery.
|
||||
//
|
||||
// Parameters:
|
||||
// - dynamicModels: Models from Kiro API (converted to ModelInfo)
|
||||
// - staticModels: Predefined model metadata (from GetKiroModels())
|
||||
//
|
||||
// Returns:
|
||||
// - []*ModelInfo: Merged model list with static metadata taking priority
|
||||
func MergeWithStaticMetadata(dynamicModels, staticModels []*ModelInfo) []*ModelInfo {
|
||||
if len(dynamicModels) == 0 && len(staticModels) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build a map of static models for quick lookup
|
||||
staticMap := make(map[string]*ModelInfo, len(staticModels))
|
||||
for _, sm := range staticModels {
|
||||
if sm != nil && sm.ID != "" {
|
||||
staticMap[sm.ID] = sm
|
||||
}
|
||||
}
|
||||
|
||||
// Build result, preferring static metadata where available
|
||||
seenIDs := make(map[string]struct{})
|
||||
result := make([]*ModelInfo, 0, len(dynamicModels)+len(staticModels))
|
||||
|
||||
// First, process dynamic models and merge with static if available
|
||||
for _, dm := range dynamicModels {
|
||||
if dm == nil || dm.ID == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip duplicates
|
||||
if _, seen := seenIDs[dm.ID]; seen {
|
||||
continue
|
||||
}
|
||||
seenIDs[dm.ID] = struct{}{}
|
||||
|
||||
// Check if static metadata exists for this model
|
||||
if sm, exists := staticMap[dm.ID]; exists {
|
||||
// Static metadata takes priority - use static model
|
||||
result = append(result, sm)
|
||||
} else {
|
||||
// No static metadata - use dynamic model
|
||||
result = append(result, dm)
|
||||
}
|
||||
}
|
||||
|
||||
// Add any static models not in dynamic list
|
||||
for _, sm := range staticModels {
|
||||
if sm == nil || sm.ID == "" {
|
||||
continue
|
||||
}
|
||||
if _, seen := seenIDs[sm.ID]; seen {
|
||||
continue
|
||||
}
|
||||
seenIDs[sm.ID] = struct{}{}
|
||||
result = append(result, sm)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// normalizeKiroModelID converts Kiro API model IDs to internal format.
|
||||
// Transformation rules:
|
||||
// - Adds "kiro-" prefix if not present
|
||||
// - Replaces dots with hyphens (e.g., 4.5 → 4-5)
|
||||
// - Handles special cases like "auto" → "kiro-auto"
|
||||
//
|
||||
// Examples:
|
||||
// - "claude-sonnet-4.5" → "kiro-claude-sonnet-4-5"
|
||||
// - "claude-opus-4.5" → "kiro-claude-opus-4-5"
|
||||
// - "auto" → "kiro-auto"
|
||||
// - "kiro-claude-sonnet-4-5" → "kiro-claude-sonnet-4-5" (unchanged)
|
||||
func normalizeKiroModelID(modelID string) string {
|
||||
if modelID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Trim whitespace
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
|
||||
// Replace dots with hyphens (e.g., 4.5 → 4-5)
|
||||
normalized := strings.ReplaceAll(modelID, ".", "-")
|
||||
|
||||
// Add kiro- prefix if not present
|
||||
if !strings.HasPrefix(normalized, "kiro-") {
|
||||
normalized = "kiro-" + normalized
|
||||
}
|
||||
|
||||
return normalized
|
||||
}
|
||||
|
||||
// generateKiroDisplayName creates a human-readable display name.
|
||||
// Uses the API-provided model name if available, otherwise generates from ID.
|
||||
func generateKiroDisplayName(modelName, normalizedID string) string {
|
||||
if modelName != "" {
|
||||
return "Kiro " + modelName
|
||||
}
|
||||
|
||||
// Generate from normalized ID by removing kiro- prefix and formatting
|
||||
displayID := strings.TrimPrefix(normalizedID, "kiro-")
|
||||
// Capitalize first letter of each word
|
||||
words := strings.Split(displayID, "-")
|
||||
for i, word := range words {
|
||||
if len(word) > 0 {
|
||||
words[i] = strings.ToUpper(word[:1]) + word[1:]
|
||||
}
|
||||
}
|
||||
return "Kiro " + strings.Join(words, " ")
|
||||
}
|
||||
|
||||
// generateAgenticDescription creates description for agentic variants.
|
||||
func generateAgenticDescription(baseDescription string) string {
|
||||
if baseDescription == "" {
|
||||
return "Optimized for coding agents with chunked writes"
|
||||
}
|
||||
return baseDescription + " (Agentic mode: chunked writes)"
|
||||
}
|
||||
|
||||
// getContextLength returns the context length, using default if not provided.
|
||||
func getContextLength(maxInputTokens int) int {
|
||||
if maxInputTokens > 0 {
|
||||
return maxInputTokens
|
||||
}
|
||||
return DefaultKiroContextLength
|
||||
}
|
||||
|
||||
// cloneThinkingSupport creates a deep copy of ThinkingSupport.
|
||||
// Returns nil if input is nil.
|
||||
func cloneThinkingSupport(ts *ThinkingSupport) *ThinkingSupport {
|
||||
if ts == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
clone := &ThinkingSupport{
|
||||
Min: ts.Min,
|
||||
Max: ts.Max,
|
||||
ZeroAllowed: ts.ZeroAllowed,
|
||||
DynamicAllowed: ts.DynamicAllowed,
|
||||
}
|
||||
|
||||
// Deep copy Levels slice if present
|
||||
if len(ts.Levels) > 0 {
|
||||
clone.Levels = make([]string, len(ts.Levels))
|
||||
copy(clone.Levels, ts.Levels)
|
||||
}
|
||||
|
||||
return clone
|
||||
}
|
||||
@@ -7,13 +7,16 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -53,9 +56,28 @@ const (
|
||||
kiroIDEUserAgent = "aws-sdk-js/1.0.18 ua/2.1 os/darwin#25.0.0 lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-0.2.13-66c23a8c5d15afabec89ef9954ef52a119f10d369df04d548fc6c1eac694b0d1"
|
||||
kiroIDEAmzUserAgent = "aws-sdk-js/1.0.18 KiroIDE-0.2.13-66c23a8c5d15afabec89ef9954ef52a119f10d369df04d548fc6c1eac694b0d1"
|
||||
kiroIDEAgentModeSpec = "spec"
|
||||
kiroAgentModeVibe = "vibe"
|
||||
|
||||
// Socket retry configuration constants (based on kiro2Api reference implementation)
|
||||
// Maximum number of retry attempts for socket/network errors
|
||||
kiroSocketMaxRetries = 3
|
||||
// Base delay between retry attempts (uses exponential backoff: delay * 2^attempt)
|
||||
kiroSocketBaseRetryDelay = 1 * time.Second
|
||||
// Maximum delay between retry attempts (cap for exponential backoff)
|
||||
kiroSocketMaxRetryDelay = 30 * time.Second
|
||||
// First token timeout for streaming responses (how long to wait for first response)
|
||||
kiroFirstTokenTimeout = 15 * time.Second
|
||||
// Streaming read timeout (how long to wait between chunks)
|
||||
kiroStreamingReadTimeout = 300 * time.Second
|
||||
)
|
||||
|
||||
// retryableHTTPStatusCodes defines HTTP status codes that are considered retryable.
|
||||
// Based on kiro2Api reference: 502 (Bad Gateway), 503 (Service Unavailable), 504 (Gateway Timeout)
|
||||
var retryableHTTPStatusCodes = map[int]bool{
|
||||
502: true, // Bad Gateway - upstream server error
|
||||
503: true, // Service Unavailable - server temporarily overloaded
|
||||
504: true, // Gateway Timeout - upstream server timeout
|
||||
}
|
||||
|
||||
// Real-time usage estimation configuration
|
||||
// These control how often usage updates are sent during streaming
|
||||
var (
|
||||
@@ -63,6 +85,241 @@ var (
|
||||
usageUpdateTimeInterval = 15 * time.Second // Or every 15 seconds, whichever comes first
|
||||
)
|
||||
|
||||
// Global FingerprintManager for dynamic User-Agent generation per token
|
||||
// Each token gets a unique fingerprint on first use, which is cached for subsequent requests
|
||||
var (
|
||||
globalFingerprintManager *kiroauth.FingerprintManager
|
||||
globalFingerprintManagerOnce sync.Once
|
||||
)
|
||||
|
||||
// getGlobalFingerprintManager returns the global FingerprintManager instance
|
||||
func getGlobalFingerprintManager() *kiroauth.FingerprintManager {
|
||||
globalFingerprintManagerOnce.Do(func() {
|
||||
globalFingerprintManager = kiroauth.NewFingerprintManager()
|
||||
log.Infof("kiro: initialized global FingerprintManager for dynamic UA generation")
|
||||
})
|
||||
return globalFingerprintManager
|
||||
}
|
||||
|
||||
// retryConfig holds configuration for socket retry logic.
|
||||
// Based on kiro2Api Python implementation patterns.
|
||||
type retryConfig struct {
|
||||
MaxRetries int // Maximum number of retry attempts
|
||||
BaseDelay time.Duration // Base delay between retries (exponential backoff)
|
||||
MaxDelay time.Duration // Maximum delay cap
|
||||
RetryableErrors []string // List of retryable error patterns
|
||||
RetryableStatus map[int]bool // HTTP status codes to retry
|
||||
FirstTokenTmout time.Duration // Timeout for first token in streaming
|
||||
StreamReadTmout time.Duration // Timeout between stream chunks
|
||||
}
|
||||
|
||||
// defaultRetryConfig returns the default retry configuration for Kiro socket operations.
|
||||
func defaultRetryConfig() retryConfig {
|
||||
return retryConfig{
|
||||
MaxRetries: kiroSocketMaxRetries,
|
||||
BaseDelay: kiroSocketBaseRetryDelay,
|
||||
MaxDelay: kiroSocketMaxRetryDelay,
|
||||
RetryableStatus: retryableHTTPStatusCodes,
|
||||
RetryableErrors: []string{
|
||||
"connection reset",
|
||||
"connection refused",
|
||||
"broken pipe",
|
||||
"EOF",
|
||||
"timeout",
|
||||
"temporary failure",
|
||||
"no such host",
|
||||
"network is unreachable",
|
||||
"i/o timeout",
|
||||
},
|
||||
FirstTokenTmout: kiroFirstTokenTimeout,
|
||||
StreamReadTmout: kiroStreamingReadTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// isRetryableError checks if an error is retryable based on error type and message.
|
||||
// Returns true for network timeouts, connection resets, and temporary failures.
|
||||
// Based on kiro2Api's retry logic patterns.
|
||||
func isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for context cancellation - not retryable
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for net.Error (timeout, temporary)
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) {
|
||||
if netErr.Timeout() {
|
||||
log.Debugf("kiro: isRetryableError: network timeout detected")
|
||||
return true
|
||||
}
|
||||
// Note: Temporary() is deprecated but still useful for some error types
|
||||
}
|
||||
|
||||
// Check for specific syscall errors (connection reset, broken pipe, etc.)
|
||||
var syscallErr syscall.Errno
|
||||
if errors.As(err, &syscallErr) {
|
||||
switch syscallErr {
|
||||
case syscall.ECONNRESET: // Connection reset by peer
|
||||
log.Debugf("kiro: isRetryableError: ECONNRESET detected")
|
||||
return true
|
||||
case syscall.ECONNREFUSED: // Connection refused
|
||||
log.Debugf("kiro: isRetryableError: ECONNREFUSED detected")
|
||||
return true
|
||||
case syscall.EPIPE: // Broken pipe
|
||||
log.Debugf("kiro: isRetryableError: EPIPE (broken pipe) detected")
|
||||
return true
|
||||
case syscall.ETIMEDOUT: // Connection timed out
|
||||
log.Debugf("kiro: isRetryableError: ETIMEDOUT detected")
|
||||
return true
|
||||
case syscall.ENETUNREACH: // Network is unreachable
|
||||
log.Debugf("kiro: isRetryableError: ENETUNREACH detected")
|
||||
return true
|
||||
case syscall.EHOSTUNREACH: // No route to host
|
||||
log.Debugf("kiro: isRetryableError: EHOSTUNREACH detected")
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check for net.OpError wrapping other errors
|
||||
var opErr *net.OpError
|
||||
if errors.As(err, &opErr) {
|
||||
log.Debugf("kiro: isRetryableError: net.OpError detected, op=%s", opErr.Op)
|
||||
// Recursively check the wrapped error
|
||||
if opErr.Err != nil {
|
||||
return isRetryableError(opErr.Err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Check error message for retryable patterns
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
cfg := defaultRetryConfig()
|
||||
for _, pattern := range cfg.RetryableErrors {
|
||||
if strings.Contains(errMsg, pattern) {
|
||||
log.Debugf("kiro: isRetryableError: pattern '%s' matched in error: %s", pattern, errMsg)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check for EOF which may indicate connection was closed
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
log.Debugf("kiro: isRetryableError: EOF/UnexpectedEOF detected")
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isRetryableHTTPStatus checks if an HTTP status code is retryable.
|
||||
// Based on kiro2Api: 502, 503, 504 are retryable server errors.
|
||||
func isRetryableHTTPStatus(statusCode int) bool {
|
||||
return retryableHTTPStatusCodes[statusCode]
|
||||
}
|
||||
|
||||
// calculateRetryDelay calculates the delay for the next retry attempt using exponential backoff.
|
||||
// delay = min(baseDelay * 2^attempt, maxDelay)
|
||||
// Adds ±30% jitter to prevent thundering herd.
|
||||
func calculateRetryDelay(attempt int, cfg retryConfig) time.Duration {
|
||||
return kiroauth.ExponentialBackoffWithJitter(attempt, cfg.BaseDelay, cfg.MaxDelay)
|
||||
}
|
||||
|
||||
// logRetryAttempt logs a retry attempt with relevant context.
|
||||
func logRetryAttempt(attempt, maxRetries int, reason string, delay time.Duration, endpoint string) {
|
||||
log.Warnf("kiro: retry attempt %d/%d for %s, waiting %v before next attempt (endpoint: %s)",
|
||||
attempt+1, maxRetries, reason, delay, endpoint)
|
||||
}
|
||||
|
||||
// kiroHTTPClientPool provides a shared HTTP client with connection pooling for Kiro API.
|
||||
// This reduces connection overhead and improves performance for concurrent requests.
|
||||
// Based on kiro2Api's connection pooling pattern.
|
||||
var (
|
||||
kiroHTTPClientPool *http.Client
|
||||
kiroHTTPClientPoolOnce sync.Once
|
||||
)
|
||||
|
||||
// getKiroPooledHTTPClient returns a shared HTTP client with optimized connection pooling.
|
||||
// The client is lazily initialized on first use and reused across requests.
|
||||
// This is especially beneficial for:
|
||||
// - Reducing TCP handshake overhead
|
||||
// - Enabling HTTP/2 multiplexing
|
||||
// - Better handling of keep-alive connections
|
||||
func getKiroPooledHTTPClient() *http.Client {
|
||||
kiroHTTPClientPoolOnce.Do(func() {
|
||||
transport := &http.Transport{
|
||||
// Connection pool settings
|
||||
MaxIdleConns: 100, // Max idle connections across all hosts
|
||||
MaxIdleConnsPerHost: 20, // Max idle connections per host
|
||||
MaxConnsPerHost: 50, // Max total connections per host
|
||||
IdleConnTimeout: 90 * time.Second, // How long idle connections stay in pool
|
||||
|
||||
// Timeouts for connection establishment
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 30 * time.Second, // TCP connection timeout
|
||||
KeepAlive: 30 * time.Second, // TCP keep-alive interval
|
||||
}).DialContext,
|
||||
|
||||
// TLS handshake timeout
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
|
||||
// Response header timeout
|
||||
ResponseHeaderTimeout: 30 * time.Second,
|
||||
|
||||
// Expect 100-continue timeout
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
|
||||
// Enable HTTP/2 when available
|
||||
ForceAttemptHTTP2: true,
|
||||
}
|
||||
|
||||
kiroHTTPClientPool = &http.Client{
|
||||
Transport: transport,
|
||||
// No global timeout - let individual requests set their own timeouts via context
|
||||
}
|
||||
|
||||
log.Debugf("kiro: initialized pooled HTTP client (MaxIdleConns=%d, MaxIdleConnsPerHost=%d, MaxConnsPerHost=%d)",
|
||||
transport.MaxIdleConns, transport.MaxIdleConnsPerHost, transport.MaxConnsPerHost)
|
||||
})
|
||||
|
||||
return kiroHTTPClientPool
|
||||
}
|
||||
|
||||
// newKiroHTTPClientWithPooling creates an HTTP client that uses connection pooling when appropriate.
|
||||
// It respects proxy configuration from auth or config, falling back to the pooled client.
|
||||
// This provides the best of both worlds: custom proxy support + connection reuse.
|
||||
func newKiroHTTPClientWithPooling(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
||||
// Check if a proxy is configured - if so, we need a custom client
|
||||
var proxyURL string
|
||||
if auth != nil {
|
||||
proxyURL = strings.TrimSpace(auth.ProxyURL)
|
||||
}
|
||||
if proxyURL == "" && cfg != nil {
|
||||
proxyURL = strings.TrimSpace(cfg.ProxyURL)
|
||||
}
|
||||
|
||||
// If proxy is configured, use the existing proxy-aware client (doesn't pool)
|
||||
if proxyURL != "" {
|
||||
log.Debugf("kiro: using proxy-aware HTTP client (proxy=%s)", proxyURL)
|
||||
return newProxyAwareHTTPClient(ctx, cfg, auth, timeout)
|
||||
}
|
||||
|
||||
// No proxy - use pooled client for better performance
|
||||
pooledClient := getKiroPooledHTTPClient()
|
||||
|
||||
// If timeout is specified, we need to wrap the pooled transport with timeout
|
||||
if timeout > 0 {
|
||||
return &http.Client{
|
||||
Transport: pooledClient.Transport,
|
||||
Timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
return pooledClient
|
||||
}
|
||||
|
||||
// kiroEndpointConfig bundles endpoint URL with its compatible Origin and AmzTarget values.
|
||||
// This solves the "triple mismatch" problem where different endpoints require matching
|
||||
// Origin and X-Amz-Target header values.
|
||||
@@ -99,7 +356,7 @@ var kiroEndpointConfigs = []kiroEndpointConfig{
|
||||
Name: "CodeWhisperer",
|
||||
},
|
||||
{
|
||||
URL: "https://q.us-east-1.amazonaws.com/generateAssistantResponse",
|
||||
URL: "https://q.us-east-1.amazonaws.com/",
|
||||
Origin: "CLI",
|
||||
AmzTarget: "AmazonQDeveloperStreamingService.SendMessage",
|
||||
Name: "AmazonQ",
|
||||
@@ -217,6 +474,29 @@ func NewKiroExecutor(cfg *config.Config) *KiroExecutor {
|
||||
// Identifier returns the unique identifier for this executor.
|
||||
func (e *KiroExecutor) Identifier() string { return "kiro" }
|
||||
|
||||
// applyDynamicFingerprint applies token-specific fingerprint headers to the request
|
||||
// For IDC auth, uses dynamic fingerprint-based User-Agent
|
||||
// For other auth types, uses static Amazon Q CLI style headers
|
||||
func applyDynamicFingerprint(req *http.Request, auth *cliproxyauth.Auth) {
|
||||
if isIDCAuth(auth) {
|
||||
// Get token-specific fingerprint for dynamic UA generation
|
||||
tokenKey := getTokenKey(auth)
|
||||
fp := getGlobalFingerprintManager().GetFingerprint(tokenKey)
|
||||
|
||||
// Use fingerprint-generated dynamic User-Agent
|
||||
req.Header.Set("User-Agent", fp.BuildUserAgent())
|
||||
req.Header.Set("X-Amz-User-Agent", fp.BuildAmzUserAgent())
|
||||
req.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec)
|
||||
|
||||
log.Debugf("kiro: using dynamic fingerprint for token %s (SDK:%s, OS:%s/%s, Kiro:%s)",
|
||||
tokenKey[:8]+"...", fp.SDKVersion, fp.OSType, fp.OSVersion, fp.KiroVersion)
|
||||
} else {
|
||||
// Use static Amazon Q CLI style headers for non-IDC auth
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
req.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
|
||||
}
|
||||
}
|
||||
|
||||
// PrepareRequest prepares the HTTP request before execution.
|
||||
func (e *KiroExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
@@ -226,16 +506,10 @@ func (e *KiroExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
|
||||
}
|
||||
if isIDCAuth(auth) {
|
||||
req.Header.Set("User-Agent", kiroIDEUserAgent)
|
||||
req.Header.Set("X-Amz-User-Agent", kiroIDEAmzUserAgent)
|
||||
req.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec)
|
||||
} else {
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
req.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
|
||||
req.Header.Set("x-amzn-kiro-agent-mode", kiroAgentModeVibe)
|
||||
}
|
||||
req.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||
|
||||
// Apply dynamic fingerprint-based headers
|
||||
applyDynamicFingerprint(req, auth)
|
||||
|
||||
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
@@ -259,10 +533,23 @@ func (e *KiroExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
if errPrepare := e.PrepareRequest(httpReq, auth); errPrepare != nil {
|
||||
return nil, errPrepare
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
// getTokenKey returns a unique key for rate limiting based on auth credentials.
|
||||
// Uses auth ID if available, otherwise falls back to a hash of the access token.
|
||||
func getTokenKey(auth *cliproxyauth.Auth) string {
|
||||
if auth != nil && auth.ID != "" {
|
||||
return auth.ID
|
||||
}
|
||||
accessToken, _ := kiroCredentials(auth)
|
||||
if len(accessToken) > 16 {
|
||||
return accessToken[:16]
|
||||
}
|
||||
return accessToken
|
||||
}
|
||||
|
||||
// Execute sends the request to Kiro API and returns the response.
|
||||
// Supports automatic token refresh on 401/403 errors.
|
||||
func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
@@ -271,6 +558,24 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
return resp, fmt.Errorf("kiro: access token not found in auth")
|
||||
}
|
||||
|
||||
// Rate limiting: get token key for tracking
|
||||
tokenKey := getTokenKey(auth)
|
||||
rateLimiter := kiroauth.GetGlobalRateLimiter()
|
||||
cooldownMgr := kiroauth.GetGlobalCooldownManager()
|
||||
|
||||
// Check if token is in cooldown period
|
||||
if cooldownMgr.IsInCooldown(tokenKey) {
|
||||
remaining := cooldownMgr.GetRemainingCooldown(tokenKey)
|
||||
reason := cooldownMgr.GetCooldownReason(tokenKey)
|
||||
log.Warnf("kiro: token %s is in cooldown (reason: %s), remaining: %v", tokenKey, reason, remaining)
|
||||
return resp, fmt.Errorf("kiro: token is in cooldown for %v (reason: %s)", remaining, reason)
|
||||
}
|
||||
|
||||
// Wait for rate limiter before proceeding
|
||||
log.Debugf("kiro: waiting for rate limiter for token %s", tokenKey)
|
||||
rateLimiter.WaitForToken(tokenKey)
|
||||
log.Debugf("kiro: rate limiter cleared for token %s", tokenKey)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
@@ -303,7 +608,7 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
|
||||
// Execute with retry on 401/403 and 429 (quota exhausted)
|
||||
// Note: currentOrigin and kiroPayload are built inside executeWithRetry for each endpoint
|
||||
resp, err = e.executeWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, to, reporter, "", kiroModelID, isAgentic, isChatOnly)
|
||||
resp, err = e.executeWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, to, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
@@ -312,9 +617,12 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
// - Amazon Q endpoint (CLI origin) uses Amazon Q Developer quota
|
||||
// - CodeWhisperer endpoint (AI_EDITOR origin) uses Kiro IDE quota
|
||||
// Also supports multi-endpoint fallback similar to Antigravity implementation.
|
||||
func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from, to sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool) (cliproxyexecutor.Response, error) {
|
||||
// tokenKey is used for rate limiting and cooldown tracking.
|
||||
func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from, to sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool, tokenKey string) (cliproxyexecutor.Response, error) {
|
||||
var resp cliproxyexecutor.Response
|
||||
maxRetries := 2 // Allow retries for token refresh + endpoint fallback
|
||||
rateLimiter := kiroauth.GetGlobalRateLimiter()
|
||||
cooldownMgr := kiroauth.GetGlobalCooldownManager()
|
||||
endpointConfigs := getKiroEndpointConfigs(auth)
|
||||
var last429Err error
|
||||
|
||||
@@ -332,6 +640,12 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
||||
endpointIdx+1, len(endpointConfigs), url, endpointConfig.Name, currentOrigin)
|
||||
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
// Apply human-like delay before first request (not on retries)
|
||||
// This mimics natural user behavior patterns
|
||||
if attempt == 0 && endpointIdx == 0 {
|
||||
kiroauth.ApplyHumanLikeDelay()
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
@@ -342,20 +656,9 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
||||
// Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors)
|
||||
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
|
||||
|
||||
// Use different headers based on auth type
|
||||
// IDC auth uses Kiro IDE style headers (from kiro2api)
|
||||
// Other auth types use Amazon Q CLI style headers
|
||||
if isIDCAuth(auth) {
|
||||
httpReq.Header.Set("User-Agent", kiroIDEUserAgent)
|
||||
httpReq.Header.Set("X-Amz-User-Agent", kiroIDEAmzUserAgent)
|
||||
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec)
|
||||
log.Debugf("kiro: using Kiro IDE headers for IDC auth")
|
||||
} else {
|
||||
httpReq.Header.Set("User-Agent", kiroUserAgent)
|
||||
httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
|
||||
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroAgentModeVibe)
|
||||
}
|
||||
httpReq.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||
// Apply dynamic fingerprint-based headers
|
||||
applyDynamicFingerprint(httpReq, auth)
|
||||
|
||||
httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||
httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
||||
|
||||
@@ -386,10 +689,34 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 120*time.Second)
|
||||
httpClient := newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 120*time.Second)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
// Check for context cancellation first - client disconnected, not a server error
|
||||
// Use 499 (Client Closed Request - nginx convention) instead of 500
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Debugf("kiro: request canceled by client (context.Canceled)")
|
||||
return resp, statusErr{code: 499, msg: "client canceled request"}
|
||||
}
|
||||
|
||||
// Check for context deadline exceeded - request timed out
|
||||
// Return 504 Gateway Timeout instead of 500
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
log.Debugf("kiro: request timed out (context.DeadlineExceeded)")
|
||||
return resp, statusErr{code: http.StatusGatewayTimeout, msg: "upstream request timed out"}
|
||||
}
|
||||
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
|
||||
// Enhanced socket retry: Check if error is retryable (network timeout, connection reset, etc.)
|
||||
retryCfg := defaultRetryConfig()
|
||||
if isRetryableError(err) && attempt < retryCfg.MaxRetries {
|
||||
delay := calculateRetryDelay(attempt, retryCfg)
|
||||
logRetryAttempt(attempt, retryCfg.MaxRetries, fmt.Sprintf("socket error: %v", err), delay, endpointConfig.Name)
|
||||
time.Sleep(delay)
|
||||
continue
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
@@ -401,6 +728,12 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
||||
_ = httpResp.Body.Close()
|
||||
appendAPIResponseChunk(ctx, e.cfg, respBody)
|
||||
|
||||
// Record failure and set cooldown for 429
|
||||
rateLimiter.MarkTokenFailed(tokenKey)
|
||||
cooldownDuration := kiroauth.CalculateCooldownFor429(attempt)
|
||||
cooldownMgr.SetCooldown(tokenKey, cooldownDuration, kiroauth.CooldownReason429)
|
||||
log.Warnf("kiro: rate limit hit (429), token %s set to cooldown for %v", tokenKey, cooldownDuration)
|
||||
|
||||
// Preserve last 429 so callers can correctly backoff when all endpoints are exhausted
|
||||
last429Err = statusErr{code: httpResp.StatusCode, msg: string(respBody)}
|
||||
|
||||
@@ -412,13 +745,21 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
||||
}
|
||||
|
||||
// Handle 5xx server errors with exponential backoff retry
|
||||
// Enhanced: Use retryConfig for consistent retry behavior
|
||||
if httpResp.StatusCode >= 500 && httpResp.StatusCode < 600 {
|
||||
respBody, _ := io.ReadAll(httpResp.Body)
|
||||
_ = httpResp.Body.Close()
|
||||
appendAPIResponseChunk(ctx, e.cfg, respBody)
|
||||
|
||||
if attempt < maxRetries {
|
||||
// Exponential backoff: 1s, 2s, 4s... (max 30s)
|
||||
retryCfg := defaultRetryConfig()
|
||||
// Check if this specific 5xx code is retryable (502, 503, 504)
|
||||
if isRetryableHTTPStatus(httpResp.StatusCode) && attempt < retryCfg.MaxRetries {
|
||||
delay := calculateRetryDelay(attempt, retryCfg)
|
||||
logRetryAttempt(attempt, retryCfg.MaxRetries, fmt.Sprintf("HTTP %d", httpResp.StatusCode), delay, endpointConfig.Name)
|
||||
time.Sleep(delay)
|
||||
continue
|
||||
} else if attempt < maxRetries {
|
||||
// Fallback for other 5xx errors (500, 501, etc.)
|
||||
backoff := time.Duration(1<<attempt) * time.Second
|
||||
if backoff > 30*time.Second {
|
||||
backoff = 30 * time.Second
|
||||
@@ -492,7 +833,10 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
||||
|
||||
// Check for SUSPENDED status - return immediately without retry
|
||||
if strings.Contains(respBodyStr, "SUSPENDED") || strings.Contains(respBodyStr, "TEMPORARILY_SUSPENDED") {
|
||||
log.Errorf("kiro: account is suspended, cannot proceed")
|
||||
// Set long cooldown for suspended accounts
|
||||
rateLimiter.CheckAndMarkSuspended(tokenKey, respBodyStr)
|
||||
cooldownMgr.SetCooldown(tokenKey, kiroauth.LongCooldown, kiroauth.CooldownReasonSuspended)
|
||||
log.Errorf("kiro: account is suspended, token %s set to cooldown for %v", tokenKey, kiroauth.LongCooldown)
|
||||
return resp, statusErr{code: httpResp.StatusCode, msg: "account suspended: " + string(respBody)}
|
||||
}
|
||||
|
||||
@@ -581,6 +925,10 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
||||
appendAPIResponseChunk(ctx, e.cfg, []byte(content))
|
||||
reporter.publish(ctx, usageInfo)
|
||||
|
||||
// Record success for rate limiting
|
||||
rateLimiter.MarkTokenSuccess(tokenKey)
|
||||
log.Debugf("kiro: request successful, token %s marked as success", tokenKey)
|
||||
|
||||
// Build response in Claude format for Kiro translator
|
||||
// stopReason is extracted from upstream response by parseEventStream
|
||||
kiroResponse := kiroclaude.BuildClaudeResponse(content, toolUses, req.Model, usageInfo, stopReason)
|
||||
@@ -608,6 +956,24 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
return nil, fmt.Errorf("kiro: access token not found in auth")
|
||||
}
|
||||
|
||||
// Rate limiting: get token key for tracking
|
||||
tokenKey := getTokenKey(auth)
|
||||
rateLimiter := kiroauth.GetGlobalRateLimiter()
|
||||
cooldownMgr := kiroauth.GetGlobalCooldownManager()
|
||||
|
||||
// Check if token is in cooldown period
|
||||
if cooldownMgr.IsInCooldown(tokenKey) {
|
||||
remaining := cooldownMgr.GetRemainingCooldown(tokenKey)
|
||||
reason := cooldownMgr.GetCooldownReason(tokenKey)
|
||||
log.Warnf("kiro: token %s is in cooldown (reason: %s), remaining: %v", tokenKey, reason, remaining)
|
||||
return nil, fmt.Errorf("kiro: token is in cooldown for %v (reason: %s)", remaining, reason)
|
||||
}
|
||||
|
||||
// Wait for rate limiter before proceeding
|
||||
log.Debugf("kiro: stream waiting for rate limiter for token %s", tokenKey)
|
||||
rateLimiter.WaitForToken(tokenKey)
|
||||
log.Debugf("kiro: stream rate limiter cleared for token %s", tokenKey)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
@@ -640,7 +1006,7 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
|
||||
// Execute stream with retry on 401/403 and 429 (quota exhausted)
|
||||
// Note: currentOrigin and kiroPayload are built inside executeStreamWithRetry for each endpoint
|
||||
return e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly)
|
||||
return e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey)
|
||||
}
|
||||
|
||||
// executeStreamWithRetry performs the streaming HTTP request with automatic retry on auth errors.
|
||||
@@ -648,8 +1014,11 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
// - Amazon Q endpoint (CLI origin) uses Amazon Q Developer quota
|
||||
// - CodeWhisperer endpoint (AI_EDITOR origin) uses Kiro IDE quota
|
||||
// Also supports multi-endpoint fallback similar to Antigravity implementation.
|
||||
func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
// tokenKey is used for rate limiting and cooldown tracking.
|
||||
func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool, tokenKey string) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
maxRetries := 2 // Allow retries for token refresh + endpoint fallback
|
||||
rateLimiter := kiroauth.GetGlobalRateLimiter()
|
||||
cooldownMgr := kiroauth.GetGlobalCooldownManager()
|
||||
endpointConfigs := getKiroEndpointConfigs(auth)
|
||||
var last429Err error
|
||||
|
||||
@@ -667,6 +1036,13 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
|
||||
endpointIdx+1, len(endpointConfigs), url, endpointConfig.Name, currentOrigin)
|
||||
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
// Apply human-like delay before first streaming request (not on retries)
|
||||
// This mimics natural user behavior patterns
|
||||
// Note: Delay is NOT applied during streaming response - only before initial request
|
||||
if attempt == 0 && endpointIdx == 0 {
|
||||
kiroauth.ApplyHumanLikeDelay()
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -677,20 +1053,9 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
|
||||
// Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors)
|
||||
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
|
||||
|
||||
// Use different headers based on auth type
|
||||
// IDC auth uses Kiro IDE style headers (from kiro2api)
|
||||
// Other auth types use Amazon Q CLI style headers
|
||||
if isIDCAuth(auth) {
|
||||
httpReq.Header.Set("User-Agent", kiroIDEUserAgent)
|
||||
httpReq.Header.Set("X-Amz-User-Agent", kiroIDEAmzUserAgent)
|
||||
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec)
|
||||
log.Debugf("kiro: using Kiro IDE headers for IDC auth")
|
||||
} else {
|
||||
httpReq.Header.Set("User-Agent", kiroUserAgent)
|
||||
httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
|
||||
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroAgentModeVibe)
|
||||
}
|
||||
httpReq.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||
// Apply dynamic fingerprint-based headers
|
||||
applyDynamicFingerprint(httpReq, auth)
|
||||
|
||||
httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||
httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
||||
|
||||
@@ -721,10 +1086,20 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpClient := newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
|
||||
// Enhanced socket retry for streaming: Check if error is retryable (network timeout, connection reset, etc.)
|
||||
retryCfg := defaultRetryConfig()
|
||||
if isRetryableError(err) && attempt < retryCfg.MaxRetries {
|
||||
delay := calculateRetryDelay(attempt, retryCfg)
|
||||
logRetryAttempt(attempt, retryCfg.MaxRetries, fmt.Sprintf("stream socket error: %v", err), delay, endpointConfig.Name)
|
||||
time.Sleep(delay)
|
||||
continue
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
@@ -736,6 +1111,12 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
|
||||
_ = httpResp.Body.Close()
|
||||
appendAPIResponseChunk(ctx, e.cfg, respBody)
|
||||
|
||||
// Record failure and set cooldown for 429
|
||||
rateLimiter.MarkTokenFailed(tokenKey)
|
||||
cooldownDuration := kiroauth.CalculateCooldownFor429(attempt)
|
||||
cooldownMgr.SetCooldown(tokenKey, cooldownDuration, kiroauth.CooldownReason429)
|
||||
log.Warnf("kiro: stream rate limit hit (429), token %s set to cooldown for %v", tokenKey, cooldownDuration)
|
||||
|
||||
// Preserve last 429 so callers can correctly backoff when all endpoints are exhausted
|
||||
last429Err = statusErr{code: httpResp.StatusCode, msg: string(respBody)}
|
||||
|
||||
@@ -747,13 +1128,21 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
|
||||
}
|
||||
|
||||
// Handle 5xx server errors with exponential backoff retry
|
||||
// Enhanced: Use retryConfig for consistent retry behavior
|
||||
if httpResp.StatusCode >= 500 && httpResp.StatusCode < 600 {
|
||||
respBody, _ := io.ReadAll(httpResp.Body)
|
||||
_ = httpResp.Body.Close()
|
||||
appendAPIResponseChunk(ctx, e.cfg, respBody)
|
||||
|
||||
if attempt < maxRetries {
|
||||
// Exponential backoff: 1s, 2s, 4s... (max 30s)
|
||||
retryCfg := defaultRetryConfig()
|
||||
// Check if this specific 5xx code is retryable (502, 503, 504)
|
||||
if isRetryableHTTPStatus(httpResp.StatusCode) && attempt < retryCfg.MaxRetries {
|
||||
delay := calculateRetryDelay(attempt, retryCfg)
|
||||
logRetryAttempt(attempt, retryCfg.MaxRetries, fmt.Sprintf("stream HTTP %d", httpResp.StatusCode), delay, endpointConfig.Name)
|
||||
time.Sleep(delay)
|
||||
continue
|
||||
} else if attempt < maxRetries {
|
||||
// Fallback for other 5xx errors (500, 501, etc.)
|
||||
backoff := time.Duration(1<<attempt) * time.Second
|
||||
if backoff > 30*time.Second {
|
||||
backoff = 30 * time.Second
|
||||
@@ -840,7 +1229,10 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
|
||||
|
||||
// Check for SUSPENDED status - return immediately without retry
|
||||
if strings.Contains(respBodyStr, "SUSPENDED") || strings.Contains(respBodyStr, "TEMPORARILY_SUSPENDED") {
|
||||
log.Errorf("kiro: account is suspended, cannot proceed")
|
||||
// Set long cooldown for suspended accounts
|
||||
rateLimiter.CheckAndMarkSuspended(tokenKey, respBodyStr)
|
||||
cooldownMgr.SetCooldown(tokenKey, kiroauth.LongCooldown, kiroauth.CooldownReasonSuspended)
|
||||
log.Errorf("kiro: stream account is suspended, token %s set to cooldown for %v", tokenKey, kiroauth.LongCooldown)
|
||||
return nil, statusErr{code: httpResp.StatusCode, msg: "account suspended: " + string(respBody)}
|
||||
}
|
||||
|
||||
@@ -890,6 +1282,11 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
|
||||
// Record success immediately since connection was established successfully
|
||||
// Streaming errors will be handled separately
|
||||
rateLimiter.MarkTokenSuccess(tokenKey)
|
||||
log.Debugf("kiro: stream request successful, token %s marked as success", tokenKey)
|
||||
|
||||
go func(resp *http.Response, thinkingEnabled bool) {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
|
||||
97
internal/translator/kiro/common/utf8_stream.go
Normal file
97
internal/translator/kiro/common/utf8_stream.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
type UTF8StreamParser struct {
|
||||
buffer []byte
|
||||
}
|
||||
|
||||
func NewUTF8StreamParser() *UTF8StreamParser {
|
||||
return &UTF8StreamParser{
|
||||
buffer: make([]byte, 0, 64),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *UTF8StreamParser) Write(data []byte) {
|
||||
p.buffer = append(p.buffer, data...)
|
||||
}
|
||||
|
||||
func (p *UTF8StreamParser) Read() (string, bool) {
|
||||
if len(p.buffer) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
validLen := p.findValidUTF8End(p.buffer)
|
||||
if validLen == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
result := string(p.buffer[:validLen])
|
||||
p.buffer = p.buffer[validLen:]
|
||||
|
||||
return result, true
|
||||
}
|
||||
|
||||
func (p *UTF8StreamParser) Flush() string {
|
||||
if len(p.buffer) == 0 {
|
||||
return ""
|
||||
}
|
||||
result := string(p.buffer)
|
||||
p.buffer = p.buffer[:0]
|
||||
return result
|
||||
}
|
||||
|
||||
func (p *UTF8StreamParser) Reset() {
|
||||
p.buffer = p.buffer[:0]
|
||||
}
|
||||
|
||||
func (p *UTF8StreamParser) findValidUTF8End(data []byte) int {
|
||||
if len(data) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
end := len(data)
|
||||
for i := 1; i <= 3 && i <= len(data); i++ {
|
||||
b := data[len(data)-i]
|
||||
if b&0x80 == 0 {
|
||||
break
|
||||
}
|
||||
if b&0xC0 == 0xC0 {
|
||||
size := p.utf8CharSize(b)
|
||||
available := i
|
||||
if size > available {
|
||||
end = len(data) - i
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if end > 0 && !utf8.Valid(data[:end]) {
|
||||
for i := end - 1; i >= 0; i-- {
|
||||
if utf8.Valid(data[:i+1]) {
|
||||
return i + 1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
return end
|
||||
}
|
||||
|
||||
func (p *UTF8StreamParser) utf8CharSize(b byte) int {
|
||||
if b&0x80 == 0 {
|
||||
return 1
|
||||
}
|
||||
if b&0xE0 == 0xC0 {
|
||||
return 2
|
||||
}
|
||||
if b&0xF0 == 0xE0 {
|
||||
return 3
|
||||
}
|
||||
if b&0xF8 == 0xF0 {
|
||||
return 4
|
||||
}
|
||||
return 1
|
||||
}
|
||||
402
internal/translator/kiro/common/utf8_stream_test.go
Normal file
402
internal/translator/kiro/common/utf8_stream_test.go
Normal file
@@ -0,0 +1,402 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
func TestNewUTF8StreamParser(t *testing.T) {
|
||||
p := NewUTF8StreamParser()
|
||||
if p == nil {
|
||||
t.Fatal("expected non-nil UTF8StreamParser")
|
||||
}
|
||||
if p.buffer == nil {
|
||||
t.Error("expected non-nil buffer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrite(t *testing.T) {
|
||||
p := NewUTF8StreamParser()
|
||||
p.Write([]byte("hello"))
|
||||
|
||||
result, ok := p.Read()
|
||||
if !ok {
|
||||
t.Error("expected ok to be true")
|
||||
}
|
||||
if result != "hello" {
|
||||
t.Errorf("expected 'hello', got '%s'", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrite_MultipleWrites(t *testing.T) {
|
||||
p := NewUTF8StreamParser()
|
||||
p.Write([]byte("hel"))
|
||||
p.Write([]byte("lo"))
|
||||
|
||||
result, ok := p.Read()
|
||||
if !ok {
|
||||
t.Error("expected ok to be true")
|
||||
}
|
||||
if result != "hello" {
|
||||
t.Errorf("expected 'hello', got '%s'", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRead_EmptyBuffer(t *testing.T) {
|
||||
p := NewUTF8StreamParser()
|
||||
result, ok := p.Read()
|
||||
if ok {
|
||||
t.Error("expected ok to be false for empty buffer")
|
||||
}
|
||||
if result != "" {
|
||||
t.Errorf("expected empty string, got '%s'", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRead_IncompleteUTF8(t *testing.T) {
|
||||
p := NewUTF8StreamParser()
|
||||
|
||||
// Write incomplete multi-byte UTF-8 character
|
||||
// 中 (U+4E2D) = E4 B8 AD
|
||||
p.Write([]byte{0xE4, 0xB8})
|
||||
|
||||
result, ok := p.Read()
|
||||
if ok {
|
||||
t.Error("expected ok to be false for incomplete UTF-8")
|
||||
}
|
||||
if result != "" {
|
||||
t.Errorf("expected empty string, got '%s'", result)
|
||||
}
|
||||
|
||||
// Complete the character
|
||||
p.Write([]byte{0xAD})
|
||||
result, ok = p.Read()
|
||||
if !ok {
|
||||
t.Error("expected ok to be true after completing UTF-8")
|
||||
}
|
||||
if result != "中" {
|
||||
t.Errorf("expected '中', got '%s'", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRead_MixedASCIIAndUTF8(t *testing.T) {
|
||||
p := NewUTF8StreamParser()
|
||||
p.Write([]byte("Hello 世界"))
|
||||
|
||||
result, ok := p.Read()
|
||||
if !ok {
|
||||
t.Error("expected ok to be true")
|
||||
}
|
||||
if result != "Hello 世界" {
|
||||
t.Errorf("expected 'Hello 世界', got '%s'", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRead_PartialMultibyteAtEnd(t *testing.T) {
|
||||
p := NewUTF8StreamParser()
|
||||
// "Hello" + partial "世" (E4 B8 96)
|
||||
p.Write([]byte("Hello"))
|
||||
p.Write([]byte{0xE4, 0xB8})
|
||||
|
||||
result, ok := p.Read()
|
||||
if !ok {
|
||||
t.Error("expected ok to be true for valid portion")
|
||||
}
|
||||
if result != "Hello" {
|
||||
t.Errorf("expected 'Hello', got '%s'", result)
|
||||
}
|
||||
|
||||
// Complete the character
|
||||
p.Write([]byte{0x96})
|
||||
result, ok = p.Read()
|
||||
if !ok {
|
||||
t.Error("expected ok to be true after completing")
|
||||
}
|
||||
if result != "世" {
|
||||
t.Errorf("expected '世', got '%s'", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlush(t *testing.T) {
|
||||
p := NewUTF8StreamParser()
|
||||
p.Write([]byte("hello"))
|
||||
|
||||
result := p.Flush()
|
||||
if result != "hello" {
|
||||
t.Errorf("expected 'hello', got '%s'", result)
|
||||
}
|
||||
|
||||
// Verify buffer is cleared
|
||||
result2, ok := p.Read()
|
||||
if ok {
|
||||
t.Error("expected ok to be false after flush")
|
||||
}
|
||||
if result2 != "" {
|
||||
t.Errorf("expected empty string after flush, got '%s'", result2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlush_EmptyBuffer(t *testing.T) {
|
||||
p := NewUTF8StreamParser()
|
||||
result := p.Flush()
|
||||
if result != "" {
|
||||
t.Errorf("expected empty string, got '%s'", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlush_IncompleteUTF8(t *testing.T) {
|
||||
p := NewUTF8StreamParser()
|
||||
p.Write([]byte{0xE4, 0xB8})
|
||||
|
||||
result := p.Flush()
|
||||
// Flush returns everything including incomplete bytes
|
||||
if len(result) != 2 {
|
||||
t.Errorf("expected 2 bytes flushed, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReset(t *testing.T) {
|
||||
p := NewUTF8StreamParser()
|
||||
p.Write([]byte("hello"))
|
||||
p.Reset()
|
||||
|
||||
result, ok := p.Read()
|
||||
if ok {
|
||||
t.Error("expected ok to be false after reset")
|
||||
}
|
||||
if result != "" {
|
||||
t.Errorf("expected empty string after reset, got '%s'", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUtf8CharSize(t *testing.T) {
|
||||
p := NewUTF8StreamParser()
|
||||
|
||||
testCases := []struct {
|
||||
b byte
|
||||
expected int
|
||||
}{
|
||||
{0x00, 1}, // ASCII
|
||||
{0x7F, 1}, // ASCII max
|
||||
{0xC0, 2}, // 2-byte start
|
||||
{0xDF, 2}, // 2-byte start
|
||||
{0xE0, 3}, // 3-byte start
|
||||
{0xEF, 3}, // 3-byte start
|
||||
{0xF0, 4}, // 4-byte start
|
||||
{0xF7, 4}, // 4-byte start
|
||||
{0x80, 1}, // Continuation byte (fallback)
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
size := p.utf8CharSize(tc.b)
|
||||
if size != tc.expected {
|
||||
t.Errorf("utf8CharSize(0x%X) = %d, expected %d", tc.b, size, tc.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamingScenario(t *testing.T) {
|
||||
p := NewUTF8StreamParser()
|
||||
|
||||
// Simulate streaming: "Hello, 世界! 🌍"
|
||||
chunks := [][]byte{
|
||||
[]byte("Hello, "),
|
||||
{0xE4, 0xB8}, // partial 世
|
||||
{0x96, 0xE7}, // complete 世, partial 界
|
||||
{0x95, 0x8C}, // complete 界
|
||||
[]byte("! "),
|
||||
{0xF0, 0x9F}, // partial 🌍
|
||||
{0x8C, 0x8D}, // complete 🌍
|
||||
}
|
||||
|
||||
var results []string
|
||||
for _, chunk := range chunks {
|
||||
p.Write(chunk)
|
||||
if result, ok := p.Read(); ok {
|
||||
results = append(results, result)
|
||||
}
|
||||
}
|
||||
|
||||
combined := strings.Join(results, "")
|
||||
if combined != "Hello, 世界! 🌍" {
|
||||
t.Errorf("expected 'Hello, 世界! 🌍', got '%s'", combined)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidUTF8Output(t *testing.T) {
|
||||
p := NewUTF8StreamParser()
|
||||
|
||||
testStrings := []string{
|
||||
"Hello World",
|
||||
"你好世界",
|
||||
"こんにちは",
|
||||
"🎉🎊🎁",
|
||||
"Mixed 混合 Текст ტექსტი",
|
||||
}
|
||||
|
||||
for _, s := range testStrings {
|
||||
p.Reset()
|
||||
p.Write([]byte(s))
|
||||
result, ok := p.Read()
|
||||
if !ok {
|
||||
t.Errorf("expected ok for '%s'", s)
|
||||
}
|
||||
if !utf8.ValidString(result) {
|
||||
t.Errorf("invalid UTF-8 output for input '%s'", s)
|
||||
}
|
||||
if result != s {
|
||||
t.Errorf("expected '%s', got '%s'", s, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLargeData(t *testing.T) {
|
||||
p := NewUTF8StreamParser()
|
||||
|
||||
// Generate large UTF-8 string
|
||||
var builder strings.Builder
|
||||
for i := 0; i < 1000; i++ {
|
||||
builder.WriteString("Hello 世界! ")
|
||||
}
|
||||
largeString := builder.String()
|
||||
|
||||
p.Write([]byte(largeString))
|
||||
result, ok := p.Read()
|
||||
if !ok {
|
||||
t.Error("expected ok for large data")
|
||||
}
|
||||
if result != largeString {
|
||||
t.Error("large data mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteByByteWriting(t *testing.T) {
|
||||
p := NewUTF8StreamParser()
|
||||
input := "Hello 世界"
|
||||
inputBytes := []byte(input)
|
||||
|
||||
var results []string
|
||||
for _, b := range inputBytes {
|
||||
p.Write([]byte{b})
|
||||
if result, ok := p.Read(); ok {
|
||||
results = append(results, result)
|
||||
}
|
||||
}
|
||||
|
||||
combined := strings.Join(results, "")
|
||||
if combined != input {
|
||||
t.Errorf("expected '%s', got '%s'", input, combined)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmoji4ByteUTF8(t *testing.T) {
|
||||
p := NewUTF8StreamParser()
|
||||
|
||||
// 🎉 = F0 9F 8E 89
|
||||
emoji := "🎉"
|
||||
emojiBytes := []byte(emoji)
|
||||
|
||||
for i := 0; i < len(emojiBytes)-1; i++ {
|
||||
p.Write(emojiBytes[i : i+1])
|
||||
result, ok := p.Read()
|
||||
if ok && result != "" {
|
||||
t.Errorf("unexpected output before emoji complete: '%s'", result)
|
||||
}
|
||||
}
|
||||
|
||||
p.Write(emojiBytes[len(emojiBytes)-1:])
|
||||
result, ok := p.Read()
|
||||
if !ok {
|
||||
t.Error("expected ok after completing emoji")
|
||||
}
|
||||
if result != emoji {
|
||||
t.Errorf("expected '%s', got '%s'", emoji, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContinuationBytesOnly(t *testing.T) {
|
||||
p := NewUTF8StreamParser()
|
||||
|
||||
// Write only continuation bytes (invalid UTF-8)
|
||||
p.Write([]byte{0x80, 0x80, 0x80})
|
||||
|
||||
result, ok := p.Read()
|
||||
// Should handle gracefully - either return nothing or return the bytes
|
||||
_ = result
|
||||
_ = ok
|
||||
}
|
||||
|
||||
func TestUTF8StreamParser_ConcurrentSafety(t *testing.T) {
|
||||
// Note: UTF8StreamParser doesn't have built-in locks,
|
||||
// so this test verifies it works with external synchronization
|
||||
p := NewUTF8StreamParser()
|
||||
var mu sync.Mutex
|
||||
const numGoroutines = 10
|
||||
const numOperations = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numOperations; j++ {
|
||||
mu.Lock()
|
||||
switch j % 4 {
|
||||
case 0:
|
||||
p.Write([]byte("test"))
|
||||
case 1:
|
||||
p.Read()
|
||||
case 2:
|
||||
p.Flush()
|
||||
case 3:
|
||||
p.Reset()
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestConsecutiveReads(t *testing.T) {
|
||||
p := NewUTF8StreamParser()
|
||||
p.Write([]byte("hello"))
|
||||
|
||||
result1, ok1 := p.Read()
|
||||
if !ok1 || result1 != "hello" {
|
||||
t.Error("first read failed")
|
||||
}
|
||||
|
||||
result2, ok2 := p.Read()
|
||||
if ok2 || result2 != "" {
|
||||
t.Error("second read should return empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlushThenWrite(t *testing.T) {
|
||||
p := NewUTF8StreamParser()
|
||||
p.Write([]byte("first"))
|
||||
p.Flush()
|
||||
p.Write([]byte("second"))
|
||||
|
||||
result, ok := p.Read()
|
||||
if !ok || result != "second" {
|
||||
t.Errorf("expected 'second', got '%s'", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetThenWrite(t *testing.T) {
|
||||
p := NewUTF8StreamParser()
|
||||
p.Write([]byte("first"))
|
||||
p.Reset()
|
||||
p.Write([]byte("second"))
|
||||
|
||||
result, ok := p.Read()
|
||||
if !ok || result != "second" {
|
||||
t.Errorf("expected 'second', got '%s'", result)
|
||||
}
|
||||
}
|
||||
@@ -170,7 +170,9 @@ func (w *Watcher) handleKiroIDETokenChange(event fsnotify.Event) {
|
||||
}
|
||||
}
|
||||
|
||||
tokenData, err := kiroauth.LoadKiroIDEToken()
|
||||
// Use retry logic to handle file lock contention (e.g., Kiro IDE writing the file)
|
||||
// This prevents "being used by another process" errors on Windows
|
||||
tokenData, err := kiroauth.LoadKiroIDETokenWithRetry(10, 50*time.Millisecond)
|
||||
if err != nil {
|
||||
log.Debugf("failed to load Kiro IDE token after change: %v", err)
|
||||
return
|
||||
|
||||
144
sdk/auth/kiro.go
144
sdk/auth/kiro.go
@@ -12,9 +12,9 @@ import (
|
||||
)
|
||||
|
||||
// extractKiroIdentifier extracts a meaningful identifier for file naming.
|
||||
// Returns account name if provided, otherwise profile ARN ID.
|
||||
// Returns account name if provided, otherwise profile ARN ID, then client ID.
|
||||
// All extracted values are sanitized to prevent path injection attacks.
|
||||
func extractKiroIdentifier(accountName, profileArn string) string {
|
||||
func extractKiroIdentifier(accountName, profileArn, clientID string) string {
|
||||
// Priority 1: Use account name if provided
|
||||
if accountName != "" {
|
||||
return kiroauth.SanitizeEmailForFilename(accountName)
|
||||
@@ -29,6 +29,11 @@ func extractKiroIdentifier(accountName, profileArn string) string {
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 3: Use client ID (for IDC auth without email/profileArn)
|
||||
if clientID != "" {
|
||||
return kiroauth.SanitizeEmailForFilename(clientID)
|
||||
}
|
||||
|
||||
// Fallback: timestamp
|
||||
return fmt.Sprintf("%d", time.Now().UnixNano()%100000)
|
||||
}
|
||||
@@ -62,7 +67,7 @@ func (a *KiroAuthenticator) createAuthRecord(tokenData *kiroauth.KiroTokenData,
|
||||
}
|
||||
|
||||
// Extract identifier for file naming
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn, tokenData.ClientID)
|
||||
|
||||
// Determine label based on auth method
|
||||
label := fmt.Sprintf("kiro-%s", source)
|
||||
@@ -173,7 +178,7 @@ func (a *KiroAuthenticator) LoginWithAuthCode(ctx context.Context, cfg *config.C
|
||||
}
|
||||
|
||||
// Extract identifier for file naming
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn, tokenData.ClientID)
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("kiro-aws-%s.json", idPart)
|
||||
@@ -217,129 +222,17 @@ func (a *KiroAuthenticator) LoginWithAuthCode(ctx context.Context, cfg *config.C
|
||||
}
|
||||
|
||||
// LoginWithGoogle performs OAuth login for Kiro with Google.
|
||||
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||
// NOTE: Google login is not available for third-party applications due to AWS Cognito restrictions.
|
||||
// Please use AWS Builder ID or import your token from Kiro IDE.
|
||||
func (a *KiroAuthenticator) LoginWithGoogle(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||
}
|
||||
|
||||
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||
|
||||
// Use Google OAuth flow with protocol handler
|
||||
tokenData, err := oauth.LoginWithGoogle(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google login failed: %w", err)
|
||||
}
|
||||
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
// Extract identifier for file naming
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("kiro-google-%s.json", idPart)
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Label: "kiro-google",
|
||||
Status: coreauth.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Metadata: map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenData.AccessToken,
|
||||
"refresh_token": tokenData.RefreshToken,
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"expires_at": tokenData.ExpiresAt,
|
||||
"auth_method": tokenData.AuthMethod,
|
||||
"provider": tokenData.Provider,
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
Attributes: map[string]string{
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"source": "google-oauth",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
fmt.Printf("\n✓ Kiro Google authentication completed successfully! (Account: %s)\n", tokenData.Email)
|
||||
} else {
|
||||
fmt.Println("\n✓ Kiro Google authentication completed successfully!")
|
||||
}
|
||||
|
||||
return record, nil
|
||||
return nil, fmt.Errorf("Google login is not available for third-party applications due to AWS Cognito restrictions.\n\nAlternatives:\n 1. Use AWS Builder ID: cliproxy kiro --builder-id\n 2. Import token from Kiro IDE: cliproxy kiro --import\n\nTo get a token from Kiro IDE:\n 1. Open Kiro IDE and login with Google\n 2. Find: ~/.kiro/kiro-auth-token.json\n 3. Run: cliproxy kiro --import")
|
||||
}
|
||||
|
||||
// LoginWithGitHub performs OAuth login for Kiro with GitHub.
|
||||
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||
// NOTE: GitHub login is not available for third-party applications due to AWS Cognito restrictions.
|
||||
// Please use AWS Builder ID or import your token from Kiro IDE.
|
||||
func (a *KiroAuthenticator) LoginWithGitHub(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||
}
|
||||
|
||||
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||
|
||||
// Use GitHub OAuth flow with protocol handler
|
||||
tokenData, err := oauth.LoginWithGitHub(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github login failed: %w", err)
|
||||
}
|
||||
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
// Extract identifier for file naming
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("kiro-github-%s.json", idPart)
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Label: "kiro-github",
|
||||
Status: coreauth.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Metadata: map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenData.AccessToken,
|
||||
"refresh_token": tokenData.RefreshToken,
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"expires_at": tokenData.ExpiresAt,
|
||||
"auth_method": tokenData.AuthMethod,
|
||||
"provider": tokenData.Provider,
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
Attributes: map[string]string{
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"source": "github-oauth",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
fmt.Printf("\n✓ Kiro GitHub authentication completed successfully! (Account: %s)\n", tokenData.Email)
|
||||
} else {
|
||||
fmt.Println("\n✓ Kiro GitHub authentication completed successfully!")
|
||||
}
|
||||
|
||||
return record, nil
|
||||
return nil, fmt.Errorf("GitHub login is not available for third-party applications due to AWS Cognito restrictions.\n\nAlternatives:\n 1. Use AWS Builder ID: cliproxy kiro --builder-id\n 2. Import token from Kiro IDE: cliproxy kiro --import\n\nTo get a token from Kiro IDE:\n 1. Open Kiro IDE and login with GitHub\n 2. Find: ~/.kiro/kiro-auth-token.json\n 3. Run: cliproxy kiro --import")
|
||||
}
|
||||
|
||||
// ImportFromKiroIDE imports token from Kiro IDE's token file.
|
||||
@@ -361,7 +254,7 @@ func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.C
|
||||
}
|
||||
|
||||
// Extract identifier for file naming
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn, tokenData.ClientID)
|
||||
// Sanitize provider to prevent path traversal (defense-in-depth)
|
||||
provider := kiroauth.SanitizeEmailForFilename(strings.ToLower(strings.TrimSpace(tokenData.Provider)))
|
||||
if provider == "" {
|
||||
@@ -387,12 +280,17 @@ func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.C
|
||||
"expires_at": tokenData.ExpiresAt,
|
||||
"auth_method": tokenData.AuthMethod,
|
||||
"provider": tokenData.Provider,
|
||||
"client_id": tokenData.ClientID,
|
||||
"client_secret": tokenData.ClientSecret,
|
||||
"email": tokenData.Email,
|
||||
"region": tokenData.Region,
|
||||
"start_url": tokenData.StartURL,
|
||||
},
|
||||
Attributes: map[string]string{
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"source": "kiro-ide-import",
|
||||
"email": tokenData.Email,
|
||||
"region": tokenData.Region,
|
||||
},
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
|
||||
470
sdk/auth/kiro.go.bak
Normal file
470
sdk/auth/kiro.go.bak
Normal file
@@ -0,0 +1,470 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// extractKiroIdentifier extracts a meaningful identifier for file naming.
|
||||
// Returns account name if provided, otherwise profile ARN ID.
|
||||
// All extracted values are sanitized to prevent path injection attacks.
|
||||
func extractKiroIdentifier(accountName, profileArn string) string {
|
||||
// Priority 1: Use account name if provided
|
||||
if accountName != "" {
|
||||
return kiroauth.SanitizeEmailForFilename(accountName)
|
||||
}
|
||||
|
||||
// Priority 2: Use profile ARN ID part (sanitized to prevent path injection)
|
||||
if profileArn != "" {
|
||||
parts := strings.Split(profileArn, "/")
|
||||
if len(parts) >= 2 {
|
||||
// Sanitize the ARN component to prevent path traversal
|
||||
return kiroauth.SanitizeEmailForFilename(parts[len(parts)-1])
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: timestamp
|
||||
return fmt.Sprintf("%d", time.Now().UnixNano()%100000)
|
||||
}
|
||||
|
||||
// KiroAuthenticator implements OAuth authentication for Kiro with Google login.
|
||||
type KiroAuthenticator struct{}
|
||||
|
||||
// NewKiroAuthenticator constructs a Kiro authenticator.
|
||||
func NewKiroAuthenticator() *KiroAuthenticator {
|
||||
return &KiroAuthenticator{}
|
||||
}
|
||||
|
||||
// Provider returns the provider key for the authenticator.
|
||||
func (a *KiroAuthenticator) Provider() string {
|
||||
return "kiro"
|
||||
}
|
||||
|
||||
// RefreshLead indicates how soon before expiry a refresh should be attempted.
|
||||
// Set to 5 minutes to match Antigravity and avoid frequent refresh checks while still ensuring timely token refresh.
|
||||
func (a *KiroAuthenticator) RefreshLead() *time.Duration {
|
||||
d := 5 * time.Minute
|
||||
return &d
|
||||
}
|
||||
|
||||
// createAuthRecord creates an auth record from token data.
|
||||
func (a *KiroAuthenticator) createAuthRecord(tokenData *kiroauth.KiroTokenData, source string) (*coreauth.Auth, error) {
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
// Extract identifier for file naming
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||
|
||||
// Determine label based on auth method
|
||||
label := fmt.Sprintf("kiro-%s", source)
|
||||
if tokenData.AuthMethod == "idc" {
|
||||
label = "kiro-idc"
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("%s-%s.json", label, idPart)
|
||||
|
||||
metadata := map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenData.AccessToken,
|
||||
"refresh_token": tokenData.RefreshToken,
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"expires_at": tokenData.ExpiresAt,
|
||||
"auth_method": tokenData.AuthMethod,
|
||||
"provider": tokenData.Provider,
|
||||
"client_id": tokenData.ClientID,
|
||||
"client_secret": tokenData.ClientSecret,
|
||||
"email": tokenData.Email,
|
||||
}
|
||||
|
||||
// Add IDC-specific fields if present
|
||||
if tokenData.StartURL != "" {
|
||||
metadata["start_url"] = tokenData.StartURL
|
||||
}
|
||||
if tokenData.Region != "" {
|
||||
metadata["region"] = tokenData.Region
|
||||
}
|
||||
|
||||
attributes := map[string]string{
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"source": source,
|
||||
"email": tokenData.Email,
|
||||
}
|
||||
|
||||
// Add IDC-specific attributes if present
|
||||
if tokenData.AuthMethod == "idc" {
|
||||
attributes["source"] = "aws-idc"
|
||||
if tokenData.StartURL != "" {
|
||||
attributes["start_url"] = tokenData.StartURL
|
||||
}
|
||||
if tokenData.Region != "" {
|
||||
attributes["region"] = tokenData.Region
|
||||
}
|
||||
}
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Label: label,
|
||||
Status: coreauth.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Metadata: metadata,
|
||||
Attributes: attributes,
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
fmt.Printf("\n✓ Kiro authentication completed successfully! (Account: %s)\n", tokenData.Email)
|
||||
} else {
|
||||
fmt.Println("\n✓ Kiro authentication completed successfully!")
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// Login performs OAuth login for Kiro with AWS (Builder ID or IDC).
|
||||
// This shows a method selection prompt and handles both flows.
|
||||
func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||
}
|
||||
|
||||
// Use the unified method selection flow (Builder ID or IDC)
|
||||
ssoClient := kiroauth.NewSSOOIDCClient(cfg)
|
||||
tokenData, err := ssoClient.LoginWithMethodSelection(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login failed: %w", err)
|
||||
}
|
||||
|
||||
return a.createAuthRecord(tokenData, "aws")
|
||||
}
|
||||
|
||||
// LoginWithAuthCode performs OAuth login for Kiro with AWS Builder ID using authorization code flow.
|
||||
// This provides a better UX than device code flow as it uses automatic browser callback.
|
||||
func (a *KiroAuthenticator) LoginWithAuthCode(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||
}
|
||||
|
||||
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||
|
||||
// Use AWS Builder ID authorization code flow
|
||||
tokenData, err := oauth.LoginWithBuilderIDAuthCode(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login failed: %w", err)
|
||||
}
|
||||
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
// Extract identifier for file naming
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("kiro-aws-%s.json", idPart)
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Label: "kiro-aws",
|
||||
Status: coreauth.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Metadata: map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenData.AccessToken,
|
||||
"refresh_token": tokenData.RefreshToken,
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"expires_at": tokenData.ExpiresAt,
|
||||
"auth_method": tokenData.AuthMethod,
|
||||
"provider": tokenData.Provider,
|
||||
"client_id": tokenData.ClientID,
|
||||
"client_secret": tokenData.ClientSecret,
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
Attributes: map[string]string{
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"source": "aws-builder-id-authcode",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
fmt.Printf("\n✓ Kiro authentication completed successfully! (Account: %s)\n", tokenData.Email)
|
||||
} else {
|
||||
fmt.Println("\n✓ Kiro authentication completed successfully!")
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// LoginWithGoogle performs OAuth login for Kiro with Google.
|
||||
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||
func (a *KiroAuthenticator) LoginWithGoogle(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||
}
|
||||
|
||||
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||
|
||||
// Use Google OAuth flow with protocol handler
|
||||
tokenData, err := oauth.LoginWithGoogle(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google login failed: %w", err)
|
||||
}
|
||||
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
// Extract identifier for file naming
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("kiro-google-%s.json", idPart)
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Label: "kiro-google",
|
||||
Status: coreauth.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Metadata: map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenData.AccessToken,
|
||||
"refresh_token": tokenData.RefreshToken,
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"expires_at": tokenData.ExpiresAt,
|
||||
"auth_method": tokenData.AuthMethod,
|
||||
"provider": tokenData.Provider,
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
Attributes: map[string]string{
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"source": "google-oauth",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
fmt.Printf("\n✓ Kiro Google authentication completed successfully! (Account: %s)\n", tokenData.Email)
|
||||
} else {
|
||||
fmt.Println("\n✓ Kiro Google authentication completed successfully!")
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// LoginWithGitHub performs OAuth login for Kiro with GitHub.
|
||||
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||
func (a *KiroAuthenticator) LoginWithGitHub(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||
}
|
||||
|
||||
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||
|
||||
// Use GitHub OAuth flow with protocol handler
|
||||
tokenData, err := oauth.LoginWithGitHub(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github login failed: %w", err)
|
||||
}
|
||||
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
// Extract identifier for file naming
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("kiro-github-%s.json", idPart)
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Label: "kiro-github",
|
||||
Status: coreauth.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Metadata: map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenData.AccessToken,
|
||||
"refresh_token": tokenData.RefreshToken,
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"expires_at": tokenData.ExpiresAt,
|
||||
"auth_method": tokenData.AuthMethod,
|
||||
"provider": tokenData.Provider,
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
Attributes: map[string]string{
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"source": "github-oauth",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
if tokenData.Email != "" {
|
||||
fmt.Printf("\n✓ Kiro GitHub authentication completed successfully! (Account: %s)\n", tokenData.Email)
|
||||
} else {
|
||||
fmt.Println("\n✓ Kiro GitHub authentication completed successfully!")
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// ImportFromKiroIDE imports token from Kiro IDE's token file.
|
||||
func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.Config) (*coreauth.Auth, error) {
|
||||
tokenData, err := kiroauth.LoadKiroIDEToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load Kiro IDE token: %w", err)
|
||||
}
|
||||
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
// Extract email from JWT if not already set (for imported tokens)
|
||||
if tokenData.Email == "" {
|
||||
tokenData.Email = kiroauth.ExtractEmailFromJWT(tokenData.AccessToken)
|
||||
}
|
||||
|
||||
// Extract identifier for file naming
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||
// Sanitize provider to prevent path traversal (defense-in-depth)
|
||||
provider := kiroauth.SanitizeEmailForFilename(strings.ToLower(strings.TrimSpace(tokenData.Provider)))
|
||||
if provider == "" {
|
||||
provider = "imported" // Fallback for legacy tokens without provider
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("kiro-%s-%s.json", provider, idPart)
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Label: fmt.Sprintf("kiro-%s", provider),
|
||||
Status: coreauth.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Metadata: map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenData.AccessToken,
|
||||
"refresh_token": tokenData.RefreshToken,
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"expires_at": tokenData.ExpiresAt,
|
||||
"auth_method": tokenData.AuthMethod,
|
||||
"provider": tokenData.Provider,
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
Attributes: map[string]string{
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"source": "kiro-ide-import",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
|
||||
// Display the email if extracted
|
||||
if tokenData.Email != "" {
|
||||
fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s, Account: %s)\n", tokenData.Provider, tokenData.Email)
|
||||
} else {
|
||||
fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s)\n", tokenData.Provider)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// Refresh refreshes an expired Kiro token using AWS SSO OIDC.
|
||||
func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return nil, fmt.Errorf("invalid auth record")
|
||||
}
|
||||
|
||||
refreshToken, ok := auth.Metadata["refresh_token"].(string)
|
||||
if !ok || refreshToken == "" {
|
||||
return nil, fmt.Errorf("refresh token not found")
|
||||
}
|
||||
|
||||
clientID, _ := auth.Metadata["client_id"].(string)
|
||||
clientSecret, _ := auth.Metadata["client_secret"].(string)
|
||||
authMethod, _ := auth.Metadata["auth_method"].(string)
|
||||
startURL, _ := auth.Metadata["start_url"].(string)
|
||||
region, _ := auth.Metadata["region"].(string)
|
||||
|
||||
var tokenData *kiroauth.KiroTokenData
|
||||
var err error
|
||||
|
||||
ssoClient := kiroauth.NewSSOOIDCClient(cfg)
|
||||
|
||||
// Use SSO OIDC refresh for AWS Builder ID or IDC, otherwise use Kiro's OAuth refresh endpoint
|
||||
switch {
|
||||
case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "":
|
||||
// IDC refresh with region-specific endpoint
|
||||
tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL)
|
||||
case clientID != "" && clientSecret != "" && authMethod == "builder-id":
|
||||
// Builder ID refresh with default endpoint
|
||||
tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken)
|
||||
default:
|
||||
// Fallback to Kiro's refresh endpoint (for social auth: Google/GitHub)
|
||||
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||
tokenData, err = oauth.RefreshToken(ctx, refreshToken)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token refresh failed: %w", err)
|
||||
}
|
||||
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
// Clone auth to avoid mutating the input parameter
|
||||
updated := auth.Clone()
|
||||
now := time.Now()
|
||||
updated.UpdatedAt = now
|
||||
updated.LastRefreshedAt = now
|
||||
updated.Metadata["access_token"] = tokenData.AccessToken
|
||||
updated.Metadata["refresh_token"] = tokenData.RefreshToken
|
||||
updated.Metadata["expires_at"] = tokenData.ExpiresAt
|
||||
updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
updated.NextRefreshAfter = expiresAt.Add(-5 * time.Minute)
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
||||
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
@@ -775,7 +776,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
models = registry.GetGitHubCopilotModels()
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "kiro":
|
||||
models = registry.GetKiroModels()
|
||||
models = s.fetchKiroModels(a)
|
||||
models = applyExcludedModels(models, excluded)
|
||||
default:
|
||||
// Handle OpenAI-compatibility providers by name using config
|
||||
@@ -1338,3 +1339,201 @@ func applyOAuthModelAlias(cfg *config.Config, provider, authKind string, models
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// fetchKiroModels attempts to dynamically fetch Kiro models from the API.
|
||||
// If dynamic fetch fails, it falls back to static registry.GetKiroModels().
|
||||
func (s *Service) fetchKiroModels(a *coreauth.Auth) []*ModelInfo {
|
||||
if a == nil {
|
||||
log.Debug("kiro: auth is nil, using static models")
|
||||
return registry.GetKiroModels()
|
||||
}
|
||||
|
||||
// Extract token data from auth attributes
|
||||
tokenData := s.extractKiroTokenData(a)
|
||||
if tokenData == nil || tokenData.AccessToken == "" {
|
||||
log.Debug("kiro: no valid token data in auth, using static models")
|
||||
return registry.GetKiroModels()
|
||||
}
|
||||
|
||||
// Create KiroAuth instance
|
||||
kAuth := kiroauth.NewKiroAuth(s.cfg)
|
||||
if kAuth == nil {
|
||||
log.Warn("kiro: failed to create KiroAuth instance, using static models")
|
||||
return registry.GetKiroModels()
|
||||
}
|
||||
|
||||
// Use timeout context for API call
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Attempt to fetch dynamic models
|
||||
apiModels, err := kAuth.ListAvailableModels(ctx, tokenData)
|
||||
if err != nil {
|
||||
log.Warnf("kiro: failed to fetch dynamic models: %v, using static models", err)
|
||||
return registry.GetKiroModels()
|
||||
}
|
||||
|
||||
if len(apiModels) == 0 {
|
||||
log.Debug("kiro: API returned no models, using static models")
|
||||
return registry.GetKiroModels()
|
||||
}
|
||||
|
||||
// Convert API models to ModelInfo
|
||||
models := convertKiroAPIModels(apiModels)
|
||||
|
||||
// Generate agentic variants
|
||||
models = generateKiroAgenticVariants(models)
|
||||
|
||||
log.Infof("kiro: successfully fetched %d models from API (including agentic variants)", len(models))
|
||||
return models
|
||||
}
|
||||
|
||||
// extractKiroTokenData extracts KiroTokenData from auth attributes and metadata.
|
||||
func (s *Service) extractKiroTokenData(a *coreauth.Auth) *kiroauth.KiroTokenData {
|
||||
if a == nil || a.Attributes == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
accessToken := strings.TrimSpace(a.Attributes["access_token"])
|
||||
if accessToken == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
tokenData := &kiroauth.KiroTokenData{
|
||||
AccessToken: accessToken,
|
||||
ProfileArn: strings.TrimSpace(a.Attributes["profile_arn"]),
|
||||
}
|
||||
|
||||
// Also try to get refresh token from metadata
|
||||
if a.Metadata != nil {
|
||||
if rt, ok := a.Metadata["refresh_token"].(string); ok {
|
||||
tokenData.RefreshToken = rt
|
||||
}
|
||||
}
|
||||
|
||||
return tokenData
|
||||
}
|
||||
|
||||
// convertKiroAPIModels converts Kiro API models to ModelInfo slice.
|
||||
func convertKiroAPIModels(apiModels []*kiroauth.KiroModel) []*ModelInfo {
|
||||
if len(apiModels) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
models := make([]*ModelInfo, 0, len(apiModels))
|
||||
|
||||
for _, m := range apiModels {
|
||||
if m == nil || m.ModelID == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create model ID with kiro- prefix
|
||||
modelID := "kiro-" + normalizeKiroModelID(m.ModelID)
|
||||
|
||||
info := &ModelInfo{
|
||||
ID: modelID,
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: formatKiroDisplayName(m.ModelName, m.RateMultiplier),
|
||||
Description: m.Description,
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
}
|
||||
|
||||
if m.MaxInputTokens > 0 {
|
||||
info.ContextLength = m.MaxInputTokens
|
||||
}
|
||||
|
||||
models = append(models, info)
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
// normalizeKiroModelID normalizes a Kiro model ID by converting dots to dashes
|
||||
// and removing common prefixes.
|
||||
func normalizeKiroModelID(modelID string) string {
|
||||
// Remove common prefixes
|
||||
modelID = strings.TrimPrefix(modelID, "anthropic.")
|
||||
modelID = strings.TrimPrefix(modelID, "amazon.")
|
||||
|
||||
// Replace dots with dashes for consistency
|
||||
modelID = strings.ReplaceAll(modelID, ".", "-")
|
||||
|
||||
// Replace underscores with dashes
|
||||
modelID = strings.ReplaceAll(modelID, "_", "-")
|
||||
|
||||
return strings.ToLower(modelID)
|
||||
}
|
||||
|
||||
// formatKiroDisplayName formats the display name with rate multiplier info.
|
||||
func formatKiroDisplayName(modelName string, rateMultiplier float64) string {
|
||||
if modelName == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
displayName := "Kiro " + modelName
|
||||
if rateMultiplier > 0 && rateMultiplier != 1.0 {
|
||||
displayName += fmt.Sprintf(" (%.1fx credit)", rateMultiplier)
|
||||
}
|
||||
|
||||
return displayName
|
||||
}
|
||||
|
||||
// generateKiroAgenticVariants generates agentic variants for Kiro models.
|
||||
// Agentic variants have optimized system prompts for coding agents.
|
||||
func generateKiroAgenticVariants(models []*ModelInfo) []*ModelInfo {
|
||||
if len(models) == 0 {
|
||||
return models
|
||||
}
|
||||
|
||||
result := make([]*ModelInfo, 0, len(models)*2)
|
||||
result = append(result, models...)
|
||||
|
||||
for _, m := range models {
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if already an agentic variant
|
||||
if strings.HasSuffix(m.ID, "-agentic") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip auto models from agentic variant generation
|
||||
if strings.Contains(m.ID, "-auto") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create agentic variant
|
||||
agentic := &ModelInfo{
|
||||
ID: m.ID + "-agentic",
|
||||
Object: m.Object,
|
||||
Created: m.Created,
|
||||
OwnedBy: m.OwnedBy,
|
||||
Type: m.Type,
|
||||
DisplayName: m.DisplayName + " (Agentic)",
|
||||
Description: m.Description + " - Optimized for coding agents (chunked writes)",
|
||||
ContextLength: m.ContextLength,
|
||||
MaxCompletionTokens: m.MaxCompletionTokens,
|
||||
}
|
||||
|
||||
// Copy thinking support if present
|
||||
if m.Thinking != nil {
|
||||
agentic.Thinking = ®istry.ThinkingSupport{
|
||||
Min: m.Thinking.Min,
|
||||
Max: m.Thinking.Max,
|
||||
ZeroAllowed: m.Thinking.ZeroAllowed,
|
||||
DynamicAllowed: m.Thinking.DynamicAllowed,
|
||||
}
|
||||
}
|
||||
|
||||
result = append(result, agentic)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
452
test_api.py
Normal file
452
test_api.py
Normal file
@@ -0,0 +1,452 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
CLIProxyAPI 全面测试脚本
|
||||
测试模型列表、流式输出、thinking模式及复杂任务
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
import time
|
||||
import sys
|
||||
import io
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
# 修复 Windows 控制台编码问题
|
||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')
|
||||
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace')
|
||||
|
||||
# 配置
|
||||
BASE_URL = "http://localhost:8317"
|
||||
API_KEY = "your-api-key-1"
|
||||
HEADERS = {
|
||||
"Authorization": f"Bearer {API_KEY}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# 复杂任务提示词 - 用于测试 thinking 模式
|
||||
COMPLEX_TASK_PROMPT = """请帮我分析以下复杂的编程问题,并给出详细的解决方案:
|
||||
|
||||
问题:设计一个高并发的分布式任务调度系统,需要满足以下要求:
|
||||
1. 支持百万级任务队列
|
||||
2. 任务可以设置优先级、延迟执行、定时执行
|
||||
3. 支持任务依赖关系(DAG调度)
|
||||
4. 失败重试机制,支持指数退避
|
||||
5. 任务结果持久化和查询
|
||||
6. 水平扩展能力
|
||||
7. 监控和告警
|
||||
|
||||
请从以下几个方面详细分析:
|
||||
1. 整体架构设计
|
||||
2. 核心数据结构
|
||||
3. 调度算法选择
|
||||
4. 容错机制设计
|
||||
5. 性能优化策略
|
||||
6. 技术选型建议
|
||||
|
||||
请逐步思考每个方面,给出你的推理过程。"""
|
||||
|
||||
# 简单测试提示词
|
||||
SIMPLE_PROMPT = "Hello! Please respond with 'OK' if you receive this message."
|
||||
|
||||
def print_separator(title: str):
|
||||
print(f"\n{'='*60}")
|
||||
print(f" {title}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
def print_result(name: str, success: bool, detail: str = ""):
|
||||
status = "✅ PASS" if success else "❌ FAIL"
|
||||
print(f"{status} | {name}")
|
||||
if detail:
|
||||
print(f" └─ {detail[:200]}{'...' if len(detail) > 200 else ''}")
|
||||
|
||||
def get_models() -> List[str]:
|
||||
"""获取可用模型列表"""
|
||||
print_separator("获取模型列表")
|
||||
try:
|
||||
resp = requests.get(f"{BASE_URL}/v1/models", headers=HEADERS, timeout=30)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
models = [m.get("id", m.get("name", "unknown")) for m in data.get("data", [])]
|
||||
print(f"找到 {len(models)} 个模型:")
|
||||
for m in models:
|
||||
print(f" - {m}")
|
||||
return models
|
||||
else:
|
||||
print(f"❌ 获取模型列表失败: HTTP {resp.status_code}")
|
||||
print(f" 响应: {resp.text[:500]}")
|
||||
return []
|
||||
except Exception as e:
|
||||
print(f"❌ 获取模型列表异常: {e}")
|
||||
return []
|
||||
|
||||
def test_model_basic(model: str) -> tuple:
|
||||
"""基础可用性测试,返回 (success, error_detail)"""
|
||||
try:
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": SIMPLE_PROMPT}],
|
||||
"max_tokens": 50,
|
||||
"stream": False
|
||||
}
|
||||
resp = requests.post(
|
||||
f"{BASE_URL}/v1/chat/completions",
|
||||
headers=HEADERS,
|
||||
json=payload,
|
||||
timeout=60
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
content = data.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||
return (bool(content), f"content_len={len(content)}")
|
||||
else:
|
||||
return (False, f"HTTP {resp.status_code}: {resp.text[:300]}")
|
||||
except Exception as e:
|
||||
return (False, str(e))
|
||||
|
||||
def test_streaming(model: str) -> Dict[str, Any]:
|
||||
"""测试流式输出"""
|
||||
result = {"success": False, "chunks": 0, "content": "", "error": None}
|
||||
try:
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": "Count from 1 to 5, one number per line."}],
|
||||
"max_tokens": 100,
|
||||
"stream": True
|
||||
}
|
||||
resp = requests.post(
|
||||
f"{BASE_URL}/v1/chat/completions",
|
||||
headers=HEADERS,
|
||||
json=payload,
|
||||
timeout=60,
|
||||
stream=True
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
result["error"] = f"HTTP {resp.status_code}: {resp.text[:200]}"
|
||||
return result
|
||||
|
||||
content_parts = []
|
||||
for line in resp.iter_lines():
|
||||
if line:
|
||||
line_str = line.decode('utf-8')
|
||||
if line_str.startswith("data: "):
|
||||
data_str = line_str[6:]
|
||||
if data_str.strip() == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
result["chunks"] += 1
|
||||
choices = data.get("choices", [])
|
||||
if choices:
|
||||
delta = choices[0].get("delta", {})
|
||||
if "content" in delta and delta["content"]:
|
||||
content_parts.append(delta["content"])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
except Exception as e:
|
||||
result["error"] = f"Parse error: {e}, data: {data_str[:200]}"
|
||||
|
||||
result["content"] = "".join(content_parts)
|
||||
result["success"] = result["chunks"] > 0 and len(result["content"]) > 0
|
||||
|
||||
except Exception as e:
|
||||
result["error"] = str(e)
|
||||
|
||||
return result
|
||||
|
||||
def test_thinking_mode(model: str, complex_task: bool = False) -> Dict[str, Any]:
|
||||
"""测试 thinking 模式"""
|
||||
result = {
|
||||
"success": False,
|
||||
"has_reasoning": False,
|
||||
"reasoning_content": "",
|
||||
"content": "",
|
||||
"error": None,
|
||||
"chunks": 0
|
||||
}
|
||||
|
||||
prompt = COMPLEX_TASK_PROMPT if complex_task else "What is 15 * 23? Please think step by step."
|
||||
|
||||
try:
|
||||
# 尝试不同的 thinking 模式参数格式
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": 8000 if complex_task else 2000,
|
||||
"stream": True
|
||||
}
|
||||
|
||||
# 根据模型类型添加 thinking 参数
|
||||
if "claude" in model.lower():
|
||||
payload["thinking"] = {"type": "enabled", "budget_tokens": 5000 if complex_task else 2000}
|
||||
elif "gemini" in model.lower():
|
||||
payload["thinking"] = {"thinking_budget": 5000 if complex_task else 2000}
|
||||
elif "gpt" in model.lower() or "codex" in model.lower() or "o1" in model.lower() or "o3" in model.lower():
|
||||
payload["reasoning_effort"] = "high" if complex_task else "medium"
|
||||
else:
|
||||
# 通用格式
|
||||
payload["thinking"] = {"type": "enabled", "budget_tokens": 5000 if complex_task else 2000}
|
||||
|
||||
resp = requests.post(
|
||||
f"{BASE_URL}/v1/chat/completions",
|
||||
headers=HEADERS,
|
||||
json=payload,
|
||||
timeout=300 if complex_task else 120,
|
||||
stream=True
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
result["error"] = f"HTTP {resp.status_code}: {resp.text[:500]}"
|
||||
return result
|
||||
|
||||
content_parts = []
|
||||
reasoning_parts = []
|
||||
|
||||
for line in resp.iter_lines():
|
||||
if line:
|
||||
line_str = line.decode('utf-8')
|
||||
if line_str.startswith("data: "):
|
||||
data_str = line_str[6:]
|
||||
if data_str.strip() == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
result["chunks"] += 1
|
||||
|
||||
choices = data.get("choices", [])
|
||||
if not choices:
|
||||
continue
|
||||
choice = choices[0]
|
||||
delta = choice.get("delta", {})
|
||||
|
||||
# 检查 reasoning_content (Claude/OpenAI格式)
|
||||
if "reasoning_content" in delta and delta["reasoning_content"]:
|
||||
reasoning_parts.append(delta["reasoning_content"])
|
||||
result["has_reasoning"] = True
|
||||
|
||||
# 检查 thinking (Gemini格式)
|
||||
if "thinking" in delta and delta["thinking"]:
|
||||
reasoning_parts.append(delta["thinking"])
|
||||
result["has_reasoning"] = True
|
||||
|
||||
# 常规内容
|
||||
if "content" in delta and delta["content"]:
|
||||
content_parts.append(delta["content"])
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
result["error"] = f"Parse error: {e}"
|
||||
|
||||
result["reasoning_content"] = "".join(reasoning_parts)
|
||||
result["content"] = "".join(content_parts)
|
||||
result["success"] = result["chunks"] > 0 and (len(result["content"]) > 0 or len(result["reasoning_content"]) > 0)
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
result["error"] = "Request timeout"
|
||||
except Exception as e:
|
||||
result["error"] = str(e)
|
||||
|
||||
return result
|
||||
|
||||
def run_full_test():
|
||||
"""运行完整测试"""
|
||||
print("\n" + "="*60)
|
||||
print(" CLIProxyAPI 全面测试")
|
||||
print("="*60)
|
||||
print(f"目标地址: {BASE_URL}")
|
||||
print(f"API Key: {API_KEY[:10]}...")
|
||||
|
||||
# 1. 获取模型列表
|
||||
models = get_models()
|
||||
if not models:
|
||||
print("\n❌ 无法获取模型列表,测试终止")
|
||||
return
|
||||
|
||||
# 2. 基础可用性测试
|
||||
print_separator("基础可用性测试")
|
||||
available_models = []
|
||||
for model in models:
|
||||
success, detail = test_model_basic(model)
|
||||
print_result(f"模型: {model}", success, detail)
|
||||
if success:
|
||||
available_models.append(model)
|
||||
|
||||
print(f"\n可用模型: {len(available_models)}/{len(models)}")
|
||||
|
||||
if not available_models:
|
||||
print("\n❌ 没有可用的模型,测试终止")
|
||||
return
|
||||
|
||||
# 3. 流式输出测试
|
||||
print_separator("流式输出测试")
|
||||
streaming_results = {}
|
||||
for model in available_models:
|
||||
result = test_streaming(model)
|
||||
streaming_results[model] = result
|
||||
detail = f"chunks={result['chunks']}, content_len={len(result['content'])}"
|
||||
if result["error"]:
|
||||
detail = f"error: {result['error']}"
|
||||
print_result(f"模型: {model}", result["success"], detail)
|
||||
|
||||
# 4. Thinking 模式测试 (简单任务)
|
||||
print_separator("Thinking 模式测试 (简单任务)")
|
||||
thinking_results = {}
|
||||
for model in available_models:
|
||||
result = test_thinking_mode(model, complex_task=False)
|
||||
thinking_results[model] = result
|
||||
detail = f"reasoning={result['has_reasoning']}, chunks={result['chunks']}"
|
||||
if result["error"]:
|
||||
detail = f"error: {result['error']}"
|
||||
print_result(f"模型: {model}", result["success"], detail)
|
||||
|
||||
# 5. Thinking 模式测试 (复杂任务) - 只测试支持 thinking 的模型
|
||||
print_separator("Thinking 模式测试 (复杂任务)")
|
||||
complex_thinking_results = {}
|
||||
|
||||
# 选择前3个可用模型进行复杂任务测试
|
||||
test_models = available_models[:3]
|
||||
print(f"测试模型 (取前3个): {test_models}\n")
|
||||
|
||||
for model in test_models:
|
||||
print(f"⏳ 正在测试 {model} (复杂任务,可能需要较长时间)...")
|
||||
result = test_thinking_mode(model, complex_task=True)
|
||||
complex_thinking_results[model] = result
|
||||
|
||||
if result["success"]:
|
||||
detail = f"reasoning={result['has_reasoning']}, reasoning_len={len(result['reasoning_content'])}, content_len={len(result['content'])}"
|
||||
else:
|
||||
detail = f"error: {result['error']}" if result["error"] else "Unknown error"
|
||||
|
||||
print_result(f"模型: {model}", result["success"], detail)
|
||||
|
||||
# 如果有 reasoning 内容,打印前500字符
|
||||
if result["has_reasoning"] and result["reasoning_content"]:
|
||||
print(f"\n 📝 Reasoning 内容预览 (前500字符):")
|
||||
print(f" {result['reasoning_content'][:500]}...")
|
||||
|
||||
# 6. 总结报告
|
||||
print_separator("测试总结报告")
|
||||
|
||||
print(f"📊 模型总数: {len(models)}")
|
||||
print(f"✅ 可用模型: {len(available_models)}")
|
||||
print(f"❌ 不可用模型: {len(models) - len(available_models)}")
|
||||
|
||||
print(f"\n📊 流式输出测试:")
|
||||
streaming_pass = sum(1 for r in streaming_results.values() if r["success"])
|
||||
print(f" 通过: {streaming_pass}/{len(streaming_results)}")
|
||||
|
||||
print(f"\n📊 Thinking 模式测试 (简单):")
|
||||
thinking_pass = sum(1 for r in thinking_results.values() if r["success"])
|
||||
thinking_with_reasoning = sum(1 for r in thinking_results.values() if r["has_reasoning"])
|
||||
print(f" 通过: {thinking_pass}/{len(thinking_results)}")
|
||||
print(f" 包含推理内容: {thinking_with_reasoning}/{len(thinking_results)}")
|
||||
|
||||
print(f"\n📊 Thinking 模式测试 (复杂):")
|
||||
complex_pass = sum(1 for r in complex_thinking_results.values() if r["success"])
|
||||
complex_with_reasoning = sum(1 for r in complex_thinking_results.values() if r["has_reasoning"])
|
||||
print(f" 通过: {complex_pass}/{len(complex_thinking_results)}")
|
||||
print(f" 包含推理内容: {complex_with_reasoning}/{len(complex_thinking_results)}")
|
||||
|
||||
# 列出所有错误
|
||||
print(f"\n📋 错误详情:")
|
||||
has_errors = False
|
||||
|
||||
for model, result in streaming_results.items():
|
||||
if result["error"]:
|
||||
has_errors = True
|
||||
print(f" [流式] {model}: {result['error'][:100]}")
|
||||
|
||||
for model, result in thinking_results.items():
|
||||
if result["error"]:
|
||||
has_errors = True
|
||||
print(f" [Thinking简单] {model}: {result['error'][:100]}")
|
||||
|
||||
for model, result in complex_thinking_results.items():
|
||||
if result["error"]:
|
||||
has_errors = True
|
||||
print(f" [Thinking复杂] {model}: {result['error'][:100]}")
|
||||
|
||||
if not has_errors:
|
||||
print(" 无错误")
|
||||
|
||||
print("\n" + "="*60)
|
||||
print(" 测试完成")
|
||||
print("="*60 + "\n")
|
||||
|
||||
def test_single_model_basic(model: str):
|
||||
"""单独测试一个模型的基础功能"""
|
||||
print_separator(f"基础测试: {model}")
|
||||
success, detail = test_model_basic(model)
|
||||
print_result(f"模型: {model}", success, detail)
|
||||
return success
|
||||
|
||||
def test_single_model_streaming(model: str):
|
||||
"""单独测试一个模型的流式输出"""
|
||||
print_separator(f"流式测试: {model}")
|
||||
result = test_streaming(model)
|
||||
detail = f"chunks={result['chunks']}, content_len={len(result['content'])}"
|
||||
if result["error"]:
|
||||
detail = f"error: {result['error']}"
|
||||
print_result(f"模型: {model}", result["success"], detail)
|
||||
if result["content"]:
|
||||
print(f"\n内容: {result['content'][:300]}")
|
||||
return result
|
||||
|
||||
def test_single_model_thinking(model: str, complex_task: bool = False):
|
||||
"""单独测试一个模型的thinking模式"""
|
||||
task_type = "复杂" if complex_task else "简单"
|
||||
print_separator(f"Thinking测试({task_type}): {model}")
|
||||
result = test_thinking_mode(model, complex_task=complex_task)
|
||||
detail = f"reasoning={result['has_reasoning']}, chunks={result['chunks']}"
|
||||
if result["error"]:
|
||||
detail = f"error: {result['error']}"
|
||||
print_result(f"模型: {model}", result["success"], detail)
|
||||
if result["reasoning_content"]:
|
||||
print(f"\nReasoning预览: {result['reasoning_content'][:500]}")
|
||||
if result["content"]:
|
||||
print(f"\n内容预览: {result['content'][:500]}")
|
||||
return result
|
||||
|
||||
def print_usage():
|
||||
print("""
|
||||
用法: python test_api.py <command> [options]
|
||||
|
||||
命令:
|
||||
models - 获取模型列表
|
||||
basic <model> - 测试单个模型基础功能
|
||||
stream <model> - 测试单个模型流式输出
|
||||
thinking <model> - 测试单个模型thinking模式(简单任务)
|
||||
thinking-complex <model> - 测试单个模型thinking模式(复杂任务)
|
||||
all - 运行完整测试(原有功能)
|
||||
|
||||
示例:
|
||||
python test_api.py models
|
||||
python test_api.py basic claude-sonnet
|
||||
python test_api.py stream claude-sonnet
|
||||
python test_api.py thinking claude-sonnet
|
||||
""")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
if len(sys.argv) < 2:
|
||||
print_usage()
|
||||
sys.exit(0)
|
||||
|
||||
cmd = sys.argv[1].lower()
|
||||
|
||||
if cmd == "models":
|
||||
get_models()
|
||||
elif cmd == "basic" and len(sys.argv) >= 3:
|
||||
test_single_model_basic(sys.argv[2])
|
||||
elif cmd == "stream" and len(sys.argv) >= 3:
|
||||
test_single_model_streaming(sys.argv[2])
|
||||
elif cmd == "thinking" and len(sys.argv) >= 3:
|
||||
test_single_model_thinking(sys.argv[2], complex_task=False)
|
||||
elif cmd == "thinking-complex" and len(sys.argv) >= 3:
|
||||
test_single_model_thinking(sys.argv[2], complex_task=True)
|
||||
elif cmd == "all":
|
||||
run_full_test()
|
||||
else:
|
||||
print_usage()
|
||||
273
test_auth_diff.go
Normal file
273
test_auth_diff.go
Normal file
@@ -0,0 +1,273 @@
|
||||
// 测试脚本 3:对比 CLIProxyAPIPlus 与官方格式的差异
|
||||
// 这个脚本分析 CLIProxyAPIPlus 保存的 token 与官方格式的差异
|
||||
// 运行方式: go run test_auth_diff.go
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
fmt.Println("=" + strings.Repeat("=", 59))
|
||||
fmt.Println(" 测试脚本 3: Token 格式差异分析")
|
||||
fmt.Println("=" + strings.Repeat("=", 59))
|
||||
|
||||
homeDir := os.Getenv("USERPROFILE")
|
||||
|
||||
// 加载官方 IDE Token (Kiro IDE 生成)
|
||||
fmt.Println("\n[1] 官方 Kiro IDE Token 格式")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
ideTokenPath := filepath.Join(homeDir, ".aws", "sso", "cache", "kiro-auth-token.json")
|
||||
ideToken := loadAndAnalyze(ideTokenPath, "Kiro IDE")
|
||||
|
||||
// 加载 CLIProxyAPIPlus 保存的 Token
|
||||
fmt.Println("\n[2] CLIProxyAPIPlus 保存的 Token 格式")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
cliProxyDir := filepath.Join(homeDir, ".cli-proxy-api")
|
||||
files, _ := os.ReadDir(cliProxyDir)
|
||||
|
||||
var cliProxyTokens []map[string]interface{}
|
||||
for _, f := range files {
|
||||
if strings.HasPrefix(f.Name(), "kiro") && strings.HasSuffix(f.Name(), ".json") {
|
||||
p := filepath.Join(cliProxyDir, f.Name())
|
||||
token := loadAndAnalyze(p, f.Name())
|
||||
if token != nil {
|
||||
cliProxyTokens = append(cliProxyTokens, token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 对比分析
|
||||
fmt.Println("\n[3] 关键差异分析")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
if ideToken == nil {
|
||||
fmt.Println("❌ 无法加载 IDE Token,跳过对比")
|
||||
} else if len(cliProxyTokens) == 0 {
|
||||
fmt.Println("❌ 无法加载 CLIProxyAPIPlus Token,跳过对比")
|
||||
} else {
|
||||
// 对比最新的 CLIProxyAPIPlus token
|
||||
cliToken := cliProxyTokens[0]
|
||||
|
||||
fmt.Println("\n字段对比:")
|
||||
fmt.Printf("%-20s | %-15s | %-15s\n", "字段", "IDE Token", "CLIProxy Token")
|
||||
fmt.Println(strings.Repeat("-", 55))
|
||||
|
||||
fields := []string{
|
||||
"accessToken", "refreshToken", "clientId", "clientSecret",
|
||||
"authMethod", "auth_method", "provider", "region", "expiresAt", "expires_at",
|
||||
}
|
||||
|
||||
for _, field := range fields {
|
||||
ideVal := getFieldStatus(ideToken, field)
|
||||
cliVal := getFieldStatus(cliToken, field)
|
||||
|
||||
status := " "
|
||||
if ideVal != cliVal {
|
||||
if ideVal == "✅ 有" && cliVal == "❌ 无" {
|
||||
status = "⚠️"
|
||||
} else if ideVal == "❌ 无" && cliVal == "✅ 有" {
|
||||
status = "📝"
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("%-20s | %-15s | %-15s %s\n", field, ideVal, cliVal, status)
|
||||
}
|
||||
|
||||
// 关键问题检测
|
||||
fmt.Println("\n🔍 问题检测:")
|
||||
|
||||
// 检查 clientId/clientSecret
|
||||
if hasField(ideToken, "clientId") && !hasField(cliToken, "clientId") {
|
||||
fmt.Println(" ⚠️ 问题: CLIProxyAPIPlus 缺少 clientId 字段!")
|
||||
fmt.Println(" 原因: IdC 认证刷新 token 时需要 clientId")
|
||||
}
|
||||
|
||||
if hasField(ideToken, "clientSecret") && !hasField(cliToken, "clientSecret") {
|
||||
fmt.Println(" ⚠️ 问题: CLIProxyAPIPlus 缺少 clientSecret 字段!")
|
||||
fmt.Println(" 原因: IdC 认证刷新 token 时需要 clientSecret")
|
||||
}
|
||||
|
||||
// 检查字段名差异
|
||||
if hasField(cliToken, "auth_method") && !hasField(cliToken, "authMethod") {
|
||||
fmt.Println(" 📝 注意: CLIProxy 使用 auth_method (snake_case)")
|
||||
fmt.Println(" 而官方使用 authMethod (camelCase)")
|
||||
}
|
||||
|
||||
if hasField(cliToken, "expires_at") && !hasField(cliToken, "expiresAt") {
|
||||
fmt.Println(" 📝 注意: CLIProxy 使用 expires_at (snake_case)")
|
||||
fmt.Println(" 而官方使用 expiresAt (camelCase)")
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: 测试使用完整格式的 token
|
||||
fmt.Println("\n[4] 测试完整格式 Token (带 clientId/clientSecret)")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
if ideToken != nil {
|
||||
testWithFullToken(ideToken)
|
||||
}
|
||||
|
||||
fmt.Println("\n" + strings.Repeat("=", 60))
|
||||
fmt.Println(" 分析完成")
|
||||
fmt.Println(strings.Repeat("=", 60))
|
||||
|
||||
// 给出建议
|
||||
fmt.Println("\n💡 修复建议:")
|
||||
fmt.Println(" 1. CLIProxyAPIPlus 导入 token 时需要保留 clientId 和 clientSecret")
|
||||
fmt.Println(" 2. IdC 认证刷新 token 必须使用这两个字段")
|
||||
fmt.Println(" 3. 检查 CLIProxyAPIPlus 的 token 导入逻辑:")
|
||||
fmt.Println(" - internal/auth/kiro/aws.go LoadKiroIDEToken()")
|
||||
fmt.Println(" - sdk/auth/kiro.go ImportFromKiroIDE()")
|
||||
}
|
||||
|
||||
func loadAndAnalyze(path, name string) map[string]interface{} {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
fmt.Printf("❌ 无法加载 %s: %v\n", name, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
var token map[string]interface{}
|
||||
if err := json.Unmarshal(data, &token); err != nil {
|
||||
fmt.Printf("❌ 无法解析 %s: %v\n", name, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Printf("📄 %s\n", path)
|
||||
fmt.Printf(" 字段数: %d\n", len(token))
|
||||
|
||||
// 列出所有字段
|
||||
fmt.Printf(" 字段列表: ")
|
||||
keys := make([]string, 0, len(token))
|
||||
for k := range token {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
fmt.Printf("%v\n", keys)
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
func getFieldStatus(token map[string]interface{}, field string) string {
|
||||
if token == nil {
|
||||
return "N/A"
|
||||
}
|
||||
if v, ok := token[field]; ok && v != nil && v != "" {
|
||||
return "✅ 有"
|
||||
}
|
||||
return "❌ 无"
|
||||
}
|
||||
|
||||
func hasField(token map[string]interface{}, field string) bool {
|
||||
if token == nil {
|
||||
return false
|
||||
}
|
||||
v, ok := token[field]
|
||||
return ok && v != nil && v != ""
|
||||
}
|
||||
|
||||
func testWithFullToken(token map[string]interface{}) {
|
||||
accessToken, _ := token["accessToken"].(string)
|
||||
refreshToken, _ := token["refreshToken"].(string)
|
||||
clientId, _ := token["clientId"].(string)
|
||||
clientSecret, _ := token["clientSecret"].(string)
|
||||
region, _ := token["region"].(string)
|
||||
|
||||
if region == "" {
|
||||
region = "us-east-1"
|
||||
}
|
||||
|
||||
// 测试当前 accessToken
|
||||
fmt.Println("\n测试当前 accessToken...")
|
||||
if testAPICall(accessToken, region) {
|
||||
fmt.Println("✅ 当前 accessToken 有效")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("⚠️ 当前 accessToken 无效,尝试刷新...")
|
||||
|
||||
// 检查是否有完整的刷新所需字段
|
||||
if clientId == "" || clientSecret == "" {
|
||||
fmt.Println("❌ 缺少 clientId 或 clientSecret,无法刷新")
|
||||
fmt.Println(" 这就是问题所在!")
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试刷新
|
||||
fmt.Println("\n使用完整字段刷新 token...")
|
||||
url := fmt.Sprintf("https://oidc.%s.amazonaws.com/token", region)
|
||||
|
||||
requestBody := map[string]interface{}{
|
||||
"refreshToken": refreshToken,
|
||||
"clientId": clientId,
|
||||
"clientSecret": clientSecret,
|
||||
"grantType": "refresh_token",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(requestBody)
|
||||
req, _ := http.NewRequest("POST", url, bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
fmt.Printf("❌ 请求失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode == 200 {
|
||||
var refreshResp map[string]interface{}
|
||||
json.Unmarshal(respBody, &refreshResp)
|
||||
|
||||
newAccessToken, _ := refreshResp["accessToken"].(string)
|
||||
fmt.Println("✅ Token 刷新成功!")
|
||||
|
||||
// 验证新 token
|
||||
if testAPICall(newAccessToken, region) {
|
||||
fmt.Println("✅ 新 Token 验证成功!")
|
||||
fmt.Println("\n✅ 结论: 使用完整格式 (含 clientId/clientSecret) 可以正常工作")
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("❌ 刷新失败: HTTP %d\n", resp.StatusCode)
|
||||
fmt.Printf(" 响应: %s\n", string(respBody))
|
||||
}
|
||||
}
|
||||
|
||||
func testAPICall(accessToken, region string) bool {
|
||||
url := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com", region)
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"origin": "AI_EDITOR",
|
||||
"isEmailRequired": true,
|
||||
"resourceType": "AGENTIC_REQUEST",
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
req, _ := http.NewRequest("POST", url, bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
||||
req.Header.Set("x-amz-target", "AmazonCodeWhispererService.GetUsageLimits")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return resp.StatusCode == 200
|
||||
}
|
||||
323
test_auth_idc_go1.go
Normal file
323
test_auth_idc_go1.go
Normal file
@@ -0,0 +1,323 @@
|
||||
// 测试脚本 1:模拟 kiro2api_go1 的 IdC 认证方式
|
||||
// 这个脚本完整模拟 kiro-gateway/temp/kiro2api_go1 的认证逻辑
|
||||
// 运行方式: go run test_auth_idc_go1.go
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 配置常量 - 来自 kiro2api_go1/config/config.go
|
||||
const (
|
||||
IdcRefreshTokenURL = "https://oidc.us-east-1.amazonaws.com/token"
|
||||
CodeWhispererAPIURL = "https://codewhisperer.us-east-1.amazonaws.com"
|
||||
)
|
||||
|
||||
// AuthConfig - 来自 kiro2api_go1/auth/config.go
|
||||
type AuthConfig struct {
|
||||
AuthType string `json:"auth"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ClientID string `json:"clientId,omitempty"`
|
||||
ClientSecret string `json:"clientSecret,omitempty"`
|
||||
}
|
||||
|
||||
// IdcRefreshRequest - 来自 kiro2api_go1/types/token.go
|
||||
type IdcRefreshRequest struct {
|
||||
ClientId string `json:"clientId"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
GrantType string `json:"grantType"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
// RefreshResponse - 来自 kiro2api_go1/types/token.go
|
||||
type RefreshResponse struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken,omitempty"`
|
||||
ExpiresIn int `json:"expiresIn"`
|
||||
TokenType string `json:"tokenType,omitempty"`
|
||||
}
|
||||
|
||||
// Fingerprint - 简化的指纹结构
|
||||
type Fingerprint struct {
|
||||
OSType string
|
||||
ConnectionBehavior string
|
||||
AcceptLanguage string
|
||||
SecFetchMode string
|
||||
AcceptEncoding string
|
||||
}
|
||||
|
||||
func generateFingerprint() *Fingerprint {
|
||||
osTypes := []string{"darwin", "windows", "linux"}
|
||||
connections := []string{"keep-alive", "close"}
|
||||
languages := []string{"en-US,en;q=0.9", "zh-CN,zh;q=0.9", "en-GB,en;q=0.9"}
|
||||
fetchModes := []string{"cors", "navigate", "no-cors"}
|
||||
|
||||
return &Fingerprint{
|
||||
OSType: osTypes[rand.Intn(len(osTypes))],
|
||||
ConnectionBehavior: connections[rand.Intn(len(connections))],
|
||||
AcceptLanguage: languages[rand.Intn(len(languages))],
|
||||
SecFetchMode: fetchModes[rand.Intn(len(fetchModes))],
|
||||
AcceptEncoding: "gzip, deflate, br",
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
|
||||
fmt.Println("=" + strings.Repeat("=", 59))
|
||||
fmt.Println(" 测试脚本 1: kiro2api_go1 风格 IdC 认证")
|
||||
fmt.Println("=" + strings.Repeat("=", 59))
|
||||
|
||||
// Step 1: 加载官方格式的 token 文件
|
||||
fmt.Println("\n[Step 1] 加载官方格式 Token 文件")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
// 尝试从多个位置加载
|
||||
tokenPaths := []string{
|
||||
// 优先使用包含完整 clientId/clientSecret 的文件
|
||||
"E:/ai_project_2api/kiro-gateway/configs/kiro/kiro-auth-token-1768317098.json",
|
||||
filepath.Join(os.Getenv("USERPROFILE"), ".aws", "sso", "cache", "kiro-auth-token.json"),
|
||||
}
|
||||
|
||||
var tokenData map[string]interface{}
|
||||
var loadedPath string
|
||||
|
||||
for _, p := range tokenPaths {
|
||||
data, err := os.ReadFile(p)
|
||||
if err == nil {
|
||||
if err := json.Unmarshal(data, &tokenData); err == nil {
|
||||
loadedPath = p
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tokenData == nil {
|
||||
fmt.Println("❌ 无法加载任何 token 文件")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("✅ 加载文件: %s\n", loadedPath)
|
||||
|
||||
// 提取关键字段
|
||||
accessToken, _ := tokenData["accessToken"].(string)
|
||||
refreshToken, _ := tokenData["refreshToken"].(string)
|
||||
clientId, _ := tokenData["clientId"].(string)
|
||||
clientSecret, _ := tokenData["clientSecret"].(string)
|
||||
authMethod, _ := tokenData["authMethod"].(string)
|
||||
region, _ := tokenData["region"].(string)
|
||||
|
||||
if region == "" {
|
||||
region = "us-east-1"
|
||||
}
|
||||
|
||||
fmt.Printf("\n当前 Token 信息:\n")
|
||||
fmt.Printf(" AuthMethod: %s\n", authMethod)
|
||||
fmt.Printf(" Region: %s\n", region)
|
||||
fmt.Printf(" AccessToken: %s...\n", truncate(accessToken, 50))
|
||||
fmt.Printf(" RefreshToken: %s...\n", truncate(refreshToken, 50))
|
||||
fmt.Printf(" ClientID: %s\n", truncate(clientId, 30))
|
||||
fmt.Printf(" ClientSecret: %s...\n", truncate(clientSecret, 50))
|
||||
|
||||
// Step 2: 验证 IdC 认证所需字段
|
||||
fmt.Println("\n[Step 2] 验证 IdC 认证必需字段")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
missingFields := []string{}
|
||||
if refreshToken == "" {
|
||||
missingFields = append(missingFields, "refreshToken")
|
||||
}
|
||||
if clientId == "" {
|
||||
missingFields = append(missingFields, "clientId")
|
||||
}
|
||||
if clientSecret == "" {
|
||||
missingFields = append(missingFields, "clientSecret")
|
||||
}
|
||||
|
||||
if len(missingFields) > 0 {
|
||||
fmt.Printf("❌ 缺少必需字段: %v\n", missingFields)
|
||||
fmt.Println(" IdC 认证需要: refreshToken, clientId, clientSecret")
|
||||
return
|
||||
}
|
||||
fmt.Println("✅ 所有必需字段都存在")
|
||||
|
||||
// Step 3: 测试直接使用 accessToken 调用 API
|
||||
fmt.Println("\n[Step 3] 测试当前 AccessToken 有效性")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
if testAPICall(accessToken, region) {
|
||||
fmt.Println("✅ 当前 AccessToken 有效,无需刷新")
|
||||
} else {
|
||||
fmt.Println("⚠️ 当前 AccessToken 无效,需要刷新")
|
||||
|
||||
// Step 4: 使用 kiro2api_go1 风格刷新 token
|
||||
fmt.Println("\n[Step 4] 使用 kiro2api_go1 风格刷新 Token")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
newToken, err := refreshIdCToken(AuthConfig{
|
||||
AuthType: "IdC",
|
||||
RefreshToken: refreshToken,
|
||||
ClientID: clientId,
|
||||
ClientSecret: clientSecret,
|
||||
}, region)
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("❌ 刷新失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("✅ Token 刷新成功!")
|
||||
fmt.Printf(" 新 AccessToken: %s...\n", truncate(newToken.AccessToken, 50))
|
||||
fmt.Printf(" ExpiresIn: %d 秒\n", newToken.ExpiresIn)
|
||||
|
||||
// Step 5: 验证新 token
|
||||
fmt.Println("\n[Step 5] 验证新 Token")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
if testAPICall(newToken.AccessToken, region) {
|
||||
fmt.Println("✅ 新 Token 验证成功!")
|
||||
|
||||
// 保存新 token
|
||||
saveNewToken(loadedPath, newToken, tokenData)
|
||||
} else {
|
||||
fmt.Println("❌ 新 Token 验证失败")
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("\n" + strings.Repeat("=", 60))
|
||||
fmt.Println(" 测试完成")
|
||||
fmt.Println(strings.Repeat("=", 60))
|
||||
}
|
||||
|
||||
// refreshIdCToken - 完全模拟 kiro2api_go1/auth/refresh.go 的 refreshIdCToken 函数
|
||||
func refreshIdCToken(authConfig AuthConfig, region string) (*RefreshResponse, error) {
|
||||
refreshReq := IdcRefreshRequest{
|
||||
ClientId: authConfig.ClientID,
|
||||
ClientSecret: authConfig.ClientSecret,
|
||||
GrantType: "refresh_token",
|
||||
RefreshToken: authConfig.RefreshToken,
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(refreshReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化IdC请求失败: %v", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("https://oidc.%s.amazonaws.com/token", region)
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(reqBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建IdC请求失败: %v", err)
|
||||
}
|
||||
|
||||
// 设置 IdC 特殊 headers(使用指纹随机化)- 完全模拟 kiro2api_go1
|
||||
fp := generateFingerprint()
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region))
|
||||
req.Header.Set("Connection", fp.ConnectionBehavior)
|
||||
req.Header.Set("x-amz-user-agent", fmt.Sprintf("aws-sdk-js/3.738.0 ua/2.1 os/%s lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE", fp.OSType))
|
||||
req.Header.Set("Accept", "*/*")
|
||||
req.Header.Set("Accept-Language", fp.AcceptLanguage)
|
||||
req.Header.Set("sec-fetch-mode", fp.SecFetchMode)
|
||||
req.Header.Set("User-Agent", "node")
|
||||
req.Header.Set("Accept-Encoding", fp.AcceptEncoding)
|
||||
|
||||
fmt.Println("发送刷新请求:")
|
||||
fmt.Printf(" URL: %s\n", url)
|
||||
fmt.Println(" Headers:")
|
||||
for k, v := range req.Header {
|
||||
if k == "Content-Type" || k == "Host" || k == "X-Amz-User-Agent" || k == "User-Agent" {
|
||||
fmt.Printf(" %s: %s\n", k, v[0])
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("IdC请求失败: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("IdC刷新失败: 状态码 %d, 响应: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var refreshResp RefreshResponse
|
||||
if err := json.Unmarshal(body, &refreshResp); err != nil {
|
||||
return nil, fmt.Errorf("解析IdC响应失败: %v", err)
|
||||
}
|
||||
|
||||
return &refreshResp, nil
|
||||
}
|
||||
|
||||
func testAPICall(accessToken, region string) bool {
|
||||
url := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com", region)
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"origin": "AI_EDITOR",
|
||||
"isEmailRequired": true,
|
||||
"resourceType": "AGENTIC_REQUEST",
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
req, _ := http.NewRequest("POST", url, bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
||||
req.Header.Set("x-amz-target", "AmazonCodeWhispererService.GetUsageLimits")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
fmt.Printf(" 请求错误: %v\n", err)
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
fmt.Printf(" API 响应: HTTP %d\n", resp.StatusCode)
|
||||
|
||||
if resp.StatusCode == 200 {
|
||||
return true
|
||||
}
|
||||
|
||||
fmt.Printf(" 错误详情: %s\n", truncate(string(respBody), 200))
|
||||
return false
|
||||
}
|
||||
|
||||
func saveNewToken(originalPath string, newToken *RefreshResponse, originalData map[string]interface{}) {
|
||||
// 更新 token 数据
|
||||
originalData["accessToken"] = newToken.AccessToken
|
||||
if newToken.RefreshToken != "" {
|
||||
originalData["refreshToken"] = newToken.RefreshToken
|
||||
}
|
||||
originalData["expiresAt"] = time.Now().Add(time.Duration(newToken.ExpiresIn) * time.Second).Format(time.RFC3339)
|
||||
|
||||
data, _ := json.MarshalIndent(originalData, "", " ")
|
||||
|
||||
// 保存到新文件
|
||||
newPath := strings.TrimSuffix(originalPath, ".json") + "_refreshed.json"
|
||||
if err := os.WriteFile(newPath, data, 0644); err != nil {
|
||||
fmt.Printf("⚠️ 保存失败: %v\n", err)
|
||||
} else {
|
||||
fmt.Printf("✅ 新 Token 已保存到: %s\n", newPath)
|
||||
}
|
||||
}
|
||||
|
||||
func truncate(s string, n int) string {
|
||||
if len(s) <= n {
|
||||
return s
|
||||
}
|
||||
return s[:n]
|
||||
}
|
||||
237
test_auth_js_style.go
Normal file
237
test_auth_js_style.go
Normal file
@@ -0,0 +1,237 @@
|
||||
// 测试脚本 2:模拟 kiro2Api_js 的认证方式
|
||||
// 这个脚本完整模拟 kiro-gateway/temp/kiro2Api_js 的认证逻辑
|
||||
// 运行方式: go run test_auth_js_style.go
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 常量 - 来自 kiro2Api_js/src/kiro/auth.js
|
||||
const (
|
||||
REFRESH_URL_TEMPLATE = "https://prod.{{region}}.auth.desktop.kiro.dev/refreshToken"
|
||||
REFRESH_IDC_URL_TEMPLATE = "https://oidc.{{region}}.amazonaws.com/token"
|
||||
AUTH_METHOD_SOCIAL = "social"
|
||||
AUTH_METHOD_IDC = "IdC"
|
||||
)
|
||||
|
||||
func main() {
|
||||
fmt.Println("=" + strings.Repeat("=", 59))
|
||||
fmt.Println(" 测试脚本 2: kiro2Api_js 风格认证")
|
||||
fmt.Println("=" + strings.Repeat("=", 59))
|
||||
|
||||
// Step 1: 加载 token 文件
|
||||
fmt.Println("\n[Step 1] 加载 Token 文件")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
tokenPaths := []string{
|
||||
filepath.Join(os.Getenv("USERPROFILE"), ".aws", "sso", "cache", "kiro-auth-token.json"),
|
||||
"E:/ai_project_2api/kiro-gateway/configs/kiro/kiro-auth-token-1768317098.json",
|
||||
}
|
||||
|
||||
var tokenData map[string]interface{}
|
||||
var loadedPath string
|
||||
|
||||
for _, p := range tokenPaths {
|
||||
data, err := os.ReadFile(p)
|
||||
if err == nil {
|
||||
if err := json.Unmarshal(data, &tokenData); err == nil {
|
||||
loadedPath = p
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tokenData == nil {
|
||||
fmt.Println("❌ 无法加载任何 token 文件")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("✅ 加载文件: %s\n", loadedPath)
|
||||
|
||||
// 提取字段 - 模拟 kiro2Api_js/src/kiro/auth.js initializeAuth
|
||||
accessToken, _ := tokenData["accessToken"].(string)
|
||||
refreshToken, _ := tokenData["refreshToken"].(string)
|
||||
clientId, _ := tokenData["clientId"].(string)
|
||||
clientSecret, _ := tokenData["clientSecret"].(string)
|
||||
authMethod, _ := tokenData["authMethod"].(string)
|
||||
region, _ := tokenData["region"].(string)
|
||||
|
||||
if region == "" {
|
||||
region = "us-east-1"
|
||||
fmt.Println("⚠️ Region 未设置,使用默认值 us-east-1")
|
||||
}
|
||||
|
||||
fmt.Printf("\nToken 信息:\n")
|
||||
fmt.Printf(" AuthMethod: %s\n", authMethod)
|
||||
fmt.Printf(" Region: %s\n", region)
|
||||
fmt.Printf(" 有 ClientID: %v\n", clientId != "")
|
||||
fmt.Printf(" 有 ClientSecret: %v\n", clientSecret != "")
|
||||
|
||||
// Step 2: 测试当前 token
|
||||
fmt.Println("\n[Step 2] 测试当前 AccessToken")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
if testAPI(accessToken, region) {
|
||||
fmt.Println("✅ 当前 AccessToken 有效")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("⚠️ 当前 AccessToken 无效,开始刷新...")
|
||||
|
||||
// Step 3: 根据 authMethod 选择刷新方式 - 模拟 doRefreshToken
|
||||
fmt.Println("\n[Step 3] 刷新 Token (JS 风格)")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
var refreshURL string
|
||||
var requestBody map[string]interface{}
|
||||
|
||||
// 判断认证方式 - 模拟 kiro2Api_js auth.js doRefreshToken
|
||||
if authMethod == AUTH_METHOD_SOCIAL {
|
||||
// Social 认证
|
||||
refreshURL = strings.Replace(REFRESH_URL_TEMPLATE, "{{region}}", region, 1)
|
||||
requestBody = map[string]interface{}{
|
||||
"refreshToken": refreshToken,
|
||||
}
|
||||
fmt.Println("使用 Social 认证方式")
|
||||
} else {
|
||||
// IdC 认证 (默认)
|
||||
refreshURL = strings.Replace(REFRESH_IDC_URL_TEMPLATE, "{{region}}", region, 1)
|
||||
requestBody = map[string]interface{}{
|
||||
"refreshToken": refreshToken,
|
||||
"clientId": clientId,
|
||||
"clientSecret": clientSecret,
|
||||
"grantType": "refresh_token",
|
||||
}
|
||||
fmt.Println("使用 IdC 认证方式")
|
||||
}
|
||||
|
||||
fmt.Printf("刷新 URL: %s\n", refreshURL)
|
||||
fmt.Printf("请求字段: %v\n", getKeys(requestBody))
|
||||
|
||||
// 发送刷新请求
|
||||
body, _ := json.Marshal(requestBody)
|
||||
req, _ := http.NewRequest("POST", refreshURL, bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
fmt.Printf("❌ 请求失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
fmt.Printf("\n响应状态: HTTP %d\n", resp.StatusCode)
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
fmt.Printf("❌ 刷新失败: %s\n", string(respBody))
|
||||
|
||||
// 分析错误
|
||||
var errResp map[string]interface{}
|
||||
if err := json.Unmarshal(respBody, &errResp); err == nil {
|
||||
if errType, ok := errResp["error"].(string); ok {
|
||||
fmt.Printf("错误类型: %s\n", errType)
|
||||
if errType == "invalid_grant" {
|
||||
fmt.Println("\n💡 提示: refresh_token 可能已过期,需要重新授权")
|
||||
}
|
||||
}
|
||||
if errDesc, ok := errResp["error_description"].(string); ok {
|
||||
fmt.Printf("错误描述: %s\n", errDesc)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var refreshResp map[string]interface{}
|
||||
json.Unmarshal(respBody, &refreshResp)
|
||||
|
||||
newAccessToken, _ := refreshResp["accessToken"].(string)
|
||||
newRefreshToken, _ := refreshResp["refreshToken"].(string)
|
||||
expiresIn, _ := refreshResp["expiresIn"].(float64)
|
||||
|
||||
fmt.Println("✅ Token 刷新成功!")
|
||||
fmt.Printf(" 新 AccessToken: %s...\n", truncate(newAccessToken, 50))
|
||||
fmt.Printf(" ExpiresIn: %.0f 秒\n", expiresIn)
|
||||
if newRefreshToken != "" {
|
||||
fmt.Printf(" 新 RefreshToken: %s...\n", truncate(newRefreshToken, 50))
|
||||
}
|
||||
|
||||
// Step 4: 验证新 token
|
||||
fmt.Println("\n[Step 4] 验证新 Token")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
if testAPI(newAccessToken, region) {
|
||||
fmt.Println("✅ 新 Token 验证成功!")
|
||||
|
||||
// 保存新 token - 模拟 saveCredentialsToFile
|
||||
tokenData["accessToken"] = newAccessToken
|
||||
if newRefreshToken != "" {
|
||||
tokenData["refreshToken"] = newRefreshToken
|
||||
}
|
||||
tokenData["expiresAt"] = time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339)
|
||||
|
||||
saveData, _ := json.MarshalIndent(tokenData, "", " ")
|
||||
newPath := strings.TrimSuffix(loadedPath, ".json") + "_js_refreshed.json"
|
||||
os.WriteFile(newPath, saveData, 0644)
|
||||
fmt.Printf("✅ 已保存到: %s\n", newPath)
|
||||
} else {
|
||||
fmt.Println("❌ 新 Token 验证失败")
|
||||
}
|
||||
|
||||
fmt.Println("\n" + strings.Repeat("=", 60))
|
||||
fmt.Println(" 测试完成")
|
||||
fmt.Println(strings.Repeat("=", 60))
|
||||
}
|
||||
|
||||
func testAPI(accessToken, region string) bool {
|
||||
url := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com", region)
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"origin": "AI_EDITOR",
|
||||
"isEmailRequired": true,
|
||||
"resourceType": "AGENTIC_REQUEST",
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
req, _ := http.NewRequest("POST", url, bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
||||
req.Header.Set("x-amz-target", "AmazonCodeWhispererService.GetUsageLimits")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return resp.StatusCode == 200
|
||||
}
|
||||
|
||||
func getKeys(m map[string]interface{}) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func truncate(s string, n int) string {
|
||||
if len(s) <= n {
|
||||
return s
|
||||
}
|
||||
return s[:n]
|
||||
}
|
||||
348
test_kiro_debug.go
Normal file
348
test_kiro_debug.go
Normal file
@@ -0,0 +1,348 @@
|
||||
// 独立测试脚本:排查 Kiro Token 403 错误
|
||||
// 运行方式: go run test_kiro_debug.go
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Token 结构 - 匹配 Kiro IDE 格式
|
||||
type KiroIDEToken struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ExpiresAt string `json:"expiresAt"`
|
||||
ClientIDHash string `json:"clientIdHash,omitempty"`
|
||||
AuthMethod string `json:"authMethod"`
|
||||
Provider string `json:"provider"`
|
||||
Region string `json:"region,omitempty"`
|
||||
}
|
||||
|
||||
// Token 结构 - 匹配 CLIProxyAPIPlus 格式
|
||||
type CLIProxyToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ProfileArn string `json:"profile_arn"`
|
||||
ExpiresAt string `json:"expires_at"`
|
||||
AuthMethod string `json:"auth_method"`
|
||||
Provider string `json:"provider"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
fmt.Println("=" + strings.Repeat("=", 59))
|
||||
fmt.Println(" Kiro Token 403 错误排查工具")
|
||||
fmt.Println("=" + strings.Repeat("=", 59))
|
||||
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
|
||||
// Step 1: 检查 Kiro IDE Token 文件
|
||||
fmt.Println("\n[Step 1] 检查 Kiro IDE Token 文件")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
ideTokenPath := filepath.Join(homeDir, ".aws", "sso", "cache", "kiro-auth-token.json")
|
||||
ideToken, err := loadKiroIDEToken(ideTokenPath)
|
||||
if err != nil {
|
||||
fmt.Printf("❌ 无法加载 Kiro IDE Token: %v\n", err)
|
||||
return
|
||||
}
|
||||
fmt.Printf("✅ Token 文件: %s\n", ideTokenPath)
|
||||
fmt.Printf(" AuthMethod: %s\n", ideToken.AuthMethod)
|
||||
fmt.Printf(" Provider: %s\n", ideToken.Provider)
|
||||
fmt.Printf(" Region: %s\n", ideToken.Region)
|
||||
fmt.Printf(" ExpiresAt: %s\n", ideToken.ExpiresAt)
|
||||
fmt.Printf(" AccessToken (前50字符): %s...\n", truncate(ideToken.AccessToken, 50))
|
||||
|
||||
// Step 2: 检查 Token 过期状态
|
||||
fmt.Println("\n[Step 2] 检查 Token 过期状态")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
expiresAt, err := parseExpiresAt(ideToken.ExpiresAt)
|
||||
if err != nil {
|
||||
fmt.Printf("❌ 无法解析过期时间: %v\n", err)
|
||||
} else {
|
||||
now := time.Now()
|
||||
if now.After(expiresAt) {
|
||||
fmt.Printf("❌ Token 已过期!过期时间: %s,当前时间: %s\n", expiresAt.Format(time.RFC3339), now.Format(time.RFC3339))
|
||||
} else {
|
||||
remaining := expiresAt.Sub(now)
|
||||
fmt.Printf("✅ Token 未过期,剩余: %s\n", remaining.Round(time.Second))
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: 检查 CLIProxyAPIPlus 保存的 Token
|
||||
fmt.Println("\n[Step 3] 检查 CLIProxyAPIPlus 保存的 Token")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
cliProxyDir := filepath.Join(homeDir, ".cli-proxy-api")
|
||||
files, _ := os.ReadDir(cliProxyDir)
|
||||
for _, f := range files {
|
||||
if strings.HasPrefix(f.Name(), "kiro") && strings.HasSuffix(f.Name(), ".json") {
|
||||
filePath := filepath.Join(cliProxyDir, f.Name())
|
||||
cliToken, err := loadCLIProxyToken(filePath)
|
||||
if err != nil {
|
||||
fmt.Printf("❌ %s: 加载失败 - %v\n", f.Name(), err)
|
||||
continue
|
||||
}
|
||||
fmt.Printf("📄 %s:\n", f.Name())
|
||||
fmt.Printf(" AuthMethod: %s\n", cliToken.AuthMethod)
|
||||
fmt.Printf(" Provider: %s\n", cliToken.Provider)
|
||||
fmt.Printf(" ExpiresAt: %s\n", cliToken.ExpiresAt)
|
||||
fmt.Printf(" AccessToken (前50字符): %s...\n", truncate(cliToken.AccessToken, 50))
|
||||
|
||||
// 比较 Token
|
||||
if cliToken.AccessToken == ideToken.AccessToken {
|
||||
fmt.Printf(" ✅ AccessToken 与 IDE Token 一致\n")
|
||||
} else {
|
||||
fmt.Printf(" ⚠️ AccessToken 与 IDE Token 不一致!\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: 直接测试 Token 有效性 (调用 Kiro API)
|
||||
fmt.Println("\n[Step 4] 直接测试 Token 有效性")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
testTokenValidity(ideToken.AccessToken, ideToken.Region)
|
||||
|
||||
// Step 5: 测试不同的请求头格式
|
||||
fmt.Println("\n[Step 5] 测试不同的请求头格式")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
testDifferentHeaders(ideToken.AccessToken, ideToken.Region)
|
||||
|
||||
// Step 6: 解析 JWT 内容
|
||||
fmt.Println("\n[Step 6] 解析 JWT Token 内容")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
parseJWT(ideToken.AccessToken)
|
||||
|
||||
fmt.Println("\n" + strings.Repeat("=", 60))
|
||||
fmt.Println(" 排查完成")
|
||||
fmt.Println(strings.Repeat("=", 60))
|
||||
}
|
||||
|
||||
func loadKiroIDEToken(path string) (*KiroIDEToken, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var token KiroIDEToken
|
||||
if err := json.Unmarshal(data, &token); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func loadCLIProxyToken(path string) (*CLIProxyToken, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var token CLIProxyToken
|
||||
if err := json.Unmarshal(data, &token); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func parseExpiresAt(s string) (time.Time, error) {
|
||||
formats := []string{
|
||||
time.RFC3339,
|
||||
"2006-01-02T15:04:05.000Z",
|
||||
"2006-01-02T15:04:05Z",
|
||||
}
|
||||
for _, f := range formats {
|
||||
if t, err := time.Parse(f, s); err == nil {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
return time.Time{}, fmt.Errorf("无法解析时间格式: %s", s)
|
||||
}
|
||||
|
||||
func truncate(s string, n int) string {
|
||||
if len(s) <= n {
|
||||
return s
|
||||
}
|
||||
return s[:n]
|
||||
}
|
||||
|
||||
func testTokenValidity(accessToken, region string) {
|
||||
if region == "" {
|
||||
region = "us-east-1"
|
||||
}
|
||||
|
||||
// 测试 GetUsageLimits API
|
||||
url := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com", region)
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"origin": "AI_EDITOR",
|
||||
"isEmailRequired": true,
|
||||
"resourceType": "AGENTIC_REQUEST",
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
req, _ := http.NewRequest("POST", url, bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
||||
req.Header.Set("x-amz-target", "AmazonCodeWhispererService.GetUsageLimits")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
fmt.Printf("请求 URL: %s\n", url)
|
||||
fmt.Printf("请求头:\n")
|
||||
for k, v := range req.Header {
|
||||
if k == "Authorization" {
|
||||
fmt.Printf(" %s: Bearer %s...\n", k, truncate(v[0][7:], 30))
|
||||
} else {
|
||||
fmt.Printf(" %s: %s\n", k, v[0])
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
fmt.Printf("❌ 请求失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
fmt.Printf("响应状态: %d\n", resp.StatusCode)
|
||||
fmt.Printf("响应内容: %s\n", string(respBody))
|
||||
|
||||
if resp.StatusCode == 200 {
|
||||
fmt.Println("✅ Token 有效!")
|
||||
} else if resp.StatusCode == 403 {
|
||||
fmt.Println("❌ Token 无效或已过期 (403)")
|
||||
}
|
||||
}
|
||||
|
||||
func testDifferentHeaders(accessToken, region string) {
|
||||
if region == "" {
|
||||
region = "us-east-1"
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
}{
|
||||
{
|
||||
name: "最小请求头",
|
||||
headers: map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + accessToken,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "模拟 kiro2api_go1 风格",
|
||||
headers: map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "text/event-stream",
|
||||
"Authorization": "Bearer " + accessToken,
|
||||
"x-amzn-kiro-agent-mode": "vibe",
|
||||
"x-amzn-codewhisperer-optout": "true",
|
||||
"amz-sdk-invocation-id": "test-invocation-id",
|
||||
"amz-sdk-request": "attempt=1; max=3",
|
||||
"x-amz-user-agent": "aws-sdk-js/1.0.27 KiroIDE-0.8.0-abc123",
|
||||
"User-Agent": "aws-sdk-js/1.0.27 ua/2.1 os/windows#10.0 lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.27 m/E KiroIDE-0.8.0-abc123",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "模拟 CLIProxyAPIPlus 风格",
|
||||
headers: map[string]string{
|
||||
"Content-Type": "application/x-amz-json-1.0",
|
||||
"x-amz-target": "AmazonCodeWhispererService.GetUsageLimits",
|
||||
"Authorization": "Bearer " + accessToken,
|
||||
"Accept": "application/json",
|
||||
"amz-sdk-invocation-id": "test-invocation-id",
|
||||
"amz-sdk-request": "attempt=1; max=1",
|
||||
"Connection": "close",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com", region)
|
||||
payload := map[string]interface{}{
|
||||
"origin": "AI_EDITOR",
|
||||
"isEmailRequired": true,
|
||||
"resourceType": "AGENTIC_REQUEST",
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
for _, test := range tests {
|
||||
fmt.Printf("\n测试: %s\n", test.name)
|
||||
|
||||
req, _ := http.NewRequest("POST", url, bytes.NewBuffer(body))
|
||||
for k, v := range test.headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
fmt.Printf(" ❌ 请求失败: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == 200 {
|
||||
fmt.Printf(" ✅ 成功 (HTTP %d)\n", resp.StatusCode)
|
||||
} else {
|
||||
fmt.Printf(" ❌ 失败 (HTTP %d): %s\n", resp.StatusCode, truncate(string(respBody), 100))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseJWT(token string) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) < 2 {
|
||||
fmt.Println("Token 不是 JWT 格式")
|
||||
return
|
||||
}
|
||||
|
||||
// 解码 header
|
||||
headerData, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
fmt.Printf("无法解码 JWT header: %v\n", err)
|
||||
} else {
|
||||
var header map[string]interface{}
|
||||
json.Unmarshal(headerData, &header)
|
||||
fmt.Printf("JWT Header: %v\n", header)
|
||||
}
|
||||
|
||||
// 解码 payload
|
||||
payloadData, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
fmt.Printf("无法解码 JWT payload: %v\n", err)
|
||||
} else {
|
||||
var payload map[string]interface{}
|
||||
json.Unmarshal(payloadData, &payload)
|
||||
fmt.Printf("JWT Payload:\n")
|
||||
for k, v := range payload {
|
||||
fmt.Printf(" %s: %v\n", k, v)
|
||||
}
|
||||
|
||||
// 检查过期时间
|
||||
if exp, ok := payload["exp"].(float64); ok {
|
||||
expTime := time.Unix(int64(exp), 0)
|
||||
if time.Now().After(expTime) {
|
||||
fmt.Printf(" ⚠️ JWT 已过期! exp=%s\n", expTime.Format(time.RFC3339))
|
||||
} else {
|
||||
fmt.Printf(" ✅ JWT 未过期, 剩余: %s\n", expTime.Sub(time.Now()).Round(time.Second))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
367
test_proxy_debug.go
Normal file
367
test_proxy_debug.go
Normal file
@@ -0,0 +1,367 @@
|
||||
// 测试脚本 2:通过 CLIProxyAPIPlus 代理层排查问题
|
||||
// 运行方式: go run test_proxy_debug.go
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
ProxyURL = "http://localhost:8317"
|
||||
APIKey = "your-api-key-1"
|
||||
)
|
||||
|
||||
func main() {
|
||||
fmt.Println("=" + strings.Repeat("=", 59))
|
||||
fmt.Println(" CLIProxyAPIPlus 代理层问题排查")
|
||||
fmt.Println("=" + strings.Repeat("=", 59))
|
||||
|
||||
// Step 1: 检查代理服务状态
|
||||
fmt.Println("\n[Step 1] 检查代理服务状态")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
resp, err := http.Get(ProxyURL + "/health")
|
||||
if err != nil {
|
||||
fmt.Printf("❌ 代理服务不可达: %v\n", err)
|
||||
fmt.Println("请确保服务正在运行: go run ./cmd/server/main.go")
|
||||
return
|
||||
}
|
||||
resp.Body.Close()
|
||||
fmt.Printf("✅ 代理服务正常 (HTTP %d)\n", resp.StatusCode)
|
||||
|
||||
// Step 2: 获取模型列表
|
||||
fmt.Println("\n[Step 2] 获取模型列表")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
models := getModels()
|
||||
if len(models) == 0 {
|
||||
fmt.Println("❌ 没有可用的模型,检查凭据加载")
|
||||
checkCredentials()
|
||||
return
|
||||
}
|
||||
fmt.Printf("✅ 找到 %d 个模型:\n", len(models))
|
||||
for _, m := range models {
|
||||
fmt.Printf(" - %s\n", m)
|
||||
}
|
||||
|
||||
// Step 3: 测试模型请求 - 捕获详细错误
|
||||
fmt.Println("\n[Step 3] 测试模型请求(详细日志)")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
if len(models) > 0 {
|
||||
testModel := models[0]
|
||||
testModelRequest(testModel)
|
||||
}
|
||||
|
||||
// Step 4: 检查代理内部 Token 状态
|
||||
fmt.Println("\n[Step 4] 检查代理服务加载的凭据")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
checkProxyCredentials()
|
||||
|
||||
// Step 5: 对比直接请求和代理请求
|
||||
fmt.Println("\n[Step 5] 对比直接请求 vs 代理请求")
|
||||
fmt.Println("-" + strings.Repeat("-", 59))
|
||||
|
||||
compareDirectVsProxy()
|
||||
|
||||
fmt.Println("\n" + strings.Repeat("=", 60))
|
||||
fmt.Println(" 排查完成")
|
||||
fmt.Println(strings.Repeat("=", 60))
|
||||
}
|
||||
|
||||
func getModels() []string {
|
||||
req, _ := http.NewRequest("GET", ProxyURL+"/v1/models", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+APIKey)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
fmt.Printf("❌ 请求失败: %v\n", err)
|
||||
return nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
fmt.Printf("❌ HTTP %d: %s\n", resp.StatusCode, string(body))
|
||||
return nil
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
json.Unmarshal(body, &result)
|
||||
|
||||
models := make([]string, len(result.Data))
|
||||
for i, m := range result.Data {
|
||||
models[i] = m.ID
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
func checkCredentials() {
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
cliProxyDir := filepath.Join(homeDir, ".cli-proxy-api")
|
||||
|
||||
fmt.Printf("\n检查凭据目录: %s\n", cliProxyDir)
|
||||
files, err := os.ReadDir(cliProxyDir)
|
||||
if err != nil {
|
||||
fmt.Printf("❌ 无法读取目录: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, f := range files {
|
||||
if strings.HasSuffix(f.Name(), ".json") {
|
||||
fmt.Printf(" 📄 %s\n", f.Name())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testModelRequest(model string) {
|
||||
fmt.Printf("测试模型: %s\n", model)
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"model": model,
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": "Say 'OK' if you receive this."},
|
||||
},
|
||||
"max_tokens": 50,
|
||||
"stream": false,
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
req, _ := http.NewRequest("POST", ProxyURL+"/v1/chat/completions", bytes.NewBuffer(body))
|
||||
req.Header.Set("Authorization", "Bearer "+APIKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
fmt.Println("\n发送请求:")
|
||||
fmt.Printf(" URL: %s/v1/chat/completions\n", ProxyURL)
|
||||
fmt.Printf(" Model: %s\n", model)
|
||||
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
fmt.Printf("❌ 请求失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
|
||||
fmt.Printf("\n响应:\n")
|
||||
fmt.Printf(" Status: %d\n", resp.StatusCode)
|
||||
fmt.Printf(" Headers:\n")
|
||||
for k, v := range resp.Header {
|
||||
fmt.Printf(" %s: %s\n", k, strings.Join(v, ", "))
|
||||
}
|
||||
|
||||
// 格式化 JSON 输出
|
||||
var prettyJSON bytes.Buffer
|
||||
if err := json.Indent(&prettyJSON, respBody, " ", " "); err == nil {
|
||||
fmt.Printf(" Body:\n %s\n", prettyJSON.String())
|
||||
} else {
|
||||
fmt.Printf(" Body: %s\n", string(respBody))
|
||||
}
|
||||
|
||||
if resp.StatusCode == 200 {
|
||||
fmt.Println("\n✅ 请求成功!")
|
||||
} else {
|
||||
fmt.Println("\n❌ 请求失败!分析错误原因...")
|
||||
analyzeError(respBody)
|
||||
}
|
||||
}
|
||||
|
||||
func analyzeError(body []byte) {
|
||||
var errResp struct {
|
||||
Message string `json:"message"`
|
||||
Reason string `json:"reason"`
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
} `json:"error"`
|
||||
}
|
||||
json.Unmarshal(body, &errResp)
|
||||
|
||||
if errResp.Message != "" {
|
||||
fmt.Printf("错误消息: %s\n", errResp.Message)
|
||||
}
|
||||
if errResp.Reason != "" {
|
||||
fmt.Printf("错误原因: %s\n", errResp.Reason)
|
||||
}
|
||||
if errResp.Error.Message != "" {
|
||||
fmt.Printf("错误详情: %s (类型: %s)\n", errResp.Error.Message, errResp.Error.Type)
|
||||
}
|
||||
|
||||
// 分析常见错误
|
||||
bodyStr := string(body)
|
||||
if strings.Contains(bodyStr, "bearer token") || strings.Contains(bodyStr, "invalid") {
|
||||
fmt.Println("\n可能的原因:")
|
||||
fmt.Println(" 1. Token 已过期 - 需要刷新")
|
||||
fmt.Println(" 2. Token 格式不正确 - 检查凭据文件")
|
||||
fmt.Println(" 3. 代理服务加载了旧的 Token")
|
||||
}
|
||||
}
|
||||
|
||||
func checkProxyCredentials() {
|
||||
// 尝试通过管理 API 获取凭据状态
|
||||
req, _ := http.NewRequest("GET", ProxyURL+"/v0/management/auth/list", nil)
|
||||
// 使用配置中的管理密钥 admin123
|
||||
req.Header.Set("Authorization", "Bearer admin123")
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
fmt.Printf("❌ 无法访问管理 API: %v\n", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
if resp.StatusCode == 200 {
|
||||
fmt.Println("管理 API 返回的凭据列表:")
|
||||
var prettyJSON bytes.Buffer
|
||||
if err := json.Indent(&prettyJSON, body, " ", " "); err == nil {
|
||||
fmt.Printf("%s\n", prettyJSON.String())
|
||||
} else {
|
||||
fmt.Printf("%s\n", string(body))
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("管理 API 返回: HTTP %d\n", resp.StatusCode)
|
||||
fmt.Printf("响应: %s\n", truncate(string(body), 200))
|
||||
}
|
||||
}
|
||||
|
||||
func compareDirectVsProxy() {
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
tokenPath := filepath.Join(homeDir, ".aws", "sso", "cache", "kiro-auth-token.json")
|
||||
|
||||
data, err := os.ReadFile(tokenPath)
|
||||
if err != nil {
|
||||
fmt.Printf("❌ 无法读取 Token 文件: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
var token struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
Region string `json:"region"`
|
||||
}
|
||||
json.Unmarshal(data, &token)
|
||||
|
||||
if token.Region == "" {
|
||||
token.Region = "us-east-1"
|
||||
}
|
||||
|
||||
// 直接请求
|
||||
fmt.Println("\n1. 直接请求 Kiro API:")
|
||||
directSuccess := testDirectKiroAPI(token.AccessToken, token.Region)
|
||||
|
||||
// 通过代理请求
|
||||
fmt.Println("\n2. 通过代理请求:")
|
||||
proxySuccess := testProxyAPI()
|
||||
|
||||
// 结论
|
||||
fmt.Println("\n结论:")
|
||||
if directSuccess && !proxySuccess {
|
||||
fmt.Println(" ⚠️ 直接请求成功,代理请求失败")
|
||||
fmt.Println(" 问题在于 CLIProxyAPIPlus 代理层")
|
||||
fmt.Println(" 可能原因:")
|
||||
fmt.Println(" 1. 代理服务使用了过期的 Token")
|
||||
fmt.Println(" 2. Token 刷新逻辑有问题")
|
||||
fmt.Println(" 3. 请求头构造不正确")
|
||||
} else if directSuccess && proxySuccess {
|
||||
fmt.Println(" ✅ 两者都成功")
|
||||
} else if !directSuccess && !proxySuccess {
|
||||
fmt.Println(" ❌ 两者都失败 - Token 本身可能有问题")
|
||||
}
|
||||
}
|
||||
|
||||
func testDirectKiroAPI(accessToken, region string) bool {
|
||||
url := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com", region)
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"origin": "AI_EDITOR",
|
||||
"isEmailRequired": true,
|
||||
"resourceType": "AGENTIC_REQUEST",
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
req, _ := http.NewRequest("POST", url, bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
||||
req.Header.Set("x-amz-target", "AmazonCodeWhispererService.GetUsageLimits")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
fmt.Printf(" ❌ 请求失败: %v\n", err)
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == 200 {
|
||||
fmt.Printf(" ✅ 成功 (HTTP %d)\n", resp.StatusCode)
|
||||
return true
|
||||
}
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
fmt.Printf(" ❌ 失败 (HTTP %d): %s\n", resp.StatusCode, truncate(string(respBody), 100))
|
||||
return false
|
||||
}
|
||||
|
||||
func testProxyAPI() bool {
|
||||
models := getModels()
|
||||
if len(models) == 0 {
|
||||
fmt.Println(" ❌ 没有可用模型")
|
||||
return false
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"model": models[0],
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": "Say OK"},
|
||||
},
|
||||
"max_tokens": 10,
|
||||
"stream": false,
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
req, _ := http.NewRequest("POST", ProxyURL+"/v1/chat/completions", bytes.NewBuffer(body))
|
||||
req.Header.Set("Authorization", "Bearer "+APIKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
fmt.Printf(" ❌ 请求失败: %v\n", err)
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == 200 {
|
||||
fmt.Printf(" ✅ 成功 (HTTP %d)\n", resp.StatusCode)
|
||||
return true
|
||||
}
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
fmt.Printf(" ❌ 失败 (HTTP %d): %s\n", resp.StatusCode, truncate(string(respBody), 100))
|
||||
return false
|
||||
}
|
||||
|
||||
func truncate(s string, n int) string {
|
||||
if len(s) <= n {
|
||||
return s
|
||||
}
|
||||
return s[:n] + "..."
|
||||
}
|
||||
Reference in New Issue
Block a user