From 98db5aabd0591b19b444db0f009816d86c76a932 Mon Sep 17 00:00:00 2001 From: Joao Date: Mon, 22 Dec 2025 12:23:10 +0000 Subject: [PATCH] feat: persist refreshed IDC tokens to auth file Add persistRefreshedAuth function to write refreshed tokens back to the auth file after inline token refresh. This prevents repeated token refreshes on every request when the token expires. Changes: - Add persistRefreshedAuth() to kiro_executor.go - Call persist after all token refresh paths (401, 403, pre-request) - Remove unused log import from sdk/auth/kiro.go --- internal/auth/kiro/cognito.go | 408 --------------------- internal/auth/kiro/sso_oidc.go | 2 +- internal/runtime/executor/kiro_executor.go | 236 +++++------- 3 files changed, 101 insertions(+), 545 deletions(-) delete mode 100644 internal/auth/kiro/cognito.go diff --git a/internal/auth/kiro/cognito.go b/internal/auth/kiro/cognito.go deleted file mode 100644 index 7cf32818..00000000 --- a/internal/auth/kiro/cognito.go +++ /dev/null @@ -1,408 +0,0 @@ -// Package kiro provides Cognito Identity credential exchange for IDC authentication. -// AWS Identity Center (IDC) requires SigV4 signing with Cognito-exchanged credentials -// instead of Bearer token authentication. -package kiro - -import ( - "context" - "crypto/hmac" - "crypto/sha256" - "encoding/hex" - "encoding/json" - "fmt" - "io" - "net/http" - "sort" - "strings" - "time" - - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/util" - log "github.com/sirupsen/logrus" -) - -const ( - // Cognito Identity endpoints - cognitoIdentityEndpoint = "https://cognito-identity.us-east-1.amazonaws.com" - - // Identity Pool ID for Q Developer / CodeWhisperer - // This is the identity pool used by kiro-cli and Amazon Q CLI - cognitoIdentityPoolID = "us-east-1:70717e99-906f-485d-8d89-c89a0b5d49c5" - - // Cognito provider name for SSO OIDC - cognitoProviderName = "cognito-identity.amazonaws.com" -) - -// CognitoCredentials holds temporary AWS credentials from Cognito Identity. -type CognitoCredentials struct { - AccessKeyID string `json:"access_key_id"` - SecretAccessKey string `json:"secret_access_key"` - SessionToken string `json:"session_token"` - Expiration time.Time `json:"expiration"` -} - -// CognitoIdentityClient handles Cognito Identity credential exchange. -type CognitoIdentityClient struct { - httpClient *http.Client - cfg *config.Config -} - -// NewCognitoIdentityClient creates a new Cognito Identity client. -func NewCognitoIdentityClient(cfg *config.Config) *CognitoIdentityClient { - client := &http.Client{Timeout: 30 * time.Second} - if cfg != nil { - client = util.SetProxy(&cfg.SDKConfig, client) - } - return &CognitoIdentityClient{ - httpClient: client, - cfg: cfg, - } -} - -// GetIdentityID retrieves a Cognito Identity ID using the SSO access token. -func (c *CognitoIdentityClient) GetIdentityID(ctx context.Context, accessToken, region string) (string, error) { - if region == "" { - region = "us-east-1" - } - - endpoint := fmt.Sprintf("https://cognito-identity.%s.amazonaws.com", region) - - // Build the GetId request - // The SSO token is passed as a login token for the identity pool - payload := map[string]interface{}{ - "IdentityPoolId": cognitoIdentityPoolID, - "Logins": map[string]string{ - // Use the OIDC provider URL as the key - fmt.Sprintf("oidc.%s.amazonaws.com", region): accessToken, - }, - } - - body, err := json.Marshal(payload) - if err != nil { - return "", fmt.Errorf("failed to marshal GetId request: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(string(body))) - if err != nil { - return "", fmt.Errorf("failed to create GetId request: %w", err) - } - - req.Header.Set("Content-Type", "application/x-amz-json-1.1") - req.Header.Set("X-Amz-Target", "AWSCognitoIdentityService.GetId") - req.Header.Set("Accept", "application/json") - - resp, err := c.httpClient.Do(req) - if err != nil { - return "", fmt.Errorf("GetId request failed: %w", err) - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return "", fmt.Errorf("failed to read GetId response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - log.Debugf("Cognito GetId failed (status %d): %s", resp.StatusCode, string(respBody)) - return "", fmt.Errorf("GetId failed (status %d): %s", resp.StatusCode, string(respBody)) - } - - var result struct { - IdentityID string `json:"IdentityId"` - } - if err := json.Unmarshal(respBody, &result); err != nil { - return "", fmt.Errorf("failed to parse GetId response: %w", err) - } - - if result.IdentityID == "" { - return "", fmt.Errorf("empty IdentityId in GetId response") - } - - log.Debugf("Cognito Identity ID: %s", result.IdentityID) - return result.IdentityID, nil -} - -// GetCredentialsForIdentity exchanges an identity ID and login token for temporary AWS credentials. -func (c *CognitoIdentityClient) GetCredentialsForIdentity(ctx context.Context, identityID, accessToken, region string) (*CognitoCredentials, error) { - if region == "" { - region = "us-east-1" - } - - endpoint := fmt.Sprintf("https://cognito-identity.%s.amazonaws.com", region) - - payload := map[string]interface{}{ - "IdentityId": identityID, - "Logins": map[string]string{ - fmt.Sprintf("oidc.%s.amazonaws.com", region): accessToken, - }, - } - - body, err := json.Marshal(payload) - if err != nil { - return nil, fmt.Errorf("failed to marshal GetCredentialsForIdentity request: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(string(body))) - if err != nil { - return nil, fmt.Errorf("failed to create GetCredentialsForIdentity request: %w", err) - } - - req.Header.Set("Content-Type", "application/x-amz-json-1.1") - req.Header.Set("X-Amz-Target", "AWSCognitoIdentityService.GetCredentialsForIdentity") - req.Header.Set("Accept", "application/json") - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("GetCredentialsForIdentity request failed: %w", err) - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read GetCredentialsForIdentity response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - log.Debugf("Cognito GetCredentialsForIdentity failed (status %d): %s", resp.StatusCode, string(respBody)) - return nil, fmt.Errorf("GetCredentialsForIdentity failed (status %d): %s", resp.StatusCode, string(respBody)) - } - - var result struct { - Credentials struct { - AccessKeyID string `json:"AccessKeyId"` - SecretKey string `json:"SecretKey"` - SessionToken string `json:"SessionToken"` - Expiration int64 `json:"Expiration"` - } `json:"Credentials"` - IdentityID string `json:"IdentityId"` - } - if err := json.Unmarshal(respBody, &result); err != nil { - return nil, fmt.Errorf("failed to parse GetCredentialsForIdentity response: %w", err) - } - - if result.Credentials.AccessKeyID == "" { - return nil, fmt.Errorf("empty AccessKeyId in GetCredentialsForIdentity response") - } - - // Expiration is in seconds since epoch - expiration := time.Unix(result.Credentials.Expiration, 0) - - log.Debugf("Cognito credentials obtained, expires: %s", expiration.Format(time.RFC3339)) - - return &CognitoCredentials{ - AccessKeyID: result.Credentials.AccessKeyID, - SecretAccessKey: result.Credentials.SecretKey, - SessionToken: result.Credentials.SessionToken, - Expiration: expiration, - }, nil -} - -// ExchangeSSOTokenForCredentials is a convenience method that performs the full -// Cognito Identity credential exchange flow: GetId -> GetCredentialsForIdentity -func (c *CognitoIdentityClient) ExchangeSSOTokenForCredentials(ctx context.Context, accessToken, region string) (*CognitoCredentials, error) { - log.Debugf("Exchanging SSO token for Cognito credentials (region: %s)", region) - - // Step 1: Get Identity ID - identityID, err := c.GetIdentityID(ctx, accessToken, region) - if err != nil { - return nil, fmt.Errorf("failed to get identity ID: %w", err) - } - - // Step 2: Get credentials for the identity - creds, err := c.GetCredentialsForIdentity(ctx, identityID, accessToken, region) - if err != nil { - return nil, fmt.Errorf("failed to get credentials for identity: %w", err) - } - - return creds, nil -} - -// SigV4Signer provides AWS Signature Version 4 signing for HTTP requests. -type SigV4Signer struct { - credentials *CognitoCredentials - region string - service string -} - -// NewSigV4Signer creates a new SigV4 signer with the given credentials. -func NewSigV4Signer(creds *CognitoCredentials, region, service string) *SigV4Signer { - return &SigV4Signer{ - credentials: creds, - region: region, - service: service, - } -} - -// SignRequest signs an HTTP request using AWS Signature Version 4. -// The request body must be provided separately since it may have been read already. -func (s *SigV4Signer) SignRequest(req *http.Request, body []byte) error { - now := time.Now().UTC() - amzDate := now.Format("20060102T150405Z") - dateStamp := now.Format("20060102") - - // Ensure required headers are set - if req.Header.Get("Host") == "" { - req.Header.Set("Host", req.URL.Host) - } - req.Header.Set("X-Amz-Date", amzDate) - if s.credentials.SessionToken != "" { - req.Header.Set("X-Amz-Security-Token", s.credentials.SessionToken) - } - - // Create canonical request - canonicalRequest, signedHeaders := s.createCanonicalRequest(req, body) - - // Create string to sign - algorithm := "AWS4-HMAC-SHA256" - credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, s.region, s.service) - stringToSign := fmt.Sprintf("%s\n%s\n%s\n%s", - algorithm, - amzDate, - credentialScope, - hashSHA256([]byte(canonicalRequest)), - ) - - // Calculate signature - signingKey := s.getSignatureKey(dateStamp) - signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign))) - - // Build Authorization header - authHeader := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", - algorithm, - s.credentials.AccessKeyID, - credentialScope, - signedHeaders, - signature, - ) - - req.Header.Set("Authorization", authHeader) - - return nil -} - -// createCanonicalRequest builds the canonical request string for SigV4. -func (s *SigV4Signer) createCanonicalRequest(req *http.Request, body []byte) (string, string) { - // HTTP method - method := req.Method - - // Canonical URI - uri := req.URL.Path - if uri == "" { - uri = "/" - } - - // Canonical query string (sorted) - queryString := s.buildCanonicalQueryString(req) - - // Canonical headers (sorted, lowercase) - canonicalHeaders, signedHeaders := s.buildCanonicalHeaders(req) - - // Hashed payload - payloadHash := hashSHA256(body) - - canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", - method, - uri, - queryString, - canonicalHeaders, - signedHeaders, - payloadHash, - ) - - return canonicalRequest, signedHeaders -} - -// buildCanonicalQueryString builds a sorted, URI-encoded query string. -func (s *SigV4Signer) buildCanonicalQueryString(req *http.Request) string { - if req.URL.RawQuery == "" { - return "" - } - - // Parse and sort query parameters - params := make([]string, 0) - for key, values := range req.URL.Query() { - for _, value := range values { - params = append(params, fmt.Sprintf("%s=%s", uriEncode(key), uriEncode(value))) - } - } - sort.Strings(params) - return strings.Join(params, "&") -} - -// buildCanonicalHeaders builds sorted, lowercase canonical headers. -func (s *SigV4Signer) buildCanonicalHeaders(req *http.Request) (string, string) { - // Headers to sign (must include host and x-amz-*) - headerMap := make(map[string]string) - headerMap["host"] = req.URL.Host - - for key, values := range req.Header { - lowKey := strings.ToLower(key) - // Include x-amz-* headers and content-type - if strings.HasPrefix(lowKey, "x-amz-") || lowKey == "content-type" { - headerMap[lowKey] = strings.TrimSpace(values[0]) - } - } - - // Sort header names - headerNames := make([]string, 0, len(headerMap)) - for name := range headerMap { - headerNames = append(headerNames, name) - } - sort.Strings(headerNames) - - // Build canonical headers and signed headers - var canonicalHeaders strings.Builder - for _, name := range headerNames { - canonicalHeaders.WriteString(name) - canonicalHeaders.WriteString(":") - canonicalHeaders.WriteString(headerMap[name]) - canonicalHeaders.WriteString("\n") - } - - signedHeaders := strings.Join(headerNames, ";") - - return canonicalHeaders.String(), signedHeaders -} - -// getSignatureKey derives the signing key for SigV4. -func (s *SigV4Signer) getSignatureKey(dateStamp string) []byte { - kDate := hmacSHA256([]byte("AWS4"+s.credentials.SecretAccessKey), []byte(dateStamp)) - kRegion := hmacSHA256(kDate, []byte(s.region)) - kService := hmacSHA256(kRegion, []byte(s.service)) - kSigning := hmacSHA256(kService, []byte("aws4_request")) - return kSigning -} - -// hmacSHA256 computes HMAC-SHA256. -func hmacSHA256(key, data []byte) []byte { - h := hmac.New(sha256.New, key) - h.Write(data) - return h.Sum(nil) -} - -// hashSHA256 computes SHA256 hash and returns hex string. -func hashSHA256(data []byte) string { - hash := sha256.Sum256(data) - return hex.EncodeToString(hash[:]) -} - -// uriEncode performs URI encoding for SigV4. -func uriEncode(s string) string { - var result strings.Builder - for i := 0; i < len(s); i++ { - c := s[i] - if (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || - (c >= '0' && c <= '9') || c == '-' || c == '.' || c == '_' || c == '~' { - result.WriteByte(c) - } else { - result.WriteString(fmt.Sprintf("%%%02X", c)) - } - } - return result.String() -} - -// IsExpired checks if the credentials are expired or about to expire. -func (c *CognitoCredentials) IsExpired() bool { - // Consider expired if within 5 minutes of expiration - return time.Now().Add(5 * time.Minute).After(c.Expiration) -} diff --git a/internal/auth/kiro/sso_oidc.go b/internal/auth/kiro/sso_oidc.go index 6ef2e960..292f5bcf 100644 --- a/internal/auth/kiro/sso_oidc.go +++ b/internal/auth/kiro/sso_oidc.go @@ -334,7 +334,7 @@ func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, cl } if resp.StatusCode != http.StatusOK { - log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + log.Warnf("IDC token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode) } diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index 70f23dfb..1e882888 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -10,6 +10,8 @@ import ( "fmt" "io" "net/http" + "os" + "path/filepath" "strings" "sync" "time" @@ -178,64 +180,6 @@ func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig { type KiroExecutor struct { cfg *config.Config refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions - - // cognitoCredsCache caches Cognito credentials per auth ID for IDC authentication - // Key: auth.ID, Value: *kiroauth.CognitoCredentials - cognitoCredsCache sync.Map -} - -// getCachedCognitoCredentials retrieves cached Cognito credentials if they are still valid. -func (e *KiroExecutor) getCachedCognitoCredentials(authID string) *kiroauth.CognitoCredentials { - if cached, ok := e.cognitoCredsCache.Load(authID); ok { - creds := cached.(*kiroauth.CognitoCredentials) - if !creds.IsExpired() { - return creds - } - // Credentials expired, remove from cache - e.cognitoCredsCache.Delete(authID) - } - return nil -} - -// cacheCognitoCredentials stores Cognito credentials in the cache. -func (e *KiroExecutor) cacheCognitoCredentials(authID string, creds *kiroauth.CognitoCredentials) { - e.cognitoCredsCache.Store(authID, creds) -} - -// getOrExchangeCognitoCredentials retrieves cached Cognito credentials or exchanges the SSO token for new ones. -func (e *KiroExecutor) getOrExchangeCognitoCredentials(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) (*kiroauth.CognitoCredentials, error) { - if auth == nil { - return nil, fmt.Errorf("auth is nil") - } - - // Check cache first - if creds := e.getCachedCognitoCredentials(auth.ID); creds != nil { - log.Debugf("kiro: using cached Cognito credentials for auth %s (expires: %s)", auth.ID, creds.Expiration.Format(time.RFC3339)) - return creds, nil - } - - // Get region from auth metadata - region := "us-east-1" - if auth.Metadata != nil { - if r, ok := auth.Metadata["region"].(string); ok && r != "" { - region = r - } - } - - log.Infof("kiro: exchanging SSO token for Cognito credentials (region: %s)", region) - - // Exchange SSO token for Cognito credentials - cognitoClient := kiroauth.NewCognitoIdentityClient(e.cfg) - creds, err := cognitoClient.ExchangeSSOTokenForCredentials(ctx, accessToken, region) - if err != nil { - return nil, fmt.Errorf("failed to exchange SSO token for Cognito credentials: %w", err) - } - - // Cache the credentials - e.cacheCognitoCredentials(auth.ID, creds) - log.Infof("kiro: Cognito credentials obtained and cached (expires: %s)", creds.Expiration.Format(time.RFC3339)) - - return creds, nil } // isIDCAuth checks if the auth uses IDC (Identity Center) authentication method. @@ -247,12 +191,6 @@ func isIDCAuth(auth *cliproxyauth.Auth) bool { return authMethod == "idc" } -// signRequestWithSigV4 signs an HTTP request with AWS SigV4 using Cognito credentials. -func signRequestWithSigV4(req *http.Request, payload []byte, creds *kiroauth.CognitoCredentials, region, service string) error { - signer := kiroauth.NewSigV4Signer(creds, region, service) - return signer.SignRequest(req, payload) -} - // buildKiroPayloadForFormat builds the Kiro API payload based on the source format. // This is critical because OpenAI and Claude formats have different tool structures: // - OpenAI: tools[].function.name, tools[].function.description @@ -301,6 +239,10 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) } else if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + } accessToken, profileArn = kiroCredentials(auth) log.Infof("kiro: token refreshed successfully before request") } @@ -372,40 +314,8 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) - // Choose auth method: SigV4 for IDC, Bearer token for others - // NOTE: Cognito credential exchange disabled for now - testing Bearer token first - if false && isIDCAuth(auth) { - // IDC auth requires SigV4 signing with Cognito-exchanged credentials - cognitoCreds, err := e.getOrExchangeCognitoCredentials(ctx, auth, accessToken) - if err != nil { - log.Warnf("kiro: failed to get Cognito credentials for IDC auth: %v", err) - return resp, fmt.Errorf("IDC auth requires Cognito credentials: %w", err) - } - - // Get region from auth metadata - region := "us-east-1" - if auth.Metadata != nil { - if r, ok := auth.Metadata["region"].(string); ok && r != "" { - region = r - } - } - - // Determine service from URL - service := "codewhisperer" - if strings.Contains(url, "q.us-east-1.amazonaws.com") { - service = "qdeveloper" - } - - // Sign the request with SigV4 - if err := signRequestWithSigV4(httpReq, kiroPayload, cognitoCreds, region, service); err != nil { - log.Warnf("kiro: failed to sign request with SigV4: %v", err) - return resp, fmt.Errorf("SigV4 signing failed: %w", err) - } - log.Debugf("kiro: request signed with SigV4 for IDC auth (service: %s, region: %s)", service, region) - } else { - // Standard Bearer token authentication for Builder ID, social auth, etc. - httpReq.Header.Set("Authorization", "Bearer "+accessToken) - } + // Bearer token authentication for all auth types (Builder ID, IDC, social, etc.) + httpReq.Header.Set("Authorization", "Bearer "+accessToken) var attrs map[string]string if auth != nil { @@ -494,6 +404,11 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + // Continue anyway - the token is valid for this request + } accessToken, profileArn = kiroCredentials(auth) // Rebuild payload with new profile ARN if changed kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) @@ -552,6 +467,11 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. } if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + // Continue anyway - the token is valid for this request + } accessToken, profileArn = kiroCredentials(auth) kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) log.Infof("kiro: token refreshed for 403, retrying request") @@ -654,6 +574,10 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) } else if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + } accessToken, profileArn = kiroCredentials(auth) log.Infof("kiro: token refreshed successfully before stream request") } @@ -723,40 +647,8 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) - // Choose auth method: SigV4 for IDC, Bearer token for others - // NOTE: Cognito credential exchange disabled for now - testing Bearer token first - if false && isIDCAuth(auth) { - // IDC auth requires SigV4 signing with Cognito-exchanged credentials - cognitoCreds, err := e.getOrExchangeCognitoCredentials(ctx, auth, accessToken) - if err != nil { - log.Warnf("kiro: failed to get Cognito credentials for IDC auth: %v", err) - return nil, fmt.Errorf("IDC auth requires Cognito credentials: %w", err) - } - - // Get region from auth metadata - region := "us-east-1" - if auth.Metadata != nil { - if r, ok := auth.Metadata["region"].(string); ok && r != "" { - region = r - } - } - - // Determine service from URL - service := "codewhisperer" - if strings.Contains(url, "q.us-east-1.amazonaws.com") { - service = "qdeveloper" - } - - // Sign the request with SigV4 - if err := signRequestWithSigV4(httpReq, kiroPayload, cognitoCreds, region, service); err != nil { - log.Warnf("kiro: failed to sign request with SigV4: %v", err) - return nil, fmt.Errorf("SigV4 signing failed: %w", err) - } - log.Debugf("kiro: stream request signed with SigV4 for IDC auth (service: %s, region: %s)", service, region) - } else { - // Standard Bearer token authentication for Builder ID, social auth, etc. - httpReq.Header.Set("Authorization", "Bearer "+accessToken) - } + // Bearer token authentication for all auth types (Builder ID, IDC, social, etc.) + httpReq.Header.Set("Authorization", "Bearer "+accessToken) var attrs map[string]string if auth != nil { @@ -858,6 +750,11 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + // Continue anyway - the token is valid for this request + } accessToken, profileArn = kiroCredentials(auth) // Rebuild payload with new profile ARN if changed kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) @@ -916,6 +813,11 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox } if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + // Continue anyway - the token is valid for this request + } accessToken, profileArn = kiroCredentials(auth) kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) log.Infof("kiro: token refreshed for 403, retrying stream request") @@ -3191,6 +3093,7 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c var refreshToken string var clientID, clientSecret string var authMethod string + var region, startURL string if auth.Metadata != nil { if rt, ok := auth.Metadata["refresh_token"].(string); ok { @@ -3205,6 +3108,12 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c if am, ok := auth.Metadata["auth_method"].(string); ok { authMethod = am } + if r, ok := auth.Metadata["region"].(string); ok { + region = r + } + if su, ok := auth.Metadata["start_url"].(string); ok { + startURL = su + } } if refreshToken == "" { @@ -3214,12 +3123,20 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c var tokenData *kiroauth.KiroTokenData var err error - // Use SSO OIDC refresh for AWS Builder ID, otherwise use Kiro's OAuth refresh endpoint - if clientID != "" && clientSecret != "" && authMethod == "builder-id" { + ssoClient := kiroauth.NewSSOOIDCClient(e.cfg) + + // Use SSO OIDC refresh for AWS Builder ID or IDC, otherwise use Kiro's OAuth refresh endpoint + switch { + case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "": + // IDC refresh with region-specific endpoint + log.Debugf("kiro executor: using SSO OIDC refresh for IDC (region=%s)", region) + tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL) + case clientID != "" && clientSecret != "" && authMethod == "builder-id": + // Builder ID refresh with default endpoint log.Debugf("kiro executor: using SSO OIDC refresh for AWS Builder ID") - ssoClient := kiroauth.NewSSOOIDCClient(e.cfg) tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken) - } else { + default: + // Fallback to Kiro's OAuth refresh endpoint (for social auth: Google/GitHub) log.Debugf("kiro executor: using Kiro OAuth refresh endpoint") oauth := kiroauth.NewKiroOAuth(e.cfg) tokenData, err = oauth.RefreshToken(ctx, refreshToken) @@ -3275,6 +3192,53 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c return updated, nil } +// persistRefreshedAuth persists a refreshed auth record to disk. +// This ensures token refreshes from inline retry are saved to the auth file. +func (e *KiroExecutor) persistRefreshedAuth(auth *cliproxyauth.Auth) error { + if auth == nil || auth.Metadata == nil { + return fmt.Errorf("kiro executor: cannot persist nil auth or metadata") + } + + // Determine the file path from auth attributes or filename + var authPath string + if auth.Attributes != nil { + if p := strings.TrimSpace(auth.Attributes["path"]); p != "" { + authPath = p + } + } + if authPath == "" { + fileName := strings.TrimSpace(auth.FileName) + if fileName == "" { + return fmt.Errorf("kiro executor: auth has no file path or filename") + } + if filepath.IsAbs(fileName) { + authPath = fileName + } else if e.cfg != nil && e.cfg.AuthDir != "" { + authPath = filepath.Join(e.cfg.AuthDir, fileName) + } else { + return fmt.Errorf("kiro executor: cannot determine auth file path") + } + } + + // Marshal metadata to JSON + raw, err := json.Marshal(auth.Metadata) + if err != nil { + return fmt.Errorf("kiro executor: marshal metadata failed: %w", err) + } + + // Write to temp file first, then rename (atomic write) + tmp := authPath + ".tmp" + if err := os.WriteFile(tmp, raw, 0o600); err != nil { + return fmt.Errorf("kiro executor: write temp auth file failed: %w", err) + } + if err := os.Rename(tmp, authPath); err != nil { + return fmt.Errorf("kiro executor: rename auth file failed: %w", err) + } + + log.Debugf("kiro executor: persisted refreshed auth to %s", authPath) + return nil +} + // isTokenExpired checks if a JWT access token has expired. // Returns true if the token is expired or cannot be parsed. func (e *KiroExecutor) isTokenExpired(accessToken string) bool {