diff --git a/internal/auth/kiro/cognito.go b/internal/auth/kiro/cognito.go deleted file mode 100644 index 7cf32818..00000000 --- a/internal/auth/kiro/cognito.go +++ /dev/null @@ -1,408 +0,0 @@ -// Package kiro provides Cognito Identity credential exchange for IDC authentication. -// AWS Identity Center (IDC) requires SigV4 signing with Cognito-exchanged credentials -// instead of Bearer token authentication. -package kiro - -import ( - "context" - "crypto/hmac" - "crypto/sha256" - "encoding/hex" - "encoding/json" - "fmt" - "io" - "net/http" - "sort" - "strings" - "time" - - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/util" - log "github.com/sirupsen/logrus" -) - -const ( - // Cognito Identity endpoints - cognitoIdentityEndpoint = "https://cognito-identity.us-east-1.amazonaws.com" - - // Identity Pool ID for Q Developer / CodeWhisperer - // This is the identity pool used by kiro-cli and Amazon Q CLI - cognitoIdentityPoolID = "us-east-1:70717e99-906f-485d-8d89-c89a0b5d49c5" - - // Cognito provider name for SSO OIDC - cognitoProviderName = "cognito-identity.amazonaws.com" -) - -// CognitoCredentials holds temporary AWS credentials from Cognito Identity. -type CognitoCredentials struct { - AccessKeyID string `json:"access_key_id"` - SecretAccessKey string `json:"secret_access_key"` - SessionToken string `json:"session_token"` - Expiration time.Time `json:"expiration"` -} - -// CognitoIdentityClient handles Cognito Identity credential exchange. -type CognitoIdentityClient struct { - httpClient *http.Client - cfg *config.Config -} - -// NewCognitoIdentityClient creates a new Cognito Identity client. -func NewCognitoIdentityClient(cfg *config.Config) *CognitoIdentityClient { - client := &http.Client{Timeout: 30 * time.Second} - if cfg != nil { - client = util.SetProxy(&cfg.SDKConfig, client) - } - return &CognitoIdentityClient{ - httpClient: client, - cfg: cfg, - } -} - -// GetIdentityID retrieves a Cognito Identity ID using the SSO access token. -func (c *CognitoIdentityClient) GetIdentityID(ctx context.Context, accessToken, region string) (string, error) { - if region == "" { - region = "us-east-1" - } - - endpoint := fmt.Sprintf("https://cognito-identity.%s.amazonaws.com", region) - - // Build the GetId request - // The SSO token is passed as a login token for the identity pool - payload := map[string]interface{}{ - "IdentityPoolId": cognitoIdentityPoolID, - "Logins": map[string]string{ - // Use the OIDC provider URL as the key - fmt.Sprintf("oidc.%s.amazonaws.com", region): accessToken, - }, - } - - body, err := json.Marshal(payload) - if err != nil { - return "", fmt.Errorf("failed to marshal GetId request: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(string(body))) - if err != nil { - return "", fmt.Errorf("failed to create GetId request: %w", err) - } - - req.Header.Set("Content-Type", "application/x-amz-json-1.1") - req.Header.Set("X-Amz-Target", "AWSCognitoIdentityService.GetId") - req.Header.Set("Accept", "application/json") - - resp, err := c.httpClient.Do(req) - if err != nil { - return "", fmt.Errorf("GetId request failed: %w", err) - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return "", fmt.Errorf("failed to read GetId response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - log.Debugf("Cognito GetId failed (status %d): %s", resp.StatusCode, string(respBody)) - return "", fmt.Errorf("GetId failed (status %d): %s", resp.StatusCode, string(respBody)) - } - - var result struct { - IdentityID string `json:"IdentityId"` - } - if err := json.Unmarshal(respBody, &result); err != nil { - return "", fmt.Errorf("failed to parse GetId response: %w", err) - } - - if result.IdentityID == "" { - return "", fmt.Errorf("empty IdentityId in GetId response") - } - - log.Debugf("Cognito Identity ID: %s", result.IdentityID) - return result.IdentityID, nil -} - -// GetCredentialsForIdentity exchanges an identity ID and login token for temporary AWS credentials. -func (c *CognitoIdentityClient) GetCredentialsForIdentity(ctx context.Context, identityID, accessToken, region string) (*CognitoCredentials, error) { - if region == "" { - region = "us-east-1" - } - - endpoint := fmt.Sprintf("https://cognito-identity.%s.amazonaws.com", region) - - payload := map[string]interface{}{ - "IdentityId": identityID, - "Logins": map[string]string{ - fmt.Sprintf("oidc.%s.amazonaws.com", region): accessToken, - }, - } - - body, err := json.Marshal(payload) - if err != nil { - return nil, fmt.Errorf("failed to marshal GetCredentialsForIdentity request: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(string(body))) - if err != nil { - return nil, fmt.Errorf("failed to create GetCredentialsForIdentity request: %w", err) - } - - req.Header.Set("Content-Type", "application/x-amz-json-1.1") - req.Header.Set("X-Amz-Target", "AWSCognitoIdentityService.GetCredentialsForIdentity") - req.Header.Set("Accept", "application/json") - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("GetCredentialsForIdentity request failed: %w", err) - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read GetCredentialsForIdentity response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - log.Debugf("Cognito GetCredentialsForIdentity failed (status %d): %s", resp.StatusCode, string(respBody)) - return nil, fmt.Errorf("GetCredentialsForIdentity failed (status %d): %s", resp.StatusCode, string(respBody)) - } - - var result struct { - Credentials struct { - AccessKeyID string `json:"AccessKeyId"` - SecretKey string `json:"SecretKey"` - SessionToken string `json:"SessionToken"` - Expiration int64 `json:"Expiration"` - } `json:"Credentials"` - IdentityID string `json:"IdentityId"` - } - if err := json.Unmarshal(respBody, &result); err != nil { - return nil, fmt.Errorf("failed to parse GetCredentialsForIdentity response: %w", err) - } - - if result.Credentials.AccessKeyID == "" { - return nil, fmt.Errorf("empty AccessKeyId in GetCredentialsForIdentity response") - } - - // Expiration is in seconds since epoch - expiration := time.Unix(result.Credentials.Expiration, 0) - - log.Debugf("Cognito credentials obtained, expires: %s", expiration.Format(time.RFC3339)) - - return &CognitoCredentials{ - AccessKeyID: result.Credentials.AccessKeyID, - SecretAccessKey: result.Credentials.SecretKey, - SessionToken: result.Credentials.SessionToken, - Expiration: expiration, - }, nil -} - -// ExchangeSSOTokenForCredentials is a convenience method that performs the full -// Cognito Identity credential exchange flow: GetId -> GetCredentialsForIdentity -func (c *CognitoIdentityClient) ExchangeSSOTokenForCredentials(ctx context.Context, accessToken, region string) (*CognitoCredentials, error) { - log.Debugf("Exchanging SSO token for Cognito credentials (region: %s)", region) - - // Step 1: Get Identity ID - identityID, err := c.GetIdentityID(ctx, accessToken, region) - if err != nil { - return nil, fmt.Errorf("failed to get identity ID: %w", err) - } - - // Step 2: Get credentials for the identity - creds, err := c.GetCredentialsForIdentity(ctx, identityID, accessToken, region) - if err != nil { - return nil, fmt.Errorf("failed to get credentials for identity: %w", err) - } - - return creds, nil -} - -// SigV4Signer provides AWS Signature Version 4 signing for HTTP requests. -type SigV4Signer struct { - credentials *CognitoCredentials - region string - service string -} - -// NewSigV4Signer creates a new SigV4 signer with the given credentials. -func NewSigV4Signer(creds *CognitoCredentials, region, service string) *SigV4Signer { - return &SigV4Signer{ - credentials: creds, - region: region, - service: service, - } -} - -// SignRequest signs an HTTP request using AWS Signature Version 4. -// The request body must be provided separately since it may have been read already. -func (s *SigV4Signer) SignRequest(req *http.Request, body []byte) error { - now := time.Now().UTC() - amzDate := now.Format("20060102T150405Z") - dateStamp := now.Format("20060102") - - // Ensure required headers are set - if req.Header.Get("Host") == "" { - req.Header.Set("Host", req.URL.Host) - } - req.Header.Set("X-Amz-Date", amzDate) - if s.credentials.SessionToken != "" { - req.Header.Set("X-Amz-Security-Token", s.credentials.SessionToken) - } - - // Create canonical request - canonicalRequest, signedHeaders := s.createCanonicalRequest(req, body) - - // Create string to sign - algorithm := "AWS4-HMAC-SHA256" - credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, s.region, s.service) - stringToSign := fmt.Sprintf("%s\n%s\n%s\n%s", - algorithm, - amzDate, - credentialScope, - hashSHA256([]byte(canonicalRequest)), - ) - - // Calculate signature - signingKey := s.getSignatureKey(dateStamp) - signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign))) - - // Build Authorization header - authHeader := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s", - algorithm, - s.credentials.AccessKeyID, - credentialScope, - signedHeaders, - signature, - ) - - req.Header.Set("Authorization", authHeader) - - return nil -} - -// createCanonicalRequest builds the canonical request string for SigV4. -func (s *SigV4Signer) createCanonicalRequest(req *http.Request, body []byte) (string, string) { - // HTTP method - method := req.Method - - // Canonical URI - uri := req.URL.Path - if uri == "" { - uri = "/" - } - - // Canonical query string (sorted) - queryString := s.buildCanonicalQueryString(req) - - // Canonical headers (sorted, lowercase) - canonicalHeaders, signedHeaders := s.buildCanonicalHeaders(req) - - // Hashed payload - payloadHash := hashSHA256(body) - - canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", - method, - uri, - queryString, - canonicalHeaders, - signedHeaders, - payloadHash, - ) - - return canonicalRequest, signedHeaders -} - -// buildCanonicalQueryString builds a sorted, URI-encoded query string. -func (s *SigV4Signer) buildCanonicalQueryString(req *http.Request) string { - if req.URL.RawQuery == "" { - return "" - } - - // Parse and sort query parameters - params := make([]string, 0) - for key, values := range req.URL.Query() { - for _, value := range values { - params = append(params, fmt.Sprintf("%s=%s", uriEncode(key), uriEncode(value))) - } - } - sort.Strings(params) - return strings.Join(params, "&") -} - -// buildCanonicalHeaders builds sorted, lowercase canonical headers. -func (s *SigV4Signer) buildCanonicalHeaders(req *http.Request) (string, string) { - // Headers to sign (must include host and x-amz-*) - headerMap := make(map[string]string) - headerMap["host"] = req.URL.Host - - for key, values := range req.Header { - lowKey := strings.ToLower(key) - // Include x-amz-* headers and content-type - if strings.HasPrefix(lowKey, "x-amz-") || lowKey == "content-type" { - headerMap[lowKey] = strings.TrimSpace(values[0]) - } - } - - // Sort header names - headerNames := make([]string, 0, len(headerMap)) - for name := range headerMap { - headerNames = append(headerNames, name) - } - sort.Strings(headerNames) - - // Build canonical headers and signed headers - var canonicalHeaders strings.Builder - for _, name := range headerNames { - canonicalHeaders.WriteString(name) - canonicalHeaders.WriteString(":") - canonicalHeaders.WriteString(headerMap[name]) - canonicalHeaders.WriteString("\n") - } - - signedHeaders := strings.Join(headerNames, ";") - - return canonicalHeaders.String(), signedHeaders -} - -// getSignatureKey derives the signing key for SigV4. -func (s *SigV4Signer) getSignatureKey(dateStamp string) []byte { - kDate := hmacSHA256([]byte("AWS4"+s.credentials.SecretAccessKey), []byte(dateStamp)) - kRegion := hmacSHA256(kDate, []byte(s.region)) - kService := hmacSHA256(kRegion, []byte(s.service)) - kSigning := hmacSHA256(kService, []byte("aws4_request")) - return kSigning -} - -// hmacSHA256 computes HMAC-SHA256. -func hmacSHA256(key, data []byte) []byte { - h := hmac.New(sha256.New, key) - h.Write(data) - return h.Sum(nil) -} - -// hashSHA256 computes SHA256 hash and returns hex string. -func hashSHA256(data []byte) string { - hash := sha256.Sum256(data) - return hex.EncodeToString(hash[:]) -} - -// uriEncode performs URI encoding for SigV4. -func uriEncode(s string) string { - var result strings.Builder - for i := 0; i < len(s); i++ { - c := s[i] - if (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || - (c >= '0' && c <= '9') || c == '-' || c == '.' || c == '_' || c == '~' { - result.WriteByte(c) - } else { - result.WriteString(fmt.Sprintf("%%%02X", c)) - } - } - return result.String() -} - -// IsExpired checks if the credentials are expired or about to expire. -func (c *CognitoCredentials) IsExpired() bool { - // Consider expired if within 5 minutes of expiration - return time.Now().Add(5 * time.Minute).After(c.Expiration) -} diff --git a/internal/auth/kiro/sso_oidc.go b/internal/auth/kiro/sso_oidc.go index 6ef2e960..292f5bcf 100644 --- a/internal/auth/kiro/sso_oidc.go +++ b/internal/auth/kiro/sso_oidc.go @@ -334,7 +334,7 @@ func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, cl } if resp.StatusCode != http.StatusOK { - log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + log.Warnf("IDC token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode) } diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index 70f23dfb..1e882888 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -10,6 +10,8 @@ import ( "fmt" "io" "net/http" + "os" + "path/filepath" "strings" "sync" "time" @@ -178,64 +180,6 @@ func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig { type KiroExecutor struct { cfg *config.Config refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions - - // cognitoCredsCache caches Cognito credentials per auth ID for IDC authentication - // Key: auth.ID, Value: *kiroauth.CognitoCredentials - cognitoCredsCache sync.Map -} - -// getCachedCognitoCredentials retrieves cached Cognito credentials if they are still valid. -func (e *KiroExecutor) getCachedCognitoCredentials(authID string) *kiroauth.CognitoCredentials { - if cached, ok := e.cognitoCredsCache.Load(authID); ok { - creds := cached.(*kiroauth.CognitoCredentials) - if !creds.IsExpired() { - return creds - } - // Credentials expired, remove from cache - e.cognitoCredsCache.Delete(authID) - } - return nil -} - -// cacheCognitoCredentials stores Cognito credentials in the cache. -func (e *KiroExecutor) cacheCognitoCredentials(authID string, creds *kiroauth.CognitoCredentials) { - e.cognitoCredsCache.Store(authID, creds) -} - -// getOrExchangeCognitoCredentials retrieves cached Cognito credentials or exchanges the SSO token for new ones. -func (e *KiroExecutor) getOrExchangeCognitoCredentials(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) (*kiroauth.CognitoCredentials, error) { - if auth == nil { - return nil, fmt.Errorf("auth is nil") - } - - // Check cache first - if creds := e.getCachedCognitoCredentials(auth.ID); creds != nil { - log.Debugf("kiro: using cached Cognito credentials for auth %s (expires: %s)", auth.ID, creds.Expiration.Format(time.RFC3339)) - return creds, nil - } - - // Get region from auth metadata - region := "us-east-1" - if auth.Metadata != nil { - if r, ok := auth.Metadata["region"].(string); ok && r != "" { - region = r - } - } - - log.Infof("kiro: exchanging SSO token for Cognito credentials (region: %s)", region) - - // Exchange SSO token for Cognito credentials - cognitoClient := kiroauth.NewCognitoIdentityClient(e.cfg) - creds, err := cognitoClient.ExchangeSSOTokenForCredentials(ctx, accessToken, region) - if err != nil { - return nil, fmt.Errorf("failed to exchange SSO token for Cognito credentials: %w", err) - } - - // Cache the credentials - e.cacheCognitoCredentials(auth.ID, creds) - log.Infof("kiro: Cognito credentials obtained and cached (expires: %s)", creds.Expiration.Format(time.RFC3339)) - - return creds, nil } // isIDCAuth checks if the auth uses IDC (Identity Center) authentication method. @@ -247,12 +191,6 @@ func isIDCAuth(auth *cliproxyauth.Auth) bool { return authMethod == "idc" } -// signRequestWithSigV4 signs an HTTP request with AWS SigV4 using Cognito credentials. -func signRequestWithSigV4(req *http.Request, payload []byte, creds *kiroauth.CognitoCredentials, region, service string) error { - signer := kiroauth.NewSigV4Signer(creds, region, service) - return signer.SignRequest(req, payload) -} - // buildKiroPayloadForFormat builds the Kiro API payload based on the source format. // This is critical because OpenAI and Claude formats have different tool structures: // - OpenAI: tools[].function.name, tools[].function.description @@ -301,6 +239,10 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) } else if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + } accessToken, profileArn = kiroCredentials(auth) log.Infof("kiro: token refreshed successfully before request") } @@ -372,40 +314,8 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) - // Choose auth method: SigV4 for IDC, Bearer token for others - // NOTE: Cognito credential exchange disabled for now - testing Bearer token first - if false && isIDCAuth(auth) { - // IDC auth requires SigV4 signing with Cognito-exchanged credentials - cognitoCreds, err := e.getOrExchangeCognitoCredentials(ctx, auth, accessToken) - if err != nil { - log.Warnf("kiro: failed to get Cognito credentials for IDC auth: %v", err) - return resp, fmt.Errorf("IDC auth requires Cognito credentials: %w", err) - } - - // Get region from auth metadata - region := "us-east-1" - if auth.Metadata != nil { - if r, ok := auth.Metadata["region"].(string); ok && r != "" { - region = r - } - } - - // Determine service from URL - service := "codewhisperer" - if strings.Contains(url, "q.us-east-1.amazonaws.com") { - service = "qdeveloper" - } - - // Sign the request with SigV4 - if err := signRequestWithSigV4(httpReq, kiroPayload, cognitoCreds, region, service); err != nil { - log.Warnf("kiro: failed to sign request with SigV4: %v", err) - return resp, fmt.Errorf("SigV4 signing failed: %w", err) - } - log.Debugf("kiro: request signed with SigV4 for IDC auth (service: %s, region: %s)", service, region) - } else { - // Standard Bearer token authentication for Builder ID, social auth, etc. - httpReq.Header.Set("Authorization", "Bearer "+accessToken) - } + // Bearer token authentication for all auth types (Builder ID, IDC, social, etc.) + httpReq.Header.Set("Authorization", "Bearer "+accessToken) var attrs map[string]string if auth != nil { @@ -494,6 +404,11 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + // Continue anyway - the token is valid for this request + } accessToken, profileArn = kiroCredentials(auth) // Rebuild payload with new profile ARN if changed kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) @@ -552,6 +467,11 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth. } if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + // Continue anyway - the token is valid for this request + } accessToken, profileArn = kiroCredentials(auth) kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) log.Infof("kiro: token refreshed for 403, retrying request") @@ -654,6 +574,10 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) } else if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + } accessToken, profileArn = kiroCredentials(auth) log.Infof("kiro: token refreshed successfully before stream request") } @@ -723,40 +647,8 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) - // Choose auth method: SigV4 for IDC, Bearer token for others - // NOTE: Cognito credential exchange disabled for now - testing Bearer token first - if false && isIDCAuth(auth) { - // IDC auth requires SigV4 signing with Cognito-exchanged credentials - cognitoCreds, err := e.getOrExchangeCognitoCredentials(ctx, auth, accessToken) - if err != nil { - log.Warnf("kiro: failed to get Cognito credentials for IDC auth: %v", err) - return nil, fmt.Errorf("IDC auth requires Cognito credentials: %w", err) - } - - // Get region from auth metadata - region := "us-east-1" - if auth.Metadata != nil { - if r, ok := auth.Metadata["region"].(string); ok && r != "" { - region = r - } - } - - // Determine service from URL - service := "codewhisperer" - if strings.Contains(url, "q.us-east-1.amazonaws.com") { - service = "qdeveloper" - } - - // Sign the request with SigV4 - if err := signRequestWithSigV4(httpReq, kiroPayload, cognitoCreds, region, service); err != nil { - log.Warnf("kiro: failed to sign request with SigV4: %v", err) - return nil, fmt.Errorf("SigV4 signing failed: %w", err) - } - log.Debugf("kiro: stream request signed with SigV4 for IDC auth (service: %s, region: %s)", service, region) - } else { - // Standard Bearer token authentication for Builder ID, social auth, etc. - httpReq.Header.Set("Authorization", "Bearer "+accessToken) - } + // Bearer token authentication for all auth types (Builder ID, IDC, social, etc.) + httpReq.Header.Set("Authorization", "Bearer "+accessToken) var attrs map[string]string if auth != nil { @@ -858,6 +750,11 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + // Continue anyway - the token is valid for this request + } accessToken, profileArn = kiroCredentials(auth) // Rebuild payload with new profile ARN if changed kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) @@ -916,6 +813,11 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox } if refreshedAuth != nil { auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + // Continue anyway - the token is valid for this request + } accessToken, profileArn = kiroCredentials(auth) kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers) log.Infof("kiro: token refreshed for 403, retrying stream request") @@ -3191,6 +3093,7 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c var refreshToken string var clientID, clientSecret string var authMethod string + var region, startURL string if auth.Metadata != nil { if rt, ok := auth.Metadata["refresh_token"].(string); ok { @@ -3205,6 +3108,12 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c if am, ok := auth.Metadata["auth_method"].(string); ok { authMethod = am } + if r, ok := auth.Metadata["region"].(string); ok { + region = r + } + if su, ok := auth.Metadata["start_url"].(string); ok { + startURL = su + } } if refreshToken == "" { @@ -3214,12 +3123,20 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c var tokenData *kiroauth.KiroTokenData var err error - // Use SSO OIDC refresh for AWS Builder ID, otherwise use Kiro's OAuth refresh endpoint - if clientID != "" && clientSecret != "" && authMethod == "builder-id" { + ssoClient := kiroauth.NewSSOOIDCClient(e.cfg) + + // Use SSO OIDC refresh for AWS Builder ID or IDC, otherwise use Kiro's OAuth refresh endpoint + switch { + case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "": + // IDC refresh with region-specific endpoint + log.Debugf("kiro executor: using SSO OIDC refresh for IDC (region=%s)", region) + tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL) + case clientID != "" && clientSecret != "" && authMethod == "builder-id": + // Builder ID refresh with default endpoint log.Debugf("kiro executor: using SSO OIDC refresh for AWS Builder ID") - ssoClient := kiroauth.NewSSOOIDCClient(e.cfg) tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken) - } else { + default: + // Fallback to Kiro's OAuth refresh endpoint (for social auth: Google/GitHub) log.Debugf("kiro executor: using Kiro OAuth refresh endpoint") oauth := kiroauth.NewKiroOAuth(e.cfg) tokenData, err = oauth.RefreshToken(ctx, refreshToken) @@ -3275,6 +3192,53 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c return updated, nil } +// persistRefreshedAuth persists a refreshed auth record to disk. +// This ensures token refreshes from inline retry are saved to the auth file. +func (e *KiroExecutor) persistRefreshedAuth(auth *cliproxyauth.Auth) error { + if auth == nil || auth.Metadata == nil { + return fmt.Errorf("kiro executor: cannot persist nil auth or metadata") + } + + // Determine the file path from auth attributes or filename + var authPath string + if auth.Attributes != nil { + if p := strings.TrimSpace(auth.Attributes["path"]); p != "" { + authPath = p + } + } + if authPath == "" { + fileName := strings.TrimSpace(auth.FileName) + if fileName == "" { + return fmt.Errorf("kiro executor: auth has no file path or filename") + } + if filepath.IsAbs(fileName) { + authPath = fileName + } else if e.cfg != nil && e.cfg.AuthDir != "" { + authPath = filepath.Join(e.cfg.AuthDir, fileName) + } else { + return fmt.Errorf("kiro executor: cannot determine auth file path") + } + } + + // Marshal metadata to JSON + raw, err := json.Marshal(auth.Metadata) + if err != nil { + return fmt.Errorf("kiro executor: marshal metadata failed: %w", err) + } + + // Write to temp file first, then rename (atomic write) + tmp := authPath + ".tmp" + if err := os.WriteFile(tmp, raw, 0o600); err != nil { + return fmt.Errorf("kiro executor: write temp auth file failed: %w", err) + } + if err := os.Rename(tmp, authPath); err != nil { + return fmt.Errorf("kiro executor: rename auth file failed: %w", err) + } + + log.Debugf("kiro executor: persisted refreshed auth to %s", authPath) + return nil +} + // isTokenExpired checks if a JWT access token has expired. // Returns true if the token is expired or cannot be parsed. func (e *KiroExecutor) isTokenExpired(accessToken string) bool {