diff --git a/README.md b/README.md index 2d950a4c..54a8f46d 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,51 @@ The Plus release stays in lockstep with the mainline features. ## Kiro Authentication +### CLI Login + +> **Note:** Google/GitHub login is not available for third-party applications due to AWS Cognito restrictions. + +**AWS Builder ID** (recommended): + +```bash +# Device code flow +./CLIProxyAPI --kiro-aws-login + +# Authorization code flow +./CLIProxyAPI --kiro-aws-authcode +``` + +**Import token from Kiro IDE:** + +```bash +./CLIProxyAPI --kiro-import +``` + +To get a token from Kiro IDE: + +1. Open Kiro IDE and login with Google (or GitHub) +2. Find the token file: `~/.kiro/kiro-auth-token.json` +3. Run: `./CLIProxyAPI --kiro-import` + +**AWS IAM Identity Center (IDC):** + +```bash +./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start + +# Specify region +./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2 +``` + +**Additional flags:** + +| Flag | Description | +|------|-------------| +| `--no-browser` | Don't open browser automatically, print URL instead | +| `--no-incognito` | Use existing browser session (Kiro defaults to incognito). Useful for corporate SSO that requires an authenticated browser session | +| `--kiro-idc-start-url` | IDC Start URL (required with `--kiro-idc-login`) | +| `--kiro-idc-region` | IDC region (default: `us-east-1`) | +| `--kiro-idc-flow` | IDC flow type: `authcode` (default) or `device` | + ### Web-based OAuth Login Access the Kiro OAuth web interface at: diff --git a/README_CN.md b/README_CN.md index 79b5203f..41ff4e50 100644 --- a/README_CN.md +++ b/README_CN.md @@ -27,6 +27,51 @@ ## Kiro 认证 +### 命令行登录 + +> **注意:** 由于 AWS Cognito 限制,Google/GitHub 登录不可用于第三方应用。 + +**AWS Builder ID**(推荐): + +```bash +# 设备码流程 +./CLIProxyAPI --kiro-aws-login + +# 授权码流程 +./CLIProxyAPI --kiro-aws-authcode +``` + +**从 Kiro IDE 导入令牌:** + +```bash +./CLIProxyAPI --kiro-import +``` + +获取令牌步骤: + +1. 打开 Kiro IDE,使用 Google(或 GitHub)登录 +2. 找到令牌文件:`~/.kiro/kiro-auth-token.json` +3. 运行:`./CLIProxyAPI --kiro-import` + +**AWS IAM Identity Center (IDC):** + +```bash +./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start + +# 指定区域 +./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2 +``` + +**附加参数:** + +| 参数 | 说明 | +|------|------| +| `--no-browser` | 不自动打开浏览器,打印 URL | +| `--no-incognito` | 使用已有浏览器会话(Kiro 默认使用无痕模式),适用于需要已登录浏览器会话的企业 SSO 场景 | +| `--kiro-idc-start-url` | IDC Start URL(`--kiro-idc-login` 必需) | +| `--kiro-idc-region` | IDC 区域(默认:`us-east-1`) | +| `--kiro-idc-flow` | IDC 流程类型:`authcode`(默认)或 `device` | + ### 网页端 OAuth 登录 访问 Kiro OAuth 网页认证界面: diff --git a/cmd/server/main.go b/cmd/server/main.go index db95f6b3..9a204ebb 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -87,6 +87,10 @@ func main() { var kiroAWSLogin bool var kiroAWSAuthCode bool var kiroImport bool + var kiroIDCLogin bool + var kiroIDCStartURL string + var kiroIDCRegion string + var kiroIDCFlow string var githubCopilotLogin bool var projectID string var vertexImport string @@ -117,6 +121,10 @@ func main() { flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)") flag.BoolVar(&kiroAWSAuthCode, "kiro-aws-authcode", false, "Login to Kiro using AWS Builder ID (authorization code flow, better UX)") flag.BoolVar(&kiroImport, "kiro-import", false, "Import Kiro token from Kiro IDE (~/.aws/sso/cache/kiro-auth-token.json)") + flag.BoolVar(&kiroIDCLogin, "kiro-idc-login", false, "Login to Kiro using IAM Identity Center (IDC)") + flag.StringVar(&kiroIDCStartURL, "kiro-idc-start-url", "", "IDC start URL (required with --kiro-idc-login)") + flag.StringVar(&kiroIDCRegion, "kiro-idc-region", "", "IDC region (default: us-east-1)") + flag.StringVar(&kiroIDCFlow, "kiro-idc-flow", "", "IDC flow type: authcode (default) or device") flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow") flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)") flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path") @@ -526,24 +534,34 @@ func main() { // Note: This config mutation is safe - auth commands exit after completion // and don't share config with StartService (which is in the else branch) setKiroIncognitoMode(cfg, useIncognito, noIncognito) + kiro.InitFingerprintConfig(cfg) cmd.DoKiroLogin(cfg, options) } else if kiroGoogleLogin { // For Kiro auth, default to incognito mode for multi-account support // Users can explicitly override with --no-incognito // Note: This config mutation is safe - auth commands exit after completion setKiroIncognitoMode(cfg, useIncognito, noIncognito) + kiro.InitFingerprintConfig(cfg) cmd.DoKiroGoogleLogin(cfg, options) } else if kiroAWSLogin { // For Kiro auth, default to incognito mode for multi-account support // Users can explicitly override with --no-incognito setKiroIncognitoMode(cfg, useIncognito, noIncognito) + kiro.InitFingerprintConfig(cfg) cmd.DoKiroAWSLogin(cfg, options) } else if kiroAWSAuthCode { // For Kiro auth with authorization code flow (better UX) setKiroIncognitoMode(cfg, useIncognito, noIncognito) + kiro.InitFingerprintConfig(cfg) cmd.DoKiroAWSAuthCodeLogin(cfg, options) } else if kiroImport { + kiro.InitFingerprintConfig(cfg) cmd.DoKiroImport(cfg, options) + } else if kiroIDCLogin { + // For Kiro IDC auth, default to incognito mode for multi-account support + setKiroIncognitoMode(cfg, useIncognito, noIncognito) + kiro.InitFingerprintConfig(cfg) + cmd.DoKiroIDCLogin(cfg, options, kiroIDCStartURL, kiroIDCRegion, kiroIDCFlow) } else { // In cloud deploy mode without config file, just wait for shutdown signals if isCloudDeploy && !configFileExists { diff --git a/config.example.yaml b/config.example.yaml index b513eb60..433efec6 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -179,6 +179,8 @@ nonstream-keepalive-interval: 0 #kiro: # - token-file: "~/.aws/sso/cache/kiro-auth-token.json" # path to Kiro token file # agent-task-type: "" # optional: "vibe" or empty (API default) +# start-url: "https://your-company.awsapps.com/start" # optional: IDC start URL (preset for login) +# region: "us-east-1" # optional: OIDC region for IDC login and token refresh # - access-token: "aoaAAAAA..." # or provide tokens directly # refresh-token: "aorAAAAA..." # profile-arn: "arn:aws:codewhisperer:us-east-1:..." diff --git a/internal/auth/kiro/aws.go b/internal/auth/kiro/aws.go index 6ec67c49..572050c5 100644 --- a/internal/auth/kiro/aws.go +++ b/internal/auth/kiro/aws.go @@ -7,10 +7,13 @@ import ( "encoding/json" "errors" "fmt" + "net/url" "os" "path/filepath" "strings" "time" + + log "github.com/sirupsen/logrus" ) // PKCECodes holds PKCE verification codes for OAuth2 PKCE flow @@ -47,7 +50,7 @@ type KiroTokenData struct { Email string `json:"email,omitempty"` // StartURL is the IDC/Identity Center start URL (only for IDC auth method) StartURL string `json:"startUrl,omitempty"` - // Region is the AWS region for IDC authentication (only for IDC auth method) + // Region is the OIDC region for IDC login and token refresh Region string `json:"region,omitempty"` } @@ -520,3 +523,159 @@ func GenerateTokenFileName(tokenData *KiroTokenData) string { // Priority 3: Fallback to authMethod only with sequence return fmt.Sprintf("kiro-%s-%05d.json", authMethod, seq) } + +// DefaultKiroRegion is the fallback region when none is specified. +const DefaultKiroRegion = "us-east-1" + +// GetCodeWhispererLegacyEndpoint returns the legacy CodeWhisperer JSON-RPC endpoint. +// This endpoint supports JSON-RPC style requests with x-amz-target headers. +// The Q endpoint (q.{region}.amazonaws.com) does NOT support JSON-RPC style. +func GetCodeWhispererLegacyEndpoint(region string) string { + if region == "" { + region = DefaultKiroRegion + } + return "https://codewhisperer." + region + ".amazonaws.com" +} + +// ProfileARN represents a parsed AWS CodeWhisperer profile ARN. +// ARN format: arn:partition:service:region:account-id:resource-type/resource-id +// Example: arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL +type ProfileARN struct { + // Raw is the original ARN string + Raw string + // Partition is the AWS partition (aws) + Partition string + // Service is the AWS service name (codewhisperer) + Service string + // Region is the AWS region (us-east-1, ap-southeast-1, etc.) + Region string + // AccountID is the AWS account ID + AccountID string + // ResourceType is the resource type (profile) + ResourceType string + // ResourceID is the resource identifier (e.g., ABCDEFGHIJKL) + ResourceID string +} + +// ParseProfileARN parses an AWS ARN string into a ProfileARN struct. +// Returns nil if the ARN is empty, invalid, or not a codewhisperer ARN. +func ParseProfileARN(arn string) *ProfileARN { + if arn == "" { + return nil + } + // ARN format: arn:partition:service:region:account-id:resource + // Minimum 6 parts separated by ":" + parts := strings.Split(arn, ":") + if len(parts) < 6 { + log.Warnf("invalid ARN format: %s", arn) + return nil + } + // Validate ARN prefix + if parts[0] != "arn" { + return nil + } + // Validate partition + partition := parts[1] + if partition == "" { + return nil + } + // Validate service is codewhisperer + service := parts[2] + if service != "codewhisperer" { + return nil + } + // Validate region format (must contain "-") + region := parts[3] + if region == "" || !strings.Contains(region, "-") { + return nil + } + // Account ID + accountID := parts[4] + + // Parse resource (format: resource-type/resource-id) + // Join remaining parts in case resource contains ":" + resource := strings.Join(parts[5:], ":") + resourceType := "" + resourceID := "" + if idx := strings.Index(resource, "/"); idx > 0 { + resourceType = resource[:idx] + resourceID = resource[idx+1:] + } else { + resourceType = resource + } + + return &ProfileARN{ + Raw: arn, + Partition: partition, + Service: service, + Region: region, + AccountID: accountID, + ResourceType: resourceType, + ResourceID: resourceID, + } +} + +// GetKiroAPIEndpoint returns the Q API endpoint for the specified region. +// If region is empty, defaults to us-east-1. +func GetKiroAPIEndpoint(region string) string { + if region == "" { + region = DefaultKiroRegion + } + return "https://q." + region + ".amazonaws.com" +} + +// GetKiroAPIEndpointFromProfileArn extracts region from profileArn and returns the endpoint. +// Returns default us-east-1 endpoint if region cannot be extracted. +func GetKiroAPIEndpointFromProfileArn(profileArn string) string { + region := ExtractRegionFromProfileArn(profileArn) + return GetKiroAPIEndpoint(region) +} + +// ExtractRegionFromProfileArn extracts the AWS region from a ProfileARN string. +// Returns empty string if ARN is invalid or region cannot be extracted. +func ExtractRegionFromProfileArn(profileArn string) string { + parsed := ParseProfileARN(profileArn) + if parsed == nil { + return "" + } + return parsed.Region +} + +// ExtractRegionFromMetadata extracts API region from auth metadata. +// Priority: api_region > profile_arn > DefaultKiroRegion +func ExtractRegionFromMetadata(metadata map[string]interface{}) string { + if metadata == nil { + return DefaultKiroRegion + } + + // Priority 1: Explicit api_region override + if r, ok := metadata["api_region"].(string); ok && r != "" { + return r + } + + // Priority 2: Extract from ProfileARN + if profileArn, ok := metadata["profile_arn"].(string); ok && profileArn != "" { + if region := ExtractRegionFromProfileArn(profileArn); region != "" { + return region + } + } + + return DefaultKiroRegion +} + +func buildURL(endpoint, path string, queryParams map[string]string) string { + fullURL := fmt.Sprintf("%s/%s", endpoint, path) + if len(queryParams) > 0 { + values := url.Values{} + for key, value := range queryParams { + if value == "" { + continue + } + values.Set(key, value) + } + if encoded := values.Encode(); encoded != "" { + fullURL = fullURL + "?" + encoded + } + } + return fullURL +} diff --git a/internal/auth/kiro/aws_auth.go b/internal/auth/kiro/aws_auth.go index 69ae2539..dcda376b 100644 --- a/internal/auth/kiro/aws_auth.go +++ b/internal/auth/kiro/aws_auth.go @@ -19,15 +19,8 @@ import ( ) const ( - // awsKiroEndpoint is used for CodeWhisperer management APIs (GetUsageLimits, ListProfiles, etc.) - // Note: This is different from the Amazon Q streaming endpoint (q.us-east-1.amazonaws.com) - // used in kiro_executor.go for GenerateAssistantResponse. Both endpoints are correct - // for their respective API operations. - awsKiroEndpoint = "https://codewhisperer.us-east-1.amazonaws.com" - defaultTokenFile = "~/.aws/sso/cache/kiro-auth-token.json" - targetGetUsage = "AmazonCodeWhispererService.GetUsageLimits" - targetListModels = "AmazonCodeWhispererService.ListAvailableModels" - targetGenerateChat = "AmazonCodeWhispererStreamingService.GenerateAssistantResponse" + pathGetUsageLimits = "getUsageLimits" + pathListAvailableModels = "ListAvailableModels" ) // KiroAuth handles AWS CodeWhisperer authentication and API communication. @@ -35,7 +28,6 @@ const ( // and communicating with the CodeWhisperer API. type KiroAuth struct { httpClient *http.Client - endpoint string } // NewKiroAuth creates a new Kiro authentication service. @@ -49,7 +41,6 @@ type KiroAuth struct { func NewKiroAuth(cfg *config.Config) *KiroAuth { return &KiroAuth{ httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 120 * time.Second}), - endpoint: awsKiroEndpoint, } } @@ -110,33 +101,30 @@ func (k *KiroAuth) IsTokenExpired(tokenData *KiroTokenData) bool { return time.Now().After(expiresAt) } -// makeRequest sends a request to the CodeWhisperer API. -// This is an internal method for making authenticated API calls. +// makeRequest sends a REST-style GET request to the CodeWhisperer API. // // Parameters: // - ctx: The context for the request -// - target: The API target (e.g., "AmazonCodeWhispererService.GetUsageLimits") -// - accessToken: The OAuth access token -// - payload: The request payload +// - path: The API path (e.g., "getUsageLimits") +// - tokenData: The token data containing access token, refresh token, and profile ARN +// - queryParams: Query parameters to add to the URL // // Returns: // - []byte: The response body // - error: An error if the request fails -func (k *KiroAuth) makeRequest(ctx context.Context, target string, accessToken string, payload interface{}) ([]byte, error) { - jsonBody, err := json.Marshal(payload) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } +func (k *KiroAuth) makeRequest(ctx context.Context, path string, tokenData *KiroTokenData, queryParams map[string]string) ([]byte, error) { + // Get endpoint from profileArn (defaults to us-east-1 if empty) + profileArn := queryParams["profileArn"] + endpoint := GetKiroAPIEndpointFromProfileArn(profileArn) + url := buildURL(endpoint, path, queryParams) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, k.endpoint, strings.NewReader(string(jsonBody))) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } - req.Header.Set("Content-Type", "application/x-amz-json-1.0") - req.Header.Set("x-amz-target", target) - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Accept", "application/json") + accountKey := GetAccountKey(tokenData.ClientID, tokenData.RefreshToken) + setRuntimeHeaders(req, tokenData.AccessToken, accountKey) resp, err := k.httpClient.Do(req) if err != nil { @@ -171,13 +159,13 @@ func (k *KiroAuth) makeRequest(ctx context.Context, target string, accessToken s // - *KiroUsageInfo: The usage information // - error: An error if the request fails func (k *KiroAuth) GetUsageLimits(ctx context.Context, tokenData *KiroTokenData) (*KiroUsageInfo, error) { - payload := map[string]interface{}{ + queryParams := map[string]string{ "origin": "AI_EDITOR", "profileArn": tokenData.ProfileArn, "resourceType": "AGENTIC_REQUEST", } - body, err := k.makeRequest(ctx, targetGetUsage, tokenData.AccessToken, payload) + body, err := k.makeRequest(ctx, pathGetUsageLimits, tokenData, queryParams) if err != nil { return nil, err } @@ -221,12 +209,12 @@ func (k *KiroAuth) GetUsageLimits(ctx context.Context, tokenData *KiroTokenData) // - []*KiroModel: The list of available models // - error: An error if the request fails func (k *KiroAuth) ListAvailableModels(ctx context.Context, tokenData *KiroTokenData) ([]*KiroModel, error) { - payload := map[string]interface{}{ + queryParams := map[string]string{ "origin": "AI_EDITOR", "profileArn": tokenData.ProfileArn, } - body, err := k.makeRequest(ctx, targetListModels, tokenData.AccessToken, payload) + body, err := k.makeRequest(ctx, pathListAvailableModels, tokenData, queryParams) if err != nil { return nil, err } diff --git a/internal/auth/kiro/aws_test.go b/internal/auth/kiro/aws_test.go index 194ad59e..3033c985 100644 --- a/internal/auth/kiro/aws_test.go +++ b/internal/auth/kiro/aws_test.go @@ -3,6 +3,7 @@ package kiro import ( "encoding/base64" "encoding/json" + "strings" "testing" ) @@ -217,7 +218,8 @@ func TestGenerateTokenFileName(t *testing.T) { tests := []struct { name string tokenData *KiroTokenData - expected string + exact string // exact match (for cases with email) + prefix string // prefix match (for cases without email, where sequence is appended) }{ { name: "IDC with email", @@ -226,7 +228,7 @@ func TestGenerateTokenFileName(t *testing.T) { Email: "user@example.com", StartURL: "https://d-1234567890.awsapps.com/start", }, - expected: "kiro-idc-user-example-com.json", + exact: "kiro-idc-user-example-com.json", }, { name: "IDC without email but with startUrl", @@ -235,7 +237,7 @@ func TestGenerateTokenFileName(t *testing.T) { Email: "", StartURL: "https://d-1234567890.awsapps.com/start", }, - expected: "kiro-idc-d-1234567890.json", + prefix: "kiro-idc-d-1234567890-", }, { name: "IDC with company name in startUrl", @@ -244,7 +246,7 @@ func TestGenerateTokenFileName(t *testing.T) { Email: "", StartURL: "https://my-company.awsapps.com/start", }, - expected: "kiro-idc-my-company.json", + prefix: "kiro-idc-my-company-", }, { name: "IDC without email and without startUrl", @@ -253,7 +255,7 @@ func TestGenerateTokenFileName(t *testing.T) { Email: "", StartURL: "", }, - expected: "kiro-idc.json", + prefix: "kiro-idc-", }, { name: "Builder ID with email", @@ -262,7 +264,7 @@ func TestGenerateTokenFileName(t *testing.T) { Email: "user@gmail.com", StartURL: "https://view.awsapps.com/start", }, - expected: "kiro-builder-id-user-gmail-com.json", + exact: "kiro-builder-id-user-gmail-com.json", }, { name: "Builder ID without email", @@ -271,7 +273,7 @@ func TestGenerateTokenFileName(t *testing.T) { Email: "", StartURL: "https://view.awsapps.com/start", }, - expected: "kiro-builder-id.json", + prefix: "kiro-builder-id-", }, { name: "Social auth with email", @@ -279,7 +281,7 @@ func TestGenerateTokenFileName(t *testing.T) { AuthMethod: "google", Email: "user@gmail.com", }, - expected: "kiro-google-user-gmail-com.json", + exact: "kiro-google-user-gmail-com.json", }, { name: "Empty auth method", @@ -287,7 +289,7 @@ func TestGenerateTokenFileName(t *testing.T) { AuthMethod: "", Email: "", }, - expected: "kiro-unknown.json", + prefix: "kiro-unknown-", }, { name: "Email with special characters", @@ -296,16 +298,454 @@ func TestGenerateTokenFileName(t *testing.T) { Email: "user.name+tag@sub.example.com", StartURL: "https://d-1234567890.awsapps.com/start", }, - expected: "kiro-idc-user-name+tag-sub-example-com.json", + exact: "kiro-idc-user-name+tag-sub-example-com.json", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := GenerateTokenFileName(tt.tokenData) - if result != tt.expected { - t.Errorf("GenerateTokenFileName() = %q, want %q", result, tt.expected) + if tt.exact != "" { + if result != tt.exact { + t.Errorf("GenerateTokenFileName() = %q, want %q", result, tt.exact) + } + } else if tt.prefix != "" { + if !strings.HasPrefix(result, tt.prefix) || !strings.HasSuffix(result, ".json") { + t.Errorf("GenerateTokenFileName() = %q, want prefix %q with .json suffix", result, tt.prefix) + } } }) } } + +func TestParseProfileARN(t *testing.T) { + tests := []struct { + name string + arn string + expected *ProfileARN + }{ + { + name: "Empty ARN", + arn: "", + expected: nil, + }, + { + name: "Invalid format - too few parts", + arn: "arn:aws:codewhisperer", + expected: nil, + }, + { + name: "Invalid prefix - not arn", + arn: "notarn:aws:codewhisperer:us-east-1:123456789012:profile/ABC", + expected: nil, + }, + { + name: "Invalid service - not codewhisperer", + arn: "arn:aws:s3:us-east-1:123456789012:bucket/mybucket", + expected: nil, + }, + { + name: "Invalid region - no hyphen", + arn: "arn:aws:codewhisperer:useast1:123456789012:profile/ABC", + expected: nil, + }, + { + name: "Empty partition", + arn: "arn::codewhisperer:us-east-1:123456789012:profile/ABC", + expected: nil, + }, + { + name: "Empty region", + arn: "arn:aws:codewhisperer::123456789012:profile/ABC", + expected: nil, + }, + { + name: "Valid ARN - us-east-1", + arn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL", + expected: &ProfileARN{ + Raw: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL", + Partition: "aws", + Service: "codewhisperer", + Region: "us-east-1", + AccountID: "123456789012", + ResourceType: "profile", + ResourceID: "ABCDEFGHIJKL", + }, + }, + { + name: "Valid ARN - ap-southeast-1", + arn: "arn:aws:codewhisperer:ap-southeast-1:987654321098:profile/ZYXWVUTSRQ", + expected: &ProfileARN{ + Raw: "arn:aws:codewhisperer:ap-southeast-1:987654321098:profile/ZYXWVUTSRQ", + Partition: "aws", + Service: "codewhisperer", + Region: "ap-southeast-1", + AccountID: "987654321098", + ResourceType: "profile", + ResourceID: "ZYXWVUTSRQ", + }, + }, + { + name: "Valid ARN - eu-west-1", + arn: "arn:aws:codewhisperer:eu-west-1:111222333444:profile/PROFILE123", + expected: &ProfileARN{ + Raw: "arn:aws:codewhisperer:eu-west-1:111222333444:profile/PROFILE123", + Partition: "aws", + Service: "codewhisperer", + Region: "eu-west-1", + AccountID: "111222333444", + ResourceType: "profile", + ResourceID: "PROFILE123", + }, + }, + { + name: "Valid ARN - aws-cn partition", + arn: "arn:aws-cn:codewhisperer:cn-north-1:123456789012:profile/CHINAID", + expected: &ProfileARN{ + Raw: "arn:aws-cn:codewhisperer:cn-north-1:123456789012:profile/CHINAID", + Partition: "aws-cn", + Service: "codewhisperer", + Region: "cn-north-1", + AccountID: "123456789012", + ResourceType: "profile", + ResourceID: "CHINAID", + }, + }, + { + name: "Valid ARN - resource without slash", + arn: "arn:aws:codewhisperer:us-west-2:123456789012:profile", + expected: &ProfileARN{ + Raw: "arn:aws:codewhisperer:us-west-2:123456789012:profile", + Partition: "aws", + Service: "codewhisperer", + Region: "us-west-2", + AccountID: "123456789012", + ResourceType: "profile", + ResourceID: "", + }, + }, + { + name: "Valid ARN - resource with colon", + arn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC:extra", + expected: &ProfileARN{ + Raw: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC:extra", + Partition: "aws", + Service: "codewhisperer", + Region: "us-east-1", + AccountID: "123456789012", + ResourceType: "profile", + ResourceID: "ABC:extra", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ParseProfileARN(tt.arn) + if tt.expected == nil { + if result != nil { + t.Errorf("ParseProfileARN(%q) = %+v, want nil", tt.arn, result) + } + return + } + if result == nil { + t.Errorf("ParseProfileARN(%q) = nil, want %+v", tt.arn, tt.expected) + return + } + if result.Raw != tt.expected.Raw { + t.Errorf("Raw = %q, want %q", result.Raw, tt.expected.Raw) + } + if result.Partition != tt.expected.Partition { + t.Errorf("Partition = %q, want %q", result.Partition, tt.expected.Partition) + } + if result.Service != tt.expected.Service { + t.Errorf("Service = %q, want %q", result.Service, tt.expected.Service) + } + if result.Region != tt.expected.Region { + t.Errorf("Region = %q, want %q", result.Region, tt.expected.Region) + } + if result.AccountID != tt.expected.AccountID { + t.Errorf("AccountID = %q, want %q", result.AccountID, tt.expected.AccountID) + } + if result.ResourceType != tt.expected.ResourceType { + t.Errorf("ResourceType = %q, want %q", result.ResourceType, tt.expected.ResourceType) + } + if result.ResourceID != tt.expected.ResourceID { + t.Errorf("ResourceID = %q, want %q", result.ResourceID, tt.expected.ResourceID) + } + }) + } +} + +func TestExtractRegionFromProfileArn(t *testing.T) { + tests := []struct { + name string + profileArn string + expected string + }{ + { + name: "Empty ARN", + profileArn: "", + expected: "", + }, + { + name: "Invalid ARN", + profileArn: "invalid-arn", + expected: "", + }, + { + name: "Valid ARN - us-east-1", + profileArn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC", + expected: "us-east-1", + }, + { + name: "Valid ARN - ap-southeast-1", + profileArn: "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC", + expected: "ap-southeast-1", + }, + { + name: "Valid ARN - eu-central-1", + profileArn: "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC", + expected: "eu-central-1", + }, + { + name: "Non-codewhisperer ARN", + profileArn: "arn:aws:s3:us-east-1:123456789012:bucket/mybucket", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExtractRegionFromProfileArn(tt.profileArn) + if result != tt.expected { + t.Errorf("ExtractRegionFromProfileArn(%q) = %q, want %q", tt.profileArn, result, tt.expected) + } + }) + } +} + +func TestGetKiroAPIEndpoint(t *testing.T) { + tests := []struct { + name string + region string + expected string + }{ + { + name: "Empty region - defaults to us-east-1", + region: "", + expected: "https://q.us-east-1.amazonaws.com", + }, + { + name: "us-east-1", + region: "us-east-1", + expected: "https://q.us-east-1.amazonaws.com", + }, + { + name: "us-west-2", + region: "us-west-2", + expected: "https://q.us-west-2.amazonaws.com", + }, + { + name: "ap-southeast-1", + region: "ap-southeast-1", + expected: "https://q.ap-southeast-1.amazonaws.com", + }, + { + name: "eu-west-1", + region: "eu-west-1", + expected: "https://q.eu-west-1.amazonaws.com", + }, + { + name: "cn-north-1", + region: "cn-north-1", + expected: "https://q.cn-north-1.amazonaws.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetKiroAPIEndpoint(tt.region) + if result != tt.expected { + t.Errorf("GetKiroAPIEndpoint(%q) = %q, want %q", tt.region, result, tt.expected) + } + }) + } +} + +func TestGetKiroAPIEndpointFromProfileArn(t *testing.T) { + tests := []struct { + name string + profileArn string + expected string + }{ + { + name: "Empty ARN - defaults to us-east-1", + profileArn: "", + expected: "https://q.us-east-1.amazonaws.com", + }, + { + name: "Invalid ARN - defaults to us-east-1", + profileArn: "invalid-arn", + expected: "https://q.us-east-1.amazonaws.com", + }, + { + name: "Valid ARN - us-east-1", + profileArn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC", + expected: "https://q.us-east-1.amazonaws.com", + }, + { + name: "Valid ARN - ap-southeast-1", + profileArn: "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC", + expected: "https://q.ap-southeast-1.amazonaws.com", + }, + { + name: "Valid ARN - eu-central-1", + profileArn: "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC", + expected: "https://q.eu-central-1.amazonaws.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetKiroAPIEndpointFromProfileArn(tt.profileArn) + if result != tt.expected { + t.Errorf("GetKiroAPIEndpointFromProfileArn(%q) = %q, want %q", tt.profileArn, result, tt.expected) + } + }) + } +} + +func TestGetCodeWhispererLegacyEndpoint(t *testing.T) { + tests := []struct { + name string + region string + expected string + }{ + { + name: "Empty region - defaults to us-east-1", + region: "", + expected: "https://codewhisperer.us-east-1.amazonaws.com", + }, + { + name: "us-east-1", + region: "us-east-1", + expected: "https://codewhisperer.us-east-1.amazonaws.com", + }, + { + name: "us-west-2", + region: "us-west-2", + expected: "https://codewhisperer.us-west-2.amazonaws.com", + }, + { + name: "ap-northeast-1", + region: "ap-northeast-1", + expected: "https://codewhisperer.ap-northeast-1.amazonaws.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetCodeWhispererLegacyEndpoint(tt.region) + if result != tt.expected { + t.Errorf("GetCodeWhispererLegacyEndpoint(%q) = %q, want %q", tt.region, result, tt.expected) + } + }) + } +} + +func TestExtractRegionFromMetadata(t *testing.T) { + tests := []struct { + name string + metadata map[string]interface{} + expected string + }{ + { + name: "Nil metadata - defaults to us-east-1", + metadata: nil, + expected: "us-east-1", + }, + { + name: "Empty metadata - defaults to us-east-1", + metadata: map[string]interface{}{}, + expected: "us-east-1", + }, + { + name: "Priority 1: api_region override", + metadata: map[string]interface{}{ + "api_region": "eu-west-1", + "profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC", + }, + expected: "eu-west-1", + }, + { + name: "Priority 2: profile_arn when api_region is empty", + metadata: map[string]interface{}{ + "api_region": "", + "profile_arn": "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC", + }, + expected: "ap-southeast-1", + }, + { + name: "Priority 2: profile_arn when api_region is missing", + metadata: map[string]interface{}{ + "profile_arn": "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC", + }, + expected: "eu-central-1", + }, + { + name: "Fallback: default when profile_arn is invalid", + metadata: map[string]interface{}{ + "profile_arn": "invalid-arn", + }, + expected: "us-east-1", + }, + { + name: "Fallback: default when profile_arn is empty", + metadata: map[string]interface{}{ + "profile_arn": "", + }, + expected: "us-east-1", + }, + { + name: "OIDC region is NOT used for API region", + metadata: map[string]interface{}{ + "region": "ap-northeast-2", // OIDC region - should be ignored + }, + expected: "us-east-1", + }, + { + name: "api_region takes precedence over OIDC region", + metadata: map[string]interface{}{ + "api_region": "us-west-2", + "region": "ap-northeast-2", // OIDC region - should be ignored + }, + expected: "us-west-2", + }, + { + name: "Non-string api_region is ignored", + metadata: map[string]interface{}{ + "api_region": 123, // wrong type + "profile_arn": "arn:aws:codewhisperer:ap-south-1:123456789012:profile/ABC", + }, + expected: "ap-south-1", + }, + { + name: "Non-string profile_arn is ignored", + metadata: map[string]interface{}{ + "profile_arn": 123, // wrong type + }, + expected: "us-east-1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExtractRegionFromMetadata(tt.metadata) + if result != tt.expected { + t.Errorf("ExtractRegionFromMetadata(%v) = %q, want %q", tt.metadata, result, tt.expected) + } + }) + } +} + diff --git a/internal/auth/kiro/codewhisperer_client.go b/internal/auth/kiro/codewhisperer_client.go index 0a7392e8..04d678f9 100644 --- a/internal/auth/kiro/codewhisperer_client.go +++ b/internal/auth/kiro/codewhisperer_client.go @@ -9,30 +9,23 @@ import ( "net/http" "time" - "github.com/google/uuid" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" log "github.com/sirupsen/logrus" ) -const ( - codeWhispererAPI = "https://codewhisperer.us-east-1.amazonaws.com" - kiroVersion = "0.6.18" -) - // CodeWhispererClient handles CodeWhisperer API calls. type CodeWhispererClient struct { httpClient *http.Client - machineID string } // UsageLimitsResponse represents the getUsageLimits API response. type UsageLimitsResponse struct { - DaysUntilReset *int `json:"daysUntilReset,omitempty"` - NextDateReset *float64 `json:"nextDateReset,omitempty"` - UserInfo *UserInfo `json:"userInfo,omitempty"` - SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"` - UsageBreakdownList []UsageBreakdown `json:"usageBreakdownList,omitempty"` + DaysUntilReset *int `json:"daysUntilReset,omitempty"` + NextDateReset *float64 `json:"nextDateReset,omitempty"` + UserInfo *UserInfo `json:"userInfo,omitempty"` + SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"` + UsageBreakdownList []UsageBreakdown `json:"usageBreakdownList,omitempty"` } // UserInfo contains user information from the API. @@ -49,13 +42,13 @@ type SubscriptionInfo struct { // UsageBreakdown contains usage details. type UsageBreakdown struct { - UsageLimit *int `json:"usageLimit,omitempty"` - CurrentUsage *int `json:"currentUsage,omitempty"` - UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision,omitempty"` - CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision,omitempty"` - NextDateReset *float64 `json:"nextDateReset,omitempty"` - DisplayName string `json:"displayName,omitempty"` - ResourceType string `json:"resourceType,omitempty"` + UsageLimit *int `json:"usageLimit,omitempty"` + CurrentUsage *int `json:"currentUsage,omitempty"` + UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision,omitempty"` + CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision,omitempty"` + NextDateReset *float64 `json:"nextDateReset,omitempty"` + DisplayName string `json:"displayName,omitempty"` + ResourceType string `json:"resourceType,omitempty"` } // NewCodeWhispererClient creates a new CodeWhisperer client. @@ -64,40 +57,34 @@ func NewCodeWhispererClient(cfg *config.Config, machineID string) *CodeWhisperer if cfg != nil { client = util.SetProxy(&cfg.SDKConfig, client) } - if machineID == "" { - machineID = uuid.New().String() - } return &CodeWhispererClient{ httpClient: client, - machineID: machineID, } } -// generateInvocationID generates a unique invocation ID. -func generateInvocationID() string { - return uuid.New().String() -} - // GetUsageLimits fetches usage limits and user info from CodeWhisperer API. // This is the recommended way to get user email after login. -func (c *CodeWhispererClient) GetUsageLimits(ctx context.Context, accessToken string) (*UsageLimitsResponse, error) { - url := fmt.Sprintf("%s/getUsageLimits?isEmailRequired=true&origin=AI_EDITOR&resourceType=AGENTIC_REQUEST", codeWhispererAPI) +func (c *CodeWhispererClient) GetUsageLimits(ctx context.Context, accessToken, clientID, refreshToken, profileArn string) (*UsageLimitsResponse, error) { + queryParams := map[string]string{ + "origin": "AI_EDITOR", + "resourceType": "AGENTIC_REQUEST", + } + // Determine endpoint based on profileArn region + endpoint := GetKiroAPIEndpointFromProfileArn(profileArn) + if profileArn != "" { + queryParams["profileArn"] = profileArn + } else { + queryParams["isEmailRequired"] = "true" + } + url := buildURL(endpoint, pathGetUsageLimits, queryParams) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } - // Set headers to match Kiro IDE - xAmzUserAgent := fmt.Sprintf("aws-sdk-js/1.0.0 KiroIDE-%s-%s", kiroVersion, c.machineID) - userAgent := fmt.Sprintf("aws-sdk-js/1.0.0 ua/2.1 os/windows lang/js md/nodejs#20.16.0 api/codewhispererruntime#1.0.0 m/E KiroIDE-%s-%s", kiroVersion, c.machineID) - - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("x-amz-user-agent", xAmzUserAgent) - req.Header.Set("User-Agent", userAgent) - req.Header.Set("amz-sdk-invocation-id", generateInvocationID()) - req.Header.Set("amz-sdk-request", "attempt=1; max=1") - req.Header.Set("Connection", "close") + accountKey := GetAccountKey(clientID, refreshToken) + setRuntimeHeaders(req, accessToken, accountKey) log.Debugf("codewhisperer: GET %s", url) @@ -128,8 +115,8 @@ func (c *CodeWhispererClient) GetUsageLimits(ctx context.Context, accessToken st // FetchUserEmailFromAPI fetches user email using CodeWhisperer getUsageLimits API. // This is more reliable than JWT parsing as it uses the official API. -func (c *CodeWhispererClient) FetchUserEmailFromAPI(ctx context.Context, accessToken string) string { - resp, err := c.GetUsageLimits(ctx, accessToken) +func (c *CodeWhispererClient) FetchUserEmailFromAPI(ctx context.Context, accessToken, clientID, refreshToken string) string { + resp, err := c.GetUsageLimits(ctx, accessToken, clientID, refreshToken, "") if err != nil { log.Debugf("codewhisperer: failed to get usage limits: %v", err) return "" @@ -146,10 +133,10 @@ func (c *CodeWhispererClient) FetchUserEmailFromAPI(ctx context.Context, accessT // FetchUserEmailWithFallback fetches user email with multiple fallback methods. // Priority: 1. CodeWhisperer API 2. userinfo endpoint 3. JWT parsing -func FetchUserEmailWithFallback(ctx context.Context, cfg *config.Config, accessToken string) string { +func FetchUserEmailWithFallback(ctx context.Context, cfg *config.Config, accessToken, clientID, refreshToken string) string { // Method 1: Try CodeWhisperer API (most reliable) cwClient := NewCodeWhispererClient(cfg, "") - email := cwClient.FetchUserEmailFromAPI(ctx, accessToken) + email := cwClient.FetchUserEmailFromAPI(ctx, accessToken, clientID, refreshToken) if email != "" { return email } diff --git a/internal/auth/kiro/fingerprint.go b/internal/auth/kiro/fingerprint.go index c35e62b2..97bcdb86 100644 --- a/internal/auth/kiro/fingerprint.go +++ b/internal/auth/kiro/fingerprint.go @@ -2,77 +2,105 @@ package kiro import ( "crypto/sha256" + "encoding/binary" "encoding/hex" "fmt" "math/rand" "net/http" + "runtime" + "slices" "sync" "time" + + "github.com/google/uuid" ) -// Fingerprint 多维度指纹信息 +// Fingerprint holds multi-dimensional fingerprint data for runtime request disguise. type Fingerprint struct { - SDKVersion string // 1.0.20-1.0.27 + OIDCSDKVersion string // 3.7xx (AWS SDK JS) + RuntimeSDKVersion string // 1.0.x (runtime API) + StreamingSDKVersion string // 1.0.x (streaming API) OSType string // darwin/windows/linux - OSVersion string // 10.0.22621 - NodeVersion string // 18.x/20.x/22.x - KiroVersion string // 0.3.x-0.8.x + OSVersion string + NodeVersion string + KiroVersion string KiroHash string // SHA256 - AcceptLanguage string - ScreenResolution string // 1920x1080 - ColorDepth int // 24 - HardwareConcurrency int // CPU 核心数 - TimezoneOffset int } -// FingerprintManager 指纹管理器 +// FingerprintConfig holds external fingerprint overrides. +type FingerprintConfig struct { + OIDCSDKVersion string + RuntimeSDKVersion string + StreamingSDKVersion string + OSType string + OSVersion string + NodeVersion string + KiroVersion string + KiroHash string +} + +// FingerprintManager manages per-account fingerprint generation and caching. type FingerprintManager struct { mu sync.RWMutex fingerprints map[string]*Fingerprint // tokenKey -> fingerprint rng *rand.Rand + config *FingerprintConfig // External config (Optional) } var ( - sdkVersions = []string{ - "1.0.20", "1.0.21", "1.0.22", "1.0.23", - "1.0.24", "1.0.25", "1.0.26", "1.0.27", + // SDK versions + oidcSDKVersions = []string{ + "3.980.0", "3.975.0", "3.972.0", "3.808.0", + "3.738.0", "3.737.0", "3.736.0", "3.735.0", } + // SDKVersions for getUsageLimits/ListAvailableModels/GetProfile (runtime API) + runtimeSDKVersions = []string{"1.0.0"} + // SDKVersions for generateAssistantResponse (streaming API) + streamingSDKVersions = []string{"1.0.27"} + // Valid OS types osTypes = []string{"darwin", "windows", "linux"} + // OS versions osVersions = map[string][]string{ - "darwin": {"14.0", "14.1", "14.2", "14.3", "14.4", "14.5", "15.0", "15.1"}, - "windows": {"10.0.19041", "10.0.19042", "10.0.19043", "10.0.19044", "10.0.22621", "10.0.22631"}, - "linux": {"5.15.0", "6.1.0", "6.2.0", "6.5.0", "6.6.0", "6.8.0"}, + "darwin": {"25.2.0", "25.1.0", "25.0.0", "24.5.0", "24.4.0", "24.3.0"}, + "windows": {"10.0.26200", "10.0.26100", "10.0.22631", "10.0.22621", "10.0.19045"}, + "linux": {"6.12.0", "6.11.0", "6.8.0", "6.6.0", "6.5.0", "6.1.0"}, } + // Node versions nodeVersions = []string{ - "18.17.0", "18.18.0", "18.19.0", "18.20.0", - "20.9.0", "20.10.0", "20.11.0", "20.12.0", "20.13.0", - "22.0.0", "22.1.0", "22.2.0", "22.3.0", + "22.21.1", "22.21.0", "22.20.0", "22.19.0", "22.18.0", + "20.18.0", "20.17.0", "20.16.0", } + // Kiro IDE versions kiroVersions = []string{ - "0.3.0", "0.3.1", "0.4.0", "0.4.1", "0.5.0", "0.5.1", - "0.6.0", "0.6.1", "0.7.0", "0.7.1", "0.8.0", "0.8.1", + "0.10.32", "0.10.16", "0.10.10", + "0.9.47", "0.9.40", "0.9.2", + "0.8.206", "0.8.140", "0.8.135", "0.8.86", } - acceptLanguages = []string{ - "en-US,en;q=0.9", - "en-GB,en;q=0.9", - "zh-CN,zh;q=0.9,en;q=0.8", - "zh-TW,zh;q=0.9,en;q=0.8", - "ja-JP,ja;q=0.9,en;q=0.8", - "ko-KR,ko;q=0.9,en;q=0.8", - "de-DE,de;q=0.9,en;q=0.8", - "fr-FR,fr;q=0.9,en;q=0.8", - } - screenResolutions = []string{ - "1920x1080", "2560x1440", "3840x2160", - "1366x768", "1440x900", "1680x1050", - "2560x1600", "3440x1440", - } - colorDepths = []int{24, 32} - hardwareConcurrencies = []int{4, 6, 8, 10, 12, 16, 20, 24, 32} - timezoneOffsets = []int{-480, -420, -360, -300, -240, 0, 60, 120, 480, 540} + // Global singleton + globalFingerprintManager *FingerprintManager + globalFingerprintManagerOnce sync.Once ) -// NewFingerprintManager 创建指纹管理器 +func GlobalFingerprintManager() *FingerprintManager { + globalFingerprintManagerOnce.Do(func() { + globalFingerprintManager = NewFingerprintManager() + }) + return globalFingerprintManager +} + +func SetGlobalFingerprintConfig(cfg *FingerprintConfig) { + GlobalFingerprintManager().SetConfig(cfg) +} + +// SetConfig applies the config and clears the fingerprint cache. +func (fm *FingerprintManager) SetConfig(cfg *FingerprintConfig) { + fm.mu.Lock() + defer fm.mu.Unlock() + fm.config = cfg + // Clear cached fingerprints so they regenerate with the new config + fm.fingerprints = make(map[string]*Fingerprint) +} + func NewFingerprintManager() *FingerprintManager { return &FingerprintManager{ fingerprints: make(map[string]*Fingerprint), @@ -80,7 +108,7 @@ func NewFingerprintManager() *FingerprintManager { } } -// GetFingerprint 获取或生成 Token 关联的指纹 +// GetFingerprint returns the fingerprint for tokenKey, creating one if it doesn't exist. func (fm *FingerprintManager) GetFingerprint(tokenKey string) *Fingerprint { fm.mu.RLock() if fp, exists := fm.fingerprints[tokenKey]; exists { @@ -101,97 +129,150 @@ func (fm *FingerprintManager) GetFingerprint(tokenKey string) *Fingerprint { return fp } -// generateFingerprint 生成新的指纹 func (fm *FingerprintManager) generateFingerprint(tokenKey string) *Fingerprint { - osType := fm.randomChoice(osTypes) - osVersion := fm.randomChoice(osVersions[osType]) - kiroVersion := fm.randomChoice(kiroVersions) + if fm.config != nil { + return fm.generateFromConfig(tokenKey) + } + return fm.generateRandom(tokenKey) +} - fp := &Fingerprint{ - SDKVersion: fm.randomChoice(sdkVersions), - OSType: osType, - OSVersion: osVersion, - NodeVersion: fm.randomChoice(nodeVersions), - KiroVersion: kiroVersion, - AcceptLanguage: fm.randomChoice(acceptLanguages), - ScreenResolution: fm.randomChoice(screenResolutions), - ColorDepth: fm.randomIntChoice(colorDepths), - HardwareConcurrency: fm.randomIntChoice(hardwareConcurrencies), - TimezoneOffset: fm.randomIntChoice(timezoneOffsets), +// generateFromConfig uses config values, falling back to random for empty fields. +func (fm *FingerprintManager) generateFromConfig(tokenKey string) *Fingerprint { + cfg := fm.config + + // Helper: config value or random selection + configOrRandom := func(configVal string, choices []string) string { + if configVal != "" { + return configVal + } + return choices[fm.rng.Intn(len(choices))] } - fp.KiroHash = fm.generateKiroHash(tokenKey, kiroVersion, osType) - return fp + osType := cfg.OSType + if osType == "" { + osType = runtime.GOOS + if !slices.Contains(osTypes, osType) { + osType = osTypes[fm.rng.Intn(len(osTypes))] + } + } + + osVersion := cfg.OSVersion + if osVersion == "" { + if versions, ok := osVersions[osType]; ok { + osVersion = versions[fm.rng.Intn(len(versions))] + } + } + + kiroHash := cfg.KiroHash + if kiroHash == "" { + hash := sha256.Sum256([]byte(tokenKey)) + kiroHash = hex.EncodeToString(hash[:]) + } + + return &Fingerprint{ + OIDCSDKVersion: configOrRandom(cfg.OIDCSDKVersion, oidcSDKVersions), + RuntimeSDKVersion: configOrRandom(cfg.RuntimeSDKVersion, runtimeSDKVersions), + StreamingSDKVersion: configOrRandom(cfg.StreamingSDKVersion, streamingSDKVersions), + OSType: osType, + OSVersion: osVersion, + NodeVersion: configOrRandom(cfg.NodeVersion, nodeVersions), + KiroVersion: configOrRandom(cfg.KiroVersion, kiroVersions), + KiroHash: kiroHash, + } } -// generateKiroHash 生成 Kiro Hash -func (fm *FingerprintManager) generateKiroHash(tokenKey, kiroVersion, osType string) string { - data := fmt.Sprintf("%s:%s:%s:%d", tokenKey, kiroVersion, osType, time.Now().UnixNano()) - hash := sha256.Sum256([]byte(data)) - return hex.EncodeToString(hash[:]) +// generateRandom generates a deterministic fingerprint seeded by accountKey hash. +func (fm *FingerprintManager) generateRandom(accountKey string) *Fingerprint { + // Use accountKey hash as seed for deterministic random selection + hash := sha256.Sum256([]byte(accountKey)) + seed := int64(binary.BigEndian.Uint64(hash[:8])) + rng := rand.New(rand.NewSource(seed)) + + osType := runtime.GOOS + if !slices.Contains(osTypes, osType) { + osType = osTypes[rng.Intn(len(osTypes))] + } + osVersion := osVersions[osType][rng.Intn(len(osVersions[osType]))] + + return &Fingerprint{ + OIDCSDKVersion: oidcSDKVersions[rng.Intn(len(oidcSDKVersions))], + RuntimeSDKVersion: runtimeSDKVersions[rng.Intn(len(runtimeSDKVersions))], + StreamingSDKVersion: streamingSDKVersions[rng.Intn(len(streamingSDKVersions))], + OSType: osType, + OSVersion: osVersion, + NodeVersion: nodeVersions[rng.Intn(len(nodeVersions))], + KiroVersion: kiroVersions[rng.Intn(len(kiroVersions))], + KiroHash: hex.EncodeToString(hash[:]), + } } -// randomChoice 随机选择字符串 -func (fm *FingerprintManager) randomChoice(choices []string) string { - return choices[fm.rng.Intn(len(choices))] +// GenerateAccountKey returns a 16-char hex key derived from SHA256(seed). +func GenerateAccountKey(seed string) string { + hash := sha256.Sum256([]byte(seed)) + return hex.EncodeToString(hash[:8]) } -// randomIntChoice 随机选择整数 -func (fm *FingerprintManager) randomIntChoice(choices []int) int { - return choices[fm.rng.Intn(len(choices))] +// GetAccountKey derives an account key from clientID > refreshToken > random UUID. +func GetAccountKey(clientID, refreshToken string) string { + // 1. Prefer ClientID + if clientID != "" { + return GenerateAccountKey(clientID) + } + + // 2. Fallback to RefreshToken + if refreshToken != "" { + return GenerateAccountKey(refreshToken) + } + + // 3. Random fallback + return GenerateAccountKey(uuid.New().String()) } -// ApplyToRequest 将指纹信息应用到 HTTP 请求头 -func (fp *Fingerprint) ApplyToRequest(req *http.Request) { - req.Header.Set("X-Kiro-SDK-Version", fp.SDKVersion) - req.Header.Set("X-Kiro-OS-Type", fp.OSType) - req.Header.Set("X-Kiro-OS-Version", fp.OSVersion) - req.Header.Set("X-Kiro-Node-Version", fp.NodeVersion) - req.Header.Set("X-Kiro-Version", fp.KiroVersion) - req.Header.Set("X-Kiro-Hash", fp.KiroHash) - req.Header.Set("Accept-Language", fp.AcceptLanguage) - req.Header.Set("X-Screen-Resolution", fp.ScreenResolution) - req.Header.Set("X-Color-Depth", fmt.Sprintf("%d", fp.ColorDepth)) - req.Header.Set("X-Hardware-Concurrency", fmt.Sprintf("%d", fp.HardwareConcurrency)) - req.Header.Set("X-Timezone-Offset", fmt.Sprintf("%d", fp.TimezoneOffset)) -} - -// RemoveFingerprint 移除 Token 关联的指纹 -func (fm *FingerprintManager) RemoveFingerprint(tokenKey string) { - fm.mu.Lock() - defer fm.mu.Unlock() - delete(fm.fingerprints, tokenKey) -} - -// Count 返回当前管理的指纹数量 -func (fm *FingerprintManager) Count() int { - fm.mu.RLock() - defer fm.mu.RUnlock() - return len(fm.fingerprints) -} - -// BuildUserAgent 构建 User-Agent 字符串 (Kiro IDE 风格) -// 格式: aws-sdk-js/{SDKVersion} ua/2.1 os/{OSType}#{OSVersion} lang/js md/nodejs#{NodeVersion} api/codewhispererstreaming#{SDKVersion} m/E KiroIDE-{KiroVersion}-{KiroHash} +// BuildUserAgent format: aws-sdk-js/{SDKVersion} ua/2.1 os/{OSType}#{OSVersion} lang/js md/nodejs#{NodeVersion} api/codewhispererstreaming#{SDKVersion} m/E KiroIDE-{KiroVersion}-{KiroHash} func (fp *Fingerprint) BuildUserAgent() string { return fmt.Sprintf( "aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s", - fp.SDKVersion, + fp.StreamingSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, - fp.SDKVersion, + fp.StreamingSDKVersion, fp.KiroVersion, fp.KiroHash, ) } -// BuildAmzUserAgent 构建 X-Amz-User-Agent 字符串 -// 格式: aws-sdk-js/{SDKVersion} KiroIDE-{KiroVersion}-{KiroHash} +// BuildAmzUserAgent format: aws-sdk-js/{SDKVersion} KiroIDE-{KiroVersion}-{KiroHash} func (fp *Fingerprint) BuildAmzUserAgent() string { return fmt.Sprintf( "aws-sdk-js/%s KiroIDE-%s-%s", - fp.SDKVersion, + fp.StreamingSDKVersion, fp.KiroVersion, fp.KiroHash, ) } + +func SetOIDCHeaders(req *http.Request) { + fp := GlobalFingerprintManager().GetFingerprint("oidc-session") + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-amz-user-agent", fmt.Sprintf("aws-sdk-js/%s KiroIDE", fp.OIDCSDKVersion)) + req.Header.Set("User-Agent", fmt.Sprintf( + "aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/%s#%s m/E KiroIDE", + fp.OIDCSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, "sso-oidc", fp.OIDCSDKVersion)) + req.Header.Set("amz-sdk-invocation-id", uuid.New().String()) + req.Header.Set("amz-sdk-request", "attempt=1; max=4") +} + +func setRuntimeHeaders(req *http.Request, accessToken string, accountKey string) { + fp := GlobalFingerprintManager().GetFingerprint(accountKey) + machineID := fp.KiroHash + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("x-amz-user-agent", fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s", + fp.RuntimeSDKVersion, fp.KiroVersion, machineID)) + req.Header.Set("User-Agent", fmt.Sprintf( + "aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererruntime#%s m/N,E KiroIDE-%s-%s", + fp.RuntimeSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, fp.RuntimeSDKVersion, + fp.KiroVersion, machineID)) + req.Header.Set("amz-sdk-invocation-id", uuid.New().String()) + req.Header.Set("amz-sdk-request", "attempt=1; max=1") +} diff --git a/internal/auth/kiro/fingerprint_test.go b/internal/auth/kiro/fingerprint_test.go index e0ae51f2..0ac1b36e 100644 --- a/internal/auth/kiro/fingerprint_test.go +++ b/internal/auth/kiro/fingerprint_test.go @@ -2,6 +2,8 @@ package kiro import ( "net/http" + "runtime" + "strings" "sync" "testing" ) @@ -26,8 +28,14 @@ func TestGetFingerprint_NewToken(t *testing.T) { if fp == nil { t.Fatal("expected non-nil Fingerprint") } - if fp.SDKVersion == "" { - t.Error("expected non-empty SDKVersion") + if fp.OIDCSDKVersion == "" { + t.Error("expected non-empty OIDCSDKVersion") + } + if fp.RuntimeSDKVersion == "" { + t.Error("expected non-empty RuntimeSDKVersion") + } + if fp.StreamingSDKVersion == "" { + t.Error("expected non-empty StreamingSDKVersion") } if fp.OSType == "" { t.Error("expected non-empty OSType") @@ -44,18 +52,6 @@ func TestGetFingerprint_NewToken(t *testing.T) { if fp.KiroHash == "" { t.Error("expected non-empty KiroHash") } - if fp.AcceptLanguage == "" { - t.Error("expected non-empty AcceptLanguage") - } - if fp.ScreenResolution == "" { - t.Error("expected non-empty ScreenResolution") - } - if fp.ColorDepth == 0 { - t.Error("expected non-zero ColorDepth") - } - if fp.HardwareConcurrency == 0 { - t.Error("expected non-zero HardwareConcurrency") - } } func TestGetFingerprint_SameTokenReturnsSameFingerprint(t *testing.T) { @@ -78,72 +74,18 @@ func TestGetFingerprint_DifferentTokens(t *testing.T) { } } -func TestRemoveFingerprint(t *testing.T) { - fm := NewFingerprintManager() - fm.GetFingerprint("token1") - if fm.Count() != 1 { - t.Fatalf("expected count 1, got %d", fm.Count()) - } - - fm.RemoveFingerprint("token1") - if fm.Count() != 0 { - t.Errorf("expected count 0, got %d", fm.Count()) - } -} - -func TestRemoveFingerprint_NonExistent(t *testing.T) { - fm := NewFingerprintManager() - fm.RemoveFingerprint("nonexistent") - if fm.Count() != 0 { - t.Errorf("expected count 0, got %d", fm.Count()) - } -} - -func TestCount(t *testing.T) { - fm := NewFingerprintManager() - if fm.Count() != 0 { - t.Errorf("expected count 0, got %d", fm.Count()) - } - - fm.GetFingerprint("token1") - fm.GetFingerprint("token2") - fm.GetFingerprint("token3") - - if fm.Count() != 3 { - t.Errorf("expected count 3, got %d", fm.Count()) - } -} - -func TestApplyToRequest(t *testing.T) { +func TestBuildUserAgent(t *testing.T) { fm := NewFingerprintManager() fp := fm.GetFingerprint("token1") - req, _ := http.NewRequest("GET", "http://example.com", nil) - fp.ApplyToRequest(req) + ua := fp.BuildUserAgent() + if ua == "" { + t.Error("expected non-empty User-Agent") + } - if req.Header.Get("X-Kiro-SDK-Version") != fp.SDKVersion { - t.Error("X-Kiro-SDK-Version header mismatch") - } - if req.Header.Get("X-Kiro-OS-Type") != fp.OSType { - t.Error("X-Kiro-OS-Type header mismatch") - } - if req.Header.Get("X-Kiro-OS-Version") != fp.OSVersion { - t.Error("X-Kiro-OS-Version header mismatch") - } - if req.Header.Get("X-Kiro-Node-Version") != fp.NodeVersion { - t.Error("X-Kiro-Node-Version header mismatch") - } - if req.Header.Get("X-Kiro-Version") != fp.KiroVersion { - t.Error("X-Kiro-Version header mismatch") - } - if req.Header.Get("X-Kiro-Hash") != fp.KiroHash { - t.Error("X-Kiro-Hash header mismatch") - } - if req.Header.Get("Accept-Language") != fp.AcceptLanguage { - t.Error("Accept-Language header mismatch") - } - if req.Header.Get("X-Screen-Resolution") != fp.ScreenResolution { - t.Error("X-Screen-Resolution header mismatch") + amzUA := fp.BuildAmzUserAgent() + if amzUA == "" { + t.Error("expected non-empty X-Amz-User-Agent") } } @@ -166,6 +108,33 @@ func TestGetFingerprint_OSVersionMatchesOSType(t *testing.T) { } } +func TestGenerateFromConfig_OSTypeFromRuntimeGOOS(t *testing.T) { + fm := NewFingerprintManager() + + // Set config with empty OSType to trigger runtime.GOOS fallback + fm.SetConfig(&FingerprintConfig{ + OIDCSDKVersion: "3.738.0", // Set other fields to use config path + }) + + fp := fm.GetFingerprint("test-token") + + // Expected OS type based on runtime.GOOS mapping + var expectedOS string + switch runtime.GOOS { + case "darwin": + expectedOS = "darwin" + case "windows": + expectedOS = "windows" + default: + expectedOS = "linux" + } + + if fp.OSType != expectedOS { + t.Errorf("expected OSType '%s' from runtime.GOOS '%s', got '%s'", + expectedOS, runtime.GOOS, fp.OSType) + } +} + func TestFingerprintManager_ConcurrentAccess(t *testing.T) { fm := NewFingerprintManager() const numGoroutines = 100 @@ -174,22 +143,18 @@ func TestFingerprintManager_ConcurrentAccess(t *testing.T) { var wg sync.WaitGroup wg.Add(numGoroutines) - for i := 0; i < numGoroutines; i++ { + for i := range numGoroutines { go func(id int) { defer wg.Done() - for j := 0; j < numOperations; j++ { + for j := range numOperations { tokenKey := "token" + string(rune('a'+id%26)) - switch j % 4 { + switch j % 2 { case 0: fm.GetFingerprint(tokenKey) case 1: - fm.Count() - case 2: fp := fm.GetFingerprint(tokenKey) - req, _ := http.NewRequest("GET", "http://example.com", nil) - fp.ApplyToRequest(req) - case 3: - fm.RemoveFingerprint(tokenKey) + _ = fp.BuildUserAgent() + _ = fp.BuildAmzUserAgent() } } }(i) @@ -198,16 +163,20 @@ func TestFingerprintManager_ConcurrentAccess(t *testing.T) { wg.Wait() } -func TestKiroHashUniqueness(t *testing.T) { +func TestKiroHashStability(t *testing.T) { fm := NewFingerprintManager() - hashes := make(map[string]bool) - for i := 0; i < 100; i++ { - fp := fm.GetFingerprint("token" + string(rune(i))) - if hashes[fp.KiroHash] { - t.Errorf("duplicate KiroHash detected: %s", fp.KiroHash) - } - hashes[fp.KiroHash] = true + // Same token should always return same hash + fp1 := fm.GetFingerprint("token1") + fp2 := fm.GetFingerprint("token1") + if fp1.KiroHash != fp2.KiroHash { + t.Errorf("same token should have same hash: %s vs %s", fp1.KiroHash, fp2.KiroHash) + } + + // Different tokens should have different hashes + fp3 := fm.GetFingerprint("token2") + if fp1.KiroHash == fp3.KiroHash { + t.Errorf("different tokens should have different hashes") } } @@ -220,8 +189,590 @@ func TestKiroHashFormat(t *testing.T) { } for _, c := range fp.KiroHash { - if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) { + if (c < '0' || c > '9') && (c < 'a' || c > 'f') { t.Errorf("invalid hex character in KiroHash: %c", c) } } } + +func TestGlobalFingerprintManager(t *testing.T) { + fm1 := GlobalFingerprintManager() + fm2 := GlobalFingerprintManager() + + if fm1 == nil { + t.Fatal("expected non-nil GlobalFingerprintManager") + } + if fm1 != fm2 { + t.Error("expected GlobalFingerprintManager to return same instance") + } +} + +func TestSetOIDCHeaders(t *testing.T) { + req, _ := http.NewRequest("GET", "http://example.com", nil) + SetOIDCHeaders(req) + + if req.Header.Get("Content-Type") != "application/json" { + t.Error("expected Content-Type header to be set") + } + + amzUA := req.Header.Get("x-amz-user-agent") + if amzUA == "" { + t.Error("expected x-amz-user-agent header to be set") + } + if !strings.Contains(amzUA, "aws-sdk-js/") { + t.Errorf("x-amz-user-agent should contain aws-sdk-js: %s", amzUA) + } + if !strings.Contains(amzUA, "KiroIDE") { + t.Errorf("x-amz-user-agent should contain KiroIDE: %s", amzUA) + } + + ua := req.Header.Get("User-Agent") + if ua == "" { + t.Error("expected User-Agent header to be set") + } + if !strings.Contains(ua, "api/sso-oidc") { + t.Errorf("User-Agent should contain api name: %s", ua) + } + + if req.Header.Get("amz-sdk-invocation-id") == "" { + t.Error("expected amz-sdk-invocation-id header to be set") + } + if req.Header.Get("amz-sdk-request") != "attempt=1; max=4" { + t.Errorf("unexpected amz-sdk-request header: %s", req.Header.Get("amz-sdk-request")) + } +} + +func TestBuildURL(t *testing.T) { + tests := []struct { + name string + endpoint string + path string + queryParams map[string]string + want string + wantContains []string + }{ + { + name: "no query params", + endpoint: "https://api.example.com", + path: "getUsageLimits", + queryParams: nil, + want: "https://api.example.com/getUsageLimits", + }, + { + name: "empty query params", + endpoint: "https://api.example.com", + path: "getUsageLimits", + queryParams: map[string]string{}, + want: "https://api.example.com/getUsageLimits", + }, + { + name: "single query param", + endpoint: "https://api.example.com", + path: "getUsageLimits", + queryParams: map[string]string{ + "origin": "AI_EDITOR", + }, + want: "https://api.example.com/getUsageLimits?origin=AI_EDITOR", + }, + { + name: "multiple query params", + endpoint: "https://api.example.com", + path: "getUsageLimits", + queryParams: map[string]string{ + "origin": "AI_EDITOR", + "resourceType": "AGENTIC_REQUEST", + "profileArn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEF", + }, + wantContains: []string{ + "https://api.example.com/getUsageLimits?", + "origin=AI_EDITOR", + "profileArn=arn%3Aaws%3Acodewhisperer%3Aus-east-1%3A123456789012%3Aprofile%2FABCDEF", + "resourceType=AGENTIC_REQUEST", + }, + }, + { + name: "omit empty params", + endpoint: "https://api.example.com", + path: "getUsageLimits", + queryParams: map[string]string{ + "origin": "AI_EDITOR", + "profileArn": "", + }, + want: "https://api.example.com/getUsageLimits?origin=AI_EDITOR", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := buildURL(tt.endpoint, tt.path, tt.queryParams) + if tt.want != "" { + if got != tt.want { + t.Errorf("buildURL() = %v, want %v", got, tt.want) + } + } + if tt.wantContains != nil { + for _, substr := range tt.wantContains { + if !strings.Contains(got, substr) { + t.Errorf("buildURL() = %v, want to contain %v", got, substr) + } + } + } + }) + } +} + +func TestBuildUserAgentFormat(t *testing.T) { + fm := NewFingerprintManager() + fp := fm.GetFingerprint("token1") + + ua := fp.BuildUserAgent() + requiredParts := []string{ + "aws-sdk-js/", + "ua/2.1", + "os/", + "lang/js", + "md/nodejs#", + "api/codewhispererstreaming#", + "m/E", + "KiroIDE-", + } + for _, part := range requiredParts { + if !strings.Contains(ua, part) { + t.Errorf("User-Agent missing required part %q: %s", part, ua) + } + } +} + +func TestBuildAmzUserAgentFormat(t *testing.T) { + fm := NewFingerprintManager() + fp := fm.GetFingerprint("token1") + + amzUA := fp.BuildAmzUserAgent() + requiredParts := []string{ + "aws-sdk-js/", + "KiroIDE-", + } + for _, part := range requiredParts { + if !strings.Contains(amzUA, part) { + t.Errorf("X-Amz-User-Agent missing required part %q: %s", part, amzUA) + } + } + + // Amz-User-Agent should be shorter than User-Agent + ua := fp.BuildUserAgent() + if len(amzUA) >= len(ua) { + t.Error("X-Amz-User-Agent should be shorter than User-Agent") + } +} + +func TestSetRuntimeHeaders(t *testing.T) { + req, _ := http.NewRequest("GET", "http://example.com", nil) + accessToken := "test-access-token-1234567890" + clientID := "test-client-id-12345" + accountKey := GenerateAccountKey(clientID) + fp := GlobalFingerprintManager().GetFingerprint(accountKey) + machineID := fp.KiroHash + + setRuntimeHeaders(req, accessToken, accountKey) + + // Check Authorization header + if req.Header.Get("Authorization") != "Bearer "+accessToken { + t.Errorf("expected Authorization header 'Bearer %s', got '%s'", accessToken, req.Header.Get("Authorization")) + } + + // Check x-amz-user-agent header + amzUA := req.Header.Get("x-amz-user-agent") + if amzUA == "" { + t.Error("expected x-amz-user-agent header to be set") + } + if !strings.Contains(amzUA, "aws-sdk-js/") { + t.Errorf("x-amz-user-agent should contain aws-sdk-js: %s", amzUA) + } + if !strings.Contains(amzUA, "KiroIDE-") { + t.Errorf("x-amz-user-agent should contain KiroIDE: %s", amzUA) + } + if !strings.Contains(amzUA, machineID) { + t.Errorf("x-amz-user-agent should contain machineID: %s", amzUA) + } + + // Check User-Agent header + ua := req.Header.Get("User-Agent") + if ua == "" { + t.Error("expected User-Agent header to be set") + } + if !strings.Contains(ua, "api/codewhispererruntime#") { + t.Errorf("User-Agent should contain api/codewhispererruntime: %s", ua) + } + if !strings.Contains(ua, "m/N,E") { + t.Errorf("User-Agent should contain m/N,E: %s", ua) + } + + // Check amz-sdk-invocation-id (should be a UUID) + invocationID := req.Header.Get("amz-sdk-invocation-id") + if invocationID == "" { + t.Error("expected amz-sdk-invocation-id header to be set") + } + if len(invocationID) != 36 { + t.Errorf("expected amz-sdk-invocation-id to be UUID (36 chars), got %d", len(invocationID)) + } + + // Check amz-sdk-request + if req.Header.Get("amz-sdk-request") != "attempt=1; max=1" { + t.Errorf("unexpected amz-sdk-request header: %s", req.Header.Get("amz-sdk-request")) + } +} + +func TestSDKVersionsAreValid(t *testing.T) { + // Verify all OIDC SDK versions match expected format (3.xxx.x) + for _, v := range oidcSDKVersions { + if !strings.HasPrefix(v, "3.") { + t.Errorf("OIDC SDK version should start with 3.: %s", v) + } + parts := strings.Split(v, ".") + if len(parts) != 3 { + t.Errorf("OIDC SDK version should have 3 parts: %s", v) + } + } + + for _, v := range runtimeSDKVersions { + parts := strings.Split(v, ".") + if len(parts) != 3 { + t.Errorf("Runtime SDK version should have 3 parts: %s", v) + } + } + + for _, v := range streamingSDKVersions { + parts := strings.Split(v, ".") + if len(parts) != 3 { + t.Errorf("Streaming SDK version should have 3 parts: %s", v) + } + } +} + +func TestKiroVersionsAreValid(t *testing.T) { + // Verify all Kiro versions match expected format (0.x.xxx) + for _, v := range kiroVersions { + if !strings.HasPrefix(v, "0.") { + t.Errorf("Kiro version should start with 0.: %s", v) + } + parts := strings.Split(v, ".") + if len(parts) != 3 { + t.Errorf("Kiro version should have 3 parts: %s", v) + } + } +} + +func TestNodeVersionsAreValid(t *testing.T) { + // Verify all Node versions match expected format (xx.xx.x) + for _, v := range nodeVersions { + parts := strings.Split(v, ".") + if len(parts) != 3 { + t.Errorf("Node version should have 3 parts: %s", v) + } + // Should be Node 20.x or 22.x + if !strings.HasPrefix(v, "20.") && !strings.HasPrefix(v, "22.") { + t.Errorf("Node version should be 20.x or 22.x LTS: %s", v) + } + } +} + +func TestFingerprintManager_SetConfig(t *testing.T) { + fm := NewFingerprintManager() + + // Without config, should generate random fingerprint + fp1 := fm.GetFingerprint("token1") + if fp1 == nil { + t.Fatal("expected non-nil fingerprint") + } + + // Set config with all fields + cfg := &FingerprintConfig{ + OIDCSDKVersion: "3.999.0", + RuntimeSDKVersion: "9.9.9", + StreamingSDKVersion: "8.8.8", + OSType: "darwin", + OSVersion: "99.0.0", + NodeVersion: "99.99.99", + KiroVersion: "9.9.999", + KiroHash: "customhash123", + } + fm.SetConfig(cfg) + + // After setting config, should use config values + fp2 := fm.GetFingerprint("token2") + if fp2.OIDCSDKVersion != "3.999.0" { + t.Errorf("expected OIDCSDKVersion '3.999.0', got '%s'", fp2.OIDCSDKVersion) + } + if fp2.RuntimeSDKVersion != "9.9.9" { + t.Errorf("expected RuntimeSDKVersion '9.9.9', got '%s'", fp2.RuntimeSDKVersion) + } + if fp2.StreamingSDKVersion != "8.8.8" { + t.Errorf("expected StreamingSDKVersion '8.8.8', got '%s'", fp2.StreamingSDKVersion) + } + if fp2.OSType != "darwin" { + t.Errorf("expected OSType 'darwin', got '%s'", fp2.OSType) + } + if fp2.OSVersion != "99.0.0" { + t.Errorf("expected OSVersion '99.0.0', got '%s'", fp2.OSVersion) + } + if fp2.NodeVersion != "99.99.99" { + t.Errorf("expected NodeVersion '99.99.99', got '%s'", fp2.NodeVersion) + } + if fp2.KiroVersion != "9.9.999" { + t.Errorf("expected KiroVersion '9.9.999', got '%s'", fp2.KiroVersion) + } + if fp2.KiroHash != "customhash123" { + t.Errorf("expected KiroHash 'customhash123', got '%s'", fp2.KiroHash) + } +} + +func TestFingerprintManager_SetConfig_PartialFields(t *testing.T) { + fm := NewFingerprintManager() + + // Set config with only some fields + cfg := &FingerprintConfig{ + KiroVersion: "1.2.345", + KiroHash: "myhash", + // Other fields empty - should use random + } + fm.SetConfig(cfg) + + fp := fm.GetFingerprint("token1") + + // Configured fields should use config values + if fp.KiroVersion != "1.2.345" { + t.Errorf("expected KiroVersion '1.2.345', got '%s'", fp.KiroVersion) + } + if fp.KiroHash != "myhash" { + t.Errorf("expected KiroHash 'myhash', got '%s'", fp.KiroHash) + } + + // Empty fields should be randomly selected (non-empty) + if fp.OIDCSDKVersion == "" { + t.Error("expected non-empty OIDCSDKVersion") + } + if fp.OSType == "" { + t.Error("expected non-empty OSType") + } + if fp.NodeVersion == "" { + t.Error("expected non-empty NodeVersion") + } +} + +func TestFingerprintManager_SetConfig_ClearsCache(t *testing.T) { + fm := NewFingerprintManager() + + // Get fingerprint before config + fp1 := fm.GetFingerprint("token1") + originalHash := fp1.KiroHash + + // Set config + cfg := &FingerprintConfig{ + KiroHash: "newcustomhash", + } + fm.SetConfig(cfg) + + // Same token should now return different fingerprint (cache cleared) + fp2 := fm.GetFingerprint("token1") + if fp2.KiroHash == originalHash { + t.Error("expected cache to be cleared after SetConfig") + } + if fp2.KiroHash != "newcustomhash" { + t.Errorf("expected KiroHash 'newcustomhash', got '%s'", fp2.KiroHash) + } +} + +func TestGenerateAccountKey(t *testing.T) { + tests := []struct { + name string + seed string + check func(t *testing.T, result string) + }{ + { + name: "Empty seed", + seed: "", + check: func(t *testing.T, result string) { + if result == "" { + t.Error("expected non-empty result for empty seed") + } + if len(result) != 16 { + t.Errorf("expected 16 char hex string, got %d chars", len(result)) + } + }, + }, + { + name: "Simple seed", + seed: "test-client-id", + check: func(t *testing.T, result string) { + if len(result) != 16 { + t.Errorf("expected 16 char hex string, got %d chars", len(result)) + } + // Verify it's valid hex + for _, c := range result { + if (c < '0' || c > '9') && (c < 'a' || c > 'f') { + t.Errorf("invalid hex character: %c", c) + } + } + }, + }, + { + name: "Same seed produces same result", + seed: "deterministic-seed", + check: func(t *testing.T, result string) { + result2 := GenerateAccountKey("deterministic-seed") + if result != result2 { + t.Errorf("same seed should produce same result: %s vs %s", result, result2) + } + }, + }, + { + name: "Different seeds produce different results", + seed: "seed-one", + check: func(t *testing.T, result string) { + result2 := GenerateAccountKey("seed-two") + if result == result2 { + t.Errorf("different seeds should produce different results: %s vs %s", result, result2) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GenerateAccountKey(tt.seed) + tt.check(t, result) + }) + } +} + +func TestGetAccountKey(t *testing.T) { + tests := []struct { + name string + clientID string + refreshToken string + check func(t *testing.T, result string) + }{ + { + name: "Priority 1: clientID when both provided", + clientID: "client-id-123", + refreshToken: "refresh-token-456", + check: func(t *testing.T, result string) { + expected := GenerateAccountKey("client-id-123") + if result != expected { + t.Errorf("expected clientID-based key %s, got %s", expected, result) + } + }, + }, + { + name: "Priority 2: refreshToken when clientID is empty", + clientID: "", + refreshToken: "refresh-token-789", + check: func(t *testing.T, result string) { + expected := GenerateAccountKey("refresh-token-789") + if result != expected { + t.Errorf("expected refreshToken-based key %s, got %s", expected, result) + } + }, + }, + { + name: "Priority 3: random when both empty", + clientID: "", + refreshToken: "", + check: func(t *testing.T, result string) { + if len(result) != 16 { + t.Errorf("expected 16 char key, got %d chars", len(result)) + } + // Should be different each time (random UUID) + result2 := GetAccountKey("", "") + if result == result2 { + t.Log("warning: random keys are the same (possible but unlikely)") + } + }, + }, + { + name: "clientID only", + clientID: "solo-client-id", + refreshToken: "", + check: func(t *testing.T, result string) { + expected := GenerateAccountKey("solo-client-id") + if result != expected { + t.Errorf("expected clientID-based key %s, got %s", expected, result) + } + }, + }, + { + name: "refreshToken only", + clientID: "", + refreshToken: "solo-refresh-token", + check: func(t *testing.T, result string) { + expected := GenerateAccountKey("solo-refresh-token") + if result != expected { + t.Errorf("expected refreshToken-based key %s, got %s", expected, result) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetAccountKey(tt.clientID, tt.refreshToken) + tt.check(t, result) + }) + } +} + +func TestGetAccountKey_Deterministic(t *testing.T) { + // Verify that GetAccountKey produces deterministic results for same inputs + clientID := "test-client-id-abc" + refreshToken := "test-refresh-token-xyz" + + // Call multiple times with same inputs + results := make([]string, 10) + for i := range 10 { + results[i] = GetAccountKey(clientID, refreshToken) + } + + // All results should be identical + for i := 1; i < 10; i++ { + if results[i] != results[0] { + t.Errorf("GetAccountKey should be deterministic: got %s and %s", results[0], results[i]) + } + } +} + +func TestFingerprintDeterministic(t *testing.T) { + // Verify that fingerprints are deterministic based on accountKey + fm := NewFingerprintManager() + + accountKey := GenerateAccountKey("test-client-id") + + // Get fingerprint multiple times + fp1 := fm.GetFingerprint(accountKey) + fp2 := fm.GetFingerprint(accountKey) + + // Should be the same pointer (cached) + if fp1 != fp2 { + t.Error("expected same fingerprint pointer for same key") + } + + // Create new manager and verify same values + fm2 := NewFingerprintManager() + fp3 := fm2.GetFingerprint(accountKey) + + // Values should be identical (deterministic generation) + if fp1.KiroHash != fp3.KiroHash { + t.Errorf("KiroHash should be deterministic: %s vs %s", fp1.KiroHash, fp3.KiroHash) + } + if fp1.OSType != fp3.OSType { + t.Errorf("OSType should be deterministic: %s vs %s", fp1.OSType, fp3.OSType) + } + if fp1.OSVersion != fp3.OSVersion { + t.Errorf("OSVersion should be deterministic: %s vs %s", fp1.OSVersion, fp3.OSVersion) + } + if fp1.KiroVersion != fp3.KiroVersion { + t.Errorf("KiroVersion should be deterministic: %s vs %s", fp1.KiroVersion, fp3.KiroVersion) + } + if fp1.NodeVersion != fp3.NodeVersion { + t.Errorf("NodeVersion should be deterministic: %s vs %s", fp1.NodeVersion, fp3.NodeVersion) + } +} diff --git a/internal/auth/kiro/oauth.go b/internal/auth/kiro/oauth.go index a286cf42..5c020de6 100644 --- a/internal/auth/kiro/oauth.go +++ b/internal/auth/kiro/oauth.go @@ -23,10 +23,10 @@ import ( const ( // Kiro auth endpoint kiroAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev" - + // Default callback port defaultCallbackPort = 9876 - + // Auth timeout authTimeout = 10 * time.Minute ) @@ -41,8 +41,10 @@ type KiroTokenResponse struct { // KiroOAuth handles the OAuth flow for Kiro authentication. type KiroOAuth struct { - httpClient *http.Client - cfg *config.Config + httpClient *http.Client + cfg *config.Config + machineID string + kiroVersion string } // NewKiroOAuth creates a new Kiro OAuth handler. @@ -51,9 +53,12 @@ func NewKiroOAuth(cfg *config.Config) *KiroOAuth { if cfg != nil { client = util.SetProxy(&cfg.SDKConfig, client) } + fp := GlobalFingerprintManager().GetFingerprint("login") return &KiroOAuth{ - httpClient: client, - cfg: cfg, + httpClient: client, + cfg: cfg, + machineID: fp.KiroHash, + kiroVersion: fp.KiroVersion, } } @@ -190,7 +195,8 @@ func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier } req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", "KiroIDE-0.7.45-cli-proxy-api") + req.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", o.kiroVersion, o.machineID)) + req.Header.Set("Accept", "application/json, text/plain, */*") resp, err := o.httpClient.Do(req) if err != nil { @@ -256,11 +262,8 @@ func (o *KiroOAuth) RefreshTokenWithFingerprint(ctx context.Context, refreshToke } req.Header.Set("Content-Type", "application/json") - - // Use KiroIDE-style User-Agent to match official Kiro IDE behavior - // This helps avoid 403 errors from server-side User-Agent validation - userAgent := buildKiroUserAgent(tokenKey) - req.Header.Set("User-Agent", userAgent) + req.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", o.kiroVersion, o.machineID)) + req.Header.Set("Accept", "application/json, text/plain, */*") resp, err := o.httpClient.Do(req) if err != nil { @@ -301,19 +304,6 @@ func (o *KiroOAuth) RefreshTokenWithFingerprint(ctx context.Context, refreshToke }, nil } -// buildKiroUserAgent builds a KiroIDE-style User-Agent string. -// If tokenKey is provided, uses fingerprint manager for consistent fingerprint. -// Otherwise generates a simple KiroIDE User-Agent. -func buildKiroUserAgent(tokenKey string) string { - if tokenKey != "" { - fm := NewFingerprintManager() - fp := fm.GetFingerprint(tokenKey) - return fmt.Sprintf("KiroIDE-%s-%s", fp.KiroVersion, fp.KiroHash[:16]) - } - // Default KiroIDE User-Agent matching kiro-openai-gateway format - return "KiroIDE-0.7.45-cli-proxy-api" -} - // LoginWithGoogle performs OAuth login with Google using Kiro's social auth. // This uses a custom protocol handler (kiro://) to receive the callback. func (o *KiroOAuth) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) { diff --git a/internal/auth/kiro/oauth_web.go b/internal/auth/kiro/oauth_web.go index 88fba672..7db8ec36 100644 --- a/internal/auth/kiro/oauth_web.go +++ b/internal/auth/kiro/oauth_web.go @@ -35,35 +35,35 @@ const ( ) type webAuthSession struct { - stateID string - deviceCode string - userCode string - authURL string - verificationURI string - expiresIn int - interval int - status authSessionStatus - startedAt time.Time - completedAt time.Time - expiresAt time.Time - error string - tokenData *KiroTokenData - ssoClient *SSOOIDCClient - clientID string - clientSecret string - region string - cancelFunc context.CancelFunc - authMethod string // "google", "github", "builder-id", "idc" - startURL string // Used for IDC - codeVerifier string // Used for social auth PKCE - codeChallenge string // Used for social auth PKCE + stateID string + deviceCode string + userCode string + authURL string + verificationURI string + expiresIn int + interval int + status authSessionStatus + startedAt time.Time + completedAt time.Time + expiresAt time.Time + error string + tokenData *KiroTokenData + ssoClient *SSOOIDCClient + clientID string + clientSecret string + region string + cancelFunc context.CancelFunc + authMethod string // "google", "github", "builder-id", "idc" + startURL string // Used for IDC + codeVerifier string // Used for social auth PKCE + codeChallenge string // Used for social auth PKCE } type OAuthWebHandler struct { - cfg *config.Config - sessions map[string]*webAuthSession - mu sync.RWMutex - onTokenObtained func(*KiroTokenData) + cfg *config.Config + sessions map[string]*webAuthSession + mu sync.RWMutex + onTokenObtained func(*KiroTokenData) } func NewOAuthWebHandler(cfg *config.Config) *OAuthWebHandler { @@ -104,7 +104,7 @@ func (h *OAuthWebHandler) handleSelect(c *gin.Context) { func (h *OAuthWebHandler) handleStart(c *gin.Context) { method := c.Query("method") - + if method == "" { c.Redirect(http.StatusFound, "/v0/oauth/kiro") return @@ -138,7 +138,7 @@ func (h *OAuthWebHandler) startSocialAuth(c *gin.Context, method string) { } socialClient := NewSocialAuthClient(h.cfg) - + var provider string if method == "google" { provider = string(ProviderGoogle) @@ -373,22 +373,28 @@ func (h *OAuthWebHandler) pollForToken(ctx context.Context, session *webAuthSess } expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) - profileArn := session.ssoClient.fetchProfileArn(ctx, tokenResp.AccessToken) - email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken) + + // Fetch profileArn for IDC + var profileArn string + if session.authMethod == "idc" { + profileArn = session.ssoClient.FetchProfileArn(ctx, tokenResp.AccessToken, session.clientID, tokenResp.RefreshToken) + } + + email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken, session.clientID, tokenResp.RefreshToken) tokenData := &KiroTokenData{ - AccessToken: tokenResp.AccessToken, - RefreshToken: tokenResp.RefreshToken, - ProfileArn: profileArn, - ExpiresAt: expiresAt.Format(time.RFC3339), - AuthMethod: session.authMethod, - Provider: "AWS", - ClientID: session.clientID, - ClientSecret: session.clientSecret, - Email: email, - Region: session.region, - StartURL: session.startURL, - } + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: profileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: session.authMethod, + Provider: "AWS", + ClientID: session.clientID, + ClientSecret: session.clientSecret, + Email: email, + Region: session.region, + StartURL: session.startURL, + } h.mu.Lock() session.status = statusSuccess @@ -442,7 +448,7 @@ func (h *OAuthWebHandler) saveTokenToFile(tokenData *KiroTokenData) { fileName := GenerateTokenFileName(tokenData) authFilePath := filepath.Join(authDir, fileName) - + // Convert to storage format and save storage := &KiroTokenStorage{ Type: "kiro", @@ -459,12 +465,12 @@ func (h *OAuthWebHandler) saveTokenToFile(tokenData *KiroTokenData) { StartURL: tokenData.StartURL, Email: tokenData.Email, } - + if err := storage.SaveTokenToFile(authFilePath); err != nil { log.Errorf("OAuth Web: failed to save token to file: %v", err) return } - + log.Infof("OAuth Web: token saved to %s", authFilePath) } diff --git a/internal/auth/kiro/refresh_manager.go b/internal/auth/kiro/refresh_manager.go index 5330c5e1..94b21500 100644 --- a/internal/auth/kiro/refresh_manager.go +++ b/internal/auth/kiro/refresh_manager.go @@ -10,14 +10,14 @@ import ( log "github.com/sirupsen/logrus" ) -// RefreshManager 是后台刷新器的单例管理器 +// RefreshManager is a singleton manager for background token refreshing. type RefreshManager struct { mu sync.Mutex refresher *BackgroundRefresher ctx context.Context cancel context.CancelFunc started bool - onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) // 刷新成功回调 + onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) } var ( @@ -25,7 +25,7 @@ var ( managerOnce sync.Once ) -// GetRefreshManager 获取全局刷新管理器实例 +// GetRefreshManager returns the global RefreshManager singleton. func GetRefreshManager() *RefreshManager { managerOnce.Do(func() { globalRefreshManager = &RefreshManager{} @@ -33,9 +33,7 @@ func GetRefreshManager() *RefreshManager { return globalRefreshManager } -// Initialize 初始化后台刷新器 -// baseDir: token 文件所在的目录 -// cfg: 应用配置 +// Initialize sets up the background refresher. func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error { m.mu.Lock() defer m.mu.Unlock() @@ -58,18 +56,16 @@ func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error { baseDir = resolvedBaseDir } - // 创建 token 存储库 repo := NewFileTokenRepository(baseDir) - // 创建后台刷新器,配置参数 opts := []RefresherOption{ - WithInterval(time.Minute), // 每分钟检查一次 - WithBatchSize(50), // 每批最多处理 50 个 token - WithConcurrency(10), // 最多 10 个并发刷新 - WithConfig(cfg), // 设置 OAuth 和 SSO 客户端 + WithInterval(time.Minute), + WithBatchSize(50), + WithConcurrency(10), + WithConfig(cfg), } - // 如果已设置回调,传递给 BackgroundRefresher + // Pass callback to BackgroundRefresher if already set if m.onTokenRefreshed != nil { opts = append(opts, WithOnTokenRefreshed(m.onTokenRefreshed)) } @@ -80,7 +76,7 @@ func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error { return nil } -// Start 启动后台刷新 +// Start begins background token refreshing. func (m *RefreshManager) Start() { m.mu.Lock() defer m.mu.Unlock() @@ -102,7 +98,7 @@ func (m *RefreshManager) Start() { log.Info("refresh manager: background refresh started") } -// Stop 停止后台刷新 +// Stop halts background token refreshing. func (m *RefreshManager) Stop() { m.mu.Lock() defer m.mu.Unlock() @@ -123,14 +119,14 @@ func (m *RefreshManager) Stop() { log.Info("refresh manager: background refresh stopped") } -// IsRunning 检查后台刷新是否正在运行 +// IsRunning reports whether background refreshing is active. func (m *RefreshManager) IsRunning() bool { m.mu.Lock() defer m.mu.Unlock() return m.started } -// UpdateBaseDir 更新 token 目录(用于运行时配置更改) +// UpdateBaseDir changes the token directory at runtime. func (m *RefreshManager) UpdateBaseDir(baseDir string) { m.mu.Lock() defer m.mu.Unlock() @@ -143,16 +139,15 @@ func (m *RefreshManager) UpdateBaseDir(baseDir string) { } } -// SetOnTokenRefreshed 设置 token 刷新成功后的回调函数 -// 可以在任何时候调用,支持运行时更新回调 -// callback: 回调函数,接收 tokenID(文件名)和新的 token 数据 +// SetOnTokenRefreshed registers a callback invoked after a successful token refresh. +// Can be called at any time; supports runtime callback updates. func (m *RefreshManager) SetOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) { m.mu.Lock() defer m.mu.Unlock() m.onTokenRefreshed = callback - // 如果 refresher 已经创建,使用并发安全的方式更新它的回调 + // Update the refresher's callback in a thread-safe manner if already created if m.refresher != nil { m.refresher.callbackMu.Lock() m.refresher.onTokenRefreshed = callback @@ -162,8 +157,11 @@ func (m *RefreshManager) SetOnTokenRefreshed(callback func(tokenID string, token log.Debug("refresh manager: token refresh callback registered") } -// InitializeAndStart 初始化并启动后台刷新(便捷方法) +// InitializeAndStart initializes and starts background refreshing (convenience method). func InitializeAndStart(baseDir string, cfg *config.Config) { + // Initialize global fingerprint config + initGlobalFingerprintConfig(cfg) + manager := GetRefreshManager() if err := manager.Initialize(baseDir, cfg); err != nil { log.Errorf("refresh manager: initialization failed: %v", err) @@ -172,7 +170,31 @@ func InitializeAndStart(baseDir string, cfg *config.Config) { manager.Start() } -// StopGlobalRefreshManager 停止全局刷新管理器 +// initGlobalFingerprintConfig loads fingerprint settings from application config. +func initGlobalFingerprintConfig(cfg *config.Config) { + if cfg == nil || cfg.KiroFingerprint == nil { + return + } + fpCfg := cfg.KiroFingerprint + SetGlobalFingerprintConfig(&FingerprintConfig{ + OIDCSDKVersion: fpCfg.OIDCSDKVersion, + RuntimeSDKVersion: fpCfg.RuntimeSDKVersion, + StreamingSDKVersion: fpCfg.StreamingSDKVersion, + OSType: fpCfg.OSType, + OSVersion: fpCfg.OSVersion, + NodeVersion: fpCfg.NodeVersion, + KiroVersion: fpCfg.KiroVersion, + KiroHash: fpCfg.KiroHash, + }) + log.Debug("kiro: global fingerprint config loaded") +} + +// InitFingerprintConfig initializes the global fingerprint config from application config. +func InitFingerprintConfig(cfg *config.Config) { + initGlobalFingerprintConfig(cfg) +} + +// StopGlobalRefreshManager stops the global refresh manager. func StopGlobalRefreshManager() { if globalRefreshManager != nil { globalRefreshManager.Stop() diff --git a/internal/auth/kiro/social_auth.go b/internal/auth/kiro/social_auth.go index 65f31ba4..e329df4b 100644 --- a/internal/auth/kiro/social_auth.go +++ b/internal/auth/kiro/social_auth.go @@ -84,6 +84,8 @@ type SocialAuthClient struct { httpClient *http.Client cfg *config.Config protocolHandler *ProtocolHandler + machineID string + kiroVersion string } // NewSocialAuthClient creates a new social auth client. @@ -92,10 +94,13 @@ func NewSocialAuthClient(cfg *config.Config) *SocialAuthClient { if cfg != nil { client = util.SetProxy(&cfg.SDKConfig, client) } + fp := GlobalFingerprintManager().GetFingerprint("login") return &SocialAuthClient{ httpClient: client, cfg: cfg, protocolHandler: NewProtocolHandler(), + machineID: fp.KiroHash, + kiroVersion: fp.KiroVersion, } } @@ -229,7 +234,8 @@ func (c *SocialAuthClient) CreateToken(ctx context.Context, req *CreateTokenRequ } httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("User-Agent", "KiroIDE-0.7.45-cli-proxy-api") + httpReq.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", c.kiroVersion, c.machineID)) + httpReq.Header.Set("Accept", "application/json, text/plain, */*") resp, err := c.httpClient.Do(httpReq) if err != nil { @@ -269,7 +275,8 @@ func (c *SocialAuthClient) RefreshSocialToken(ctx context.Context, refreshToken } httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("User-Agent", "cli-proxy-api/1.0.0") + httpReq.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", c.kiroVersion, c.machineID)) + httpReq.Header.Set("Accept", "application/json, text/plain, */*") resp, err := c.httpClient.Do(httpReq) if err != nil { @@ -466,7 +473,7 @@ func forceDefaultProtocolHandler() { if runtime.GOOS != "linux" { return // Non-Linux platforms use different handler mechanisms } - + // Set our handler as default using xdg-mime cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro") if err := cmd.Run(); err != nil { diff --git a/internal/auth/kiro/sso_oidc.go b/internal/auth/kiro/sso_oidc.go index 60fb8871..4747b24c 100644 --- a/internal/auth/kiro/sso_oidc.go +++ b/internal/auth/kiro/sso_oidc.go @@ -14,6 +14,7 @@ import ( "io" "net" "net/http" + "net/url" "os" "strings" "time" @@ -40,21 +41,13 @@ const ( // Authorization code flow callback authCodeCallbackPath = "/oauth/callback" authCodeCallbackPort = 19877 - - // User-Agent to match official Kiro IDE - kiroUserAgent = "KiroIDE" - - // IDC token refresh headers (matching Kiro IDE behavior) - idcAmzUserAgent = "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE" ) -// Sentinel errors for OIDC token polling var ( ErrAuthorizationPending = errors.New("authorization_pending") ErrSlowDown = errors.New("slow_down") ) -// SSOOIDCClient handles AWS SSO OIDC authentication. type SSOOIDCClient struct { httpClient *http.Client cfg *config.Config @@ -74,10 +67,10 @@ func NewSSOOIDCClient(cfg *config.Config) *SSOOIDCClient { // RegisterClientResponse from AWS SSO OIDC. type RegisterClientResponse struct { - ClientID string `json:"clientId"` - ClientSecret string `json:"clientSecret"` - ClientIDIssuedAt int64 `json:"clientIdIssuedAt"` - ClientSecretExpiresAt int64 `json:"clientSecretExpiresAt"` + ClientID string `json:"clientId"` + ClientSecret string `json:"clientSecret"` + ClientIDIssuedAt int64 `json:"clientIdIssuedAt"` + ClientSecretExpiresAt int64 `json:"clientSecretExpiresAt"` } // StartDeviceAuthResponse from AWS SSO OIDC. @@ -174,8 +167,7 @@ func (c *SSOOIDCClient) RegisterClientWithRegion(ctx context.Context, region str if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", kiroUserAgent) + SetOIDCHeaders(req) resp, err := c.httpClient.Do(req) if err != nil { @@ -220,8 +212,7 @@ func (c *SSOOIDCClient) StartDeviceAuthorizationWithIDC(ctx context.Context, cli if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", kiroUserAgent) + SetOIDCHeaders(req) resp, err := c.httpClient.Do(req) if err != nil { @@ -267,8 +258,7 @@ func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, cli if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", kiroUserAgent) + SetOIDCHeaders(req) resp, err := c.httpClient.Do(req) if err != nil { @@ -311,8 +301,11 @@ func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, cli return &result, nil } -// RefreshTokenWithRegion refreshes an access token using the refresh token with a specific region. +// RefreshTokenWithRegion refreshes an access token using the refresh token with a specific OIDC region. func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, clientSecret, refreshToken, region, startURL string) (*KiroTokenData, error) { + if region == "" { + region = defaultIDCRegion + } endpoint := getOIDCEndpoint(region) payload := map[string]string{ @@ -331,18 +324,7 @@ func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, cl if err != nil { return nil, err } - - // Set headers matching kiro2api's IDC token refresh - // These headers are required for successful IDC token refresh - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region)) - req.Header.Set("Connection", "keep-alive") - req.Header.Set("x-amz-user-agent", idcAmzUserAgent) - req.Header.Set("Accept", "*/*") - req.Header.Set("Accept-Language", "*") - req.Header.Set("sec-fetch-mode", "cors") - req.Header.Set("User-Agent", "node") - req.Header.Set("Accept-Encoding", "br, gzip, deflate") + SetOIDCHeaders(req) resp, err := c.httpClient.Do(req) if err != nil { @@ -469,10 +451,10 @@ func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region strin // Step 5: Get profile ARN from CodeWhisperer API fmt.Println("Fetching profile information...") - profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken) + profileArn := c.FetchProfileArn(ctx, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken) // Fetch user email - email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken) + email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken) if email != "" { fmt.Printf(" Logged in as: %s\n", email) } @@ -502,12 +484,36 @@ func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region strin return nil, fmt.Errorf("authorization timed out") } +// IDCLoginOptions holds optional parameters for IDC login. +type IDCLoginOptions struct { + StartURL string // Pre-configured start URL (skips prompt if set) + Region string // OIDC region for login and token refresh (defaults to us-east-1) + UseDeviceCode bool // Use Device Code flow instead of Auth Code flow +} + // LoginWithMethodSelection prompts the user to select between Builder ID and IDC, then performs the login. -func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context) (*KiroTokenData, error) { +// Options can be provided to pre-configure IDC parameters (startURL, region). +// If StartURL is provided in opts, IDC flow is used directly without prompting. +func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context, opts *IDCLoginOptions) (*KiroTokenData, error) { fmt.Println("\n╔══════════════════════════════════════════════════════════╗") fmt.Println("║ Kiro Authentication (AWS) ║") fmt.Println("╚══════════════════════════════════════════════════════════╝") + // If IDC options with StartURL are provided, skip method selection and use IDC directly + if opts != nil && opts.StartURL != "" { + region := opts.Region + if region == "" { + region = defaultIDCRegion + } + fmt.Printf("\n Using IDC with Start URL: %s\n", opts.StartURL) + fmt.Printf(" Region: %s\n", region) + + if opts.UseDeviceCode { + return c.LoginWithIDCAndOptions(ctx, opts.StartURL, region) + } + return c.LoginWithIDCAuthCode(ctx, opts.StartURL, region) + } + // Prompt for login method options := []string{ "Use with Builder ID (personal AWS account)", @@ -520,15 +526,41 @@ func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context) (*KiroToke return c.LoginWithBuilderID(ctx) } - // IDC flow - prompt for start URL and region - fmt.Println() - startURL := promptInput("? Enter Start URL", "") - if startURL == "" { - return nil, fmt.Errorf("start URL is required for IDC login") + // IDC flow - use pre-configured values or prompt + var startURL, region string + + if opts != nil { + startURL = opts.StartURL + region = opts.Region } - region := promptInput("? Enter Region", defaultIDCRegion) + fmt.Println() + // Use pre-configured startURL or prompt + if startURL == "" { + startURL = promptInput("? Enter Start URL", "") + if startURL == "" { + return nil, fmt.Errorf("start URL is required for IDC login") + } + } else { + fmt.Printf(" Using pre-configured Start URL: %s\n", startURL) + } + + // Use pre-configured region or prompt + if region == "" { + region = promptInput("? Enter Region", defaultIDCRegion) + } else { + fmt.Printf(" Using pre-configured Region: %s\n", region) + } + + if opts != nil && opts.UseDeviceCode { + return c.LoginWithIDCAndOptions(ctx, startURL, region) + } + return c.LoginWithIDCAuthCode(ctx, startURL, region) +} + +// LoginWithIDCAndOptions performs IDC login with the specified region. +func (c *SSOOIDCClient) LoginWithIDCAndOptions(ctx context.Context, startURL, region string) (*KiroTokenData, error) { return c.LoginWithIDC(ctx, startURL, region) } @@ -550,8 +582,7 @@ func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResp if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", kiroUserAgent) + SetOIDCHeaders(req) resp, err := c.httpClient.Do(req) if err != nil { @@ -594,8 +625,7 @@ func (c *SSOOIDCClient) StartDeviceAuthorization(ctx context.Context, clientID, if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", kiroUserAgent) + SetOIDCHeaders(req) resp, err := c.httpClient.Do(req) if err != nil { @@ -639,8 +669,7 @@ func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret, if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", kiroUserAgent) + SetOIDCHeaders(req) resp, err := c.httpClient.Do(req) if err != nil { @@ -702,13 +731,7 @@ func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret if err != nil { return nil, err } - - // Set headers matching Kiro IDE behavior for better compatibility - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Host", "oidc.us-east-1.amazonaws.com") - req.Header.Set("x-amz-user-agent", idcAmzUserAgent) - req.Header.Set("User-Agent", "node") - req.Header.Set("Accept", "*/*") + SetOIDCHeaders(req) resp, err := c.httpClient.Do(req) if err != nil { @@ -835,12 +858,8 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, log.Debugf("Failed to close browser: %v", err) } - // Step 5: Get profile ARN from CodeWhisperer API - fmt.Println("Fetching profile information...") - profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken) - // Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing) - email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken) + email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken) if email != "" { fmt.Printf(" Logged in as: %s\n", email) } @@ -850,7 +869,7 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, return &KiroTokenData{ AccessToken: tokenResp.AccessToken, RefreshToken: tokenResp.RefreshToken, - ProfileArn: profileArn, + ProfileArn: "", // Builder ID has no profile ExpiresAt: expiresAt.Format(time.RFC3339), AuthMethod: "builder-id", Provider: "AWS", @@ -859,15 +878,15 @@ func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, Email: email, Region: defaultIDCRegion, }, nil - } - } + } + } - // Close browser on timeout for better UX - if err := browser.CloseBrowser(); err != nil { - log.Debugf("Failed to close browser on timeout: %v", err) - } - return nil, fmt.Errorf("authorization timed out") - } + // Close browser on timeout for better UX + if err := browser.CloseBrowser(); err != nil { + log.Debugf("Failed to close browser on timeout: %v", err) + } + return nil, fmt.Errorf("authorization timed out") +} // FetchUserEmail retrieves the user's email from AWS SSO OIDC userinfo endpoint. // Falls back to JWT parsing if userinfo fails. @@ -931,20 +950,64 @@ func (c *SSOOIDCClient) tryUserInfoEndpoint(ctx context.Context, accessToken str return "" } -// fetchProfileArn retrieves the profile ARN from CodeWhisperer API. -// This is needed for file naming since AWS SSO OIDC doesn't return profile info. -func (c *SSOOIDCClient) fetchProfileArn(ctx context.Context, accessToken string) string { - // Try ListProfiles API first - profileArn := c.tryListProfiles(ctx, accessToken) +// FetchProfileArn fetches the profile ARN from ListAvailableProfiles API. +// This is used to get profileArn for imported accounts that may not have it. +func (c *SSOOIDCClient) FetchProfileArn(ctx context.Context, accessToken, clientID, refreshToken string) string { + profileArn := c.tryListAvailableProfiles(ctx, accessToken, clientID, refreshToken) if profileArn != "" { return profileArn } - - // Fallback: Try ListAvailableCustomizations - return c.tryListCustomizations(ctx, accessToken) + return c.tryListProfilesLegacy(ctx, accessToken) } -func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string) string { +func (c *SSOOIDCClient) tryListAvailableProfiles(ctx context.Context, accessToken, clientID, refreshToken string) string { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, GetKiroAPIEndpoint("")+"/ListAvailableProfiles", strings.NewReader("{}")) + if err != nil { + return "" + } + + req.Header.Set("Content-Type", "application/json") + accountKey := GetAccountKey(clientID, refreshToken) + setRuntimeHeaders(req, accessToken, accountKey) + + resp, err := c.httpClient.Do(req) + if err != nil { + log.Debugf("ListAvailableProfiles request failed: %v", err) + return "" + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + log.Debugf("ListAvailableProfiles failed (status %d): %s", resp.StatusCode, string(respBody)) + return "" + } + + log.Debugf("ListAvailableProfiles response: %s", string(respBody)) + + var result struct { + Profiles []struct { + Arn string `json:"arn"` + ProfileName string `json:"profileName"` + } `json:"profiles"` + NextToken *string `json:"nextToken"` + } + + if err := json.Unmarshal(respBody, &result); err != nil { + log.Debugf("ListAvailableProfiles parse error: %v", err) + return "" + } + + if len(result.Profiles) > 0 { + log.Debugf("Found profile: %s (%s)", result.Profiles[0].ProfileName, result.Profiles[0].Arn) + return result.Profiles[0].Arn + } + + return "" +} + +func (c *SSOOIDCClient) tryListProfilesLegacy(ctx context.Context, accessToken string) string { payload := map[string]interface{}{ "origin": "AI_EDITOR", } @@ -954,7 +1017,9 @@ func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string) return "" } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body))) + // Use the legacy CodeWhisperer endpoint for JSON-RPC style requests. + // The Q endpoint (q.{region}.amazonaws.com) does NOT support x-amz-target headers. + req, err := http.NewRequestWithContext(ctx, http.MethodPost, GetCodeWhispererLegacyEndpoint(""), strings.NewReader(string(body))) if err != nil { return "" } @@ -973,11 +1038,11 @@ func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string) respBody, _ := io.ReadAll(resp.Body) if resp.StatusCode != http.StatusOK { - log.Debugf("ListProfiles failed (status %d): %s", resp.StatusCode, string(respBody)) + log.Debugf("ListProfiles (legacy) failed (status %d): %s", resp.StatusCode, string(respBody)) return "" } - log.Debugf("ListProfiles response: %s", string(respBody)) + log.Debugf("ListProfiles (legacy) response: %s", string(respBody)) var result struct { Profiles []struct { @@ -1001,63 +1066,6 @@ func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string) return "" } -func (c *SSOOIDCClient) tryListCustomizations(ctx context.Context, accessToken string) string { - payload := map[string]interface{}{ - "origin": "AI_EDITOR", - } - - body, err := json.Marshal(payload) - if err != nil { - return "" - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body))) - if err != nil { - return "" - } - - req.Header.Set("Content-Type", "application/x-amz-json-1.0") - req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListAvailableCustomizations") - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Accept", "application/json") - - resp, err := c.httpClient.Do(req) - if err != nil { - return "" - } - defer resp.Body.Close() - - respBody, _ := io.ReadAll(resp.Body) - - if resp.StatusCode != http.StatusOK { - log.Debugf("ListAvailableCustomizations failed (status %d): %s", resp.StatusCode, string(respBody)) - return "" - } - - log.Debugf("ListAvailableCustomizations response: %s", string(respBody)) - - var result struct { - Customizations []struct { - Arn string `json:"arn"` - } `json:"customizations"` - ProfileArn string `json:"profileArn"` - } - - if err := json.Unmarshal(respBody, &result); err != nil { - return "" - } - - if result.ProfileArn != "" { - return result.ProfileArn - } - - if len(result.Customizations) > 0 { - return result.Customizations[0].Arn - } - - return "" -} - // RegisterClientForAuthCode registers a new OIDC client for authorization code flow. func (c *SSOOIDCClient) RegisterClientForAuthCode(ctx context.Context, redirectURI string) (*RegisterClientResponse, error) { payload := map[string]interface{}{ @@ -1078,8 +1086,7 @@ func (c *SSOOIDCClient) RegisterClientForAuthCode(ctx context.Context, redirectU if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", kiroUserAgent) + SetOIDCHeaders(req) resp, err := c.httpClient.Do(req) if err != nil { @@ -1105,6 +1112,53 @@ func (c *SSOOIDCClient) RegisterClientForAuthCode(ctx context.Context, redirectU return &result, nil } +func (c *SSOOIDCClient) RegisterClientForAuthCodeWithIDC(ctx context.Context, redirectURI, issuerUrl, region string) (*RegisterClientResponse, error) { + endpoint := getOIDCEndpoint(region) + + payload := map[string]interface{}{ + "clientName": "Kiro IDE", + "clientType": "public", + "scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"}, + "grantTypes": []string{"authorization_code", "refresh_token"}, + "redirectUris": []string{redirectURI}, + "issuerUrl": issuerUrl, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/client/register", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + SetOIDCHeaders(req) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("register client for auth code with IDC failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode) + } + + var result RegisterClientResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + // AuthCodeCallbackResult contains the result from authorization code callback. type AuthCodeCallbackResult struct { Code string @@ -1128,6 +1182,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte port := listener.Addr().(*net.TCPAddr).Port redirectURI := fmt.Sprintf("http://127.0.0.1:%d%s", port, authCodeCallbackPath) resultChan := make(chan AuthCodeCallbackResult, 1) + doneChan := make(chan struct{}) server := &http.Server{ ReadHeaderTimeout: 10 * time.Second, @@ -1147,6 +1202,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte Login Failed

