mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-21 16:40:22 +00:00
386 lines
9.6 KiB
Go
386 lines
9.6 KiB
Go
// Package kiro provides OAuth Web authentication for Kiro.
|
|
package kiro
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"html/template"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// Timing parameters of the web-based device authorization flow.
const (
	// defaultSessionExpiry is the age after which CleanupExpiredSessions
	// cancels and discards a session that is still pending.
	defaultSessionExpiry = 10 * time.Minute

	// pollIntervalSeconds is the lower bound, in seconds, applied to the
	// delay between two token polling attempts.
	pollIntervalSeconds = 5
)
|
|
|
|
// authSessionStatus is the lifecycle state of a web authentication session.
// Its string value is what the status endpoint reports to clients.
type authSessionStatus string

// Lifecycle states a webAuthSession can be in.
const (
	// statusPending marks a session whose token polling is still running.
	statusPending = authSessionStatus("pending")
	// statusSuccess marks a session for which a token was obtained.
	statusSuccess = authSessionStatus("success")
	// statusFailed marks a session that ended without a token.
	statusFailed = authSessionStatus("failed")
)
|
|
|
|
type webAuthSession struct {
|
|
stateID string
|
|
deviceCode string
|
|
userCode string
|
|
authURL string
|
|
verificationURI string
|
|
expiresIn int
|
|
interval int
|
|
status authSessionStatus
|
|
startedAt time.Time
|
|
completedAt time.Time
|
|
expiresAt time.Time
|
|
error string
|
|
tokenData *KiroTokenData
|
|
ssoClient *SSOOIDCClient
|
|
clientID string
|
|
clientSecret string
|
|
region string
|
|
cancelFunc context.CancelFunc
|
|
}
|
|
|
|
type OAuthWebHandler struct {
|
|
cfg *config.Config
|
|
sessions map[string]*webAuthSession
|
|
mu sync.RWMutex
|
|
onTokenObtained func(*KiroTokenData)
|
|
}
|
|
|
|
func NewOAuthWebHandler(cfg *config.Config) *OAuthWebHandler {
|
|
return &OAuthWebHandler{
|
|
cfg: cfg,
|
|
sessions: make(map[string]*webAuthSession),
|
|
}
|
|
}
|
|
|
|
func (h *OAuthWebHandler) SetTokenCallback(callback func(*KiroTokenData)) {
|
|
h.onTokenObtained = callback
|
|
}
|
|
|
|
func (h *OAuthWebHandler) RegisterRoutes(router gin.IRouter) {
|
|
oauth := router.Group("/v0/oauth/kiro")
|
|
{
|
|
oauth.GET("/start", h.handleStart)
|
|
oauth.GET("/callback", h.handleCallback)
|
|
oauth.GET("/status", h.handleStatus)
|
|
}
|
|
}
|
|
|
|
// generateStateID returns 16 bytes from crypto/rand encoded as unpadded,
// URL-safe base64 (22 characters). It returns an empty string and the
// underlying error when the random source fails.
func generateStateID() (string, error) {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	encoded := base64.RawURLEncoding.EncodeToString(raw[:])
	return encoded, nil
}
|
|
|
|
func (h *OAuthWebHandler) handleStart(c *gin.Context) {
|
|
stateID, err := generateStateID()
|
|
if err != nil {
|
|
h.renderError(c, "Failed to generate state parameter")
|
|
return
|
|
}
|
|
|
|
region := defaultIDCRegion
|
|
startURL := builderIDStartURL
|
|
|
|
ssoClient := NewSSOOIDCClient(h.cfg)
|
|
|
|
regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region)
|
|
if err != nil {
|
|
log.Errorf("OAuth Web: failed to register client: %v", err)
|
|
h.renderError(c, fmt.Sprintf("Failed to register client: %v", err))
|
|
return
|
|
}
|
|
|
|
authResp, err := ssoClient.StartDeviceAuthorizationWithIDC(
|
|
c.Request.Context(),
|
|
regResp.ClientID,
|
|
regResp.ClientSecret,
|
|
startURL,
|
|
region,
|
|
)
|
|
if err != nil {
|
|
log.Errorf("OAuth Web: failed to start device authorization: %v", err)
|
|
h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err))
|
|
return
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second)
|
|
|
|
session := &webAuthSession{
|
|
stateID: stateID,
|
|
deviceCode: authResp.DeviceCode,
|
|
userCode: authResp.UserCode,
|
|
authURL: authResp.VerificationURIComplete,
|
|
verificationURI: authResp.VerificationURI,
|
|
expiresIn: authResp.ExpiresIn,
|
|
interval: authResp.Interval,
|
|
status: statusPending,
|
|
startedAt: time.Now(),
|
|
ssoClient: ssoClient,
|
|
clientID: regResp.ClientID,
|
|
clientSecret: regResp.ClientSecret,
|
|
region: region,
|
|
cancelFunc: cancel,
|
|
}
|
|
|
|
h.mu.Lock()
|
|
h.sessions[stateID] = session
|
|
h.mu.Unlock()
|
|
|
|
go h.pollForToken(ctx, session)
|
|
|
|
h.renderStartPage(c, session)
|
|
}
|
|
|
|
func (h *OAuthWebHandler) pollForToken(ctx context.Context, session *webAuthSession) {
|
|
defer session.cancelFunc()
|
|
|
|
interval := time.Duration(session.interval) * time.Second
|
|
if interval < time.Duration(pollIntervalSeconds)*time.Second {
|
|
interval = time.Duration(pollIntervalSeconds) * time.Second
|
|
}
|
|
|
|
ticker := time.NewTicker(interval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
h.mu.Lock()
|
|
if session.status == statusPending {
|
|
session.status = statusFailed
|
|
session.error = "Authentication timed out"
|
|
}
|
|
h.mu.Unlock()
|
|
return
|
|
case <-ticker.C:
|
|
tokenResp, err := h.ssoClient(session).CreateTokenWithRegion(
|
|
ctx,
|
|
session.clientID,
|
|
session.clientSecret,
|
|
session.deviceCode,
|
|
session.region,
|
|
)
|
|
|
|
if err != nil {
|
|
errStr := err.Error()
|
|
if errStr == ErrAuthorizationPending.Error() {
|
|
continue
|
|
}
|
|
if errStr == ErrSlowDown.Error() {
|
|
interval += 5 * time.Second
|
|
ticker.Reset(interval)
|
|
continue
|
|
}
|
|
|
|
h.mu.Lock()
|
|
session.status = statusFailed
|
|
session.error = errStr
|
|
session.completedAt = time.Now()
|
|
h.mu.Unlock()
|
|
|
|
log.Errorf("OAuth Web: token polling failed: %v", err)
|
|
return
|
|
}
|
|
|
|
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
|
profileArn := session.ssoClient.fetchProfileArn(ctx, tokenResp.AccessToken)
|
|
email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken)
|
|
|
|
tokenData := &KiroTokenData{
|
|
AccessToken: tokenResp.AccessToken,
|
|
RefreshToken: tokenResp.RefreshToken,
|
|
ProfileArn: profileArn,
|
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
|
AuthMethod: "builder-id",
|
|
Provider: "AWS",
|
|
ClientID: session.clientID,
|
|
ClientSecret: session.clientSecret,
|
|
Email: email,
|
|
}
|
|
|
|
h.mu.Lock()
|
|
session.status = statusSuccess
|
|
session.completedAt = time.Now()
|
|
session.expiresAt = expiresAt
|
|
session.tokenData = tokenData
|
|
h.mu.Unlock()
|
|
|
|
if h.onTokenObtained != nil {
|
|
h.onTokenObtained(tokenData)
|
|
}
|
|
|
|
log.Infof("OAuth Web: authentication successful for %s", email)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *OAuthWebHandler) ssoClient(session *webAuthSession) *SSOOIDCClient {
|
|
return session.ssoClient
|
|
}
|
|
|
|
func (h *OAuthWebHandler) handleCallback(c *gin.Context) {
|
|
stateID := c.Query("state")
|
|
errParam := c.Query("error")
|
|
|
|
if errParam != "" {
|
|
h.renderError(c, errParam)
|
|
return
|
|
}
|
|
|
|
if stateID == "" {
|
|
h.renderError(c, "Missing state parameter")
|
|
return
|
|
}
|
|
|
|
h.mu.RLock()
|
|
session, exists := h.sessions[stateID]
|
|
h.mu.RUnlock()
|
|
|
|
if !exists {
|
|
h.renderError(c, "Invalid or expired session")
|
|
return
|
|
}
|
|
|
|
if session.status == statusSuccess {
|
|
h.renderSuccess(c, session)
|
|
} else if session.status == statusFailed {
|
|
h.renderError(c, session.error)
|
|
} else {
|
|
c.Redirect(http.StatusFound, "/v0/oauth/kiro/start")
|
|
}
|
|
}
|
|
|
|
func (h *OAuthWebHandler) handleStatus(c *gin.Context) {
|
|
stateID := c.Query("state")
|
|
if stateID == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "missing state parameter"})
|
|
return
|
|
}
|
|
|
|
h.mu.RLock()
|
|
session, exists := h.sessions[stateID]
|
|
h.mu.RUnlock()
|
|
|
|
if !exists {
|
|
c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
|
|
return
|
|
}
|
|
|
|
response := gin.H{
|
|
"status": string(session.status),
|
|
}
|
|
|
|
switch session.status {
|
|
case statusPending:
|
|
elapsed := time.Since(session.startedAt).Seconds()
|
|
remaining := float64(session.expiresIn) - elapsed
|
|
if remaining < 0 {
|
|
remaining = 0
|
|
}
|
|
response["remaining_seconds"] = int(remaining)
|
|
case statusSuccess:
|
|
response["completed_at"] = session.completedAt.Format(time.RFC3339)
|
|
response["expires_at"] = session.expiresAt.Format(time.RFC3339)
|
|
case statusFailed:
|
|
response["error"] = session.error
|
|
response["failed_at"] = session.completedAt.Format(time.RFC3339)
|
|
}
|
|
|
|
c.JSON(http.StatusOK, response)
|
|
}
|
|
|
|
func (h *OAuthWebHandler) renderStartPage(c *gin.Context, session *webAuthSession) {
|
|
tmpl, err := template.New("start").Parse(oauthWebStartPageHTML)
|
|
if err != nil {
|
|
log.Errorf("OAuth Web: failed to parse template: %v", err)
|
|
c.String(http.StatusInternalServerError, "Template error")
|
|
return
|
|
}
|
|
|
|
data := map[string]interface{}{
|
|
"AuthURL": session.authURL,
|
|
"UserCode": session.userCode,
|
|
"ExpiresIn": session.expiresIn,
|
|
"StateID": session.stateID,
|
|
}
|
|
|
|
c.Header("Content-Type", "text/html; charset=utf-8")
|
|
if err := tmpl.Execute(c.Writer, data); err != nil {
|
|
log.Errorf("OAuth Web: failed to render template: %v", err)
|
|
}
|
|
}
|
|
|
|
func (h *OAuthWebHandler) renderError(c *gin.Context, errMsg string) {
|
|
tmpl, err := template.New("error").Parse(oauthWebErrorPageHTML)
|
|
if err != nil {
|
|
log.Errorf("OAuth Web: failed to parse error template: %v", err)
|
|
c.String(http.StatusInternalServerError, "Template error")
|
|
return
|
|
}
|
|
|
|
data := map[string]interface{}{
|
|
"Error": errMsg,
|
|
}
|
|
|
|
c.Header("Content-Type", "text/html; charset=utf-8")
|
|
c.Status(http.StatusBadRequest)
|
|
if err := tmpl.Execute(c.Writer, data); err != nil {
|
|
log.Errorf("OAuth Web: failed to render error template: %v", err)
|
|
}
|
|
}
|
|
|
|
func (h *OAuthWebHandler) renderSuccess(c *gin.Context, session *webAuthSession) {
|
|
tmpl, err := template.New("success").Parse(oauthWebSuccessPageHTML)
|
|
if err != nil {
|
|
log.Errorf("OAuth Web: failed to parse success template: %v", err)
|
|
c.String(http.StatusInternalServerError, "Template error")
|
|
return
|
|
}
|
|
|
|
data := map[string]interface{}{
|
|
"ExpiresAt": session.expiresAt.Format(time.RFC3339),
|
|
}
|
|
|
|
c.Header("Content-Type", "text/html; charset=utf-8")
|
|
if err := tmpl.Execute(c.Writer, data); err != nil {
|
|
log.Errorf("OAuth Web: failed to render success template: %v", err)
|
|
}
|
|
}
|
|
|
|
func (h *OAuthWebHandler) CleanupExpiredSessions() {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
for id, session := range h.sessions {
|
|
if session.status != statusPending && now.Sub(session.completedAt) > 30*time.Minute {
|
|
delete(h.sessions, id)
|
|
} else if session.status == statusPending && now.Sub(session.startedAt) > defaultSessionExpiry {
|
|
session.cancelFunc()
|
|
delete(h.sessions, id)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *OAuthWebHandler) GetSession(stateID string) (*webAuthSession, bool) {
|
|
h.mu.RLock()
|
|
defer h.mu.RUnlock()
|
|
session, exists := h.sessions[stateID]
|
|
return session, exists
|
|
}
|