Merge pull request #395 from Xm798/feat/kiro

feat(kiro): add IDC auth code flow, redesign fingerprint and API protocol
This commit is contained in:
Luis Pater
2026-02-27 20:50:43 +08:00
committed by GitHub
27 changed files with 3104 additions and 750 deletions

View File

@@ -49,15 +49,8 @@ const (
ErrStreamFatal = "fatal" // Connection/authentication errors, not recoverable
ErrStreamMalformed = "malformed" // Format errors, data cannot be parsed
// kiroUserAgent matches Amazon Q CLI style for User-Agent header
kiroUserAgent = "aws-sdk-rust/1.3.9 os/macos lang/rust/1.87.0"
// kiroFullUserAgent is the complete x-amz-user-agent header (Amazon Q CLI style)
kiroFullUserAgent = "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/macos lang/rust/1.87.0 m/E app/AmazonQ-For-CLI"
// Kiro IDE style headers for IDC auth
kiroIDEUserAgent = "aws-sdk-js/1.0.27 ua/2.1 os/win32#10.0.19044 lang/js md/nodejs#22.21.1 api/codewhispererstreaming#1.0.27 m/E"
kiroIDEAmzUserAgent = "aws-sdk-js/1.0.27"
kiroIDEAgentModeVibe = "vibe"
// kiroIDEAgentMode is the agent mode header value for Kiro IDE requests
kiroIDEAgentMode = "vibe"
// Socket retry configuration constants
// Maximum number of retry attempts for socket/network errors
@@ -87,20 +80,13 @@ var (
usageUpdateTimeInterval = 15 * time.Second // Or every 15 seconds, whichever comes first
)
// Global FingerprintManager for dynamic User-Agent generation per token
// Each token gets a unique fingerprint on first use, which is cached for subsequent requests
var (
globalFingerprintManager *kiroauth.FingerprintManager
globalFingerprintManagerOnce sync.Once
)
// getGlobalFingerprintManager returns the global FingerprintManager instance
func getGlobalFingerprintManager() *kiroauth.FingerprintManager {
globalFingerprintManagerOnce.Do(func() {
globalFingerprintManager = kiroauth.NewFingerprintManager()
log.Infof("kiro: initialized global FingerprintManager for dynamic UA generation")
})
return globalFingerprintManager
// endpointAliases maps user preference values to canonical endpoint names.
var endpointAliases = map[string]string{
"codewhisperer": "codewhisperer",
"ide": "codewhisperer",
"amazonq": "amazonq",
"q": "amazonq",
"cli": "amazonq",
}
// retryConfig holds configuration for socket retry logic.
@@ -433,87 +419,41 @@ func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig {
return kiroEndpointConfigs
}
// Determine API region using shared resolution logic
region := resolveKiroAPIRegion(auth)
log.Debugf("kiro: using region %s", region)
// Build endpoint configs for the specified region
endpointConfigs := buildKiroEndpointConfigs(region)
// For IDC auth, use Q endpoint with AI_EDITOR origin
// IDC tokens work with Q endpoint using Bearer auth
// The difference is only in how tokens are refreshed (OIDC with clientId/clientSecret for IDC)
// NOT in how API calls are made - both Social and IDC use the same endpoint/origin
if auth.Metadata != nil {
authMethod, _ := auth.Metadata["auth_method"].(string)
if strings.ToLower(authMethod) == "idc" {
log.Debugf("kiro: IDC auth, using Q endpoint (region: %s)", region)
return endpointConfigs
}
}
// Check for preference
var preference string
if auth.Metadata != nil {
if p, ok := auth.Metadata["preferred_endpoint"].(string); ok {
preference = p
}
}
// Check attributes as fallback (e.g. from HTTP headers)
if preference == "" && auth.Attributes != nil {
preference = auth.Attributes["preferred_endpoint"]
}
configs := buildKiroEndpointConfigs(region)
preference := getAuthValue(auth, "preferred_endpoint")
if preference == "" {
return endpointConfigs
return configs
}
preference = strings.ToLower(strings.TrimSpace(preference))
targetName, ok := endpointAliases[preference]
if !ok {
return configs
}
// Create new slice to avoid modifying global state
var sorted []kiroEndpointConfig
var remaining []kiroEndpointConfig
for _, cfg := range endpointConfigs {
name := strings.ToLower(cfg.Name)
// Check for matches
// CodeWhisperer aliases: codewhisperer, ide
// AmazonQ aliases: amazonq, q, cli
isMatch := false
if (preference == "codewhisperer" || preference == "ide") && name == "codewhisperer" {
isMatch = true
} else if (preference == "amazonq" || preference == "q" || preference == "cli") && name == "amazonq" {
isMatch = true
}
if isMatch {
sorted = append(sorted, cfg)
var preferred, others []kiroEndpointConfig
for _, cfg := range configs {
if strings.ToLower(cfg.Name) == targetName {
preferred = append(preferred, cfg)
} else {
remaining = append(remaining, cfg)
others = append(others, cfg)
}
}
// If preference didn't match anything, return default
if len(sorted) == 0 {
return endpointConfigs
if len(preferred) == 0 {
return configs
}
// Combine: preferred first, then others
return append(sorted, remaining...)
return append(preferred, others...)
}
// KiroExecutor handles requests to AWS CodeWhisperer (Kiro) API.
type KiroExecutor struct {
cfg *config.Config
refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions
}
// isIDCAuth checks if the auth uses IDC (Identity Center) authentication method.
func isIDCAuth(auth *cliproxyauth.Auth) bool {
if auth == nil || auth.Metadata == nil {
return false
}
authMethod, _ := auth.Metadata["auth_method"].(string)
return strings.ToLower(authMethod) == "idc"
cfg *config.Config
refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions
profileArnMu sync.Mutex // Serializes profileArn fetches to prevent concurrent map writes
}
// buildKiroPayloadForFormat builds the Kiro API payload based on the source format.
@@ -546,27 +486,22 @@ func NewKiroExecutor(cfg *config.Config) *KiroExecutor {
// Identifier returns the unique identifier for this executor.
func (e *KiroExecutor) Identifier() string { return "kiro" }
// applyDynamicFingerprint applies token-specific fingerprint headers to the request
// For IDC auth, uses dynamic fingerprint-based User-Agent
// For other auth types, uses static Amazon Q CLI style headers
// applyDynamicFingerprint applies account-specific fingerprint headers to the request.
func applyDynamicFingerprint(req *http.Request, auth *cliproxyauth.Auth) {
if isIDCAuth(auth) {
// Get token-specific fingerprint for dynamic UA generation
tokenKey := getTokenKey(auth)
fp := getGlobalFingerprintManager().GetFingerprint(tokenKey)
accountKey := getAccountKey(auth)
fp := kiroauth.GlobalFingerprintManager().GetFingerprint(accountKey)
// Use fingerprint-generated dynamic User-Agent
req.Header.Set("User-Agent", fp.BuildUserAgent())
req.Header.Set("X-Amz-User-Agent", fp.BuildAmzUserAgent())
req.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeVibe)
req.Header.Set("User-Agent", fp.BuildUserAgent())
req.Header.Set("X-Amz-User-Agent", fp.BuildAmzUserAgent())
req.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentMode)
req.Header.Set("x-amzn-codewhisperer-optout", "true")
log.Debugf("kiro: using dynamic fingerprint for token %s (SDK:%s, OS:%s/%s, Kiro:%s)",
tokenKey[:8]+"...", fp.SDKVersion, fp.OSType, fp.OSVersion, fp.KiroVersion)
} else {
// Use static Amazon Q CLI style headers for non-IDC auth
req.Header.Set("User-Agent", kiroUserAgent)
req.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
keyPrefix := accountKey
if len(keyPrefix) > 8 {
keyPrefix = keyPrefix[:8]
}
log.Debugf("kiro: using dynamic fingerprint for account %s (SDK:%s, OS:%s/%s, Kiro:%s)",
keyPrefix+"...", fp.StreamingSDKVersion, fp.OSType, fp.OSVersion, fp.KiroVersion)
}
// PrepareRequest prepares the HTTP request before execution.
@@ -609,17 +544,51 @@ func (e *KiroExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth,
return httpClient.Do(httpReq)
}
// getTokenKey returns a unique key for rate limiting based on auth credentials.
// Uses auth ID if available, otherwise falls back to a hash of the access token.
func getTokenKey(auth *cliproxyauth.Auth) string {
// getAccountKey returns a stable account key for fingerprint lookup and rate limiting.
// Fallback order:
// 1) client_id / refresh_token (best account identity)
// 2) auth.ID (stable local auth record)
// 3) profile_arn (stable AWS profile identity)
// 4) access_token (least preferred but deterministic)
// 5) fixed anonymous seed
func getAccountKey(auth *cliproxyauth.Auth) string {
var clientID, refreshToken, profileArn string
if auth != nil && auth.Metadata != nil {
clientID, _ = auth.Metadata["client_id"].(string)
refreshToken, _ = auth.Metadata["refresh_token"].(string)
profileArn, _ = auth.Metadata["profile_arn"].(string)
}
if clientID != "" || refreshToken != "" {
return kiroauth.GetAccountKey(clientID, refreshToken)
}
if auth != nil && auth.ID != "" {
return auth.ID
return kiroauth.GenerateAccountKey(auth.ID)
}
accessToken, _ := kiroCredentials(auth)
if len(accessToken) > 16 {
return accessToken[:16]
if profileArn != "" {
return kiroauth.GenerateAccountKey(profileArn)
}
return accessToken
if accessToken, _ := kiroCredentials(auth); accessToken != "" {
return kiroauth.GenerateAccountKey(accessToken)
}
return kiroauth.GenerateAccountKey("kiro-anonymous")
}
// getAuthValue looks up a value by key in auth Metadata, then Attributes.
func getAuthValue(auth *cliproxyauth.Auth, key string) string {
if auth == nil {
return ""
}
if auth.Metadata != nil {
if v, ok := auth.Metadata[key].(string); ok && v != "" {
return strings.ToLower(strings.TrimSpace(v))
}
}
if auth.Attributes != nil {
if v := auth.Attributes[key]; v != "" {
return strings.ToLower(strings.TrimSpace(v))
}
}
return ""
}
// Execute sends the request to Kiro API and returns the response.
@@ -631,7 +600,7 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
}
// Rate limiting: get token key for tracking
tokenKey := getTokenKey(auth)
tokenKey := getAccountKey(auth)
rateLimiter := kiroauth.GetGlobalRateLimiter()
cooldownMgr := kiroauth.GetGlobalCooldownManager()
@@ -693,6 +662,13 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
kiroModelID := e.mapModelToKiro(req.Model)
// Fetch profileArn if missing (for imported accounts from Kiro IDE)
if profileArn == "" {
if fetched := e.fetchAndSaveProfileArn(ctx, auth, accessToken); fetched != "" {
profileArn = fetched
}
}
// Determine agentic mode and effective profile ARN using helper functions
isAgentic, isChatOnly := determineAgenticMode(req.Model)
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
@@ -749,7 +725,7 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
}
// Kiro-specific headers
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeVibe)
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentMode)
httpReq.Header.Set("x-amzn-codewhisperer-optout", "true")
// Apply dynamic fingerprint-based headers
@@ -1060,7 +1036,7 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
}
// Rate limiting: get token key for tracking
tokenKey := getTokenKey(auth)
tokenKey := getAccountKey(auth)
rateLimiter := kiroauth.GetGlobalRateLimiter()
cooldownMgr := kiroauth.GetGlobalCooldownManager()
@@ -1126,6 +1102,13 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
kiroModelID := e.mapModelToKiro(req.Model)
// Fetch profileArn if missing (for imported accounts from Kiro IDE)
if profileArn == "" {
if fetched := e.fetchAndSaveProfileArn(ctx, auth, accessToken); fetched != "" {
profileArn = fetched
}
}
// Determine agentic mode and effective profile ARN using helper functions
isAgentic, isChatOnly := determineAgenticMode(req.Model)
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
@@ -1185,7 +1168,7 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
}
// Kiro-specific headers
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeVibe)
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentMode)
httpReq.Header.Set("x-amzn-codewhisperer-optout", "true")
// Apply dynamic fingerprint-based headers
@@ -1647,62 +1630,23 @@ func determineAgenticMode(model string) (isAgentic, isChatOnly bool) {
return isAgentic, isChatOnly
}
// getEffectiveProfileArn determines if profileArn should be included based on auth method.
// profileArn is only needed for social auth (Google OAuth), not for AWS SSO OIDC (Builder ID/IDC).
//
// Detection logic (matching kiro-openai-gateway):
// 1. Check auth_method field: "builder-id" or "idc"
// 2. Check auth_type field: "aws_sso_oidc" (from kiro-cli tokens)
// 3. Check for client_id + client_secret presence (AWS SSO OIDC signature)
func getEffectiveProfileArn(auth *cliproxyauth.Auth, profileArn string) string {
if auth != nil && auth.Metadata != nil {
// Check 1: auth_method field (from CLIProxyAPI tokens)
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && (authMethod == "builder-id" || authMethod == "idc") {
return "" // AWS SSO OIDC - don't include profileArn
}
// Check 2: auth_type field (from kiro-cli tokens)
if authType, ok := auth.Metadata["auth_type"].(string); ok && authType == "aws_sso_oidc" {
return "" // AWS SSO OIDC - don't include profileArn
}
// Check 3: client_id + client_secret presence (AWS SSO OIDC signature)
_, hasClientID := auth.Metadata["client_id"].(string)
_, hasClientSecret := auth.Metadata["client_secret"].(string)
if hasClientID && hasClientSecret {
return "" // AWS SSO OIDC - don't include profileArn
}
}
return profileArn
}
// getEffectiveProfileArnWithWarning determines if profileArn should be included based on auth method,
// and logs a warning if profileArn is missing for non-builder-id auth.
// This consolidates the auth_method check that was previously done separately.
//
// AWS SSO OIDC (Builder ID/IDC) users don't need profileArn - sending it causes 403 errors.
// Only Kiro Desktop (social auth like Google/GitHub) users need profileArn.
//
// Detection logic (matching kiro-openai-gateway):
// 1. Check auth_method field: "builder-id" or "idc"
// 2. Check auth_type field: "aws_sso_oidc" (from kiro-cli tokens)
// 3. Check for client_id + client_secret presence (AWS SSO OIDC signature)
// getEffectiveProfileArnWithWarning suppresses profileArn for builder-id and AWS SSO OIDC auth.
// Builder-id users (auth_method == "builder-id") and AWS SSO OIDC users (auth_type == "aws_sso_oidc")
// don't need profileArn — sending it causes 403 errors.
// For all other auth methods (e.g. social auth), profileArn is returned as-is,
// with a warning logged if it is empty.
func getEffectiveProfileArnWithWarning(auth *cliproxyauth.Auth, profileArn string) string {
if auth != nil && auth.Metadata != nil {
// Check 1: auth_method field (from CLIProxyAPI tokens)
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && (authMethod == "builder-id" || authMethod == "idc") {
return "" // AWS SSO OIDC - don't include profileArn
// Check 1: auth_method field, skip for builder-id only
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" {
return ""
}
// Check 2: auth_type field (from kiro-cli tokens)
if authType, ok := auth.Metadata["auth_type"].(string); ok && authType == "aws_sso_oidc" {
return "" // AWS SSO OIDC - don't include profileArn
}
// Check 3: client_id + client_secret presence (AWS SSO OIDC signature, like kiro-openai-gateway)
_, hasClientID := auth.Metadata["client_id"].(string)
_, hasClientSecret := auth.Metadata["client_secret"].(string)
if hasClientID && hasClientSecret {
return "" // AWS SSO OIDC - don't include profileArn
}
}
// For social auth (Kiro Desktop), profileArn is required
// For social auth and IDC, profileArn is required
if profileArn == "" {
log.Warnf("kiro: profile ARN not found in auth, API calls may fail")
}
@@ -3999,6 +3943,51 @@ func (e *KiroExecutor) persistRefreshedAuth(auth *cliproxyauth.Auth) error {
return nil
}
// fetchAndSaveProfileArn fetches profileArn from API if missing, updates auth and persists to file.
func (e *KiroExecutor) fetchAndSaveProfileArn(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) string {
if auth == nil || auth.Metadata == nil {
return ""
}
// Skip for Builder ID - they don't have profiles
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" {
log.Debugf("kiro executor: skipping profileArn fetch for builder-id auth")
return ""
}
e.profileArnMu.Lock()
defer e.profileArnMu.Unlock()
// Double-check: another goroutine may have already fetched and saved the profileArn
if arn, ok := auth.Metadata["profile_arn"].(string); ok && arn != "" {
return arn
}
clientID, _ := auth.Metadata["client_id"].(string)
refreshToken, _ := auth.Metadata["refresh_token"].(string)
ssoClient := kiroauth.NewSSOOIDCClient(e.cfg)
profileArn := ssoClient.FetchProfileArn(ctx, accessToken, clientID, refreshToken)
if profileArn == "" {
log.Debugf("kiro executor: FetchProfileArn returned no profiles")
return ""
}
auth.Metadata["profile_arn"] = profileArn
if auth.Attributes == nil {
auth.Attributes = make(map[string]string)
}
auth.Attributes["profile_arn"] = profileArn
if err := e.persistRefreshedAuth(auth); err != nil {
log.Warnf("kiro executor: failed to persist profileArn: %v", err)
} else {
log.Infof("kiro executor: fetched and saved profileArn: %s", profileArn)
}
return profileArn
}
// reloadAuthFromFile 从文件重新加载 auth 数据(方案 B: Fallback 机制)
// 当内存中的 token 已过期时,尝试从文件读取最新的 token
// 这解决了后台刷新器已更新文件但内存中 Auth 对象尚未同步的时间差问题
@@ -4728,7 +4717,7 @@ func (e *KiroExecutor) callKiroAndBuffer(
isAgentic, isChatOnly := determineAgenticMode(req.Model)
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
tokenKey := getTokenKey(auth)
tokenKey := getAccountKey(auth)
kiroStream, err := e.executeStreamWithRetry(
ctx, auth, req, opts, accessToken, effectiveProfileArn,
@@ -4770,7 +4759,7 @@ func (e *KiroExecutor) callKiroDirectStream(
isAgentic, isChatOnly := determineAgenticMode(req.Model)
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
tokenKey := getTokenKey(auth)
tokenKey := getAccountKey(auth)
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
var streamErr error
@@ -4819,7 +4808,7 @@ func (e *KiroExecutor) executeNonStreamFallback(
kiroModelID := e.mapModelToKiro(req.Model)
isAgentic, isChatOnly := determineAgenticMode(req.Model)
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
tokenKey := getTokenKey(auth)
tokenKey := getAccountKey(auth)
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
var err error

View File

@@ -0,0 +1,423 @@
package executor
import (
"fmt"
"testing"
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func TestBuildKiroEndpointConfigs(t *testing.T) {
tests := []struct {
name string
region string
expectedURL string
expectedOrigin string
expectedName string
}{
{
name: "Empty region - defaults to us-east-1",
region: "",
expectedURL: "https://q.us-east-1.amazonaws.com/generateAssistantResponse",
expectedOrigin: "AI_EDITOR",
expectedName: "AmazonQ",
},
{
name: "us-east-1",
region: "us-east-1",
expectedURL: "https://q.us-east-1.amazonaws.com/generateAssistantResponse",
expectedOrigin: "AI_EDITOR",
expectedName: "AmazonQ",
},
{
name: "ap-southeast-1",
region: "ap-southeast-1",
expectedURL: "https://q.ap-southeast-1.amazonaws.com/generateAssistantResponse",
expectedOrigin: "AI_EDITOR",
expectedName: "AmazonQ",
},
{
name: "eu-west-1",
region: "eu-west-1",
expectedURL: "https://q.eu-west-1.amazonaws.com/generateAssistantResponse",
expectedOrigin: "AI_EDITOR",
expectedName: "AmazonQ",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
configs := buildKiroEndpointConfigs(tt.region)
if len(configs) != 2 {
t.Fatalf("expected 2 endpoint configs, got %d", len(configs))
}
// Check primary endpoint (AmazonQ)
primary := configs[0]
if primary.URL != tt.expectedURL {
t.Errorf("primary URL = %q, want %q", primary.URL, tt.expectedURL)
}
if primary.Origin != tt.expectedOrigin {
t.Errorf("primary Origin = %q, want %q", primary.Origin, tt.expectedOrigin)
}
if primary.Name != tt.expectedName {
t.Errorf("primary Name = %q, want %q", primary.Name, tt.expectedName)
}
if primary.AmzTarget != "" {
t.Errorf("primary AmzTarget should be empty, got %q", primary.AmzTarget)
}
// Check fallback endpoint (CodeWhisperer)
fallback := configs[1]
if fallback.Name != "CodeWhisperer" {
t.Errorf("fallback Name = %q, want %q", fallback.Name, "CodeWhisperer")
}
// CodeWhisperer fallback uses the same region as Q endpoint
expectedRegion := tt.region
if expectedRegion == "" {
expectedRegion = kiroDefaultRegion
}
expectedFallbackURL := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com/generateAssistantResponse", expectedRegion)
if fallback.URL != expectedFallbackURL {
t.Errorf("fallback URL = %q, want %q", fallback.URL, expectedFallbackURL)
}
if fallback.AmzTarget == "" {
t.Error("fallback AmzTarget should NOT be empty")
}
})
}
}
func TestGetKiroEndpointConfigs_NilAuth(t *testing.T) {
configs := getKiroEndpointConfigs(nil)
if len(configs) != 2 {
t.Fatalf("expected 2 endpoint configs, got %d", len(configs))
}
// Should return default us-east-1 configs
if configs[0].Name != "AmazonQ" {
t.Errorf("first config Name = %q, want %q", configs[0].Name, "AmazonQ")
}
expectedURL := "https://q.us-east-1.amazonaws.com/generateAssistantResponse"
if configs[0].URL != expectedURL {
t.Errorf("first config URL = %q, want %q", configs[0].URL, expectedURL)
}
}
func TestGetKiroEndpointConfigs_WithRegionFromProfileArn(t *testing.T) {
auth := &cliproxyauth.Auth{
Metadata: map[string]any{
"profile_arn": "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
},
}
configs := getKiroEndpointConfigs(auth)
if len(configs) != 2 {
t.Fatalf("expected 2 endpoint configs, got %d", len(configs))
}
expectedURL := "https://q.ap-southeast-1.amazonaws.com/generateAssistantResponse"
if configs[0].URL != expectedURL {
t.Errorf("primary URL = %q, want %q", configs[0].URL, expectedURL)
}
}
func TestGetKiroEndpointConfigs_WithApiRegionOverride(t *testing.T) {
auth := &cliproxyauth.Auth{
Metadata: map[string]any{
"api_region": "eu-central-1",
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
},
}
configs := getKiroEndpointConfigs(auth)
// api_region should take precedence over profile_arn
expectedURL := "https://q.eu-central-1.amazonaws.com/generateAssistantResponse"
if configs[0].URL != expectedURL {
t.Errorf("primary URL = %q, want %q", configs[0].URL, expectedURL)
}
}
func TestGetKiroEndpointConfigs_PreferredEndpoint(t *testing.T) {
tests := []struct {
name string
preference string
expectedFirstName string
}{
{
name: "Prefer codewhisperer",
preference: "codewhisperer",
expectedFirstName: "CodeWhisperer",
},
{
name: "Prefer ide (alias for codewhisperer)",
preference: "ide",
expectedFirstName: "CodeWhisperer",
},
{
name: "Prefer amazonq",
preference: "amazonq",
expectedFirstName: "AmazonQ",
},
{
name: "Prefer q (alias for amazonq)",
preference: "q",
expectedFirstName: "AmazonQ",
},
{
name: "Prefer cli (alias for amazonq)",
preference: "cli",
expectedFirstName: "AmazonQ",
},
{
name: "Unknown preference - no reordering",
preference: "unknown",
expectedFirstName: "AmazonQ",
},
{
name: "Empty preference - no reordering",
preference: "",
expectedFirstName: "AmazonQ",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
auth := &cliproxyauth.Auth{
Metadata: map[string]any{
"preferred_endpoint": tt.preference,
},
}
configs := getKiroEndpointConfigs(auth)
if configs[0].Name != tt.expectedFirstName {
t.Errorf("first endpoint Name = %q, want %q", configs[0].Name, tt.expectedFirstName)
}
})
}
}
func TestGetKiroEndpointConfigs_PreferredEndpointFromAttributes(t *testing.T) {
// Test that preferred_endpoint can also come from Attributes
auth := &cliproxyauth.Auth{
Metadata: map[string]any{},
Attributes: map[string]string{"preferred_endpoint": "codewhisperer"},
}
configs := getKiroEndpointConfigs(auth)
if configs[0].Name != "CodeWhisperer" {
t.Errorf("first endpoint Name = %q, want %q", configs[0].Name, "CodeWhisperer")
}
}
func TestGetKiroEndpointConfigs_MetadataTakesPrecedenceOverAttributes(t *testing.T) {
auth := &cliproxyauth.Auth{
Metadata: map[string]any{"preferred_endpoint": "amazonq"},
Attributes: map[string]string{"preferred_endpoint": "codewhisperer"},
}
configs := getKiroEndpointConfigs(auth)
// Metadata should take precedence
if configs[0].Name != "AmazonQ" {
t.Errorf("first endpoint Name = %q, want %q", configs[0].Name, "AmazonQ")
}
}
func TestGetAuthValue(t *testing.T) {
tests := []struct {
name string
auth *cliproxyauth.Auth
key string
expected string
}{
{
name: "From metadata",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"test_key": "metadata_value"},
},
key: "test_key",
expected: "metadata_value",
},
{
name: "From attributes (fallback)",
auth: &cliproxyauth.Auth{
Attributes: map[string]string{"test_key": "attribute_value"},
},
key: "test_key",
expected: "attribute_value",
},
{
name: "Metadata takes precedence",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"test_key": "metadata_value"},
Attributes: map[string]string{"test_key": "attribute_value"},
},
key: "test_key",
expected: "metadata_value",
},
{
name: "Key not found",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"other_key": "value"},
Attributes: map[string]string{"another_key": "value"},
},
key: "test_key",
expected: "",
},
{
name: "Nil metadata",
auth: &cliproxyauth.Auth{
Attributes: map[string]string{"test_key": "attribute_value"},
},
key: "test_key",
expected: "attribute_value",
},
{
name: "Both nil",
auth: &cliproxyauth.Auth{},
key: "test_key",
expected: "",
},
{
name: "Value is trimmed and lowercased",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"test_key": " UPPER_VALUE "},
},
key: "test_key",
expected: "upper_value",
},
{
name: "Empty string value in metadata - falls back to attributes",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"test_key": ""},
Attributes: map[string]string{"test_key": "attribute_value"},
},
key: "test_key",
expected: "attribute_value",
},
{
name: "Non-string value in metadata - falls back to attributes",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"test_key": 123},
Attributes: map[string]string{"test_key": "attribute_value"},
},
key: "test_key",
expected: "attribute_value",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getAuthValue(tt.auth, tt.key)
if result != tt.expected {
t.Errorf("getAuthValue() = %q, want %q", result, tt.expected)
}
})
}
}
func TestGetAccountKey(t *testing.T) {
tests := []struct {
name string
auth *cliproxyauth.Auth
checkFn func(t *testing.T, result string)
}{
{
name: "From client_id",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{
"client_id": "test-client-id-123",
"refresh_token": "test-refresh-token-456",
},
},
checkFn: func(t *testing.T, result string) {
expected := kiroauth.GetAccountKey("test-client-id-123", "test-refresh-token-456")
if result != expected {
t.Errorf("expected %s, got %s", expected, result)
}
},
},
{
name: "From refresh_token only",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{
"refresh_token": "test-refresh-token-789",
},
},
checkFn: func(t *testing.T, result string) {
expected := kiroauth.GetAccountKey("", "test-refresh-token-789")
if result != expected {
t.Errorf("expected %s, got %s", expected, result)
}
},
},
{
name: "Nil auth",
auth: nil,
checkFn: func(t *testing.T, result string) {
if len(result) != 16 {
t.Errorf("expected 16 char key, got %d chars", len(result))
}
},
},
{
name: "Nil metadata",
auth: &cliproxyauth.Auth{},
checkFn: func(t *testing.T, result string) {
if len(result) != 16 {
t.Errorf("expected 16 char key, got %d chars", len(result))
}
},
},
{
name: "Empty metadata",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{},
},
checkFn: func(t *testing.T, result string) {
if len(result) != 16 {
t.Errorf("expected 16 char key, got %d chars", len(result))
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getAccountKey(tt.auth)
tt.checkFn(t, result)
})
}
}
func TestEndpointAliases(t *testing.T) {
// Verify all expected aliases are defined
expectedAliases := map[string]string{
"codewhisperer": "codewhisperer",
"ide": "codewhisperer",
"amazonq": "amazonq",
"q": "amazonq",
"cli": "amazonq",
}
for alias, target := range expectedAliases {
if actual, ok := endpointAliases[alias]; !ok {
t.Errorf("missing alias %q", alias)
} else if actual != target {
t.Errorf("alias %q = %q, want %q", alias, actual, target)
}
}
// Verify no unexpected aliases
if len(endpointAliases) != len(expectedAliases) {
t.Errorf("unexpected number of aliases: got %d, want %d", len(endpointAliases), len(expectedAliases))
}
}