feat: persist refreshed IDC tokens to auth file

Add persistRefreshedAuth function to write refreshed tokens back to the
auth file after inline token refresh. This prevents repeated token
refreshes on every request when the token expires.

Changes:
- Add persistRefreshedAuth() to kiro_executor.go
- Call persist after all token refresh paths (401, 403, pre-request)
- Remove unused log import from sdk/auth/kiro.go
This commit is contained in:
Joao
2025-12-22 12:23:10 +00:00
parent 7fd98f3556
commit 98db5aabd0
3 changed files with 101 additions and 545 deletions

View File

@@ -1,408 +0,0 @@
// Package kiro provides Cognito Identity credential exchange for IDC authentication.
// AWS Identity Center (IDC) requires SigV4 signing with Cognito-exchanged credentials
// instead of Bearer token authentication.
package kiro
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"sort"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
const (
// Cognito Identity endpoints
cognitoIdentityEndpoint = "https://cognito-identity.us-east-1.amazonaws.com"
// Identity Pool ID for Q Developer / CodeWhisperer
// This is the identity pool used by kiro-cli and Amazon Q CLI
cognitoIdentityPoolID = "us-east-1:70717e99-906f-485d-8d89-c89a0b5d49c5"
// Cognito provider name for SSO OIDC
cognitoProviderName = "cognito-identity.amazonaws.com"
)
// CognitoCredentials holds temporary AWS credentials from Cognito Identity.
type CognitoCredentials struct {
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
SessionToken string `json:"session_token"`
Expiration time.Time `json:"expiration"`
}
// CognitoIdentityClient handles Cognito Identity credential exchange.
type CognitoIdentityClient struct {
httpClient *http.Client
cfg *config.Config
}
// NewCognitoIdentityClient creates a new Cognito Identity client.
func NewCognitoIdentityClient(cfg *config.Config) *CognitoIdentityClient {
client := &http.Client{Timeout: 30 * time.Second}
if cfg != nil {
client = util.SetProxy(&cfg.SDKConfig, client)
}
return &CognitoIdentityClient{
httpClient: client,
cfg: cfg,
}
}
// GetIdentityID retrieves a Cognito Identity ID using the SSO access token.
func (c *CognitoIdentityClient) GetIdentityID(ctx context.Context, accessToken, region string) (string, error) {
if region == "" {
region = "us-east-1"
}
endpoint := fmt.Sprintf("https://cognito-identity.%s.amazonaws.com", region)
// Build the GetId request
// The SSO token is passed as a login token for the identity pool
payload := map[string]interface{}{
"IdentityPoolId": cognitoIdentityPoolID,
"Logins": map[string]string{
// Use the OIDC provider URL as the key
fmt.Sprintf("oidc.%s.amazonaws.com", region): accessToken,
},
}
body, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("failed to marshal GetId request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(string(body)))
if err != nil {
return "", fmt.Errorf("failed to create GetId request: %w", err)
}
req.Header.Set("Content-Type", "application/x-amz-json-1.1")
req.Header.Set("X-Amz-Target", "AWSCognitoIdentityService.GetId")
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("GetId request failed: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read GetId response: %w", err)
}
if resp.StatusCode != http.StatusOK {
log.Debugf("Cognito GetId failed (status %d): %s", resp.StatusCode, string(respBody))
return "", fmt.Errorf("GetId failed (status %d): %s", resp.StatusCode, string(respBody))
}
var result struct {
IdentityID string `json:"IdentityId"`
}
if err := json.Unmarshal(respBody, &result); err != nil {
return "", fmt.Errorf("failed to parse GetId response: %w", err)
}
if result.IdentityID == "" {
return "", fmt.Errorf("empty IdentityId in GetId response")
}
log.Debugf("Cognito Identity ID: %s", result.IdentityID)
return result.IdentityID, nil
}
// GetCredentialsForIdentity exchanges an identity ID and login token for temporary AWS credentials.
func (c *CognitoIdentityClient) GetCredentialsForIdentity(ctx context.Context, identityID, accessToken, region string) (*CognitoCredentials, error) {
if region == "" {
region = "us-east-1"
}
endpoint := fmt.Sprintf("https://cognito-identity.%s.amazonaws.com", region)
payload := map[string]interface{}{
"IdentityId": identityID,
"Logins": map[string]string{
fmt.Sprintf("oidc.%s.amazonaws.com", region): accessToken,
},
}
body, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal GetCredentialsForIdentity request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(string(body)))
if err != nil {
return nil, fmt.Errorf("failed to create GetCredentialsForIdentity request: %w", err)
}
req.Header.Set("Content-Type", "application/x-amz-json-1.1")
req.Header.Set("X-Amz-Target", "AWSCognitoIdentityService.GetCredentialsForIdentity")
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("GetCredentialsForIdentity request failed: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read GetCredentialsForIdentity response: %w", err)
}
if resp.StatusCode != http.StatusOK {
log.Debugf("Cognito GetCredentialsForIdentity failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("GetCredentialsForIdentity failed (status %d): %s", resp.StatusCode, string(respBody))
}
var result struct {
Credentials struct {
AccessKeyID string `json:"AccessKeyId"`
SecretKey string `json:"SecretKey"`
SessionToken string `json:"SessionToken"`
Expiration int64 `json:"Expiration"`
} `json:"Credentials"`
IdentityID string `json:"IdentityId"`
}
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("failed to parse GetCredentialsForIdentity response: %w", err)
}
if result.Credentials.AccessKeyID == "" {
return nil, fmt.Errorf("empty AccessKeyId in GetCredentialsForIdentity response")
}
// Expiration is in seconds since epoch
expiration := time.Unix(result.Credentials.Expiration, 0)
log.Debugf("Cognito credentials obtained, expires: %s", expiration.Format(time.RFC3339))
return &CognitoCredentials{
AccessKeyID: result.Credentials.AccessKeyID,
SecretAccessKey: result.Credentials.SecretKey,
SessionToken: result.Credentials.SessionToken,
Expiration: expiration,
}, nil
}
// ExchangeSSOTokenForCredentials is a convenience method that performs the full
// Cognito Identity credential exchange flow: GetId -> GetCredentialsForIdentity
func (c *CognitoIdentityClient) ExchangeSSOTokenForCredentials(ctx context.Context, accessToken, region string) (*CognitoCredentials, error) {
log.Debugf("Exchanging SSO token for Cognito credentials (region: %s)", region)
// Step 1: Get Identity ID
identityID, err := c.GetIdentityID(ctx, accessToken, region)
if err != nil {
return nil, fmt.Errorf("failed to get identity ID: %w", err)
}
// Step 2: Get credentials for the identity
creds, err := c.GetCredentialsForIdentity(ctx, identityID, accessToken, region)
if err != nil {
return nil, fmt.Errorf("failed to get credentials for identity: %w", err)
}
return creds, nil
}
// SigV4Signer provides AWS Signature Version 4 signing for HTTP requests.
type SigV4Signer struct {
credentials *CognitoCredentials
region string
service string
}
// NewSigV4Signer creates a new SigV4 signer with the given credentials.
func NewSigV4Signer(creds *CognitoCredentials, region, service string) *SigV4Signer {
return &SigV4Signer{
credentials: creds,
region: region,
service: service,
}
}
// SignRequest signs an HTTP request using AWS Signature Version 4.
// The request body must be provided separately since it may have been read already.
func (s *SigV4Signer) SignRequest(req *http.Request, body []byte) error {
now := time.Now().UTC()
amzDate := now.Format("20060102T150405Z")
dateStamp := now.Format("20060102")
// Ensure required headers are set
if req.Header.Get("Host") == "" {
req.Header.Set("Host", req.URL.Host)
}
req.Header.Set("X-Amz-Date", amzDate)
if s.credentials.SessionToken != "" {
req.Header.Set("X-Amz-Security-Token", s.credentials.SessionToken)
}
// Create canonical request
canonicalRequest, signedHeaders := s.createCanonicalRequest(req, body)
// Create string to sign
algorithm := "AWS4-HMAC-SHA256"
credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, s.region, s.service)
stringToSign := fmt.Sprintf("%s\n%s\n%s\n%s",
algorithm,
amzDate,
credentialScope,
hashSHA256([]byte(canonicalRequest)),
)
// Calculate signature
signingKey := s.getSignatureKey(dateStamp)
signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign)))
// Build Authorization header
authHeader := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
algorithm,
s.credentials.AccessKeyID,
credentialScope,
signedHeaders,
signature,
)
req.Header.Set("Authorization", authHeader)
return nil
}
// createCanonicalRequest builds the canonical request string for SigV4.
func (s *SigV4Signer) createCanonicalRequest(req *http.Request, body []byte) (string, string) {
// HTTP method
method := req.Method
// Canonical URI
uri := req.URL.Path
if uri == "" {
uri = "/"
}
// Canonical query string (sorted)
queryString := s.buildCanonicalQueryString(req)
// Canonical headers (sorted, lowercase)
canonicalHeaders, signedHeaders := s.buildCanonicalHeaders(req)
// Hashed payload
payloadHash := hashSHA256(body)
canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
method,
uri,
queryString,
canonicalHeaders,
signedHeaders,
payloadHash,
)
return canonicalRequest, signedHeaders
}
// buildCanonicalQueryString builds a sorted, URI-encoded query string.
func (s *SigV4Signer) buildCanonicalQueryString(req *http.Request) string {
if req.URL.RawQuery == "" {
return ""
}
// Parse and sort query parameters
params := make([]string, 0)
for key, values := range req.URL.Query() {
for _, value := range values {
params = append(params, fmt.Sprintf("%s=%s", uriEncode(key), uriEncode(value)))
}
}
sort.Strings(params)
return strings.Join(params, "&")
}
// buildCanonicalHeaders builds sorted, lowercase canonical headers.
func (s *SigV4Signer) buildCanonicalHeaders(req *http.Request) (string, string) {
// Headers to sign (must include host and x-amz-*)
headerMap := make(map[string]string)
headerMap["host"] = req.URL.Host
for key, values := range req.Header {
lowKey := strings.ToLower(key)
// Include x-amz-* headers and content-type
if strings.HasPrefix(lowKey, "x-amz-") || lowKey == "content-type" {
headerMap[lowKey] = strings.TrimSpace(values[0])
}
}
// Sort header names
headerNames := make([]string, 0, len(headerMap))
for name := range headerMap {
headerNames = append(headerNames, name)
}
sort.Strings(headerNames)
// Build canonical headers and signed headers
var canonicalHeaders strings.Builder
for _, name := range headerNames {
canonicalHeaders.WriteString(name)
canonicalHeaders.WriteString(":")
canonicalHeaders.WriteString(headerMap[name])
canonicalHeaders.WriteString("\n")
}
signedHeaders := strings.Join(headerNames, ";")
return canonicalHeaders.String(), signedHeaders
}
// getSignatureKey derives the signing key for SigV4.
func (s *SigV4Signer) getSignatureKey(dateStamp string) []byte {
kDate := hmacSHA256([]byte("AWS4"+s.credentials.SecretAccessKey), []byte(dateStamp))
kRegion := hmacSHA256(kDate, []byte(s.region))
kService := hmacSHA256(kRegion, []byte(s.service))
kSigning := hmacSHA256(kService, []byte("aws4_request"))
return kSigning
}
// hmacSHA256 computes HMAC-SHA256.
func hmacSHA256(key, data []byte) []byte {
h := hmac.New(sha256.New, key)
h.Write(data)
return h.Sum(nil)
}
// hashSHA256 computes SHA256 hash and returns hex string.
func hashSHA256(data []byte) string {
hash := sha256.Sum256(data)
return hex.EncodeToString(hash[:])
}
// uriEncode performs URI encoding for SigV4.
func uriEncode(s string) string {
var result strings.Builder
for i := 0; i < len(s); i++ {
c := s[i]
if (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') ||
(c >= '0' && c <= '9') || c == '-' || c == '.' || c == '_' || c == '~' {
result.WriteByte(c)
} else {
result.WriteString(fmt.Sprintf("%%%02X", c))
}
}
return result.String()
}
// IsExpired checks if the credentials are expired or about to expire.
func (c *CognitoCredentials) IsExpired() bool {
// Consider expired if within 5 minutes of expiration
return time.Now().Add(5 * time.Minute).After(c.Expiration)
}

View File

@@ -334,7 +334,7 @@ func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, cl
}
if resp.StatusCode != http.StatusOK {
log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
log.Warnf("IDC token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
}

View File

@@ -10,6 +10,8 @@ import (
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
@@ -178,64 +180,6 @@ func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig {
type KiroExecutor struct {
cfg *config.Config
refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions
// cognitoCredsCache caches Cognito credentials per auth ID for IDC authentication
// Key: auth.ID, Value: *kiroauth.CognitoCredentials
cognitoCredsCache sync.Map
}
// getCachedCognitoCredentials retrieves cached Cognito credentials if they are still valid.
func (e *KiroExecutor) getCachedCognitoCredentials(authID string) *kiroauth.CognitoCredentials {
if cached, ok := e.cognitoCredsCache.Load(authID); ok {
creds := cached.(*kiroauth.CognitoCredentials)
if !creds.IsExpired() {
return creds
}
// Credentials expired, remove from cache
e.cognitoCredsCache.Delete(authID)
}
return nil
}
// cacheCognitoCredentials stores Cognito credentials in the cache.
func (e *KiroExecutor) cacheCognitoCredentials(authID string, creds *kiroauth.CognitoCredentials) {
e.cognitoCredsCache.Store(authID, creds)
}
// getOrExchangeCognitoCredentials retrieves cached Cognito credentials or exchanges the SSO token for new ones.
func (e *KiroExecutor) getOrExchangeCognitoCredentials(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) (*kiroauth.CognitoCredentials, error) {
if auth == nil {
return nil, fmt.Errorf("auth is nil")
}
// Check cache first
if creds := e.getCachedCognitoCredentials(auth.ID); creds != nil {
log.Debugf("kiro: using cached Cognito credentials for auth %s (expires: %s)", auth.ID, creds.Expiration.Format(time.RFC3339))
return creds, nil
}
// Get region from auth metadata
region := "us-east-1"
if auth.Metadata != nil {
if r, ok := auth.Metadata["region"].(string); ok && r != "" {
region = r
}
}
log.Infof("kiro: exchanging SSO token for Cognito credentials (region: %s)", region)
// Exchange SSO token for Cognito credentials
cognitoClient := kiroauth.NewCognitoIdentityClient(e.cfg)
creds, err := cognitoClient.ExchangeSSOTokenForCredentials(ctx, accessToken, region)
if err != nil {
return nil, fmt.Errorf("failed to exchange SSO token for Cognito credentials: %w", err)
}
// Cache the credentials
e.cacheCognitoCredentials(auth.ID, creds)
log.Infof("kiro: Cognito credentials obtained and cached (expires: %s)", creds.Expiration.Format(time.RFC3339))
return creds, nil
}
// isIDCAuth checks if the auth uses IDC (Identity Center) authentication method.
@@ -247,12 +191,6 @@ func isIDCAuth(auth *cliproxyauth.Auth) bool {
return authMethod == "idc"
}
// signRequestWithSigV4 signs an HTTP request with AWS SigV4 using Cognito credentials.
func signRequestWithSigV4(req *http.Request, payload []byte, creds *kiroauth.CognitoCredentials, region, service string) error {
signer := kiroauth.NewSigV4Signer(creds, region, service)
return signer.SignRequest(req, payload)
}
// buildKiroPayloadForFormat builds the Kiro API payload based on the source format.
// This is critical because OpenAI and Claude formats have different tool structures:
// - OpenAI: tools[].function.name, tools[].function.description
@@ -301,6 +239,10 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr)
} else if refreshedAuth != nil {
auth = refreshedAuth
// Persist the refreshed auth to file so subsequent requests use it
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
}
accessToken, profileArn = kiroCredentials(auth)
log.Infof("kiro: token refreshed successfully before request")
}
@@ -372,40 +314,8 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
// Choose auth method: SigV4 for IDC, Bearer token for others
// NOTE: Cognito credential exchange disabled for now - testing Bearer token first
if false && isIDCAuth(auth) {
// IDC auth requires SigV4 signing with Cognito-exchanged credentials
cognitoCreds, err := e.getOrExchangeCognitoCredentials(ctx, auth, accessToken)
if err != nil {
log.Warnf("kiro: failed to get Cognito credentials for IDC auth: %v", err)
return resp, fmt.Errorf("IDC auth requires Cognito credentials: %w", err)
}
// Get region from auth metadata
region := "us-east-1"
if auth.Metadata != nil {
if r, ok := auth.Metadata["region"].(string); ok && r != "" {
region = r
}
}
// Determine service from URL
service := "codewhisperer"
if strings.Contains(url, "q.us-east-1.amazonaws.com") {
service = "qdeveloper"
}
// Sign the request with SigV4
if err := signRequestWithSigV4(httpReq, kiroPayload, cognitoCreds, region, service); err != nil {
log.Warnf("kiro: failed to sign request with SigV4: %v", err)
return resp, fmt.Errorf("SigV4 signing failed: %w", err)
}
log.Debugf("kiro: request signed with SigV4 for IDC auth (service: %s, region: %s)", service, region)
} else {
// Standard Bearer token authentication for Builder ID, social auth, etc.
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
}
// Bearer token authentication for all auth types (Builder ID, IDC, social, etc.)
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
var attrs map[string]string
if auth != nil {
@@ -494,6 +404,11 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
if refreshedAuth != nil {
auth = refreshedAuth
// Persist the refreshed auth to file so subsequent requests use it
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
// Continue anyway - the token is valid for this request
}
accessToken, profileArn = kiroCredentials(auth)
// Rebuild payload with new profile ARN if changed
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
@@ -552,6 +467,11 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
}
if refreshedAuth != nil {
auth = refreshedAuth
// Persist the refreshed auth to file so subsequent requests use it
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
// Continue anyway - the token is valid for this request
}
accessToken, profileArn = kiroCredentials(auth)
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
log.Infof("kiro: token refreshed for 403, retrying request")
@@ -654,6 +574,10 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr)
} else if refreshedAuth != nil {
auth = refreshedAuth
// Persist the refreshed auth to file so subsequent requests use it
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
}
accessToken, profileArn = kiroCredentials(auth)
log.Infof("kiro: token refreshed successfully before stream request")
}
@@ -723,40 +647,8 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
// Choose auth method: SigV4 for IDC, Bearer token for others
// NOTE: Cognito credential exchange disabled for now - testing Bearer token first
if false && isIDCAuth(auth) {
// IDC auth requires SigV4 signing with Cognito-exchanged credentials
cognitoCreds, err := e.getOrExchangeCognitoCredentials(ctx, auth, accessToken)
if err != nil {
log.Warnf("kiro: failed to get Cognito credentials for IDC auth: %v", err)
return nil, fmt.Errorf("IDC auth requires Cognito credentials: %w", err)
}
// Get region from auth metadata
region := "us-east-1"
if auth.Metadata != nil {
if r, ok := auth.Metadata["region"].(string); ok && r != "" {
region = r
}
}
// Determine service from URL
service := "codewhisperer"
if strings.Contains(url, "q.us-east-1.amazonaws.com") {
service = "qdeveloper"
}
// Sign the request with SigV4
if err := signRequestWithSigV4(httpReq, kiroPayload, cognitoCreds, region, service); err != nil {
log.Warnf("kiro: failed to sign request with SigV4: %v", err)
return nil, fmt.Errorf("SigV4 signing failed: %w", err)
}
log.Debugf("kiro: stream request signed with SigV4 for IDC auth (service: %s, region: %s)", service, region)
} else {
// Standard Bearer token authentication for Builder ID, social auth, etc.
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
}
// Bearer token authentication for all auth types (Builder ID, IDC, social, etc.)
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
var attrs map[string]string
if auth != nil {
@@ -858,6 +750,11 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
if refreshedAuth != nil {
auth = refreshedAuth
// Persist the refreshed auth to file so subsequent requests use it
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
// Continue anyway - the token is valid for this request
}
accessToken, profileArn = kiroCredentials(auth)
// Rebuild payload with new profile ARN if changed
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
@@ -916,6 +813,11 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
}
if refreshedAuth != nil {
auth = refreshedAuth
// Persist the refreshed auth to file so subsequent requests use it
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
// Continue anyway - the token is valid for this request
}
accessToken, profileArn = kiroCredentials(auth)
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
log.Infof("kiro: token refreshed for 403, retrying stream request")
@@ -3191,6 +3093,7 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
var refreshToken string
var clientID, clientSecret string
var authMethod string
var region, startURL string
if auth.Metadata != nil {
if rt, ok := auth.Metadata["refresh_token"].(string); ok {
@@ -3205,6 +3108,12 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
if am, ok := auth.Metadata["auth_method"].(string); ok {
authMethod = am
}
if r, ok := auth.Metadata["region"].(string); ok {
region = r
}
if su, ok := auth.Metadata["start_url"].(string); ok {
startURL = su
}
}
if refreshToken == "" {
@@ -3214,12 +3123,20 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
var tokenData *kiroauth.KiroTokenData
var err error
// Use SSO OIDC refresh for AWS Builder ID, otherwise use Kiro's OAuth refresh endpoint
if clientID != "" && clientSecret != "" && authMethod == "builder-id" {
ssoClient := kiroauth.NewSSOOIDCClient(e.cfg)
// Use SSO OIDC refresh for AWS Builder ID or IDC, otherwise use Kiro's OAuth refresh endpoint
switch {
case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "":
// IDC refresh with region-specific endpoint
log.Debugf("kiro executor: using SSO OIDC refresh for IDC (region=%s)", region)
tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL)
case clientID != "" && clientSecret != "" && authMethod == "builder-id":
// Builder ID refresh with default endpoint
log.Debugf("kiro executor: using SSO OIDC refresh for AWS Builder ID")
ssoClient := kiroauth.NewSSOOIDCClient(e.cfg)
tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken)
} else {
default:
// Fallback to Kiro's OAuth refresh endpoint (for social auth: Google/GitHub)
log.Debugf("kiro executor: using Kiro OAuth refresh endpoint")
oauth := kiroauth.NewKiroOAuth(e.cfg)
tokenData, err = oauth.RefreshToken(ctx, refreshToken)
@@ -3275,6 +3192,53 @@ func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*c
return updated, nil
}
// persistRefreshedAuth persists a refreshed auth record to disk.
// This ensures token refreshes from inline retry are saved to the auth file.
func (e *KiroExecutor) persistRefreshedAuth(auth *cliproxyauth.Auth) error {
if auth == nil || auth.Metadata == nil {
return fmt.Errorf("kiro executor: cannot persist nil auth or metadata")
}
// Determine the file path from auth attributes or filename
var authPath string
if auth.Attributes != nil {
if p := strings.TrimSpace(auth.Attributes["path"]); p != "" {
authPath = p
}
}
if authPath == "" {
fileName := strings.TrimSpace(auth.FileName)
if fileName == "" {
return fmt.Errorf("kiro executor: auth has no file path or filename")
}
if filepath.IsAbs(fileName) {
authPath = fileName
} else if e.cfg != nil && e.cfg.AuthDir != "" {
authPath = filepath.Join(e.cfg.AuthDir, fileName)
} else {
return fmt.Errorf("kiro executor: cannot determine auth file path")
}
}
// Marshal metadata to JSON
raw, err := json.Marshal(auth.Metadata)
if err != nil {
return fmt.Errorf("kiro executor: marshal metadata failed: %w", err)
}
// Write to temp file first, then rename (atomic write)
tmp := authPath + ".tmp"
if err := os.WriteFile(tmp, raw, 0o600); err != nil {
return fmt.Errorf("kiro executor: write temp auth file failed: %w", err)
}
if err := os.Rename(tmp, authPath); err != nil {
return fmt.Errorf("kiro executor: rename auth file failed: %w", err)
}
log.Debugf("kiro executor: persisted refreshed auth to %s", authPath)
return nil
}
// isTokenExpired checks if a JWT access token has expired.
// Returns true if the token is expired or cannot be parsed.
func (e *KiroExecutor) isTokenExpired(accessToken string) bool {