// Mirror of https://github.com/router-for-me/CLIProxyAPIPlus.git
// (synced 2026-03-21 16:40:22 +00:00; 193 lines, 3.8 KiB, Go).

package kiro
|
|
|
|
import (
|
|
"context"
|
|
"log"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
|
"golang.org/x/sync/semaphore"
|
|
)
|
|
|
|
// Token is the refresher's view of a single stored credential. The background
// refresher reads the refresh-related fields to obtain new credentials and
// writes the refreshed values back into the same struct.
type Token struct {
	// ID identifies the token; refresh and update failures are logged under it.
	ID string

	// AccessToken and RefreshToken are the current credential pair. Both are
	// overwritten with the values returned by a successful refresh.
	AccessToken, RefreshToken string

	// ExpiresAt is updated from the RFC 3339 expiry string of a refresh
	// response when that string parses. LastVerified is stamped with the
	// current time after each successful refresh.
	ExpiresAt, LastVerified time.Time

	// ClientID and ClientSecret are forwarded to the SSO OIDC client when
	// AuthMethod is "idc" or "builder-id".
	ClientID, ClientSecret string

	// AuthMethod selects the refresh path: "idc" and "builder-id" go through
	// the SSO OIDC client, any other value goes through the OAuth client.
	AuthMethod string

	// Provider is not read by the refresher itself.
	// NOTE(review): presumably the upstream identity provider - confirm with callers.
	Provider string

	// StartURL and Region are passed along only on the "idc" refresh path.
	StartURL, Region string
}
|
|
|
|
type TokenRepository interface {
|
|
FindOldestUnverified(limit int) []*Token
|
|
UpdateToken(token *Token) error
|
|
}
|
|
|
|
// RefresherOption mutates a BackgroundRefresher while it is being constructed;
// NewBackgroundRefresher applies options in the order given, after defaults.
type RefresherOption func(*BackgroundRefresher)
|
|
|
|
func WithInterval(interval time.Duration) RefresherOption {
|
|
return func(r *BackgroundRefresher) {
|
|
r.interval = interval
|
|
}
|
|
}
|
|
|
|
func WithBatchSize(size int) RefresherOption {
|
|
return func(r *BackgroundRefresher) {
|
|
r.batchSize = size
|
|
}
|
|
}
|
|
|
|
func WithConcurrency(concurrency int) RefresherOption {
|
|
return func(r *BackgroundRefresher) {
|
|
r.concurrency = concurrency
|
|
}
|
|
}
|
|
|
|
type BackgroundRefresher struct {
|
|
interval time.Duration
|
|
batchSize int
|
|
concurrency int
|
|
tokenRepo TokenRepository
|
|
stopCh chan struct{}
|
|
wg sync.WaitGroup
|
|
oauth *KiroOAuth
|
|
ssoClient *SSOOIDCClient
|
|
}
|
|
|
|
func NewBackgroundRefresher(repo TokenRepository, opts ...RefresherOption) *BackgroundRefresher {
|
|
r := &BackgroundRefresher{
|
|
interval: time.Minute,
|
|
batchSize: 50,
|
|
concurrency: 10,
|
|
tokenRepo: repo,
|
|
stopCh: make(chan struct{}),
|
|
oauth: nil, // Lazy init - will be set when config available
|
|
ssoClient: nil, // Lazy init - will be set when config available
|
|
}
|
|
for _, opt := range opts {
|
|
opt(r)
|
|
}
|
|
return r
|
|
}
|
|
|
|
// WithConfig sets the configuration for OAuth and SSO clients.
|
|
func WithConfig(cfg *config.Config) RefresherOption {
|
|
return func(r *BackgroundRefresher) {
|
|
r.oauth = NewKiroOAuth(cfg)
|
|
r.ssoClient = NewSSOOIDCClient(cfg)
|
|
}
|
|
}
|
|
|
|
func (r *BackgroundRefresher) Start(ctx context.Context) {
|
|
r.wg.Add(1)
|
|
go func() {
|
|
defer r.wg.Done()
|
|
ticker := time.NewTicker(r.interval)
|
|
defer ticker.Stop()
|
|
|
|
r.refreshBatch(ctx)
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-r.stopCh:
|
|
return
|
|
case <-ticker.C:
|
|
r.refreshBatch(ctx)
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (r *BackgroundRefresher) Stop() {
|
|
close(r.stopCh)
|
|
r.wg.Wait()
|
|
}
|
|
|
|
func (r *BackgroundRefresher) refreshBatch(ctx context.Context) {
|
|
tokens := r.tokenRepo.FindOldestUnverified(r.batchSize)
|
|
if len(tokens) == 0 {
|
|
return
|
|
}
|
|
|
|
sem := semaphore.NewWeighted(int64(r.concurrency))
|
|
var wg sync.WaitGroup
|
|
|
|
for i, token := range tokens {
|
|
if i > 0 {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-r.stopCh:
|
|
return
|
|
case <-time.After(100 * time.Millisecond):
|
|
}
|
|
}
|
|
|
|
if err := sem.Acquire(ctx, 1); err != nil {
|
|
return
|
|
}
|
|
|
|
wg.Add(1)
|
|
go func(t *Token) {
|
|
defer wg.Done()
|
|
defer sem.Release(1)
|
|
r.refreshSingle(ctx, t)
|
|
}(token)
|
|
}
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
func (r *BackgroundRefresher) refreshSingle(ctx context.Context, token *Token) {
|
|
var newTokenData *KiroTokenData
|
|
var err error
|
|
|
|
switch token.AuthMethod {
|
|
case "idc":
|
|
newTokenData, err = r.ssoClient.RefreshTokenWithRegion(
|
|
ctx,
|
|
token.ClientID,
|
|
token.ClientSecret,
|
|
token.RefreshToken,
|
|
token.Region,
|
|
token.StartURL,
|
|
)
|
|
case "builder-id":
|
|
newTokenData, err = r.ssoClient.RefreshToken(
|
|
ctx,
|
|
token.ClientID,
|
|
token.ClientSecret,
|
|
token.RefreshToken,
|
|
)
|
|
default:
|
|
newTokenData, err = r.oauth.RefreshToken(ctx, token.RefreshToken)
|
|
}
|
|
|
|
if err != nil {
|
|
log.Printf("failed to refresh token %s: %v", token.ID, err)
|
|
return
|
|
}
|
|
|
|
token.AccessToken = newTokenData.AccessToken
|
|
token.RefreshToken = newTokenData.RefreshToken
|
|
token.LastVerified = time.Now()
|
|
|
|
if newTokenData.ExpiresAt != "" {
|
|
if expTime, parseErr := time.Parse(time.RFC3339, newTokenData.ExpiresAt); parseErr == nil {
|
|
token.ExpiresAt = expTime
|
|
}
|
|
}
|
|
|
|
if err := r.tokenRepo.UpdateToken(token); err != nil {
|
|
log.Printf("failed to update token %s: %v", token.ID, err)
|
|
}
|
|
}
|