package kiro import ( "context" "log" "sync" "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "golang.org/x/sync/semaphore" ) type Token struct { ID string AccessToken string RefreshToken string ExpiresAt time.Time LastVerified time.Time ClientID string ClientSecret string AuthMethod string Provider string StartURL string Region string } type TokenRepository interface { FindOldestUnverified(limit int) []*Token UpdateToken(token *Token) error } type RefresherOption func(*BackgroundRefresher) func WithInterval(interval time.Duration) RefresherOption { return func(r *BackgroundRefresher) { r.interval = interval } } func WithBatchSize(size int) RefresherOption { return func(r *BackgroundRefresher) { r.batchSize = size } } func WithConcurrency(concurrency int) RefresherOption { return func(r *BackgroundRefresher) { r.concurrency = concurrency } } type BackgroundRefresher struct { interval time.Duration batchSize int concurrency int tokenRepo TokenRepository stopCh chan struct{} wg sync.WaitGroup oauth *KiroOAuth ssoClient *SSOOIDCClient } func NewBackgroundRefresher(repo TokenRepository, opts ...RefresherOption) *BackgroundRefresher { r := &BackgroundRefresher{ interval: time.Minute, batchSize: 50, concurrency: 10, tokenRepo: repo, stopCh: make(chan struct{}), oauth: nil, // Lazy init - will be set when config available ssoClient: nil, // Lazy init - will be set when config available } for _, opt := range opts { opt(r) } return r } // WithConfig sets the configuration for OAuth and SSO clients. func WithConfig(cfg *config.Config) RefresherOption { return func(r *BackgroundRefresher) { r.oauth = NewKiroOAuth(cfg) r.ssoClient = NewSSOOIDCClient(cfg) } } func (r *BackgroundRefresher) Start(ctx context.Context) { r.wg.Add(1) go func() { defer r.wg.Done() ticker := time.NewTicker(r.interval) defer ticker.Stop() r.refreshBatch(ctx) for { select { case <-ctx.Done(): return case <-r.stopCh: return case <-ticker.C: r.refreshBatch(ctx) } } }() } func (r *BackgroundRefresher) Stop() { close(r.stopCh) r.wg.Wait() } func (r *BackgroundRefresher) refreshBatch(ctx context.Context) { tokens := r.tokenRepo.FindOldestUnverified(r.batchSize) if len(tokens) == 0 { return } sem := semaphore.NewWeighted(int64(r.concurrency)) var wg sync.WaitGroup for i, token := range tokens { if i > 0 { select { case <-ctx.Done(): return case <-r.stopCh: return case <-time.After(100 * time.Millisecond): } } if err := sem.Acquire(ctx, 1); err != nil { return } wg.Add(1) go func(t *Token) { defer wg.Done() defer sem.Release(1) r.refreshSingle(ctx, t) }(token) } wg.Wait() } func (r *BackgroundRefresher) refreshSingle(ctx context.Context, token *Token) { var newTokenData *KiroTokenData var err error switch token.AuthMethod { case "idc": newTokenData, err = r.ssoClient.RefreshTokenWithRegion( ctx, token.ClientID, token.ClientSecret, token.RefreshToken, token.Region, token.StartURL, ) case "builder-id": newTokenData, err = r.ssoClient.RefreshToken( ctx, token.ClientID, token.ClientSecret, token.RefreshToken, ) default: newTokenData, err = r.oauth.RefreshToken(ctx, token.RefreshToken) } if err != nil { log.Printf("failed to refresh token %s: %v", token.ID, err) return } token.AccessToken = newTokenData.AccessToken token.RefreshToken = newTokenData.RefreshToken token.LastVerified = time.Now() if newTokenData.ExpiresAt != "" { if expTime, parseErr := time.Parse(time.RFC3339, newTokenData.ExpiresAt); parseErr == nil { token.ExpiresAt = expTime } } if err := r.tokenRepo.UpdateToken(token); err != nil { log.Printf("failed to update token %s: %v", token.ID, err) } }