Login Failed

Error: %s

You can close this window.

`, html.EscapeString(errParam)) resultChan <- AuthCodeCallbackResult{Error: errParam} + close(doneChan) return } @@ -1156,6 +1212,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte Login Failed

Login Failed

Invalid state parameter

You can close this window.

`) resultChan <- AuthCodeCallbackResult{Error: "state mismatch"} + close(doneChan) return } @@ -1164,6 +1221,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte

Login Successful!

You can close this window and return to the terminal.

`) resultChan <- AuthCodeCallbackResult{Code: code, State: state} + close(doneChan) }) server.Handler = mux @@ -1178,7 +1236,7 @@ func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expecte select { case <-ctx.Done(): case <-time.After(10 * time.Minute): - case <-resultChan: + case <-doneChan: } _ = server.Shutdown(context.Background()) }() @@ -1227,8 +1285,54 @@ func (c *SSOOIDCClient) CreateTokenWithAuthCode(ctx context.Context, clientID, c if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", kiroUserAgent) + SetOIDCHeaders(req) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("create token with auth code failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode) + } + + var result CreateTokenResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +func (c *SSOOIDCClient) CreateTokenWithAuthCodeAndRegion(ctx context.Context, clientID, clientSecret, code, codeVerifier, redirectURI, region string) (*CreateTokenResponse, error) { + endpoint := getOIDCEndpoint(region) + + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "code": code, + "codeVerifier": codeVerifier, + "redirectUri": redirectURI, + "grantType": "authorization_code", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + SetOIDCHeaders(req) resp, err := c.httpClient.Do(req) if err != nil { @@ -1352,12 +1456,118 @@ func (c *SSOOIDCClient) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTo fmt.Println("\n✓ Authentication successful!") - // Step 8: Get profile ARN - fmt.Println("Fetching profile information...") - profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken) - // Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing) - email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken) + email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken) + if email != "" { + fmt.Printf(" Logged in as: %s\n", email) + } + + expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: "", // Builder ID has no profile + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "builder-id", + Provider: "AWS", + ClientID: regResp.ClientID, + ClientSecret: regResp.ClientSecret, + Email: email, + Region: defaultIDCRegion, + }, nil + } +} + +func (c *SSOOIDCClient) LoginWithIDCAuthCode(ctx context.Context, startURL, region string) (*KiroTokenData, error) { + fmt.Println("\n╔══════════════════════════════════════════════════════════╗") + fmt.Println("║ Kiro Authentication (AWS IDC - Auth Code) ║") + fmt.Println("╚══════════════════════════════════════════════════════════╝") + + if region == "" { + region = defaultIDCRegion + } + + codeVerifier, codeChallenge, err := generatePKCEForAuthCode() + if err != nil { + return nil, fmt.Errorf("failed to generate PKCE: %w", err) + } + + state, err := generateStateForAuthCode() + if err != nil { + return nil, fmt.Errorf("failed to generate state: %w", err) + } + + fmt.Println("\nStarting callback server...") + redirectURI, resultChan, err := c.startAuthCodeCallbackServer(ctx, state) + if err != nil { + return nil, fmt.Errorf("failed to start callback server: %w", err) + } + log.Debugf("Callback server started, redirect URI: %s", redirectURI) + + fmt.Println("Registering client...") + regResp, err := c.RegisterClientForAuthCodeWithIDC(ctx, redirectURI, startURL, region) + if err != nil { + return nil, fmt.Errorf("failed to register client: %w", err) + } + log.Debugf("Client registered: %s", regResp.ClientID) + + endpoint := getOIDCEndpoint(region) + scopes := "codewhisperer:completions,codewhisperer:analysis,codewhisperer:conversations,codewhisperer:transformations,codewhisperer:taskassist" + authURL := buildAuthorizationURL(endpoint, regResp.ClientID, redirectURI, scopes, state, codeChallenge) + + fmt.Println("\n════════════════════════════════════════════════════════════") + fmt.Println(" Opening browser for authentication...") + fmt.Println("════════════════════════════════════════════════════════════") + fmt.Printf("\n URL: %s\n\n", authURL) + + if c.cfg != nil { + browser.SetIncognitoMode(c.cfg.IncognitoBrowser) + } else { + browser.SetIncognitoMode(true) + } + + if err := browser.OpenURL(authURL); err != nil { + log.Warnf("Could not open browser automatically: %v", err) + fmt.Println(" ⚠ Could not open browser automatically.") + fmt.Println(" Please open the URL above in your browser manually.") + } else { + fmt.Println(" (Browser opened automatically)") + } + + fmt.Println("\n Waiting for authorization callback...") + + select { + case <-ctx.Done(): + browser.CloseBrowser() + return nil, ctx.Err() + case <-time.After(10 * time.Minute): + browser.CloseBrowser() + return nil, fmt.Errorf("authorization timed out") + case result := <-resultChan: + if result.Error != "" { + browser.CloseBrowser() + return nil, fmt.Errorf("authorization failed: %s", result.Error) + } + + fmt.Println("\n✓ Authorization received!") + + if err := browser.CloseBrowser(); err != nil { + log.Debugf("Failed to close browser: %v", err) + } + + fmt.Println("Exchanging code for tokens...") + tokenResp, err := c.CreateTokenWithAuthCodeAndRegion(ctx, regResp.ClientID, regResp.ClientSecret, result.Code, codeVerifier, redirectURI, region) + if err != nil { + return nil, fmt.Errorf("failed to exchange code for tokens: %w", err) + } + + fmt.Println("\n✓ Authentication successful!") + + fmt.Println("Fetching profile information...") + profileArn := c.FetchProfileArn(ctx, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken) + + email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken, regResp.ClientID, tokenResp.RefreshToken) if email != "" { fmt.Printf(" Logged in as: %s\n", email) } @@ -1369,12 +1579,25 @@ func (c *SSOOIDCClient) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTo RefreshToken: tokenResp.RefreshToken, ProfileArn: profileArn, ExpiresAt: expiresAt.Format(time.RFC3339), - AuthMethod: "builder-id", + AuthMethod: "idc", Provider: "AWS", ClientID: regResp.ClientID, ClientSecret: regResp.ClientSecret, Email: email, - Region: defaultIDCRegion, + StartURL: startURL, + Region: region, }, nil } } + +func buildAuthorizationURL(endpoint, clientID, redirectURI, scopes, state, codeChallenge string) string { + params := url.Values{} + params.Set("response_type", "code") + params.Set("client_id", clientID) + params.Set("redirect_uri", redirectURI) + params.Set("scopes", scopes) + params.Set("state", state) + params.Set("code_challenge", codeChallenge) + params.Set("code_challenge_method", "S256") + return fmt.Sprintf("%s/authorize?%s", endpoint, params.Encode()) +} diff --git a/internal/auth/kiro/sso_oidc_test.go b/internal/auth/kiro/sso_oidc_test.go new file mode 100644 index 00000000..760a6033 --- /dev/null +++ b/internal/auth/kiro/sso_oidc_test.go @@ -0,0 +1,261 @@ +package kiro + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +type recordingRoundTripper struct { + lastReq *http.Request +} + +func (rt *recordingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + rt.lastReq = req + body := `{"nextToken":null,"profiles":[{"arn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC","profileName":"test"}]}` + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: make(http.Header), + }, nil +} + +func TestTryListAvailableProfiles_UsesClientIDForAccountKey(t *testing.T) { + rt := &recordingRoundTripper{} + client := &SSOOIDCClient{ + httpClient: &http.Client{Transport: rt}, + } + + profileArn := client.tryListAvailableProfiles(context.Background(), "access-token", "client-id-123", "refresh-token-456") + if profileArn == "" { + t.Fatal("expected profileArn, got empty result") + } + + accountKey := GetAccountKey("client-id-123", "refresh-token-456") + fp := GlobalFingerprintManager().GetFingerprint(accountKey) + expected := fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s", fp.RuntimeSDKVersion, fp.KiroVersion, fp.KiroHash) + got := rt.lastReq.Header.Get("X-Amz-User-Agent") + if got != expected { + t.Errorf("X-Amz-User-Agent = %q, want %q", got, expected) + } +} + +func TestTryListAvailableProfiles_UsesRefreshTokenWhenClientIDMissing(t *testing.T) { + rt := &recordingRoundTripper{} + client := &SSOOIDCClient{ + httpClient: &http.Client{Transport: rt}, + } + + profileArn := client.tryListAvailableProfiles(context.Background(), "access-token", "", "refresh-token-789") + if profileArn == "" { + t.Fatal("expected profileArn, got empty result") + } + + accountKey := GetAccountKey("", "refresh-token-789") + fp := GlobalFingerprintManager().GetFingerprint(accountKey) + expected := fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s", fp.RuntimeSDKVersion, fp.KiroVersion, fp.KiroHash) + got := rt.lastReq.Header.Get("X-Amz-User-Agent") + if got != expected { + t.Errorf("X-Amz-User-Agent = %q, want %q", got, expected) + } +} + +func TestRegisterClientForAuthCodeWithIDC(t *testing.T) { + var capturedReq struct { + Method string + Path string + Headers http.Header + Body map[string]interface{} + } + + mockResp := RegisterClientResponse{ + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + ClientIDIssuedAt: 1700000000, + ClientSecretExpiresAt: 1700086400, + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedReq.Method = r.Method + capturedReq.Path = r.URL.Path + capturedReq.Headers = r.Header.Clone() + + bodyBytes, _ := io.ReadAll(r.Body) + json.Unmarshal(bodyBytes, &capturedReq.Body) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(mockResp) + })) + defer ts.Close() + + // Extract host to build a region that resolves to our test server. + // Override getOIDCEndpoint by passing region="" and patching the endpoint. + // Since getOIDCEndpoint builds "https://oidc.{region}.amazonaws.com", we + // instead inject the test server URL directly via a custom HTTP client transport. + client := &SSOOIDCClient{ + httpClient: ts.Client(), + } + + // We need to route the request to our test server. Use a transport that rewrites the URL. + client.httpClient.Transport = &rewriteTransport{ + base: ts.Client().Transport, + targetURL: ts.URL, + } + + resp, err := client.RegisterClientForAuthCodeWithIDC( + context.Background(), + "http://127.0.0.1:19877/oauth/callback", + "https://my-idc-instance.awsapps.com/start", + "us-east-1", + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify request method and path + if capturedReq.Method != http.MethodPost { + t.Errorf("method = %q, want POST", capturedReq.Method) + } + if capturedReq.Path != "/client/register" { + t.Errorf("path = %q, want /client/register", capturedReq.Path) + } + + // Verify headers + if ct := capturedReq.Headers.Get("Content-Type"); ct != "application/json" { + t.Errorf("Content-Type = %q, want application/json", ct) + } + ua := capturedReq.Headers.Get("User-Agent") + if !strings.Contains(ua, "KiroIDE") { + t.Errorf("User-Agent %q does not contain KiroIDE", ua) + } + if !strings.Contains(ua, "sso-oidc") { + t.Errorf("User-Agent %q does not contain sso-oidc", ua) + } + xua := capturedReq.Headers.Get("X-Amz-User-Agent") + if !strings.Contains(xua, "KiroIDE") { + t.Errorf("x-amz-user-agent %q does not contain KiroIDE", xua) + } + + // Verify body fields + if v, _ := capturedReq.Body["clientName"].(string); v != "Kiro IDE" { + t.Errorf("clientName = %q, want %q", v, "Kiro IDE") + } + if v, _ := capturedReq.Body["clientType"].(string); v != "public" { + t.Errorf("clientType = %q, want %q", v, "public") + } + if v, _ := capturedReq.Body["issuerUrl"].(string); v != "https://my-idc-instance.awsapps.com/start" { + t.Errorf("issuerUrl = %q, want %q", v, "https://my-idc-instance.awsapps.com/start") + } + + // Verify scopes array + scopesRaw, ok := capturedReq.Body["scopes"].([]interface{}) + if !ok || len(scopesRaw) != 5 { + t.Fatalf("scopes: got %v, want 5-element array", capturedReq.Body["scopes"]) + } + expectedScopes := []string{ + "codewhisperer:completions", "codewhisperer:analysis", + "codewhisperer:conversations", "codewhisperer:transformations", + "codewhisperer:taskassist", + } + for i, s := range expectedScopes { + if scopesRaw[i].(string) != s { + t.Errorf("scopes[%d] = %q, want %q", i, scopesRaw[i], s) + } + } + + // Verify grantTypes + grantTypesRaw, ok := capturedReq.Body["grantTypes"].([]interface{}) + if !ok || len(grantTypesRaw) != 2 { + t.Fatalf("grantTypes: got %v, want 2-element array", capturedReq.Body["grantTypes"]) + } + if grantTypesRaw[0].(string) != "authorization_code" || grantTypesRaw[1].(string) != "refresh_token" { + t.Errorf("grantTypes = %v, want [authorization_code, refresh_token]", grantTypesRaw) + } + + // Verify redirectUris + redirectRaw, ok := capturedReq.Body["redirectUris"].([]interface{}) + if !ok || len(redirectRaw) != 1 { + t.Fatalf("redirectUris: got %v, want 1-element array", capturedReq.Body["redirectUris"]) + } + if redirectRaw[0].(string) != "http://127.0.0.1:19877/oauth/callback" { + t.Errorf("redirectUris[0] = %q, want %q", redirectRaw[0], "http://127.0.0.1:19877/oauth/callback") + } + + // Verify response parsing + if resp.ClientID != "test-client-id" { + t.Errorf("ClientID = %q, want %q", resp.ClientID, "test-client-id") + } + if resp.ClientSecret != "test-client-secret" { + t.Errorf("ClientSecret = %q, want %q", resp.ClientSecret, "test-client-secret") + } + if resp.ClientIDIssuedAt != 1700000000 { + t.Errorf("ClientIDIssuedAt = %d, want %d", resp.ClientIDIssuedAt, 1700000000) + } + if resp.ClientSecretExpiresAt != 1700086400 { + t.Errorf("ClientSecretExpiresAt = %d, want %d", resp.ClientSecretExpiresAt, 1700086400) + } +} + +// rewriteTransport redirects all requests to the test server URL. +type rewriteTransport struct { + base http.RoundTripper + targetURL string +} + +func (t *rewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) { + target, _ := url.Parse(t.targetURL) + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + if t.base != nil { + return t.base.RoundTrip(req) + } + return http.DefaultTransport.RoundTrip(req) +} + +func TestBuildAuthorizationURL(t *testing.T) { + scopes := "codewhisperer:completions,codewhisperer:analysis,codewhisperer:conversations,codewhisperer:transformations,codewhisperer:taskassist" + endpoint := "https://oidc.us-east-1.amazonaws.com" + redirectURI := "http://127.0.0.1:19877/oauth/callback" + + authURL := buildAuthorizationURL(endpoint, "test-client-id", redirectURI, scopes, "random-state", "test-challenge") + + // Verify colons and commas in scopes are percent-encoded + if !strings.Contains(authURL, "codewhisperer%3Acompletions") { + t.Errorf("expected colons in scopes to be percent-encoded, got: %s", authURL) + } + if !strings.Contains(authURL, "completions%2Ccodewhisperer") { + t.Errorf("expected commas in scopes to be percent-encoded, got: %s", authURL) + } + + // Parse back and verify all parameters round-trip correctly + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("failed to parse auth URL: %v", err) + } + + if !strings.HasPrefix(authURL, endpoint+"/authorize?") { + t.Errorf("expected URL to start with %s/authorize?, got: %s", endpoint, authURL) + } + + q := parsed.Query() + checks := map[string]string{ + "response_type": "code", + "client_id": "test-client-id", + "redirect_uri": redirectURI, + "scopes": scopes, + "state": "random-state", + "code_challenge": "test-challenge", + "code_challenge_method": "S256", + } + for key, want := range checks { + if got := q.Get(key); got != want { + t.Errorf("%s = %q, want %q", key, got, want) + } + } +} diff --git a/internal/auth/kiro/token.go b/internal/auth/kiro/token.go index 0484a2dc..91a4995b 100644 --- a/internal/auth/kiro/token.go +++ b/internal/auth/kiro/token.go @@ -29,7 +29,7 @@ type KiroTokenStorage struct { ClientID string `json:"client_id,omitempty"` // ClientSecret is the OAuth client secret (required for token refresh) ClientSecret string `json:"client_secret,omitempty"` - // Region is the AWS region + // Region is the OIDC region for IDC login and token refresh Region string `json:"region,omitempty"` // StartURL is the AWS Identity Center start URL (for IDC auth) StartURL string `json:"start_url,omitempty"` diff --git a/internal/auth/kiro/token_repository.go b/internal/auth/kiro/token_repository.go index 815f1827..3ddf620e 100644 --- a/internal/auth/kiro/token_repository.go +++ b/internal/auth/kiro/token_repository.go @@ -200,36 +200,22 @@ func (r *FileTokenRepository) readTokenFile(path string) (*Token, error) { } // 解析各字段 - if v, ok := metadata["access_token"].(string); ok { - token.AccessToken = v - } - if v, ok := metadata["refresh_token"].(string); ok { - token.RefreshToken = v - } - if v, ok := metadata["client_id"].(string); ok { - token.ClientID = v - } - if v, ok := metadata["client_secret"].(string); ok { - token.ClientSecret = v - } - if v, ok := metadata["region"].(string); ok { - token.Region = v - } - if v, ok := metadata["start_url"].(string); ok { - token.StartURL = v - } - if v, ok := metadata["provider"].(string); ok { - token.Provider = v - } + token.AccessToken, _ = metadata["access_token"].(string) + token.RefreshToken, _ = metadata["refresh_token"].(string) + token.ClientID, _ = metadata["client_id"].(string) + token.ClientSecret, _ = metadata["client_secret"].(string) + token.Region, _ = metadata["region"].(string) + token.StartURL, _ = metadata["start_url"].(string) + token.Provider, _ = metadata["provider"].(string) // 解析时间字段 - if v, ok := metadata["expires_at"].(string); ok { - if t, err := time.Parse(time.RFC3339, v); err == nil { + if expiresAtStr, ok := metadata["expires_at"].(string); ok && expiresAtStr != "" { + if t, err := time.Parse(time.RFC3339, expiresAtStr); err == nil { token.ExpiresAt = t } } - if v, ok := metadata["last_refresh"].(string); ok { - if t, err := time.Parse(time.RFC3339, v); err == nil { + if lastRefreshStr, ok := metadata["last_refresh"].(string); ok && lastRefreshStr != "" { + if t, err := time.Parse(time.RFC3339, lastRefreshStr); err == nil { token.LastVerified = t } } diff --git a/internal/auth/kiro/usage_checker.go b/internal/auth/kiro/usage_checker.go index 94870214..363f6384 100644 --- a/internal/auth/kiro/usage_checker.go +++ b/internal/auth/kiro/usage_checker.go @@ -8,7 +8,6 @@ import ( "fmt" "io" "net/http" - "strings" "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" @@ -51,14 +50,12 @@ type QuotaStatus struct { // UsageChecker provides methods for checking token quota usage. type UsageChecker struct { httpClient *http.Client - endpoint string } // NewUsageChecker creates a new UsageChecker instance. func NewUsageChecker(cfg *config.Config) *UsageChecker { return &UsageChecker{ httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}), - endpoint: awsKiroEndpoint, } } @@ -66,7 +63,6 @@ func NewUsageChecker(cfg *config.Config) *UsageChecker { func NewUsageCheckerWithClient(client *http.Client) *UsageChecker { return &UsageChecker{ httpClient: client, - endpoint: awsKiroEndpoint, } } @@ -80,26 +76,23 @@ func (c *UsageChecker) CheckUsage(ctx context.Context, tokenData *KiroTokenData) return nil, fmt.Errorf("access token is empty") } - payload := map[string]interface{}{ + queryParams := map[string]string{ "origin": "AI_EDITOR", "profileArn": tokenData.ProfileArn, "resourceType": "AGENTIC_REQUEST", } - jsonBody, err := json.Marshal(payload) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } + // Use endpoint from profileArn if available + endpoint := GetKiroAPIEndpointFromProfileArn(tokenData.ProfileArn) + url := buildURL(endpoint, pathGetUsageLimits, queryParams) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, strings.NewReader(string(jsonBody))) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } - req.Header.Set("Content-Type", "application/x-amz-json-1.0") - req.Header.Set("x-amz-target", targetGetUsage) - req.Header.Set("Authorization", "Bearer "+tokenData.AccessToken) - req.Header.Set("Accept", "application/json") + accountKey := GetAccountKey(tokenData.ClientID, tokenData.RefreshToken) + setRuntimeHeaders(req, tokenData.AccessToken, accountKey) resp, err := c.httpClient.Do(req) if err != nil { diff --git a/internal/cmd/kiro_login.go b/internal/cmd/kiro_login.go index 74d09686..075e48e6 100644 --- a/internal/cmd/kiro_login.go +++ b/internal/cmd/kiro_login.go @@ -206,3 +206,52 @@ func DoKiroImport(cfg *config.Config, options *LoginOptions) { } fmt.Println("Kiro token import successful!") } + +func DoKiroIDCLogin(cfg *config.Config, options *LoginOptions, startURL, region, flow string) { + if options == nil { + options = &LoginOptions{} + } + + if startURL == "" { + log.Errorf("Kiro IDC login requires --kiro-idc-start-url") + fmt.Println("\nUsage: --kiro-idc-login --kiro-idc-start-url https://d-xxx.awsapps.com/start") + return + } + + manager := newAuthManager() + + authenticator := sdkAuth.NewKiroAuthenticator() + metadata := map[string]string{ + "start-url": startURL, + "region": region, + "flow": flow, + } + + record, err := authenticator.Login(context.Background(), cfg, &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + Metadata: metadata, + Prompt: options.Prompt, + }) + if err != nil { + log.Errorf("Kiro IDC authentication failed: %v", err) + fmt.Println("\nTroubleshooting:") + fmt.Println("1. Make sure your IDC Start URL is correct") + fmt.Println("2. Complete the authorization in the browser") + fmt.Println("3. If auth code flow fails, try: --kiro-idc-flow device") + return + } + + savedPath, err := manager.SaveAuth(record, cfg) + if err != nil { + log.Errorf("Failed to save auth: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Authenticated as %s\n", record.Label) + } + fmt.Println("Kiro IDC authentication successful!") +} diff --git a/internal/config/config.go b/internal/config/config.go index eb887384..87735aa2 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -87,6 +87,10 @@ type Config struct { // KiroKey defines a list of Kiro (AWS CodeWhisperer) configurations. KiroKey []KiroKey `yaml:"kiro" json:"kiro"` + // KiroFingerprint defines a global fingerprint configuration for all Kiro requests. + // When set, all Kiro requests will use this fixed fingerprint instead of random generation. + KiroFingerprint *KiroFingerprintConfig `yaml:"kiro-fingerprint,omitempty" json:"kiro-fingerprint,omitempty"` + // KiroPreferredEndpoint sets the global default preferred endpoint for all Kiro providers. // Values: "ide" (default, CodeWhisperer) or "cli" (Amazon Q). KiroPreferredEndpoint string `yaml:"kiro-preferred-endpoint" json:"kiro-preferred-endpoint"` @@ -477,6 +481,9 @@ type KiroKey struct { // Region is the AWS region (default: us-east-1). Region string `yaml:"region,omitempty" json:"region,omitempty"` + // StartURL is the IAM Identity Center (IDC) start URL for SSO login. + StartURL string `yaml:"start-url,omitempty" json:"start-url,omitempty"` + // ProxyURL optionally overrides the global proxy for this configuration. ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"` @@ -489,6 +496,20 @@ type KiroKey struct { PreferredEndpoint string `yaml:"preferred-endpoint,omitempty" json:"preferred-endpoint,omitempty"` } +// KiroFingerprintConfig defines a global fingerprint configuration for Kiro requests. +// When configured, all Kiro requests will use this fixed fingerprint instead of random generation. +// Empty fields will fall back to random selection from built-in pools. +type KiroFingerprintConfig struct { + OIDCSDKVersion string `yaml:"oidc-sdk-version,omitempty" json:"oidc-sdk-version,omitempty"` + RuntimeSDKVersion string `yaml:"runtime-sdk-version,omitempty" json:"runtime-sdk-version,omitempty"` + StreamingSDKVersion string `yaml:"streaming-sdk-version,omitempty" json:"streaming-sdk-version,omitempty"` + OSType string `yaml:"os-type,omitempty" json:"os-type,omitempty"` + OSVersion string `yaml:"os-version,omitempty" json:"os-version,omitempty"` + NodeVersion string `yaml:"node-version,omitempty" json:"node-version,omitempty"` + KiroVersion string `yaml:"kiro-version,omitempty" json:"kiro-version,omitempty"` + KiroHash string `yaml:"kiro-hash,omitempty" json:"kiro-hash,omitempty"` +} + // OpenAICompatibility represents the configuration for OpenAI API compatibility // with external providers, allowing model aliases to be routed through OpenAI API format. type OpenAICompatibility struct { diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index 3d1e2d51..f5e5d9ae 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -49,15 +49,8 @@ const ( ErrStreamFatal = "fatal" // Connection/authentication errors, not recoverable ErrStreamMalformed = "malformed" // Format errors, data cannot be parsed - // kiroUserAgent matches Amazon Q CLI style for User-Agent header - kiroUserAgent = "aws-sdk-rust/1.3.9 os/macos lang/rust/1.87.0" - // kiroFullUserAgent is the complete x-amz-user-agent header (Amazon Q CLI style) - kiroFullUserAgent = "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/macos lang/rust/1.87.0 m/E app/AmazonQ-For-CLI" - - // Kiro IDE style headers for IDC auth - kiroIDEUserAgent = "aws-sdk-js/1.0.27 ua/2.1 os/win32#10.0.19044 lang/js md/nodejs#22.21.1 api/codewhispererstreaming#1.0.27 m/E" - kiroIDEAmzUserAgent = "aws-sdk-js/1.0.27" - kiroIDEAgentModeVibe = "vibe" + // kiroIDEAgentMode is the agent mode header value for Kiro IDE requests + kiroIDEAgentMode = "vibe" // Socket retry configuration constants // Maximum number of retry attempts for socket/network errors @@ -87,20 +80,13 @@ var ( usageUpdateTimeInterval = 15 * time.Second // Or every 15 seconds, whichever comes first ) -// Global FingerprintManager for dynamic User-Agent generation per token -// Each token gets a unique fingerprint on first use, which is cached for subsequent requests -var ( - globalFingerprintManager *kiroauth.FingerprintManager - globalFingerprintManagerOnce sync.Once -) - -// getGlobalFingerprintManager returns the global FingerprintManager instance -func getGlobalFingerprintManager() *kiroauth.FingerprintManager { - globalFingerprintManagerOnce.Do(func() { - globalFingerprintManager = kiroauth.NewFingerprintManager() - log.Infof("kiro: initialized global FingerprintManager for dynamic UA generation") - }) - return globalFingerprintManager +// endpointAliases maps user preference values to canonical endpoint names. +var endpointAliases = map[string]string{ + "codewhisperer": "codewhisperer", + "ide": "codewhisperer", + "amazonq": "amazonq", + "q": "amazonq", + "cli": "amazonq", } // retryConfig holds configuration for socket retry logic. @@ -433,87 +419,41 @@ func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig { return kiroEndpointConfigs } - // Determine API region using shared resolution logic region := resolveKiroAPIRegion(auth) + log.Debugf("kiro: using region %s", region) - // Build endpoint configs for the specified region - endpointConfigs := buildKiroEndpointConfigs(region) - - // For IDC auth, use Q endpoint with AI_EDITOR origin - // IDC tokens work with Q endpoint using Bearer auth - // The difference is only in how tokens are refreshed (OIDC with clientId/clientSecret for IDC) - // NOT in how API calls are made - both Social and IDC use the same endpoint/origin - if auth.Metadata != nil { - authMethod, _ := auth.Metadata["auth_method"].(string) - if strings.ToLower(authMethod) == "idc" { - log.Debugf("kiro: IDC auth, using Q endpoint (region: %s)", region) - return endpointConfigs - } - } - - // Check for preference - var preference string - if auth.Metadata != nil { - if p, ok := auth.Metadata["preferred_endpoint"].(string); ok { - preference = p - } - } - // Check attributes as fallback (e.g. from HTTP headers) - if preference == "" && auth.Attributes != nil { - preference = auth.Attributes["preferred_endpoint"] - } + configs := buildKiroEndpointConfigs(region) + preference := getAuthValue(auth, "preferred_endpoint") if preference == "" { - return endpointConfigs + return configs } - preference = strings.ToLower(strings.TrimSpace(preference)) + targetName, ok := endpointAliases[preference] + if !ok { + return configs + } - // Create new slice to avoid modifying global state - var sorted []kiroEndpointConfig - var remaining []kiroEndpointConfig - - for _, cfg := range endpointConfigs { - name := strings.ToLower(cfg.Name) - // Check for matches - // CodeWhisperer aliases: codewhisperer, ide - // AmazonQ aliases: amazonq, q, cli - isMatch := false - if (preference == "codewhisperer" || preference == "ide") && name == "codewhisperer" { - isMatch = true - } else if (preference == "amazonq" || preference == "q" || preference == "cli") && name == "amazonq" { - isMatch = true - } - - if isMatch { - sorted = append(sorted, cfg) + var preferred, others []kiroEndpointConfig + for _, cfg := range configs { + if strings.ToLower(cfg.Name) == targetName { + preferred = append(preferred, cfg) } else { - remaining = append(remaining, cfg) + others = append(others, cfg) } } - // If preference didn't match anything, return default - if len(sorted) == 0 { - return endpointConfigs + if len(preferred) == 0 { + return configs } - - // Combine: preferred first, then others - return append(sorted, remaining...) + return append(preferred, others...) } // KiroExecutor handles requests to AWS CodeWhisperer (Kiro) API. type KiroExecutor struct { - cfg *config.Config - refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions -} - -// isIDCAuth checks if the auth uses IDC (Identity Center) authentication method. -func isIDCAuth(auth *cliproxyauth.Auth) bool { - if auth == nil || auth.Metadata == nil { - return false - } - authMethod, _ := auth.Metadata["auth_method"].(string) - return strings.ToLower(authMethod) == "idc" + cfg *config.Config + refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions + profileArnMu sync.Mutex // Serializes profileArn fetches to prevent concurrent map writes } // buildKiroPayloadForFormat builds the Kiro API payload based on the source format. @@ -546,27 +486,22 @@ func NewKiroExecutor(cfg *config.Config) *KiroExecutor { // Identifier returns the unique identifier for this executor. func (e *KiroExecutor) Identifier() string { return "kiro" } -// applyDynamicFingerprint applies token-specific fingerprint headers to the request -// For IDC auth, uses dynamic fingerprint-based User-Agent -// For other auth types, uses static Amazon Q CLI style headers +// applyDynamicFingerprint applies account-specific fingerprint headers to the request. func applyDynamicFingerprint(req *http.Request, auth *cliproxyauth.Auth) { - if isIDCAuth(auth) { - // Get token-specific fingerprint for dynamic UA generation - tokenKey := getTokenKey(auth) - fp := getGlobalFingerprintManager().GetFingerprint(tokenKey) + accountKey := getAccountKey(auth) + fp := kiroauth.GlobalFingerprintManager().GetFingerprint(accountKey) - // Use fingerprint-generated dynamic User-Agent - req.Header.Set("User-Agent", fp.BuildUserAgent()) - req.Header.Set("X-Amz-User-Agent", fp.BuildAmzUserAgent()) - req.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeVibe) + req.Header.Set("User-Agent", fp.BuildUserAgent()) + req.Header.Set("X-Amz-User-Agent", fp.BuildAmzUserAgent()) + req.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentMode) + req.Header.Set("x-amzn-codewhisperer-optout", "true") - log.Debugf("kiro: using dynamic fingerprint for token %s (SDK:%s, OS:%s/%s, Kiro:%s)", - tokenKey[:8]+"...", fp.SDKVersion, fp.OSType, fp.OSVersion, fp.KiroVersion) - } else { - // Use static Amazon Q CLI style headers for non-IDC auth - req.Header.Set("User-Agent", kiroUserAgent) - req.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) + keyPrefix := accountKey + if len(keyPrefix) > 8 { + keyPrefix = keyPrefix[:8] } + log.Debugf("kiro: using dynamic fingerprint for account %s (SDK:%s, OS:%s/%s, Kiro:%s)", + keyPrefix+"...", fp.StreamingSDKVersion, fp.OSType, fp.OSVersion, fp.KiroVersion) } // PrepareRequest prepares the HTTP request before execution. @@ -609,17 +544,51 @@ func (e *KiroExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, return httpClient.Do(httpReq) } -// getTokenKey returns a unique key for rate limiting based on auth credentials. -// Uses auth ID if available, otherwise falls back to a hash of the access token. -func getTokenKey(auth *cliproxyauth.Auth) string { +// getAccountKey returns a stable account key for fingerprint lookup and rate limiting. +// Fallback order: +// 1) client_id / refresh_token (best account identity) +// 2) auth.ID (stable local auth record) +// 3) profile_arn (stable AWS profile identity) +// 4) access_token (least preferred but deterministic) +// 5) fixed anonymous seed +func getAccountKey(auth *cliproxyauth.Auth) string { + var clientID, refreshToken, profileArn string + if auth != nil && auth.Metadata != nil { + clientID, _ = auth.Metadata["client_id"].(string) + refreshToken, _ = auth.Metadata["refresh_token"].(string) + profileArn, _ = auth.Metadata["profile_arn"].(string) + } + if clientID != "" || refreshToken != "" { + return kiroauth.GetAccountKey(clientID, refreshToken) + } if auth != nil && auth.ID != "" { - return auth.ID + return kiroauth.GenerateAccountKey(auth.ID) } - accessToken, _ := kiroCredentials(auth) - if len(accessToken) > 16 { - return accessToken[:16] + if profileArn != "" { + return kiroauth.GenerateAccountKey(profileArn) } - return accessToken + if accessToken, _ := kiroCredentials(auth); accessToken != "" { + return kiroauth.GenerateAccountKey(accessToken) + } + return kiroauth.GenerateAccountKey("kiro-anonymous") +} + +// getAuthValue looks up a value by key in auth Metadata, then Attributes. +func getAuthValue(auth *cliproxyauth.Auth, key string) string { + if auth == nil { + return "" + } + if auth.Metadata != nil { + if v, ok := auth.Metadata[key].(string); ok && v != "" { + return strings.ToLower(strings.TrimSpace(v)) + } + } + if auth.Attributes != nil { + if v := auth.Attributes[key]; v != "" { + return strings.ToLower(strings.TrimSpace(v)) + } + } + return "" } // Execute sends the request to Kiro API and returns the response. @@ -631,7 +600,7 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req } // Rate limiting: get token key for tracking - tokenKey := getTokenKey(auth) + tokenKey := getAccountKey(auth) rateLimiter := kiroauth.GetGlobalRateLimiter() cooldownMgr := kiroauth.GetGlobalCooldownManager() @@ -693,6 +662,13 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req kiroModelID := e.mapModelToKiro(req.Model) + // Fetch profileArn if missing (for imported accounts from Kiro IDE) + if profileArn == "" { + if fetched := e.fetchAndSaveProfileArn(ctx, auth, accessToken); fetched != "" { + profileArn = fetched + } + } + // Determine agentic mode and effective profile ARN using helper functions isAgentic, isChatOnly := determineAgenticMode(req.Model) effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) @@ -749,7 +725,7 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget) } // Kiro-specific headers - httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeVibe) + httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentMode) httpReq.Header.Set("x-amzn-codewhisperer-optout", "true") // Apply dynamic fingerprint-based headers @@ -1060,7 +1036,7 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut } // Rate limiting: get token key for tracking - tokenKey := getTokenKey(auth) + tokenKey := getAccountKey(auth) rateLimiter := kiroauth.GetGlobalRateLimiter() cooldownMgr := kiroauth.GetGlobalCooldownManager() @@ -1126,6 +1102,13 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut kiroModelID := e.mapModelToKiro(req.Model) + // Fetch profileArn if missing (for imported accounts from Kiro IDE) + if profileArn == "" { + if fetched := e.fetchAndSaveProfileArn(ctx, auth, accessToken); fetched != "" { + profileArn = fetched + } + } + // Determine agentic mode and effective profile ARN using helper functions isAgentic, isChatOnly := determineAgenticMode(req.Model) effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) @@ -1185,7 +1168,7 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget) } // Kiro-specific headers - httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeVibe) + httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentMode) httpReq.Header.Set("x-amzn-codewhisperer-optout", "true") // Apply dynamic fingerprint-based headers @@ -1647,62 +1630,23 @@ func determineAgenticMode(model string) (isAgentic, isChatOnly bool) { return isAgentic, isChatOnly } -// getEffectiveProfileArn determines if profileArn should be included based on auth method. -// profileArn is only needed for social auth (Google OAuth), not for AWS SSO OIDC (Builder ID/IDC). -// -// Detection logic (matching kiro-openai-gateway): -// 1. Check auth_method field: "builder-id" or "idc" -// 2. Check auth_type field: "aws_sso_oidc" (from kiro-cli tokens) -// 3. Check for client_id + client_secret presence (AWS SSO OIDC signature) -func getEffectiveProfileArn(auth *cliproxyauth.Auth, profileArn string) string { - if auth != nil && auth.Metadata != nil { - // Check 1: auth_method field (from CLIProxyAPI tokens) - if authMethod, ok := auth.Metadata["auth_method"].(string); ok && (authMethod == "builder-id" || authMethod == "idc") { - return "" // AWS SSO OIDC - don't include profileArn - } - // Check 2: auth_type field (from kiro-cli tokens) - if authType, ok := auth.Metadata["auth_type"].(string); ok && authType == "aws_sso_oidc" { - return "" // AWS SSO OIDC - don't include profileArn - } - // Check 3: client_id + client_secret presence (AWS SSO OIDC signature) - _, hasClientID := auth.Metadata["client_id"].(string) - _, hasClientSecret := auth.Metadata["client_secret"].(string) - if hasClientID && hasClientSecret { - return "" // AWS SSO OIDC - don't include profileArn - } - } - return profileArn -} - -// getEffectiveProfileArnWithWarning determines if profileArn should be included based on auth method, -// and logs a warning if profileArn is missing for non-builder-id auth. -// This consolidates the auth_method check that was previously done separately. -// -// AWS SSO OIDC (Builder ID/IDC) users don't need profileArn - sending it causes 403 errors. -// Only Kiro Desktop (social auth like Google/GitHub) users need profileArn. -// -// Detection logic (matching kiro-openai-gateway): -// 1. Check auth_method field: "builder-id" or "idc" -// 2. Check auth_type field: "aws_sso_oidc" (from kiro-cli tokens) -// 3. Check for client_id + client_secret presence (AWS SSO OIDC signature) +// getEffectiveProfileArnWithWarning suppresses profileArn for builder-id and AWS SSO OIDC auth. +// Builder-id users (auth_method == "builder-id") and AWS SSO OIDC users (auth_type == "aws_sso_oidc") +// don't need profileArn — sending it causes 403 errors. +// For all other auth methods (e.g. social auth), profileArn is returned as-is, +// with a warning logged if it is empty. func getEffectiveProfileArnWithWarning(auth *cliproxyauth.Auth, profileArn string) string { if auth != nil && auth.Metadata != nil { - // Check 1: auth_method field (from CLIProxyAPI tokens) - if authMethod, ok := auth.Metadata["auth_method"].(string); ok && (authMethod == "builder-id" || authMethod == "idc") { - return "" // AWS SSO OIDC - don't include profileArn + // Check 1: auth_method field, skip for builder-id only + if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" { + return "" } // Check 2: auth_type field (from kiro-cli tokens) if authType, ok := auth.Metadata["auth_type"].(string); ok && authType == "aws_sso_oidc" { return "" // AWS SSO OIDC - don't include profileArn } - // Check 3: client_id + client_secret presence (AWS SSO OIDC signature, like kiro-openai-gateway) - _, hasClientID := auth.Metadata["client_id"].(string) - _, hasClientSecret := auth.Metadata["client_secret"].(string) - if hasClientID && hasClientSecret { - return "" // AWS SSO OIDC - don't include profileArn - } } - // For social auth (Kiro Desktop), profileArn is required + // For social auth and IDC, profileArn is required if profileArn == "" { log.Warnf("kiro: profile ARN not found in auth, API calls may fail") } @@ -3999,6 +3943,51 @@ func (e *KiroExecutor) persistRefreshedAuth(auth *cliproxyauth.Auth) error { return nil } +// fetchAndSaveProfileArn fetches profileArn from API if missing, updates auth and persists to file. +func (e *KiroExecutor) fetchAndSaveProfileArn(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) string { + if auth == nil || auth.Metadata == nil { + return "" + } + + // Skip for Builder ID - they don't have profiles + if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" { + log.Debugf("kiro executor: skipping profileArn fetch for builder-id auth") + return "" + } + + e.profileArnMu.Lock() + defer e.profileArnMu.Unlock() + + // Double-check: another goroutine may have already fetched and saved the profileArn + if arn, ok := auth.Metadata["profile_arn"].(string); ok && arn != "" { + return arn + } + + clientID, _ := auth.Metadata["client_id"].(string) + refreshToken, _ := auth.Metadata["refresh_token"].(string) + + ssoClient := kiroauth.NewSSOOIDCClient(e.cfg) + profileArn := ssoClient.FetchProfileArn(ctx, accessToken, clientID, refreshToken) + if profileArn == "" { + log.Debugf("kiro executor: FetchProfileArn returned no profiles") + return "" + } + + auth.Metadata["profile_arn"] = profileArn + if auth.Attributes == nil { + auth.Attributes = make(map[string]string) + } + auth.Attributes["profile_arn"] = profileArn + + if err := e.persistRefreshedAuth(auth); err != nil { + log.Warnf("kiro executor: failed to persist profileArn: %v", err) + } else { + log.Infof("kiro executor: fetched and saved profileArn: %s", profileArn) + } + + return profileArn +} + // reloadAuthFromFile 从文件重新加载 auth 数据(方案 B: Fallback 机制) // 当内存中的 token 已过期时,尝试从文件读取最新的 token // 这解决了后台刷新器已更新文件但内存中 Auth 对象尚未同步的时间差问题 @@ -4728,7 +4717,7 @@ func (e *KiroExecutor) callKiroAndBuffer( isAgentic, isChatOnly := determineAgenticMode(req.Model) effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) - tokenKey := getTokenKey(auth) + tokenKey := getAccountKey(auth) kiroStream, err := e.executeStreamWithRetry( ctx, auth, req, opts, accessToken, effectiveProfileArn, @@ -4770,7 +4759,7 @@ func (e *KiroExecutor) callKiroDirectStream( isAgentic, isChatOnly := determineAgenticMode(req.Model) effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) - tokenKey := getTokenKey(auth) + tokenKey := getAccountKey(auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) var streamErr error @@ -4819,7 +4808,7 @@ func (e *KiroExecutor) executeNonStreamFallback( kiroModelID := e.mapModelToKiro(req.Model) isAgentic, isChatOnly := determineAgenticMode(req.Model) effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) - tokenKey := getTokenKey(auth) + tokenKey := getAccountKey(auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) var err error diff --git a/internal/runtime/executor/kiro_executor_test.go b/internal/runtime/executor/kiro_executor_test.go new file mode 100644 index 00000000..7a2819fd --- /dev/null +++ b/internal/runtime/executor/kiro_executor_test.go @@ -0,0 +1,423 @@ +package executor + +import ( + "fmt" + "testing" + + kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +func TestBuildKiroEndpointConfigs(t *testing.T) { + tests := []struct { + name string + region string + expectedURL string + expectedOrigin string + expectedName string + }{ + { + name: "Empty region - defaults to us-east-1", + region: "", + expectedURL: "https://q.us-east-1.amazonaws.com/generateAssistantResponse", + expectedOrigin: "AI_EDITOR", + expectedName: "AmazonQ", + }, + { + name: "us-east-1", + region: "us-east-1", + expectedURL: "https://q.us-east-1.amazonaws.com/generateAssistantResponse", + expectedOrigin: "AI_EDITOR", + expectedName: "AmazonQ", + }, + { + name: "ap-southeast-1", + region: "ap-southeast-1", + expectedURL: "https://q.ap-southeast-1.amazonaws.com/generateAssistantResponse", + expectedOrigin: "AI_EDITOR", + expectedName: "AmazonQ", + }, + { + name: "eu-west-1", + region: "eu-west-1", + expectedURL: "https://q.eu-west-1.amazonaws.com/generateAssistantResponse", + expectedOrigin: "AI_EDITOR", + expectedName: "AmazonQ", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + configs := buildKiroEndpointConfigs(tt.region) + + if len(configs) != 2 { + t.Fatalf("expected 2 endpoint configs, got %d", len(configs)) + } + + // Check primary endpoint (AmazonQ) + primary := configs[0] + if primary.URL != tt.expectedURL { + t.Errorf("primary URL = %q, want %q", primary.URL, tt.expectedURL) + } + if primary.Origin != tt.expectedOrigin { + t.Errorf("primary Origin = %q, want %q", primary.Origin, tt.expectedOrigin) + } + if primary.Name != tt.expectedName { + t.Errorf("primary Name = %q, want %q", primary.Name, tt.expectedName) + } + if primary.AmzTarget != "" { + t.Errorf("primary AmzTarget should be empty, got %q", primary.AmzTarget) + } + + // Check fallback endpoint (CodeWhisperer) + fallback := configs[1] + if fallback.Name != "CodeWhisperer" { + t.Errorf("fallback Name = %q, want %q", fallback.Name, "CodeWhisperer") + } + // CodeWhisperer fallback uses the same region as Q endpoint + expectedRegion := tt.region + if expectedRegion == "" { + expectedRegion = kiroDefaultRegion + } + expectedFallbackURL := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com/generateAssistantResponse", expectedRegion) + if fallback.URL != expectedFallbackURL { + t.Errorf("fallback URL = %q, want %q", fallback.URL, expectedFallbackURL) + } + if fallback.AmzTarget == "" { + t.Error("fallback AmzTarget should NOT be empty") + } + }) + } +} + +func TestGetKiroEndpointConfigs_NilAuth(t *testing.T) { + configs := getKiroEndpointConfigs(nil) + + if len(configs) != 2 { + t.Fatalf("expected 2 endpoint configs, got %d", len(configs)) + } + + // Should return default us-east-1 configs + if configs[0].Name != "AmazonQ" { + t.Errorf("first config Name = %q, want %q", configs[0].Name, "AmazonQ") + } + expectedURL := "https://q.us-east-1.amazonaws.com/generateAssistantResponse" + if configs[0].URL != expectedURL { + t.Errorf("first config URL = %q, want %q", configs[0].URL, expectedURL) + } +} + +func TestGetKiroEndpointConfigs_WithRegionFromProfileArn(t *testing.T) { + auth := &cliproxyauth.Auth{ + Metadata: map[string]any{ + "profile_arn": "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC", + }, + } + + configs := getKiroEndpointConfigs(auth) + + if len(configs) != 2 { + t.Fatalf("expected 2 endpoint configs, got %d", len(configs)) + } + + expectedURL := "https://q.ap-southeast-1.amazonaws.com/generateAssistantResponse" + if configs[0].URL != expectedURL { + t.Errorf("primary URL = %q, want %q", configs[0].URL, expectedURL) + } +} + +func TestGetKiroEndpointConfigs_WithApiRegionOverride(t *testing.T) { + auth := &cliproxyauth.Auth{ + Metadata: map[string]any{ + "api_region": "eu-central-1", + "profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC", + }, + } + + configs := getKiroEndpointConfigs(auth) + + // api_region should take precedence over profile_arn + expectedURL := "https://q.eu-central-1.amazonaws.com/generateAssistantResponse" + if configs[0].URL != expectedURL { + t.Errorf("primary URL = %q, want %q", configs[0].URL, expectedURL) + } +} + +func TestGetKiroEndpointConfigs_PreferredEndpoint(t *testing.T) { + tests := []struct { + name string + preference string + expectedFirstName string + }{ + { + name: "Prefer codewhisperer", + preference: "codewhisperer", + expectedFirstName: "CodeWhisperer", + }, + { + name: "Prefer ide (alias for codewhisperer)", + preference: "ide", + expectedFirstName: "CodeWhisperer", + }, + { + name: "Prefer amazonq", + preference: "amazonq", + expectedFirstName: "AmazonQ", + }, + { + name: "Prefer q (alias for amazonq)", + preference: "q", + expectedFirstName: "AmazonQ", + }, + { + name: "Prefer cli (alias for amazonq)", + preference: "cli", + expectedFirstName: "AmazonQ", + }, + { + name: "Unknown preference - no reordering", + preference: "unknown", + expectedFirstName: "AmazonQ", + }, + { + name: "Empty preference - no reordering", + preference: "", + expectedFirstName: "AmazonQ", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + auth := &cliproxyauth.Auth{ + Metadata: map[string]any{ + "preferred_endpoint": tt.preference, + }, + } + + configs := getKiroEndpointConfigs(auth) + + if configs[0].Name != tt.expectedFirstName { + t.Errorf("first endpoint Name = %q, want %q", configs[0].Name, tt.expectedFirstName) + } + }) + } +} + +func TestGetKiroEndpointConfigs_PreferredEndpointFromAttributes(t *testing.T) { + // Test that preferred_endpoint can also come from Attributes + auth := &cliproxyauth.Auth{ + Metadata: map[string]any{}, + Attributes: map[string]string{"preferred_endpoint": "codewhisperer"}, + } + + configs := getKiroEndpointConfigs(auth) + + if configs[0].Name != "CodeWhisperer" { + t.Errorf("first endpoint Name = %q, want %q", configs[0].Name, "CodeWhisperer") + } +} + +func TestGetKiroEndpointConfigs_MetadataTakesPrecedenceOverAttributes(t *testing.T) { + auth := &cliproxyauth.Auth{ + Metadata: map[string]any{"preferred_endpoint": "amazonq"}, + Attributes: map[string]string{"preferred_endpoint": "codewhisperer"}, + } + + configs := getKiroEndpointConfigs(auth) + + // Metadata should take precedence + if configs[0].Name != "AmazonQ" { + t.Errorf("first endpoint Name = %q, want %q", configs[0].Name, "AmazonQ") + } +} + +func TestGetAuthValue(t *testing.T) { + tests := []struct { + name string + auth *cliproxyauth.Auth + key string + expected string + }{ + { + name: "From metadata", + auth: &cliproxyauth.Auth{ + Metadata: map[string]any{"test_key": "metadata_value"}, + }, + key: "test_key", + expected: "metadata_value", + }, + { + name: "From attributes (fallback)", + auth: &cliproxyauth.Auth{ + Attributes: map[string]string{"test_key": "attribute_value"}, + }, + key: "test_key", + expected: "attribute_value", + }, + { + name: "Metadata takes precedence", + auth: &cliproxyauth.Auth{ + Metadata: map[string]any{"test_key": "metadata_value"}, + Attributes: map[string]string{"test_key": "attribute_value"}, + }, + key: "test_key", + expected: "metadata_value", + }, + { + name: "Key not found", + auth: &cliproxyauth.Auth{ + Metadata: map[string]any{"other_key": "value"}, + Attributes: map[string]string{"another_key": "value"}, + }, + key: "test_key", + expected: "", + }, + { + name: "Nil metadata", + auth: &cliproxyauth.Auth{ + Attributes: map[string]string{"test_key": "attribute_value"}, + }, + key: "test_key", + expected: "attribute_value", + }, + { + name: "Both nil", + auth: &cliproxyauth.Auth{}, + key: "test_key", + expected: "", + }, + { + name: "Value is trimmed and lowercased", + auth: &cliproxyauth.Auth{ + Metadata: map[string]any{"test_key": " UPPER_VALUE "}, + }, + key: "test_key", + expected: "upper_value", + }, + { + name: "Empty string value in metadata - falls back to attributes", + auth: &cliproxyauth.Auth{ + Metadata: map[string]any{"test_key": ""}, + Attributes: map[string]string{"test_key": "attribute_value"}, + }, + key: "test_key", + expected: "attribute_value", + }, + { + name: "Non-string value in metadata - falls back to attributes", + auth: &cliproxyauth.Auth{ + Metadata: map[string]any{"test_key": 123}, + Attributes: map[string]string{"test_key": "attribute_value"}, + }, + key: "test_key", + expected: "attribute_value", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := getAuthValue(tt.auth, tt.key) + if result != tt.expected { + t.Errorf("getAuthValue() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestGetAccountKey(t *testing.T) { + tests := []struct { + name string + auth *cliproxyauth.Auth + checkFn func(t *testing.T, result string) + }{ + { + name: "From client_id", + auth: &cliproxyauth.Auth{ + Metadata: map[string]any{ + "client_id": "test-client-id-123", + "refresh_token": "test-refresh-token-456", + }, + }, + checkFn: func(t *testing.T, result string) { + expected := kiroauth.GetAccountKey("test-client-id-123", "test-refresh-token-456") + if result != expected { + t.Errorf("expected %s, got %s", expected, result) + } + }, + }, + { + name: "From refresh_token only", + auth: &cliproxyauth.Auth{ + Metadata: map[string]any{ + "refresh_token": "test-refresh-token-789", + }, + }, + checkFn: func(t *testing.T, result string) { + expected := kiroauth.GetAccountKey("", "test-refresh-token-789") + if result != expected { + t.Errorf("expected %s, got %s", expected, result) + } + }, + }, + { + name: "Nil auth", + auth: nil, + checkFn: func(t *testing.T, result string) { + if len(result) != 16 { + t.Errorf("expected 16 char key, got %d chars", len(result)) + } + }, + }, + { + name: "Nil metadata", + auth: &cliproxyauth.Auth{}, + checkFn: func(t *testing.T, result string) { + if len(result) != 16 { + t.Errorf("expected 16 char key, got %d chars", len(result)) + } + }, + }, + { + name: "Empty metadata", + auth: &cliproxyauth.Auth{ + Metadata: map[string]any{}, + }, + checkFn: func(t *testing.T, result string) { + if len(result) != 16 { + t.Errorf("expected 16 char key, got %d chars", len(result)) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := getAccountKey(tt.auth) + tt.checkFn(t, result) + }) + } +} + +func TestEndpointAliases(t *testing.T) { + // Verify all expected aliases are defined + expectedAliases := map[string]string{ + "codewhisperer": "codewhisperer", + "ide": "codewhisperer", + "amazonq": "amazonq", + "q": "amazonq", + "cli": "amazonq", + } + + for alias, target := range expectedAliases { + if actual, ok := endpointAliases[alias]; !ok { + t.Errorf("missing alias %q", alias) + } else if actual != target { + t.Errorf("alias %q = %q, want %q", alias, actual, target) + } + } + + // Verify no unexpected aliases + if len(endpointAliases) != len(expectedAliases) { + t.Errorf("unexpected number of aliases: got %d, want %d", len(endpointAliases), len(expectedAliases)) + } +} diff --git a/internal/translator/kiro/claude/kiro_claude_request.go b/internal/translator/kiro/claude/kiro_claude_request.go index 0ad090ae..067a2710 100644 --- a/internal/translator/kiro/claude/kiro_claude_request.go +++ b/internal/translator/kiro/claude/kiro_claude_request.go @@ -38,10 +38,12 @@ type KiroInferenceConfig struct { // KiroConversationState holds the conversation context type KiroConversationState struct { - ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL" - must be first field - ConversationID string `json:"conversationId"` - CurrentMessage KiroCurrentMessage `json:"currentMessage"` - History []KiroHistoryMessage `json:"history,omitempty"` + AgentContinuationID string `json:"agentContinuationId,omitempty"` + AgentTaskType string `json:"agentTaskType,omitempty"` + ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL" + ConversationID string `json:"conversationId"` + CurrentMessage KiroCurrentMessage `json:"currentMessage"` + History []KiroHistoryMessage `json:"history,omitempty"` } // KiroCurrentMessage wraps the current user message @@ -293,10 +295,18 @@ func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isA } } + // Session IDs: extract from messages[].additional_kwargs (LangChain format) or random + conversationID := extractMetadataFromMessages(messages, "conversationId") + continuationID := extractMetadataFromMessages(messages, "continuationId") + if conversationID == "" { + conversationID = uuid.New().String() + } + payload := KiroPayload{ ConversationState: KiroConversationState{ + AgentTaskType: "vibe", ChatTriggerType: "MANUAL", - ConversationID: uuid.New().String(), + ConversationID: conversationID, CurrentMessage: currentMessage, History: history, }, @@ -304,6 +314,11 @@ func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isA InferenceConfig: inferenceConfig, } + // Only set AgentContinuationID if client provided + if continuationID != "" { + payload.ConversationState.AgentContinuationID = continuationID + } + result, err := json.Marshal(payload) if err != nil { log.Debugf("kiro: failed to marshal payload: %v", err) @@ -329,6 +344,18 @@ func normalizeOrigin(origin string) string { } } +// extractMetadataFromMessages extracts metadata from messages[].additional_kwargs (LangChain format). +// Searches from the last message backwards, returns empty string if not found. +func extractMetadataFromMessages(messages gjson.Result, key string) string { + arr := messages.Array() + for i := len(arr) - 1; i >= 0; i-- { + if val := arr[i].Get("additional_kwargs." + key); val.Exists() && val.String() != "" { + return val.String() + } + } + return "" +} + // extractSystemPrompt extracts system prompt from Claude request func extractSystemPrompt(claudeBody []byte) string { systemField := gjson.GetBytes(claudeBody, "system") diff --git a/internal/translator/kiro/openai/kiro_openai_request.go b/internal/translator/kiro/openai/kiro_openai_request.go index 79411c42..ee313abb 100644 --- a/internal/translator/kiro/openai/kiro_openai_request.go +++ b/internal/translator/kiro/openai/kiro_openai_request.go @@ -36,10 +36,12 @@ type KiroInferenceConfig struct { // KiroConversationState holds the conversation context type KiroConversationState struct { - ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL" - ConversationID string `json:"conversationId"` - CurrentMessage KiroCurrentMessage `json:"currentMessage"` - History []KiroHistoryMessage `json:"history,omitempty"` + AgentContinuationID string `json:"agentContinuationId,omitempty"` + AgentTaskType string `json:"agentTaskType,omitempty"` + ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL" + ConversationID string `json:"conversationId"` + CurrentMessage KiroCurrentMessage `json:"currentMessage"` + History []KiroHistoryMessage `json:"history,omitempty"` } // KiroCurrentMessage wraps the current user message @@ -297,10 +299,18 @@ func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin s } } + // Session IDs: extract from messages[].additional_kwargs (LangChain format) or random + conversationID := extractMetadataFromMessages(messages, "conversationId") + continuationID := extractMetadataFromMessages(messages, "continuationId") + if conversationID == "" { + conversationID = uuid.New().String() + } + payload := KiroPayload{ ConversationState: KiroConversationState{ + AgentTaskType: "vibe", ChatTriggerType: "MANUAL", - ConversationID: uuid.New().String(), + ConversationID: conversationID, CurrentMessage: currentMessage, History: history, }, @@ -308,6 +318,11 @@ func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin s InferenceConfig: inferenceConfig, } + // Only set AgentContinuationID if client provided + if continuationID != "" { + payload.ConversationState.AgentContinuationID = continuationID + } + result, err := json.Marshal(payload) if err != nil { log.Debugf("kiro-openai: failed to marshal payload: %v", err) @@ -333,6 +348,18 @@ func normalizeOrigin(origin string) string { } } +// extractMetadataFromMessages extracts metadata from messages[].additional_kwargs (LangChain format). +// Searches from the last message backwards, returns empty string if not found. +func extractMetadataFromMessages(messages gjson.Result, key string) string { + arr := messages.Array() + for i := len(arr) - 1; i >= 0; i-- { + if val := arr[i].Get("additional_kwargs." + key); val.Exists() && val.String() != "" { + return val.String() + } + } + return "" +} + // extractSystemPromptFromOpenAI extracts system prompt from OpenAI messages func extractSystemPromptFromOpenAI(messages gjson.Result) string { if !messages.IsArray() { diff --git a/sdk/auth/kiro.go b/sdk/auth/kiro.go index ad165b75..76092108 100644 --- a/sdk/auth/kiro.go +++ b/sdk/auth/kiro.go @@ -166,9 +166,21 @@ func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts return nil, fmt.Errorf("kiro auth: configuration is required") } + // Extract IDC options from metadata if present + var idcOpts *kiroauth.IDCLoginOptions + if opts != nil && opts.Metadata != nil { + if startURL := opts.Metadata["start-url"]; startURL != "" { + idcOpts = &kiroauth.IDCLoginOptions{ + StartURL: startURL, + Region: opts.Metadata["region"], + UseDeviceCode: opts.Metadata["flow"] == "device", + } + } + } + // Use the unified method selection flow (Builder ID or IDC) ssoClient := kiroauth.NewSSOOIDCClient(cfg) - tokenData, err := ssoClient.LoginWithMethodSelection(ctx) + tokenData, err := ssoClient.LoginWithMethodSelection(ctx, idcOpts) if err != nil { return nil, fmt.Errorf("login failed: %w", err) }