|
|
|
|
@@ -3,6 +3,9 @@ package management
|
|
|
|
|
import (
|
|
|
|
|
"bytes"
|
|
|
|
|
"context"
|
|
|
|
|
"crypto/rand"
|
|
|
|
|
"crypto/sha256"
|
|
|
|
|
"encoding/base64"
|
|
|
|
|
"encoding/json"
|
|
|
|
|
"errors"
|
|
|
|
|
"fmt"
|
|
|
|
|
@@ -23,6 +26,7 @@ import (
|
|
|
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
|
|
|
|
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
|
|
|
|
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
|
|
|
|
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
|
|
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
|
|
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
|
|
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
|
|
|
|
@@ -37,9 +41,32 @@ import (
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
var (
|
|
|
|
|
oauthStatus = make(map[string]string)
|
|
|
|
|
oauthStatus = make(map[string]string)
|
|
|
|
|
oauthStatusMutex sync.RWMutex
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// getOAuthStatus safely retrieves an OAuth status
|
|
|
|
|
func getOAuthStatus(key string) (string, bool) {
|
|
|
|
|
oauthStatusMutex.RLock()
|
|
|
|
|
defer oauthStatusMutex.RUnlock()
|
|
|
|
|
status, ok := oauthStatus[key]
|
|
|
|
|
return status, ok
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// setOAuthStatus safely sets an OAuth status
|
|
|
|
|
func setOAuthStatus(key string, status string) {
|
|
|
|
|
oauthStatusMutex.Lock()
|
|
|
|
|
defer oauthStatusMutex.Unlock()
|
|
|
|
|
oauthStatus[key] = status
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// deleteOAuthStatus safely deletes an OAuth status
|
|
|
|
|
func deleteOAuthStatus(key string) {
|
|
|
|
|
oauthStatusMutex.Lock()
|
|
|
|
|
defer oauthStatusMutex.Unlock()
|
|
|
|
|
delete(oauthStatus, key)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
|
|
|
|
|
|
|
|
|
|
const (
|
|
|
|
|
@@ -812,7 +839,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|
|
|
|
deadline := time.Now().Add(timeout)
|
|
|
|
|
for {
|
|
|
|
|
if time.Now().After(deadline) {
|
|
|
|
|
oauthStatus[state] = "Timeout waiting for OAuth callback"
|
|
|
|
|
setOAuthStatus(state, "Timeout waiting for OAuth callback")
|
|
|
|
|
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
|
|
|
|
}
|
|
|
|
|
data, errRead := os.ReadFile(path)
|
|
|
|
|
@@ -837,13 +864,13 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|
|
|
|
if errStr := resultMap["error"]; errStr != "" {
|
|
|
|
|
oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest)
|
|
|
|
|
log.Error(claude.GetUserFriendlyMessage(oauthErr))
|
|
|
|
|
oauthStatus[state] = "Bad request"
|
|
|
|
|
setOAuthStatus(state, "Bad request")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if resultMap["state"] != state {
|
|
|
|
|
authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"]))
|
|
|
|
|
log.Error(claude.GetUserFriendlyMessage(authErr))
|
|
|
|
|
oauthStatus[state] = "State code error"
|
|
|
|
|
setOAuthStatus(state, "State code error")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -876,7 +903,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|
|
|
|
if errDo != nil {
|
|
|
|
|
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
|
|
|
|
|
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
|
|
|
|
oauthStatus[state] = "Failed to exchange authorization code for tokens"
|
|
|
|
|
setOAuthStatus(state, "Failed to exchange authorization code for tokens")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
defer func() {
|
|
|
|
|
@@ -887,7 +914,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|
|
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
|
|
|
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
|
|
|
|
oauthStatus[state] = fmt.Sprintf("token exchange failed with status %d", resp.StatusCode)
|
|
|
|
|
setOAuthStatus(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode))
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
var tResp struct {
|
|
|
|
|
@@ -900,7 +927,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|
|
|
|
}
|
|
|
|
|
if errU := json.Unmarshal(respBody, &tResp); errU != nil {
|
|
|
|
|
log.Errorf("failed to parse token response: %v", errU)
|
|
|
|
|
oauthStatus[state] = "Failed to parse token response"
|
|
|
|
|
setOAuthStatus(state, "Failed to parse token response")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
bundle := &claude.ClaudeAuthBundle{
|
|
|
|
|
@@ -925,7 +952,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|
|
|
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
|
|
|
|
if errSave != nil {
|
|
|
|
|
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
|
|
|
|
oauthStatus[state] = "Failed to save authentication tokens"
|
|
|
|
|
setOAuthStatus(state, "Failed to save authentication tokens")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -934,10 +961,10 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|
|
|
|
fmt.Println("API key obtained and saved")
|
|
|
|
|
}
|
|
|
|
|
fmt.Println("You can now use Claude services through this CLI")
|
|
|
|
|
delete(oauthStatus, state)
|
|
|
|
|
deleteOAuthStatus(state)
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
oauthStatus[state] = ""
|
|
|
|
|
setOAuthStatus(state, "")
|
|
|
|
|
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -996,7 +1023,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|
|
|
|
for {
|
|
|
|
|
if time.Now().After(deadline) {
|
|
|
|
|
log.Error("oauth flow timed out")
|
|
|
|
|
oauthStatus[state] = "OAuth flow timed out"
|
|
|
|
|
setOAuthStatus(state, "OAuth flow timed out")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if data, errR := os.ReadFile(waitFile); errR == nil {
|
|
|
|
|
@@ -1005,13 +1032,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|
|
|
|
_ = os.Remove(waitFile)
|
|
|
|
|
if errStr := m["error"]; errStr != "" {
|
|
|
|
|
log.Errorf("Authentication failed: %s", errStr)
|
|
|
|
|
oauthStatus[state] = "Authentication failed"
|
|
|
|
|
setOAuthStatus(state, "Authentication failed")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
authCode = m["code"]
|
|
|
|
|
if authCode == "" {
|
|
|
|
|
log.Errorf("Authentication failed: code not found")
|
|
|
|
|
oauthStatus[state] = "Authentication failed: code not found"
|
|
|
|
|
setOAuthStatus(state, "Authentication failed: code not found")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
break
|
|
|
|
|
@@ -1023,7 +1050,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|
|
|
|
token, err := conf.Exchange(ctx, authCode)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Errorf("Failed to exchange token: %v", err)
|
|
|
|
|
oauthStatus[state] = "Failed to exchange token"
|
|
|
|
|
setOAuthStatus(state, "Failed to exchange token")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -1034,7 +1061,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|
|
|
|
req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
|
|
|
|
if errNewRequest != nil {
|
|
|
|
|
log.Errorf("Could not get user info: %v", errNewRequest)
|
|
|
|
|
oauthStatus[state] = "Could not get user info"
|
|
|
|
|
setOAuthStatus(state, "Could not get user info")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
|
@@ -1043,7 +1070,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|
|
|
|
resp, errDo := authHTTPClient.Do(req)
|
|
|
|
|
if errDo != nil {
|
|
|
|
|
log.Errorf("Failed to execute request: %v", errDo)
|
|
|
|
|
oauthStatus[state] = "Failed to execute request"
|
|
|
|
|
setOAuthStatus(state, "Failed to execute request")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
defer func() {
|
|
|
|
|
@@ -1055,7 +1082,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|
|
|
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
|
|
|
log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
|
|
|
|
oauthStatus[state] = fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode)
|
|
|
|
|
setOAuthStatus(state, fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode))
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -1064,7 +1091,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|
|
|
|
fmt.Printf("Authenticated user email: %s\n", email)
|
|
|
|
|
} else {
|
|
|
|
|
fmt.Println("Failed to get user email from token")
|
|
|
|
|
oauthStatus[state] = "Failed to get user email from token"
|
|
|
|
|
setOAuthStatus(state, "Failed to get user email from token")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Marshal/unmarshal oauth2.Token to generic map and enrich fields
|
|
|
|
|
@@ -1072,7 +1099,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|
|
|
|
jsonData, _ := json.Marshal(token)
|
|
|
|
|
if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil {
|
|
|
|
|
log.Errorf("Failed to unmarshal token: %v", errUnmarshal)
|
|
|
|
|
oauthStatus[state] = "Failed to unmarshal token"
|
|
|
|
|
setOAuthStatus(state, "Failed to unmarshal token")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -1098,7 +1125,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|
|
|
|
gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true)
|
|
|
|
|
if errGetClient != nil {
|
|
|
|
|
log.Errorf("failed to get authenticated client: %v", errGetClient)
|
|
|
|
|
oauthStatus[state] = "Failed to get authenticated client"
|
|
|
|
|
setOAuthStatus(state, "Failed to get authenticated client")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
fmt.Println("Authentication successful.")
|
|
|
|
|
@@ -1108,12 +1135,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|
|
|
|
projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
|
|
|
|
|
if errAll != nil {
|
|
|
|
|
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
|
|
|
|
|
oauthStatus[state] = "Failed to complete Gemini CLI onboarding"
|
|
|
|
|
setOAuthStatus(state, "Failed to complete Gemini CLI onboarding")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
|
|
|
|
|
log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
|
|
|
|
|
oauthStatus[state] = "Failed to verify Cloud AI API status"
|
|
|
|
|
setOAuthStatus(state, "Failed to verify Cloud AI API status")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
ts.ProjectID = strings.Join(projects, ",")
|
|
|
|
|
@@ -1121,26 +1148,26 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|
|
|
|
} else {
|
|
|
|
|
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
|
|
|
|
|
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
|
|
|
|
|
oauthStatus[state] = "Failed to complete Gemini CLI onboarding"
|
|
|
|
|
setOAuthStatus(state, "Failed to complete Gemini CLI onboarding")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if strings.TrimSpace(ts.ProjectID) == "" {
|
|
|
|
|
log.Error("Onboarding did not return a project ID")
|
|
|
|
|
oauthStatus[state] = "Failed to resolve project ID"
|
|
|
|
|
setOAuthStatus(state, "Failed to resolve project ID")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
|
|
|
|
if errCheck != nil {
|
|
|
|
|
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
|
|
|
|
oauthStatus[state] = "Failed to verify Cloud AI API status"
|
|
|
|
|
setOAuthStatus(state, "Failed to verify Cloud AI API status")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
ts.Checked = isChecked
|
|
|
|
|
if !isChecked {
|
|
|
|
|
log.Error("Cloud AI API is not enabled for the selected project")
|
|
|
|
|
oauthStatus[state] = "Cloud AI API not enabled"
|
|
|
|
|
setOAuthStatus(state, "Cloud AI API not enabled")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
@@ -1163,15 +1190,15 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|
|
|
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
|
|
|
|
if errSave != nil {
|
|
|
|
|
log.Errorf("Failed to save token to file: %v", errSave)
|
|
|
|
|
oauthStatus[state] = "Failed to save token to file"
|
|
|
|
|
setOAuthStatus(state, "Failed to save token to file")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
delete(oauthStatus, state)
|
|
|
|
|
deleteOAuthStatus(state)
|
|
|
|
|
fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath)
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
oauthStatus[state] = ""
|
|
|
|
|
setOAuthStatus(state, "")
|
|
|
|
|
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -1235,7 +1262,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|
|
|
|
if time.Now().After(deadline) {
|
|
|
|
|
authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback"))
|
|
|
|
|
log.Error(codex.GetUserFriendlyMessage(authErr))
|
|
|
|
|
oauthStatus[state] = "Timeout waiting for OAuth callback"
|
|
|
|
|
setOAuthStatus(state, "Timeout waiting for OAuth callback")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if data, errR := os.ReadFile(waitFile); errR == nil {
|
|
|
|
|
@@ -1245,12 +1272,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|
|
|
|
if errStr := m["error"]; errStr != "" {
|
|
|
|
|
oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest)
|
|
|
|
|
log.Error(codex.GetUserFriendlyMessage(oauthErr))
|
|
|
|
|
oauthStatus[state] = "Bad Request"
|
|
|
|
|
setOAuthStatus(state, "Bad Request")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if m["state"] != state {
|
|
|
|
|
authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"]))
|
|
|
|
|
oauthStatus[state] = "State code error"
|
|
|
|
|
setOAuthStatus(state, "State code error")
|
|
|
|
|
log.Error(codex.GetUserFriendlyMessage(authErr))
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
@@ -1281,14 +1308,14 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|
|
|
|
resp, errDo := httpClient.Do(req)
|
|
|
|
|
if errDo != nil {
|
|
|
|
|
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
|
|
|
|
|
oauthStatus[state] = "Failed to exchange authorization code for tokens"
|
|
|
|
|
setOAuthStatus(state, "Failed to exchange authorization code for tokens")
|
|
|
|
|
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
defer func() { _ = resp.Body.Close() }()
|
|
|
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
|
|
|
oauthStatus[state] = fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode)
|
|
|
|
|
setOAuthStatus(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode))
|
|
|
|
|
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
@@ -1299,7 +1326,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|
|
|
|
ExpiresIn int `json:"expires_in"`
|
|
|
|
|
}
|
|
|
|
|
if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
|
|
|
|
|
oauthStatus[state] = "Failed to parse token response"
|
|
|
|
|
setOAuthStatus(state, "Failed to parse token response")
|
|
|
|
|
log.Errorf("failed to parse token response: %v", errU)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
@@ -1337,8 +1364,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|
|
|
|
}
|
|
|
|
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
|
|
|
|
if errSave != nil {
|
|
|
|
|
oauthStatus[state] = "Failed to save authentication tokens"
|
|
|
|
|
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
|
|
|
|
setOAuthStatus(state, "Failed to save authentication tokens")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
|
|
|
|
@@ -1346,10 +1373,10 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|
|
|
|
fmt.Println("API key obtained and saved")
|
|
|
|
|
}
|
|
|
|
|
fmt.Println("You can now use Codex services through this CLI")
|
|
|
|
|
delete(oauthStatus, state)
|
|
|
|
|
deleteOAuthStatus(state)
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
oauthStatus[state] = ""
|
|
|
|
|
setOAuthStatus(state, "")
|
|
|
|
|
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -1416,7 +1443,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|
|
|
|
for {
|
|
|
|
|
if time.Now().After(deadline) {
|
|
|
|
|
log.Error("oauth flow timed out")
|
|
|
|
|
oauthStatus[state] = "OAuth flow timed out"
|
|
|
|
|
setOAuthStatus(state, "OAuth flow timed out")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil {
|
|
|
|
|
@@ -1425,18 +1452,18 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|
|
|
|
_ = os.Remove(waitFile)
|
|
|
|
|
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
|
|
|
|
|
log.Errorf("Authentication failed: %s", errStr)
|
|
|
|
|
oauthStatus[state] = "Authentication failed"
|
|
|
|
|
setOAuthStatus(state, "Authentication failed")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state {
|
|
|
|
|
log.Errorf("Authentication failed: state mismatch")
|
|
|
|
|
oauthStatus[state] = "Authentication failed: state mismatch"
|
|
|
|
|
setOAuthStatus(state, "Authentication failed: state mismatch")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
authCode = strings.TrimSpace(payload["code"])
|
|
|
|
|
if authCode == "" {
|
|
|
|
|
log.Error("Authentication failed: code not found")
|
|
|
|
|
oauthStatus[state] = "Authentication failed: code not found"
|
|
|
|
|
setOAuthStatus(state, "Authentication failed: code not found")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
break
|
|
|
|
|
@@ -1455,7 +1482,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|
|
|
|
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
|
|
|
|
|
if errNewRequest != nil {
|
|
|
|
|
log.Errorf("Failed to build token request: %v", errNewRequest)
|
|
|
|
|
oauthStatus[state] = "Failed to build token request"
|
|
|
|
|
setOAuthStatus(state, "Failed to build token request")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
|
|
|
@@ -1463,7 +1490,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|
|
|
|
resp, errDo := httpClient.Do(req)
|
|
|
|
|
if errDo != nil {
|
|
|
|
|
log.Errorf("Failed to execute token request: %v", errDo)
|
|
|
|
|
oauthStatus[state] = "Failed to exchange token"
|
|
|
|
|
setOAuthStatus(state, "Failed to exchange token")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
defer func() {
|
|
|
|
|
@@ -1475,7 +1502,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|
|
|
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
|
|
|
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
|
|
|
|
log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
|
|
|
|
oauthStatus[state] = fmt.Sprintf("Token exchange failed: %d", resp.StatusCode)
|
|
|
|
|
setOAuthStatus(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode))
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -1487,7 +1514,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|
|
|
|
}
|
|
|
|
|
if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil {
|
|
|
|
|
log.Errorf("Failed to parse token response: %v", errDecode)
|
|
|
|
|
oauthStatus[state] = "Failed to parse token response"
|
|
|
|
|
setOAuthStatus(state, "Failed to parse token response")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -1496,7 +1523,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|
|
|
|
infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
|
|
|
|
if errInfoReq != nil {
|
|
|
|
|
log.Errorf("Failed to build user info request: %v", errInfoReq)
|
|
|
|
|
oauthStatus[state] = "Failed to build user info request"
|
|
|
|
|
setOAuthStatus(state, "Failed to build user info request")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
|
|
|
|
|
@@ -1504,7 +1531,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|
|
|
|
infoResp, errInfo := httpClient.Do(infoReq)
|
|
|
|
|
if errInfo != nil {
|
|
|
|
|
log.Errorf("Failed to execute user info request: %v", errInfo)
|
|
|
|
|
oauthStatus[state] = "Failed to execute user info request"
|
|
|
|
|
setOAuthStatus(state, "Failed to execute user info request")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
defer func() {
|
|
|
|
|
@@ -1523,7 +1550,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|
|
|
|
} else {
|
|
|
|
|
bodyBytes, _ := io.ReadAll(infoResp.Body)
|
|
|
|
|
log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes))
|
|
|
|
|
oauthStatus[state] = fmt.Sprintf("User info request failed: %d", infoResp.StatusCode)
|
|
|
|
|
setOAuthStatus(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode))
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
@@ -1571,11 +1598,11 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|
|
|
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
|
|
|
|
if errSave != nil {
|
|
|
|
|
log.Errorf("Failed to save token to file: %v", errSave)
|
|
|
|
|
oauthStatus[state] = "Failed to save token to file"
|
|
|
|
|
setOAuthStatus(state, "Failed to save token to file")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
delete(oauthStatus, state)
|
|
|
|
|
deleteOAuthStatus(state)
|
|
|
|
|
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
|
|
|
|
if projectID != "" {
|
|
|
|
|
fmt.Printf("Using GCP project: %s\n", projectID)
|
|
|
|
|
@@ -1583,7 +1610,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|
|
|
|
fmt.Println("You can now use Antigravity services through this CLI")
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
oauthStatus[state] = ""
|
|
|
|
|
setOAuthStatus(state, "")
|
|
|
|
|
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -1609,7 +1636,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
|
|
|
|
fmt.Println("Waiting for authentication...")
|
|
|
|
|
tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
|
|
|
|
|
if errPollForToken != nil {
|
|
|
|
|
oauthStatus[state] = "Authentication failed"
|
|
|
|
|
setOAuthStatus(state, "Authentication failed")
|
|
|
|
|
fmt.Printf("Authentication failed: %v\n", errPollForToken)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
@@ -1628,16 +1655,16 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
|
|
|
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
|
|
|
|
if errSave != nil {
|
|
|
|
|
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
|
|
|
|
oauthStatus[state] = "Failed to save authentication tokens"
|
|
|
|
|
setOAuthStatus(state, "Failed to save authentication tokens")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
|
|
|
|
fmt.Println("You can now use Qwen services through this CLI")
|
|
|
|
|
delete(oauthStatus, state)
|
|
|
|
|
deleteOAuthStatus(state)
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
oauthStatus[state] = ""
|
|
|
|
|
setOAuthStatus(state, "")
|
|
|
|
|
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -1676,7 +1703,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
|
|
|
|
var resultMap map[string]string
|
|
|
|
|
for {
|
|
|
|
|
if time.Now().After(deadline) {
|
|
|
|
|
oauthStatus[state] = "Authentication failed"
|
|
|
|
|
setOAuthStatus(state, "Authentication failed")
|
|
|
|
|
fmt.Println("Authentication failed: timeout waiting for callback")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
@@ -1689,26 +1716,26 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" {
|
|
|
|
|
oauthStatus[state] = "Authentication failed"
|
|
|
|
|
setOAuthStatus(state, "Authentication failed")
|
|
|
|
|
fmt.Printf("Authentication failed: %s\n", errStr)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if resultState := strings.TrimSpace(resultMap["state"]); resultState != state {
|
|
|
|
|
oauthStatus[state] = "Authentication failed"
|
|
|
|
|
setOAuthStatus(state, "Authentication failed")
|
|
|
|
|
fmt.Println("Authentication failed: state mismatch")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
code := strings.TrimSpace(resultMap["code"])
|
|
|
|
|
if code == "" {
|
|
|
|
|
oauthStatus[state] = "Authentication failed"
|
|
|
|
|
setOAuthStatus(state, "Authentication failed")
|
|
|
|
|
fmt.Println("Authentication failed: code missing")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI)
|
|
|
|
|
if errExchange != nil {
|
|
|
|
|
oauthStatus[state] = "Authentication failed"
|
|
|
|
|
setOAuthStatus(state, "Authentication failed")
|
|
|
|
|
fmt.Printf("Authentication failed: %v\n", errExchange)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
@@ -1730,8 +1757,8 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
|
|
|
|
|
|
|
|
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
|
|
|
|
if errSave != nil {
|
|
|
|
|
oauthStatus[state] = "Failed to save authentication tokens"
|
|
|
|
|
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
|
|
|
|
setOAuthStatus(state, "Failed to save authentication tokens")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -1740,10 +1767,10 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
|
|
|
|
fmt.Println("API key obtained and saved")
|
|
|
|
|
}
|
|
|
|
|
fmt.Println("You can now use iFlow services through this CLI")
|
|
|
|
|
delete(oauthStatus, state)
|
|
|
|
|
deleteOAuthStatus(state)
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
oauthStatus[state] = ""
|
|
|
|
|
setOAuthStatus(state, "")
|
|
|
|
|
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -2180,9 +2207,35 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
|
|
|
|
|
|
|
|
|
func (h *Handler) GetAuthStatus(c *gin.Context) {
|
|
|
|
|
state := c.Query("state")
|
|
|
|
|
if err, ok := oauthStatus[state]; ok {
|
|
|
|
|
if err != "" {
|
|
|
|
|
c.JSON(200, gin.H{"status": "error", "error": err})
|
|
|
|
|
if statusValue, ok := getOAuthStatus(state); ok {
|
|
|
|
|
if statusValue != "" {
|
|
|
|
|
// Check for device_code prefix (Kiro AWS Builder ID flow)
|
|
|
|
|
// Format: "device_code|verification_url|user_code"
|
|
|
|
|
// Using "|" as separator because URLs contain ":"
|
|
|
|
|
if strings.HasPrefix(statusValue, "device_code|") {
|
|
|
|
|
parts := strings.SplitN(statusValue, "|", 3)
|
|
|
|
|
if len(parts) == 3 {
|
|
|
|
|
c.JSON(200, gin.H{
|
|
|
|
|
"status": "device_code",
|
|
|
|
|
"verification_url": parts[1],
|
|
|
|
|
"user_code": parts[2],
|
|
|
|
|
})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Check for auth_url prefix (Kiro social auth flow)
|
|
|
|
|
// Format: "auth_url|url"
|
|
|
|
|
// Using "|" as separator because URLs contain ":"
|
|
|
|
|
if strings.HasPrefix(statusValue, "auth_url|") {
|
|
|
|
|
authURL := strings.TrimPrefix(statusValue, "auth_url|")
|
|
|
|
|
c.JSON(200, gin.H{
|
|
|
|
|
"status": "auth_url",
|
|
|
|
|
"url": authURL,
|
|
|
|
|
})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
// Otherwise treat as error
|
|
|
|
|
c.JSON(200, gin.H{"status": "error", "error": statusValue})
|
|
|
|
|
} else {
|
|
|
|
|
c.JSON(200, gin.H{"status": "wait"})
|
|
|
|
|
return
|
|
|
|
|
@@ -2190,5 +2243,297 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
|
|
|
|
|
} else {
|
|
|
|
|
c.JSON(200, gin.H{"status": "ok"})
|
|
|
|
|
}
|
|
|
|
|
delete(oauthStatus, state)
|
|
|
|
|
deleteOAuthStatus(state)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const kiroCallbackPort = 9876
|
|
|
|
|
|
|
|
|
|
func (h *Handler) RequestKiroToken(c *gin.Context) {
|
|
|
|
|
ctx := context.Background()
|
|
|
|
|
|
|
|
|
|
// Get the login method from query parameter (default: aws for device code flow)
|
|
|
|
|
method := strings.ToLower(strings.TrimSpace(c.Query("method")))
|
|
|
|
|
if method == "" {
|
|
|
|
|
method = "aws"
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fmt.Println("Initializing Kiro authentication...")
|
|
|
|
|
|
|
|
|
|
state := fmt.Sprintf("kiro-%d", time.Now().UnixNano())
|
|
|
|
|
|
|
|
|
|
switch method {
|
|
|
|
|
case "aws", "builder-id":
|
|
|
|
|
// AWS Builder ID uses device code flow (no callback needed)
|
|
|
|
|
go func() {
|
|
|
|
|
ssoClient := kiroauth.NewSSOOIDCClient(h.cfg)
|
|
|
|
|
|
|
|
|
|
// Step 1: Register client
|
|
|
|
|
fmt.Println("Registering client...")
|
|
|
|
|
regResp, err := ssoClient.RegisterClient(ctx)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Errorf("Failed to register client: %v", err)
|
|
|
|
|
setOAuthStatus(state, "Failed to register client")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Step 2: Start device authorization
|
|
|
|
|
fmt.Println("Starting device authorization...")
|
|
|
|
|
authResp, err := ssoClient.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Errorf("Failed to start device auth: %v", err)
|
|
|
|
|
setOAuthStatus(state, "Failed to start device authorization")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Store the verification URL for the frontend to display
|
|
|
|
|
// Using "|" as separator because URLs contain ":"
|
|
|
|
|
setOAuthStatus(state, "device_code|"+authResp.VerificationURIComplete+"|"+authResp.UserCode)
|
|
|
|
|
|
|
|
|
|
// Step 3: Poll for token
|
|
|
|
|
fmt.Println("Waiting for authorization...")
|
|
|
|
|
interval := 5 * time.Second
|
|
|
|
|
if authResp.Interval > 0 {
|
|
|
|
|
interval = time.Duration(authResp.Interval) * time.Second
|
|
|
|
|
}
|
|
|
|
|
deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)
|
|
|
|
|
|
|
|
|
|
for time.Now().Before(deadline) {
|
|
|
|
|
select {
|
|
|
|
|
case <-ctx.Done():
|
|
|
|
|
setOAuthStatus(state, "Authorization cancelled")
|
|
|
|
|
return
|
|
|
|
|
case <-time.After(interval):
|
|
|
|
|
tokenResp, err := ssoClient.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
|
|
|
|
|
if err != nil {
|
|
|
|
|
errStr := err.Error()
|
|
|
|
|
if strings.Contains(errStr, "authorization_pending") {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
if strings.Contains(errStr, "slow_down") {
|
|
|
|
|
interval += 5 * time.Second
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
log.Errorf("Token creation failed: %v", err)
|
|
|
|
|
setOAuthStatus(state, "Token creation failed")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Success! Save the token
|
|
|
|
|
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
|
|
|
|
email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken)
|
|
|
|
|
|
|
|
|
|
idPart := kiroauth.SanitizeEmailForFilename(email)
|
|
|
|
|
if idPart == "" {
|
|
|
|
|
idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
now := time.Now()
|
|
|
|
|
fileName := fmt.Sprintf("kiro-aws-%s.json", idPart)
|
|
|
|
|
|
|
|
|
|
record := &coreauth.Auth{
|
|
|
|
|
ID: fileName,
|
|
|
|
|
Provider: "kiro",
|
|
|
|
|
FileName: fileName,
|
|
|
|
|
Metadata: map[string]any{
|
|
|
|
|
"type": "kiro",
|
|
|
|
|
"access_token": tokenResp.AccessToken,
|
|
|
|
|
"refresh_token": tokenResp.RefreshToken,
|
|
|
|
|
"expires_at": expiresAt.Format(time.RFC3339),
|
|
|
|
|
"auth_method": "builder-id",
|
|
|
|
|
"provider": "AWS",
|
|
|
|
|
"client_id": regResp.ClientID,
|
|
|
|
|
"client_secret": regResp.ClientSecret,
|
|
|
|
|
"email": email,
|
|
|
|
|
"last_refresh": now.Format(time.RFC3339),
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
|
|
|
|
if errSave != nil {
|
|
|
|
|
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
|
|
|
|
setOAuthStatus(state, "Failed to save authentication tokens")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
|
|
|
|
if email != "" {
|
|
|
|
|
fmt.Printf("Authenticated as: %s\n", email)
|
|
|
|
|
}
|
|
|
|
|
deleteOAuthStatus(state)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
setOAuthStatus(state, "Authorization timed out")
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
// Return immediately with the state for polling
|
|
|
|
|
c.JSON(200, gin.H{"status": "ok", "state": state, "method": "device_code"})
|
|
|
|
|
|
|
|
|
|
case "google", "github":
|
|
|
|
|
// Social auth uses protocol handler - for WEB UI we use a callback forwarder
|
|
|
|
|
provider := "Google"
|
|
|
|
|
if method == "github" {
|
|
|
|
|
provider = "Github"
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
isWebUI := isWebUIRequest(c)
|
|
|
|
|
if isWebUI {
|
|
|
|
|
targetURL, errTarget := h.managementCallbackURL("/kiro/callback")
|
|
|
|
|
if errTarget != nil {
|
|
|
|
|
log.WithError(errTarget).Error("failed to compute kiro callback target")
|
|
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if _, errStart := startCallbackForwarder(kiroCallbackPort, "kiro", targetURL); errStart != nil {
|
|
|
|
|
log.WithError(errStart).Error("failed to start kiro callback forwarder")
|
|
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
|
if isWebUI {
|
|
|
|
|
defer stopCallbackForwarder(kiroCallbackPort)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
socialClient := kiroauth.NewSocialAuthClient(h.cfg)
|
|
|
|
|
|
|
|
|
|
// Generate PKCE codes
|
|
|
|
|
codeVerifier, codeChallenge, err := generateKiroPKCE()
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Errorf("Failed to generate PKCE: %v", err)
|
|
|
|
|
setOAuthStatus(state, "Failed to generate PKCE")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Build login URL
|
|
|
|
|
authURL := fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account",
|
|
|
|
|
"https://prod.us-east-1.auth.desktop.kiro.dev",
|
|
|
|
|
provider,
|
|
|
|
|
url.QueryEscape(kiroauth.KiroRedirectURI),
|
|
|
|
|
codeChallenge,
|
|
|
|
|
state,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// Store auth URL for frontend
|
|
|
|
|
// Using "|" as separator because URLs contain ":"
|
|
|
|
|
setOAuthStatus(state, "auth_url|"+authURL)
|
|
|
|
|
|
|
|
|
|
// Wait for callback file
|
|
|
|
|
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-kiro-%s.oauth", state))
|
|
|
|
|
deadline := time.Now().Add(5 * time.Minute)
|
|
|
|
|
|
|
|
|
|
for {
|
|
|
|
|
if time.Now().After(deadline) {
|
|
|
|
|
log.Error("oauth flow timed out")
|
|
|
|
|
setOAuthStatus(state, "OAuth flow timed out")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if data, errR := os.ReadFile(waitFile); errR == nil {
|
|
|
|
|
var m map[string]string
|
|
|
|
|
_ = json.Unmarshal(data, &m)
|
|
|
|
|
_ = os.Remove(waitFile)
|
|
|
|
|
if errStr := m["error"]; errStr != "" {
|
|
|
|
|
log.Errorf("Authentication failed: %s", errStr)
|
|
|
|
|
setOAuthStatus(state, "Authentication failed")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if m["state"] != state {
|
|
|
|
|
log.Errorf("State mismatch")
|
|
|
|
|
setOAuthStatus(state, "State mismatch")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
code := m["code"]
|
|
|
|
|
if code == "" {
|
|
|
|
|
log.Error("No authorization code received")
|
|
|
|
|
setOAuthStatus(state, "No authorization code received")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Exchange code for tokens
|
|
|
|
|
tokenReq := &kiroauth.CreateTokenRequest{
|
|
|
|
|
Code: code,
|
|
|
|
|
CodeVerifier: codeVerifier,
|
|
|
|
|
RedirectURI: kiroauth.KiroRedirectURI,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
tokenResp, errToken := socialClient.CreateToken(ctx, tokenReq)
|
|
|
|
|
if errToken != nil {
|
|
|
|
|
log.Errorf("Failed to exchange code for tokens: %v", errToken)
|
|
|
|
|
setOAuthStatus(state, "Failed to exchange code for tokens")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Save the token
|
|
|
|
|
expiresIn := tokenResp.ExpiresIn
|
|
|
|
|
if expiresIn <= 0 {
|
|
|
|
|
expiresIn = 3600
|
|
|
|
|
}
|
|
|
|
|
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
|
|
|
|
email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken)
|
|
|
|
|
|
|
|
|
|
idPart := kiroauth.SanitizeEmailForFilename(email)
|
|
|
|
|
if idPart == "" {
|
|
|
|
|
idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
now := time.Now()
|
|
|
|
|
fileName := fmt.Sprintf("kiro-%s-%s.json", strings.ToLower(provider), idPart)
|
|
|
|
|
|
|
|
|
|
record := &coreauth.Auth{
|
|
|
|
|
ID: fileName,
|
|
|
|
|
Provider: "kiro",
|
|
|
|
|
FileName: fileName,
|
|
|
|
|
Metadata: map[string]any{
|
|
|
|
|
"type": "kiro",
|
|
|
|
|
"access_token": tokenResp.AccessToken,
|
|
|
|
|
"refresh_token": tokenResp.RefreshToken,
|
|
|
|
|
"profile_arn": tokenResp.ProfileArn,
|
|
|
|
|
"expires_at": expiresAt.Format(time.RFC3339),
|
|
|
|
|
"auth_method": "social",
|
|
|
|
|
"provider": provider,
|
|
|
|
|
"email": email,
|
|
|
|
|
"last_refresh": now.Format(time.RFC3339),
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
|
|
|
|
if errSave != nil {
|
|
|
|
|
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
|
|
|
|
setOAuthStatus(state, "Failed to save authentication tokens")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
|
|
|
|
if email != "" {
|
|
|
|
|
fmt.Printf("Authenticated as: %s\n", email)
|
|
|
|
|
}
|
|
|
|
|
deleteOAuthStatus(state)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
time.Sleep(500 * time.Millisecond)
|
|
|
|
|
}
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
setOAuthStatus(state, "")
|
|
|
|
|
c.JSON(200, gin.H{"status": "ok", "state": state, "method": "social"})
|
|
|
|
|
|
|
|
|
|
default:
|
|
|
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid method, use 'aws', 'google', or 'github'"})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// generateKiroPKCE generates PKCE code verifier and challenge for Kiro OAuth.
|
|
|
|
|
func generateKiroPKCE() (verifier, challenge string, err error) {
|
|
|
|
|
b := make([]byte, 32)
|
|
|
|
|
if _, err := io.ReadFull(rand.Reader, b); err != nil {
|
|
|
|
|
return "", "", fmt.Errorf("failed to generate random bytes: %w", err)
|
|
|
|
|
}
|
|
|
|
|
verifier = base64.RawURLEncoding.EncodeToString(b)
|
|
|
|
|
|
|
|
|
|
h := sha256.Sum256([]byte(verifier))
|
|
|
|
|
challenge = base64.RawURLEncoding.EncodeToString(h[:])
|
|
|
|
|
|
|
|
|
|
return verifier, challenge, nil
|
|
|
|
|
}
|
|
|
|
|
|