diff --git a/internal/auth/kiro/background_refresh.go b/internal/auth/kiro/background_refresh.go index 3fecc417..1203ff47 100644 --- a/internal/auth/kiro/background_refresh.go +++ b/internal/auth/kiro/background_refresh.go @@ -50,14 +50,16 @@ func WithConcurrency(concurrency int) RefresherOption { } type BackgroundRefresher struct { - interval time.Duration - batchSize int - concurrency int - tokenRepo TokenRepository - stopCh chan struct{} - wg sync.WaitGroup - oauth *KiroOAuth - ssoClient *SSOOIDCClient + interval time.Duration + batchSize int + concurrency int + tokenRepo TokenRepository + stopCh chan struct{} + wg sync.WaitGroup + oauth *KiroOAuth + ssoClient *SSOOIDCClient + callbackMu sync.RWMutex // 保护回调函数的并发访问 + onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) // 刷新成功回调 } func NewBackgroundRefresher(repo TokenRepository, opts ...RefresherOption) *BackgroundRefresher { @@ -84,6 +86,17 @@ func WithConfig(cfg *config.Config) RefresherOption { } } +// WithOnTokenRefreshed sets the callback function to be called when a token is successfully refreshed. +// The callback receives the token ID (filename) and the new token data. +// This allows external components (e.g., Watcher) to be notified of token updates. +func WithOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) RefresherOption { + return func(r *BackgroundRefresher) { + r.callbackMu.Lock() + r.onTokenRefreshed = callback + r.callbackMu.Unlock() + } +} + func (r *BackgroundRefresher) Start(ctx context.Context) { r.wg.Add(1) go func() { @@ -188,5 +201,24 @@ func (r *BackgroundRefresher) refreshSingle(ctx context.Context, token *Token) { if err := r.tokenRepo.UpdateToken(token); err != nil { log.Printf("failed to update token %s: %v", token.ID, err) + return + } + + // 方案 A: 刷新成功后触发回调,通知 Watcher 更新内存中的 Auth 对象 + r.callbackMu.RLock() + callback := r.onTokenRefreshed + r.callbackMu.RUnlock() + + if callback != nil { + // 使用 defer recover 隔离回调 panic,防止崩溃整个进程 + func() { + defer func() { + if rec := recover(); rec != nil { + log.Printf("background refresh: callback panic for token %s: %v", token.ID, rec) + } + }() + log.Printf("background refresh: notifying token refresh callback for %s", token.ID) + callback(token.ID, newTokenData) + }() } } diff --git a/internal/auth/kiro/refresh_manager.go b/internal/auth/kiro/refresh_manager.go index cd27b432..05e27a54 100644 --- a/internal/auth/kiro/refresh_manager.go +++ b/internal/auth/kiro/refresh_manager.go @@ -11,11 +11,12 @@ import ( // RefreshManager 是后台刷新器的单例管理器 type RefreshManager struct { - mu sync.Mutex - refresher *BackgroundRefresher - ctx context.Context - cancel context.CancelFunc - started bool + mu sync.Mutex + refresher *BackgroundRefresher + ctx context.Context + cancel context.CancelFunc + started bool + onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) // 刷新成功回调 } var ( @@ -52,13 +53,19 @@ func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error { repo := NewFileTokenRepository(baseDir) // 创建后台刷新器,配置参数 - m.refresher = NewBackgroundRefresher( - repo, - WithInterval(time.Minute), // 每分钟检查一次 - WithBatchSize(50), // 每批最多处理 50 个 token - WithConcurrency(10), // 最多 10 个并发刷新 - WithConfig(cfg), // 设置 OAuth 和 SSO 客户端 - ) + opts := []RefresherOption{ + WithInterval(time.Minute), // 每分钟检查一次 + WithBatchSize(50), // 每批最多处理 50 个 token + WithConcurrency(10), // 最多 10 个并发刷新 + WithConfig(cfg), // 设置 OAuth 和 SSO 客户端 + } + + // 如果已设置回调,传递给 BackgroundRefresher + if m.onTokenRefreshed != nil { + opts = append(opts, WithOnTokenRefreshed(m.onTokenRefreshed)) + } + + m.refresher = NewBackgroundRefresher(repo, opts...) log.Infof("refresh manager: initialized with base directory %s", baseDir) return nil @@ -127,6 +134,25 @@ func (m *RefreshManager) UpdateBaseDir(baseDir string) { } } +// SetOnTokenRefreshed 设置 token 刷新成功后的回调函数 +// 可以在任何时候调用,支持运行时更新回调 +// callback: 回调函数,接收 tokenID(文件名)和新的 token 数据 +func (m *RefreshManager) SetOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) { + m.mu.Lock() + defer m.mu.Unlock() + + m.onTokenRefreshed = callback + + // 如果 refresher 已经创建,使用并发安全的方式更新它的回调 + if m.refresher != nil { + m.refresher.callbackMu.Lock() + m.refresher.onTokenRefreshed = callback + m.refresher.callbackMu.Unlock() + } + + log.Debug("refresh manager: token refresh callback registered") +} + // InitializeAndStart 初始化并启动后台刷新(便捷方法) func InitializeAndStart(baseDir string, cfg *config.Config) { manager := GetRefreshManager() diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go index 4506601d..ed6014a2 100644 --- a/internal/runtime/executor/kiro_executor.go +++ b/internal/runtime/executor/kiro_executor.go @@ -581,18 +581,30 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req // Check if token is expired before making request if e.isTokenExpired(accessToken) { - log.Infof("kiro: access token expired, attempting refresh before request") - refreshedAuth, refreshErr := e.Refresh(ctx, auth) - if refreshErr != nil { - log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) - } else if refreshedAuth != nil { - auth = refreshedAuth - // Persist the refreshed auth to file so subsequent requests use it - if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { - log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) - } + log.Infof("kiro: access token expired, attempting recovery") + + // 方案 B: 先尝试从文件重新加载 token(后台刷新器可能已更新文件) + reloadedAuth, reloadErr := e.reloadAuthFromFile(auth) + if reloadErr == nil && reloadedAuth != nil { + // 文件中有更新的 token,使用它 + auth = reloadedAuth accessToken, profileArn = kiroCredentials(auth) - log.Infof("kiro: token refreshed successfully before request") + log.Infof("kiro: recovered token from file (background refresh), expires_at: %v", auth.Metadata["expires_at"]) + } else { + // 文件中的 token 也过期了,执行主动刷新 + log.Debugf("kiro: file reload failed (%v), attempting active refresh", reloadErr) + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) + } else if refreshedAuth != nil { + auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + } + accessToken, profileArn = kiroCredentials(auth) + log.Infof("kiro: token refreshed successfully before request") + } } } @@ -979,18 +991,30 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut // Check if token is expired before making request if e.isTokenExpired(accessToken) { - log.Infof("kiro: access token expired, attempting refresh before stream request") - refreshedAuth, refreshErr := e.Refresh(ctx, auth) - if refreshErr != nil { - log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) - } else if refreshedAuth != nil { - auth = refreshedAuth - // Persist the refreshed auth to file so subsequent requests use it - if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { - log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) - } + log.Infof("kiro: access token expired, attempting recovery before stream request") + + // 方案 B: 先尝试从文件重新加载 token(后台刷新器可能已更新文件) + reloadedAuth, reloadErr := e.reloadAuthFromFile(auth) + if reloadErr == nil && reloadedAuth != nil { + // 文件中有更新的 token,使用它 + auth = reloadedAuth accessToken, profileArn = kiroCredentials(auth) - log.Infof("kiro: token refreshed successfully before stream request") + log.Infof("kiro: recovered token from file (background refresh) for stream, expires_at: %v", auth.Metadata["expires_at"]) + } else { + // 文件中的 token 也过期了,执行主动刷新 + log.Debugf("kiro: file reload failed (%v), attempting active refresh for stream", reloadErr) + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) + } else if refreshedAuth != nil { + auth = refreshedAuth + // Persist the refreshed auth to file so subsequent requests use it + if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { + log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr) + } + accessToken, profileArn = kiroCredentials(auth) + log.Infof("kiro: token refreshed successfully before stream request") + } } } @@ -3689,6 +3713,121 @@ func (e *KiroExecutor) persistRefreshedAuth(auth *cliproxyauth.Auth) error { return nil } +// reloadAuthFromFile 从文件重新加载 auth 数据(方案 B: Fallback 机制) +// 当内存中的 token 已过期时,尝试从文件读取最新的 token +// 这解决了后台刷新器已更新文件但内存中 Auth 对象尚未同步的时间差问题 +func (e *KiroExecutor) reloadAuthFromFile(auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + if auth == nil { + return nil, fmt.Errorf("kiro executor: cannot reload nil auth") + } + + // 确定文件路径 + var authPath string + if auth.Attributes != nil { + if p := strings.TrimSpace(auth.Attributes["path"]); p != "" { + authPath = p + } + } + if authPath == "" { + fileName := strings.TrimSpace(auth.FileName) + if fileName == "" { + return nil, fmt.Errorf("kiro executor: auth has no file path or filename for reload") + } + if filepath.IsAbs(fileName) { + authPath = fileName + } else if e.cfg != nil && e.cfg.AuthDir != "" { + authPath = filepath.Join(e.cfg.AuthDir, fileName) + } else { + return nil, fmt.Errorf("kiro executor: cannot determine auth file path for reload") + } + } + + // 读取文件 + raw, err := os.ReadFile(authPath) + if err != nil { + return nil, fmt.Errorf("kiro executor: failed to read auth file %s: %w", authPath, err) + } + + // 解析 JSON + var metadata map[string]any + if err := json.Unmarshal(raw, &metadata); err != nil { + return nil, fmt.Errorf("kiro executor: failed to parse auth file %s: %w", authPath, err) + } + + // 检查文件中的 token 是否比内存中的更新 + fileExpiresAt, _ := metadata["expires_at"].(string) + fileAccessToken, _ := metadata["access_token"].(string) + memExpiresAt, _ := auth.Metadata["expires_at"].(string) + memAccessToken, _ := auth.Metadata["access_token"].(string) + + // 文件中必须有有效的 access_token + if fileAccessToken == "" { + return nil, fmt.Errorf("kiro executor: auth file has no access_token field") + } + + // 如果有 expires_at,检查是否过期 + if fileExpiresAt != "" { + fileExpTime, parseErr := time.Parse(time.RFC3339, fileExpiresAt) + if parseErr == nil { + // 如果文件中的 token 也已过期,不使用它 + if time.Now().After(fileExpTime) { + log.Debugf("kiro executor: file token also expired at %s, not using", fileExpiresAt) + return nil, fmt.Errorf("kiro executor: file token also expired") + } + } + } + + // 判断文件中的 token 是否比内存中的更新 + // 条件1: access_token 不同(说明已刷新) + // 条件2: expires_at 更新(说明已刷新) + isNewer := false + + // 优先检查 access_token 是否变化 + if fileAccessToken != memAccessToken { + isNewer = true + log.Debugf("kiro executor: file access_token differs from memory, using file token") + } + + // 如果 access_token 相同,检查 expires_at + if !isNewer && fileExpiresAt != "" && memExpiresAt != "" { + fileExpTime, fileParseErr := time.Parse(time.RFC3339, fileExpiresAt) + memExpTime, memParseErr := time.Parse(time.RFC3339, memExpiresAt) + if fileParseErr == nil && memParseErr == nil && fileExpTime.After(memExpTime) { + isNewer = true + log.Debugf("kiro executor: file expires_at (%s) is newer than memory (%s)", fileExpiresAt, memExpiresAt) + } + } + + // 如果文件中没有 expires_at 但 access_token 相同,无法判断是否更新 + if !isNewer && fileExpiresAt == "" && fileAccessToken == memAccessToken { + return nil, fmt.Errorf("kiro executor: cannot determine if file token is newer (no expires_at, same access_token)") + } + + if !isNewer { + log.Debugf("kiro executor: file token not newer than memory token") + return nil, fmt.Errorf("kiro executor: file token not newer") + } + + // 创建更新后的 auth 对象 + updated := auth.Clone() + updated.Metadata = metadata + updated.UpdatedAt = time.Now() + + // 同步更新 Attributes + if updated.Attributes == nil { + updated.Attributes = make(map[string]string) + } + if accessToken, ok := metadata["access_token"].(string); ok { + updated.Attributes["access_token"] = accessToken + } + if profileArn, ok := metadata["profile_arn"].(string); ok { + updated.Attributes["profile_arn"] = profileArn + } + + log.Infof("kiro executor: reloaded auth from file %s, new expires_at: %s", authPath, fileExpiresAt) + return updated, nil +} + // isTokenExpired checks if a JWT access token has expired. // Returns true if the token is expired or cannot be parsed. func (e *KiroExecutor) isTokenExpired(accessToken string) bool { diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 77006cf8..8141ca07 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -145,3 +145,111 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { w.clientsMutex.RUnlock() return snapshotCoreAuths(cfg, w.authDir) } + +// NotifyTokenRefreshed 处理后台刷新器的 token 更新通知 +// 当后台刷新器成功刷新 token 后调用此方法,更新内存中的 Auth 对象 +// tokenID: token 文件名(如 kiro-xxx.json) +// accessToken: 新的 access token +// refreshToken: 新的 refresh token +// expiresAt: 新的过期时间 +func (w *Watcher) NotifyTokenRefreshed(tokenID, accessToken, refreshToken, expiresAt string) { + if w == nil { + return + } + + w.clientsMutex.Lock() + defer w.clientsMutex.Unlock() + + // 遍历 currentAuths,找到匹配的 Auth 并更新 + updated := false + for id, auth := range w.currentAuths { + if auth == nil || auth.Metadata == nil { + continue + } + + // 检查是否是 kiro 类型的 auth + authType, _ := auth.Metadata["type"].(string) + if authType != "kiro" { + continue + } + + // 多种匹配方式,解决不同来源的 auth 对象字段差异 + matched := false + + // 1. 通过 auth.ID 匹配(ID 可能包含文件名) + if !matched && auth.ID != "" { + if auth.ID == tokenID || strings.HasSuffix(auth.ID, "/"+tokenID) || strings.HasSuffix(auth.ID, "\\"+tokenID) { + matched = true + } + // ID 可能是 "kiro-xxx" 格式(无扩展名),tokenID 是 "kiro-xxx.json" + if !matched && strings.TrimSuffix(tokenID, ".json") == auth.ID { + matched = true + } + } + + // 2. 通过 auth.Attributes["path"] 匹配 + if !matched && auth.Attributes != nil { + if authPath := auth.Attributes["path"]; authPath != "" { + // 提取文件名部分进行比较 + pathBase := authPath + if idx := strings.LastIndexAny(authPath, "/\\"); idx >= 0 { + pathBase = authPath[idx+1:] + } + if pathBase == tokenID || strings.TrimSuffix(pathBase, ".json") == strings.TrimSuffix(tokenID, ".json") { + matched = true + } + } + } + + // 3. 通过 auth.FileName 匹配(原有逻辑) + if !matched && auth.FileName != "" { + if auth.FileName == tokenID || strings.HasSuffix(auth.FileName, "/"+tokenID) || strings.HasSuffix(auth.FileName, "\\"+tokenID) { + matched = true + } + } + + if matched { + // 更新内存中的 token + auth.Metadata["access_token"] = accessToken + auth.Metadata["refresh_token"] = refreshToken + auth.Metadata["expires_at"] = expiresAt + auth.Metadata["last_refresh"] = time.Now().Format(time.RFC3339) + auth.UpdatedAt = time.Now() + auth.LastRefreshedAt = time.Now() + + log.Infof("watcher: updated in-memory auth for token %s (auth ID: %s)", tokenID, id) + updated = true + + // 同时更新 runtimeAuths 中的副本(如果存在) + if w.runtimeAuths != nil { + if runtimeAuth, ok := w.runtimeAuths[id]; ok && runtimeAuth != nil { + if runtimeAuth.Metadata == nil { + runtimeAuth.Metadata = make(map[string]any) + } + runtimeAuth.Metadata["access_token"] = accessToken + runtimeAuth.Metadata["refresh_token"] = refreshToken + runtimeAuth.Metadata["expires_at"] = expiresAt + runtimeAuth.Metadata["last_refresh"] = time.Now().Format(time.RFC3339) + runtimeAuth.UpdatedAt = time.Now() + runtimeAuth.LastRefreshedAt = time.Now() + } + } + + // 发送更新通知到 authQueue + if w.authQueue != nil { + go func(authClone *coreauth.Auth) { + update := AuthUpdate{ + Action: AuthUpdateActionModify, + ID: authClone.ID, + Auth: authClone, + } + w.dispatchAuthUpdates([]AuthUpdate{update}) + }(auth.Clone()) + } + } + } + + if !updated { + log.Debugf("watcher: no matching auth found for token %s, will be picked up on next file scan", tokenID) + } +} diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 885304ad..750eb885 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -98,6 +98,16 @@ func (s *Service) RegisterUsagePlugin(plugin usage.Plugin) { usage.RegisterPlugin(plugin) } +// GetWatcher returns the underlying WatcherWrapper instance. +// This allows external components (e.g., RefreshManager) to interact with the watcher. +// Returns nil if the service or watcher is not initialized. +func (s *Service) GetWatcher() *WatcherWrapper { + if s == nil { + return nil + } + return s.watcher +} + // newDefaultAuthManager creates a default authentication manager with all supported providers. func newDefaultAuthManager() *sdkAuth.Manager { return sdkAuth.NewManager( @@ -575,6 +585,18 @@ func (s *Service) Run(ctx context.Context) error { } watcherWrapper.SetConfig(s.cfg) + // 方案 A: 连接 Kiro 后台刷新器回调到 Watcher + // 当后台刷新器成功刷新 token 后,立即通知 Watcher 更新内存中的 Auth 对象 + // 这解决了后台刷新与内存 Auth 对象之间的时间差问题 + kiroauth.GetRefreshManager().SetOnTokenRefreshed(func(tokenID string, tokenData *kiroauth.KiroTokenData) { + if tokenData == nil || watcherWrapper == nil { + return + } + log.Debugf("kiro refresh callback: notifying watcher for token %s", tokenID) + watcherWrapper.NotifyTokenRefreshed(tokenID, tokenData.AccessToken, tokenData.RefreshToken, tokenData.ExpiresAt) + }) + log.Debug("kiro: connected background refresh callback to watcher") + watcherCtx, watcherCancel := context.WithCancel(context.Background()) s.watcherCancel = watcherCancel if err = watcherWrapper.Start(watcherCtx); err != nil { diff --git a/sdk/cliproxy/types.go b/sdk/cliproxy/types.go index 1521dffe..ee8f761d 100644 --- a/sdk/cliproxy/types.go +++ b/sdk/cliproxy/types.go @@ -89,6 +89,7 @@ type WatcherWrapper struct { snapshotAuths func() []*coreauth.Auth setUpdateQueue func(queue chan<- watcher.AuthUpdate) dispatchRuntimeUpdate func(update watcher.AuthUpdate) bool + notifyTokenRefreshed func(tokenID, accessToken, refreshToken, expiresAt string) // 方案 A: 后台刷新通知 } // Start proxies to the underlying watcher Start implementation. @@ -146,3 +147,16 @@ func (w *WatcherWrapper) SetAuthUpdateQueue(queue chan<- watcher.AuthUpdate) { } w.setUpdateQueue(queue) } + +// NotifyTokenRefreshed 通知 Watcher 后台刷新器已更新 token +// 这是方案 A 的核心方法,用于解决后台刷新与内存 Auth 对象的时间差问题 +// tokenID: token 文件名(如 kiro-xxx.json) +// accessToken: 新的 access token +// refreshToken: 新的 refresh token +// expiresAt: 新的过期时间(RFC3339 格式) +func (w *WatcherWrapper) NotifyTokenRefreshed(tokenID, accessToken, refreshToken, expiresAt string) { + if w == nil || w.notifyTokenRefreshed == nil { + return + } + w.notifyTokenRefreshed(tokenID, accessToken, refreshToken, expiresAt) +} diff --git a/sdk/cliproxy/watcher.go b/sdk/cliproxy/watcher.go index caeadf19..e6e91bdd 100644 --- a/sdk/cliproxy/watcher.go +++ b/sdk/cliproxy/watcher.go @@ -31,5 +31,8 @@ func defaultWatcherFactory(configPath, authDir string, reload func(*config.Confi dispatchRuntimeUpdate: func(update watcher.AuthUpdate) bool { return w.DispatchRuntimeAuthUpdate(update) }, + notifyTokenRefreshed: func(tokenID, accessToken, refreshToken, expiresAt string) { + w.NotifyTokenRefreshed(tokenID, accessToken, refreshToken, expiresAt) + }, }, nil }