mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-30 01:06:39 +00:00
Merge branch 'main' into plus
This commit is contained in:
@@ -242,6 +242,11 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
|
||||
platformURL = "https://console.anthropic.com/"
|
||||
}
|
||||
|
||||
// Validate platformURL to prevent XSS - only allow http/https URLs
|
||||
if !isValidURL(platformURL) {
|
||||
platformURL = "https://console.anthropic.com/"
|
||||
}
|
||||
|
||||
// Generate success page HTML with dynamic content
|
||||
successHTML := s.generateSuccessHTML(setupRequired, platformURL)
|
||||
|
||||
@@ -251,6 +256,12 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// isValidURL checks if the URL is a valid http/https URL to prevent XSS
|
||||
func isValidURL(urlStr string) bool {
|
||||
urlStr = strings.TrimSpace(urlStr)
|
||||
return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://")
|
||||
}
|
||||
|
||||
// generateSuccessHTML creates the HTML content for the success page.
|
||||
// It customizes the page based on whether additional setup is required
|
||||
// and includes a link to the platform.
|
||||
|
||||
@@ -239,6 +239,11 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
|
||||
platformURL = "https://platform.openai.com"
|
||||
}
|
||||
|
||||
// Validate platformURL to prevent XSS - only allow http/https URLs
|
||||
if !isValidURL(platformURL) {
|
||||
platformURL = "https://platform.openai.com"
|
||||
}
|
||||
|
||||
// Generate success page HTML with dynamic content
|
||||
successHTML := s.generateSuccessHTML(setupRequired, platformURL)
|
||||
|
||||
@@ -248,6 +253,12 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// isValidURL checks if the URL is a valid http/https URL to prevent XSS
|
||||
func isValidURL(urlStr string) bool {
|
||||
urlStr = strings.TrimSpace(urlStr)
|
||||
return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://")
|
||||
}
|
||||
|
||||
// generateSuccessHTML creates the HTML content for the success page.
|
||||
// It customizes the page based on whether additional setup is required
|
||||
// and includes a link to the platform.
|
||||
|
||||
326
internal/auth/copilot/copilot_auth.go
Normal file
326
internal/auth/copilot/copilot_auth.go
Normal file
@@ -0,0 +1,326 @@
|
||||
// Package copilot provides authentication and token management for GitHub Copilot API.
|
||||
// It handles the OAuth2 device flow for secure authentication with the Copilot API.
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// copilotAPITokenURL is the endpoint for getting Copilot API tokens from GitHub token.
|
||||
copilotAPITokenURL = "https://api.github.com/copilot_internal/v2/token"
|
||||
// copilotAPIEndpoint is the base URL for making API requests.
|
||||
copilotAPIEndpoint = "https://api.githubcopilot.com"
|
||||
|
||||
// Common HTTP header values for Copilot API requests.
|
||||
copilotUserAgent = "GithubCopilot/1.0"
|
||||
copilotEditorVersion = "vscode/1.100.0"
|
||||
copilotPluginVersion = "copilot/1.300.0"
|
||||
copilotIntegrationID = "vscode-chat"
|
||||
copilotOpenAIIntent = "conversation-panel"
|
||||
)
|
||||
|
||||
// CopilotAPIToken represents the Copilot API token response.
|
||||
type CopilotAPIToken struct {
|
||||
// Token is the JWT token for authenticating with the Copilot API.
|
||||
Token string `json:"token"`
|
||||
// ExpiresAt is the Unix timestamp when the token expires.
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
// Endpoints contains the available API endpoints.
|
||||
Endpoints struct {
|
||||
API string `json:"api"`
|
||||
Proxy string `json:"proxy"`
|
||||
OriginTracker string `json:"origin-tracker"`
|
||||
Telemetry string `json:"telemetry"`
|
||||
} `json:"endpoints,omitempty"`
|
||||
// ErrorDetails contains error information if the request failed.
|
||||
ErrorDetails *struct {
|
||||
URL string `json:"url"`
|
||||
Message string `json:"message"`
|
||||
DocumentationURL string `json:"documentation_url"`
|
||||
} `json:"error_details,omitempty"`
|
||||
}
|
||||
|
||||
// CopilotAuth handles GitHub Copilot authentication flow.
|
||||
// It provides methods for device flow authentication and token management.
|
||||
type CopilotAuth struct {
|
||||
httpClient *http.Client
|
||||
deviceClient *DeviceFlowClient
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewCopilotAuth creates a new CopilotAuth service instance.
|
||||
// It initializes an HTTP client with proxy settings from the provided configuration.
|
||||
func NewCopilotAuth(cfg *config.Config) *CopilotAuth {
|
||||
return &CopilotAuth{
|
||||
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}),
|
||||
deviceClient: NewDeviceFlowClient(cfg),
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// StartDeviceFlow initiates the device flow authentication.
|
||||
// Returns the device code response containing the user code and verification URI.
|
||||
func (c *CopilotAuth) StartDeviceFlow(ctx context.Context) (*DeviceCodeResponse, error) {
|
||||
return c.deviceClient.RequestDeviceCode(ctx)
|
||||
}
|
||||
|
||||
// WaitForAuthorization polls for user authorization and returns the auth bundle.
|
||||
func (c *CopilotAuth) WaitForAuthorization(ctx context.Context, deviceCode *DeviceCodeResponse) (*CopilotAuthBundle, error) {
|
||||
tokenData, err := c.deviceClient.PollForToken(ctx, deviceCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fetch the GitHub username
|
||||
userInfo, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
||||
if err != nil {
|
||||
log.Warnf("copilot: failed to fetch user info: %v", err)
|
||||
}
|
||||
|
||||
username := userInfo.Login
|
||||
if username == "" {
|
||||
username = "github-user"
|
||||
}
|
||||
|
||||
return &CopilotAuthBundle{
|
||||
TokenData: tokenData,
|
||||
Username: username,
|
||||
Email: userInfo.Email,
|
||||
Name: userInfo.Name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetCopilotAPIToken exchanges a GitHub access token for a Copilot API token.
|
||||
// This token is used to make authenticated requests to the Copilot API.
|
||||
func (c *CopilotAuth) GetCopilotAPIToken(ctx context.Context, githubAccessToken string) (*CopilotAPIToken, error) {
|
||||
if githubAccessToken == "" {
|
||||
return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("github access token is empty"))
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotAPITokenURL, nil)
|
||||
if err != nil {
|
||||
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "token "+githubAccessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", copilotUserAgent)
|
||||
req.Header.Set("Editor-Version", copilotEditorVersion)
|
||||
req.Header.Set("Editor-Plugin-Version", copilotPluginVersion)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("copilot api token: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
|
||||
}
|
||||
|
||||
if !isHTTPSuccess(resp.StatusCode) {
|
||||
return nil, NewAuthenticationError(ErrTokenExchangeFailed,
|
||||
fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
|
||||
}
|
||||
|
||||
var apiToken CopilotAPIToken
|
||||
if err = json.Unmarshal(bodyBytes, &apiToken); err != nil {
|
||||
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
|
||||
}
|
||||
|
||||
if apiToken.Token == "" {
|
||||
return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("empty copilot api token"))
|
||||
}
|
||||
|
||||
return &apiToken, nil
|
||||
}
|
||||
|
||||
// ValidateToken checks if a GitHub access token is valid by attempting to fetch user info.
|
||||
func (c *CopilotAuth) ValidateToken(ctx context.Context, accessToken string) (bool, string, error) {
|
||||
if accessToken == "" {
|
||||
return false, "", nil
|
||||
}
|
||||
|
||||
userInfo, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
return true, userInfo.Login, nil
|
||||
}
|
||||
|
||||
// CreateTokenStorage creates a new CopilotTokenStorage from auth bundle.
|
||||
func (c *CopilotAuth) CreateTokenStorage(bundle *CopilotAuthBundle) *CopilotTokenStorage {
|
||||
return &CopilotTokenStorage{
|
||||
AccessToken: bundle.TokenData.AccessToken,
|
||||
TokenType: bundle.TokenData.TokenType,
|
||||
Scope: bundle.TokenData.Scope,
|
||||
Username: bundle.Username,
|
||||
Email: bundle.Email,
|
||||
Name: bundle.Name,
|
||||
Type: "github-copilot",
|
||||
}
|
||||
}
|
||||
|
||||
// LoadAndValidateToken loads a token from storage and validates it.
|
||||
// Returns the storage if valid, or an error if the token is invalid or expired.
|
||||
func (c *CopilotAuth) LoadAndValidateToken(ctx context.Context, storage *CopilotTokenStorage) (bool, error) {
|
||||
if storage == nil || storage.AccessToken == "" {
|
||||
return false, fmt.Errorf("no token available")
|
||||
}
|
||||
|
||||
// Check if we can still use the GitHub token to get a Copilot API token
|
||||
apiToken, err := c.GetCopilotAPIToken(ctx, storage.AccessToken)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Check if the API token is expired
|
||||
if apiToken.ExpiresAt > 0 && time.Now().Unix() >= apiToken.ExpiresAt {
|
||||
return false, fmt.Errorf("copilot api token expired")
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetAPIEndpoint returns the Copilot API endpoint URL.
|
||||
func (c *CopilotAuth) GetAPIEndpoint() string {
|
||||
return copilotAPIEndpoint
|
||||
}
|
||||
|
||||
// MakeAuthenticatedRequest creates an authenticated HTTP request to the Copilot API.
|
||||
func (c *CopilotAuth) MakeAuthenticatedRequest(ctx context.Context, method, url string, body io.Reader, apiToken *CopilotAPIToken) (*http.Request, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+apiToken.Token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", copilotUserAgent)
|
||||
req.Header.Set("Editor-Version", copilotEditorVersion)
|
||||
req.Header.Set("Editor-Plugin-Version", copilotPluginVersion)
|
||||
req.Header.Set("Openai-Intent", copilotOpenAIIntent)
|
||||
req.Header.Set("Copilot-Integration-Id", copilotIntegrationID)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// CopilotModelEntry represents a single model entry returned by the Copilot /models API.
|
||||
type CopilotModelEntry struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Version string `json:"version,omitempty"`
|
||||
Capabilities map[string]any `json:"capabilities,omitempty"`
|
||||
}
|
||||
|
||||
// CopilotModelsResponse represents the response from the Copilot /models endpoint.
|
||||
type CopilotModelsResponse struct {
|
||||
Data []CopilotModelEntry `json:"data"`
|
||||
Object string `json:"object"`
|
||||
}
|
||||
|
||||
// maxModelsResponseSize is the maximum allowed response size from the /models endpoint (2 MB).
|
||||
const maxModelsResponseSize = 2 * 1024 * 1024
|
||||
|
||||
// allowedCopilotAPIHosts is the set of hosts that are considered safe for Copilot API requests.
|
||||
var allowedCopilotAPIHosts = map[string]bool{
|
||||
"api.githubcopilot.com": true,
|
||||
"api.individual.githubcopilot.com": true,
|
||||
"api.business.githubcopilot.com": true,
|
||||
"copilot-proxy.githubusercontent.com": true,
|
||||
}
|
||||
|
||||
// ListModels fetches the list of available models from the Copilot API.
|
||||
// It requires a valid Copilot API token (not the GitHub access token).
|
||||
func (c *CopilotAuth) ListModels(ctx context.Context, apiToken *CopilotAPIToken) ([]CopilotModelEntry, error) {
|
||||
if apiToken == nil || apiToken.Token == "" {
|
||||
return nil, fmt.Errorf("copilot: api token is required for listing models")
|
||||
}
|
||||
|
||||
// Build models URL, validating the endpoint host to prevent SSRF.
|
||||
modelsURL := copilotAPIEndpoint + "/models"
|
||||
if ep := strings.TrimRight(apiToken.Endpoints.API, "/"); ep != "" {
|
||||
parsed, err := url.Parse(ep)
|
||||
if err == nil && parsed.Scheme == "https" && allowedCopilotAPIHosts[parsed.Host] {
|
||||
modelsURL = ep + "/models"
|
||||
} else {
|
||||
log.Warnf("copilot: ignoring untrusted API endpoint %q, using default", ep)
|
||||
}
|
||||
}
|
||||
|
||||
req, err := c.MakeAuthenticatedRequest(ctx, http.MethodGet, modelsURL, nil, apiToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("copilot: failed to create models request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("copilot: models request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("copilot list models: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
// Limit response body to prevent memory exhaustion.
|
||||
limitedReader := io.LimitReader(resp.Body, maxModelsResponseSize)
|
||||
bodyBytes, err := io.ReadAll(limitedReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("copilot: failed to read models response: %w", err)
|
||||
}
|
||||
|
||||
if !isHTTPSuccess(resp.StatusCode) {
|
||||
return nil, fmt.Errorf("copilot: list models failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var modelsResp CopilotModelsResponse
|
||||
if err = json.Unmarshal(bodyBytes, &modelsResp); err != nil {
|
||||
return nil, fmt.Errorf("copilot: failed to parse models response: %w", err)
|
||||
}
|
||||
|
||||
return modelsResp.Data, nil
|
||||
}
|
||||
|
||||
// ListModelsWithGitHubToken is a convenience method that exchanges a GitHub access token
|
||||
// for a Copilot API token and then fetches the available models.
|
||||
func (c *CopilotAuth) ListModelsWithGitHubToken(ctx context.Context, githubAccessToken string) ([]CopilotModelEntry, error) {
|
||||
apiToken, err := c.GetCopilotAPIToken(ctx, githubAccessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("copilot: failed to get API token for model listing: %w", err)
|
||||
}
|
||||
|
||||
return c.ListModels(ctx, apiToken)
|
||||
}
|
||||
|
||||
// buildChatCompletionURL builds the URL for chat completions API.
|
||||
func buildChatCompletionURL() string {
|
||||
return copilotAPIEndpoint + "/chat/completions"
|
||||
}
|
||||
|
||||
// isHTTPSuccess checks if the status code indicates success (2xx).
|
||||
func isHTTPSuccess(statusCode int) bool {
|
||||
return statusCode >= 200 && statusCode < 300
|
||||
}
|
||||
187
internal/auth/copilot/errors.go
Normal file
187
internal/auth/copilot/errors.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// OAuthError represents an OAuth-specific error.
|
||||
type OAuthError struct {
|
||||
// Code is the OAuth error code.
|
||||
Code string `json:"error"`
|
||||
// Description is a human-readable description of the error.
|
||||
Description string `json:"error_description,omitempty"`
|
||||
// URI is a URI identifying a human-readable web page with information about the error.
|
||||
URI string `json:"error_uri,omitempty"`
|
||||
// StatusCode is the HTTP status code associated with the error.
|
||||
StatusCode int `json:"-"`
|
||||
}
|
||||
|
||||
// Error returns a string representation of the OAuth error.
|
||||
func (e *OAuthError) Error() string {
|
||||
if e.Description != "" {
|
||||
return fmt.Sprintf("OAuth error %s: %s", e.Code, e.Description)
|
||||
}
|
||||
return fmt.Sprintf("OAuth error: %s", e.Code)
|
||||
}
|
||||
|
||||
// NewOAuthError creates a new OAuth error with the specified code, description, and status code.
|
||||
func NewOAuthError(code, description string, statusCode int) *OAuthError {
|
||||
return &OAuthError{
|
||||
Code: code,
|
||||
Description: description,
|
||||
StatusCode: statusCode,
|
||||
}
|
||||
}
|
||||
|
||||
// AuthenticationError represents authentication-related errors.
|
||||
type AuthenticationError struct {
|
||||
// Type is the type of authentication error.
|
||||
Type string `json:"type"`
|
||||
// Message is a human-readable message describing the error.
|
||||
Message string `json:"message"`
|
||||
// Code is the HTTP status code associated with the error.
|
||||
Code int `json:"code"`
|
||||
// Cause is the underlying error that caused this authentication error.
|
||||
Cause error `json:"-"`
|
||||
}
|
||||
|
||||
// Error returns a string representation of the authentication error.
|
||||
func (e *AuthenticationError) Error() string {
|
||||
if e.Cause != nil {
|
||||
return fmt.Sprintf("%s: %s (caused by: %v)", e.Type, e.Message, e.Cause)
|
||||
}
|
||||
return fmt.Sprintf("%s: %s", e.Type, e.Message)
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying cause of the error.
|
||||
func (e *AuthenticationError) Unwrap() error {
|
||||
return e.Cause
|
||||
}
|
||||
|
||||
// Common authentication error types for GitHub Copilot device flow.
|
||||
var (
|
||||
// ErrDeviceCodeFailed represents an error when requesting the device code fails.
|
||||
ErrDeviceCodeFailed = &AuthenticationError{
|
||||
Type: "device_code_failed",
|
||||
Message: "Failed to request device code from GitHub",
|
||||
Code: http.StatusBadRequest,
|
||||
}
|
||||
|
||||
// ErrDeviceCodeExpired represents an error when the device code has expired.
|
||||
ErrDeviceCodeExpired = &AuthenticationError{
|
||||
Type: "device_code_expired",
|
||||
Message: "Device code has expired. Please try again.",
|
||||
Code: http.StatusGone,
|
||||
}
|
||||
|
||||
// ErrAuthorizationPending represents a pending authorization state (not an error, used for polling).
|
||||
ErrAuthorizationPending = &AuthenticationError{
|
||||
Type: "authorization_pending",
|
||||
Message: "Authorization is pending. Waiting for user to authorize.",
|
||||
Code: http.StatusAccepted,
|
||||
}
|
||||
|
||||
// ErrSlowDown represents a request to slow down polling.
|
||||
ErrSlowDown = &AuthenticationError{
|
||||
Type: "slow_down",
|
||||
Message: "Polling too frequently. Slowing down.",
|
||||
Code: http.StatusTooManyRequests,
|
||||
}
|
||||
|
||||
// ErrAccessDenied represents an error when the user denies authorization.
|
||||
ErrAccessDenied = &AuthenticationError{
|
||||
Type: "access_denied",
|
||||
Message: "User denied authorization",
|
||||
Code: http.StatusForbidden,
|
||||
}
|
||||
|
||||
// ErrTokenExchangeFailed represents an error when token exchange fails.
|
||||
ErrTokenExchangeFailed = &AuthenticationError{
|
||||
Type: "token_exchange_failed",
|
||||
Message: "Failed to exchange device code for access token",
|
||||
Code: http.StatusBadRequest,
|
||||
}
|
||||
|
||||
// ErrPollingTimeout represents an error when polling times out.
|
||||
ErrPollingTimeout = &AuthenticationError{
|
||||
Type: "polling_timeout",
|
||||
Message: "Timeout waiting for user authorization",
|
||||
Code: http.StatusRequestTimeout,
|
||||
}
|
||||
|
||||
// ErrUserInfoFailed represents an error when fetching user info fails.
|
||||
ErrUserInfoFailed = &AuthenticationError{
|
||||
Type: "user_info_failed",
|
||||
Message: "Failed to fetch GitHub user information",
|
||||
Code: http.StatusBadRequest,
|
||||
}
|
||||
)
|
||||
|
||||
// NewAuthenticationError creates a new authentication error with a cause based on a base error.
|
||||
func NewAuthenticationError(baseErr *AuthenticationError, cause error) *AuthenticationError {
|
||||
return &AuthenticationError{
|
||||
Type: baseErr.Type,
|
||||
Message: baseErr.Message,
|
||||
Code: baseErr.Code,
|
||||
Cause: cause,
|
||||
}
|
||||
}
|
||||
|
||||
// IsAuthenticationError checks if an error is an authentication error.
|
||||
func IsAuthenticationError(err error) bool {
|
||||
var authenticationError *AuthenticationError
|
||||
ok := errors.As(err, &authenticationError)
|
||||
return ok
|
||||
}
|
||||
|
||||
// IsOAuthError checks if an error is an OAuth error.
|
||||
func IsOAuthError(err error) bool {
|
||||
var oAuthError *OAuthError
|
||||
ok := errors.As(err, &oAuthError)
|
||||
return ok
|
||||
}
|
||||
|
||||
// GetUserFriendlyMessage returns a user-friendly error message based on the error type.
|
||||
func GetUserFriendlyMessage(err error) string {
|
||||
var authErr *AuthenticationError
|
||||
if errors.As(err, &authErr) {
|
||||
switch authErr.Type {
|
||||
case "device_code_failed":
|
||||
return "Failed to start GitHub authentication. Please check your network connection and try again."
|
||||
case "device_code_expired":
|
||||
return "The authentication code has expired. Please try again."
|
||||
case "authorization_pending":
|
||||
return "Waiting for you to authorize the application on GitHub."
|
||||
case "slow_down":
|
||||
return "Please wait a moment before trying again."
|
||||
case "access_denied":
|
||||
return "Authentication was cancelled or denied."
|
||||
case "token_exchange_failed":
|
||||
return "Failed to complete authentication. Please try again."
|
||||
case "polling_timeout":
|
||||
return "Authentication timed out. Please try again."
|
||||
case "user_info_failed":
|
||||
return "Failed to get your GitHub account information. Please try again."
|
||||
default:
|
||||
return "Authentication failed. Please try again."
|
||||
}
|
||||
}
|
||||
|
||||
var oauthErr *OAuthError
|
||||
if errors.As(err, &oauthErr) {
|
||||
switch oauthErr.Code {
|
||||
case "access_denied":
|
||||
return "Authentication was cancelled or denied."
|
||||
case "invalid_request":
|
||||
return "Invalid authentication request. Please try again."
|
||||
case "server_error":
|
||||
return "GitHub server error. Please try again later."
|
||||
default:
|
||||
return fmt.Sprintf("Authentication failed: %s", oauthErr.Description)
|
||||
}
|
||||
}
|
||||
|
||||
return "An unexpected error occurred. Please try again."
|
||||
}
|
||||
271
internal/auth/copilot/oauth.go
Normal file
271
internal/auth/copilot/oauth.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// copilotClientID is GitHub's Copilot CLI OAuth client ID.
|
||||
copilotClientID = "Iv1.b507a08c87ecfe98"
|
||||
// copilotDeviceCodeURL is the endpoint for requesting device codes.
|
||||
copilotDeviceCodeURL = "https://github.com/login/device/code"
|
||||
// copilotTokenURL is the endpoint for exchanging device codes for tokens.
|
||||
copilotTokenURL = "https://github.com/login/oauth/access_token"
|
||||
// copilotUserInfoURL is the endpoint for fetching GitHub user information.
|
||||
copilotUserInfoURL = "https://api.github.com/user"
|
||||
// defaultPollInterval is the default interval for polling token endpoint.
|
||||
defaultPollInterval = 5 * time.Second
|
||||
// maxPollDuration is the maximum time to wait for user authorization.
|
||||
maxPollDuration = 15 * time.Minute
|
||||
)
|
||||
|
||||
// DeviceFlowClient handles the OAuth2 device flow for GitHub Copilot.
|
||||
type DeviceFlowClient struct {
|
||||
httpClient *http.Client
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewDeviceFlowClient creates a new device flow client.
|
||||
func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient {
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
if cfg != nil {
|
||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||
}
|
||||
return &DeviceFlowClient{
|
||||
httpClient: client,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// RequestDeviceCode initiates the device flow by requesting a device code from GitHub.
|
||||
func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", copilotClientID)
|
||||
data.Set("scope", "read:user user:email")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotDeviceCodeURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, NewAuthenticationError(ErrDeviceCodeFailed, err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, NewAuthenticationError(ErrDeviceCodeFailed, err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("copilot device code: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
if !isHTTPSuccess(resp.StatusCode) {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
return nil, NewAuthenticationError(ErrDeviceCodeFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
|
||||
}
|
||||
|
||||
var deviceCode DeviceCodeResponse
|
||||
if err = json.NewDecoder(resp.Body).Decode(&deviceCode); err != nil {
|
||||
return nil, NewAuthenticationError(ErrDeviceCodeFailed, err)
|
||||
}
|
||||
|
||||
return &deviceCode, nil
|
||||
}
|
||||
|
||||
// PollForToken polls the token endpoint until the user authorizes or the device code expires.
|
||||
func (c *DeviceFlowClient) PollForToken(ctx context.Context, deviceCode *DeviceCodeResponse) (*CopilotTokenData, error) {
|
||||
if deviceCode == nil {
|
||||
return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("device code is nil"))
|
||||
}
|
||||
|
||||
interval := time.Duration(deviceCode.Interval) * time.Second
|
||||
if interval < defaultPollInterval {
|
||||
interval = defaultPollInterval
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(maxPollDuration)
|
||||
if deviceCode.ExpiresIn > 0 {
|
||||
codeDeadline := time.Now().Add(time.Duration(deviceCode.ExpiresIn) * time.Second)
|
||||
if codeDeadline.Before(deadline) {
|
||||
deadline = codeDeadline
|
||||
}
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, NewAuthenticationError(ErrPollingTimeout, ctx.Err())
|
||||
case <-ticker.C:
|
||||
if time.Now().After(deadline) {
|
||||
return nil, ErrPollingTimeout
|
||||
}
|
||||
|
||||
token, err := c.exchangeDeviceCode(ctx, deviceCode.DeviceCode)
|
||||
if err != nil {
|
||||
var authErr *AuthenticationError
|
||||
if errors.As(err, &authErr) {
|
||||
switch authErr.Type {
|
||||
case ErrAuthorizationPending.Type:
|
||||
// Continue polling
|
||||
continue
|
||||
case ErrSlowDown.Type:
|
||||
// Increase interval and continue
|
||||
interval += 5 * time.Second
|
||||
ticker.Reset(interval)
|
||||
continue
|
||||
case ErrDeviceCodeExpired.Type:
|
||||
return nil, err
|
||||
case ErrAccessDenied.Type:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// exchangeDeviceCode attempts to exchange the device code for an access token.
|
||||
func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode string) (*CopilotTokenData, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", copilotClientID)
|
||||
data.Set("device_code", deviceCode)
|
||||
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotTokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("copilot token exchange: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
|
||||
}
|
||||
|
||||
// GitHub returns 200 for both success and error cases in device flow
|
||||
// Check for OAuth error response first
|
||||
var oauthResp struct {
|
||||
Error string `json:"error"`
|
||||
ErrorDescription string `json:"error_description"`
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
if err = json.Unmarshal(bodyBytes, &oauthResp); err != nil {
|
||||
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
|
||||
}
|
||||
|
||||
if oauthResp.Error != "" {
|
||||
switch oauthResp.Error {
|
||||
case "authorization_pending":
|
||||
return nil, ErrAuthorizationPending
|
||||
case "slow_down":
|
||||
return nil, ErrSlowDown
|
||||
case "expired_token":
|
||||
return nil, ErrDeviceCodeExpired
|
||||
case "access_denied":
|
||||
return nil, ErrAccessDenied
|
||||
default:
|
||||
return nil, NewOAuthError(oauthResp.Error, oauthResp.ErrorDescription, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
if oauthResp.AccessToken == "" {
|
||||
return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("empty access token"))
|
||||
}
|
||||
|
||||
return &CopilotTokenData{
|
||||
AccessToken: oauthResp.AccessToken,
|
||||
TokenType: oauthResp.TokenType,
|
||||
Scope: oauthResp.Scope,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GitHubUserInfo holds GitHub user profile information.
|
||||
type GitHubUserInfo struct {
|
||||
// Login is the GitHub username.
|
||||
Login string
|
||||
// Email is the primary email address (may be empty if not public).
|
||||
Email string
|
||||
// Name is the display name.
|
||||
Name string
|
||||
}
|
||||
|
||||
// FetchUserInfo retrieves the GitHub user profile for the authenticated user.
|
||||
func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (GitHubUserInfo, error) {
|
||||
if accessToken == "" {
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotUserInfoURL, nil)
|
||||
if err != nil {
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", "CLIProxyAPI")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("copilot user info: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
if !isHTTPSuccess(resp.StatusCode) {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
|
||||
}
|
||||
|
||||
var raw struct {
|
||||
Login string `json:"login"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
if err = json.NewDecoder(resp.Body).Decode(&raw); err != nil {
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
|
||||
}
|
||||
|
||||
if raw.Login == "" {
|
||||
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
|
||||
}
|
||||
|
||||
return GitHubUserInfo{
|
||||
Login: raw.Login,
|
||||
Email: raw.Email,
|
||||
Name: raw.Name,
|
||||
}, nil
|
||||
}
|
||||
213
internal/auth/copilot/oauth_test.go
Normal file
213
internal/auth/copilot/oauth_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// roundTripFunc lets us inject a custom transport for testing.
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
|
||||
|
||||
// newTestClient returns an *http.Client whose requests are redirected to the given test server,
|
||||
// regardless of the original URL host.
|
||||
func newTestClient(srv *httptest.Server) *http.Client {
|
||||
return &http.Client{
|
||||
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
req2 := req.Clone(req.Context())
|
||||
req2.URL.Scheme = "http"
|
||||
req2.URL.Host = strings.TrimPrefix(srv.URL, "http://")
|
||||
return srv.Client().Transport.RoundTrip(req2)
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchUserInfo_FullProfile verifies that FetchUserInfo returns login, email, and name.
|
||||
func TestFetchUserInfo_FullProfile(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"login": "octocat",
|
||||
"email": "octocat@github.com",
|
||||
"name": "The Octocat",
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||
info, err := client.FetchUserInfo(context.Background(), "test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if info.Login != "octocat" {
|
||||
t.Errorf("Login: got %q, want %q", info.Login, "octocat")
|
||||
}
|
||||
if info.Email != "octocat@github.com" {
|
||||
t.Errorf("Email: got %q, want %q", info.Email, "octocat@github.com")
|
||||
}
|
||||
if info.Name != "The Octocat" {
|
||||
t.Errorf("Name: got %q, want %q", info.Name, "The Octocat")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchUserInfo_EmptyEmail verifies graceful handling when email is absent (private account).
|
||||
func TestFetchUserInfo_EmptyEmail(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// GitHub returns null for private emails.
|
||||
_, _ = w.Write([]byte(`{"login":"privateuser","email":null,"name":"Private User"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||
info, err := client.FetchUserInfo(context.Background(), "test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if info.Login != "privateuser" {
|
||||
t.Errorf("Login: got %q, want %q", info.Login, "privateuser")
|
||||
}
|
||||
if info.Email != "" {
|
||||
t.Errorf("Email: got %q, want empty string", info.Email)
|
||||
}
|
||||
if info.Name != "Private User" {
|
||||
t.Errorf("Name: got %q, want %q", info.Name, "Private User")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchUserInfo_EmptyToken verifies error is returned for empty access token.
|
||||
func TestFetchUserInfo_EmptyToken(t *testing.T) {
|
||||
client := &DeviceFlowClient{httpClient: http.DefaultClient}
|
||||
_, err := client.FetchUserInfo(context.Background(), "")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty token, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchUserInfo_EmptyLogin verifies error is returned when API returns no login.
|
||||
func TestFetchUserInfo_EmptyLogin(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"email":"someone@example.com","name":"No Login"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||
_, err := client.FetchUserInfo(context.Background(), "test-token")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty login, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchUserInfo_HTTPError verifies error is returned on non-2xx response.
|
||||
func TestFetchUserInfo_HTTPError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"message":"Bad credentials"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||
_, err := client.FetchUserInfo(context.Background(), "bad-token")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 401 response, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCopilotTokenStorage_EmailNameFields verifies Email and Name serialise correctly.
|
||||
func TestCopilotTokenStorage_EmailNameFields(t *testing.T) {
|
||||
ts := &CopilotTokenStorage{
|
||||
AccessToken: "ghu_abc",
|
||||
TokenType: "bearer",
|
||||
Scope: "read:user user:email",
|
||||
Username: "octocat",
|
||||
Email: "octocat@github.com",
|
||||
Name: "The Octocat",
|
||||
Type: "github-copilot",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(ts)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal error: %v", err)
|
||||
}
|
||||
|
||||
var out map[string]any
|
||||
if err = json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
}
|
||||
|
||||
for _, key := range []string{"access_token", "username", "email", "name", "type"} {
|
||||
if _, ok := out[key]; !ok {
|
||||
t.Errorf("expected key %q in JSON output, not found", key)
|
||||
}
|
||||
}
|
||||
if out["email"] != "octocat@github.com" {
|
||||
t.Errorf("email: got %v, want %q", out["email"], "octocat@github.com")
|
||||
}
|
||||
if out["name"] != "The Octocat" {
|
||||
t.Errorf("name: got %v, want %q", out["name"], "The Octocat")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCopilotTokenStorage_OmitEmptyEmailName verifies email/name are omitted when empty (omitempty).
|
||||
func TestCopilotTokenStorage_OmitEmptyEmailName(t *testing.T) {
|
||||
ts := &CopilotTokenStorage{
|
||||
AccessToken: "ghu_abc",
|
||||
Username: "octocat",
|
||||
Type: "github-copilot",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(ts)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal error: %v", err)
|
||||
}
|
||||
|
||||
var out map[string]any
|
||||
if err = json.Unmarshal(data, &out); err != nil {
|
||||
t.Fatalf("unmarshal error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := out["email"]; ok {
|
||||
t.Error("email key should be omitted when empty (omitempty), but was present")
|
||||
}
|
||||
if _, ok := out["name"]; ok {
|
||||
t.Error("name key should be omitted when empty (omitempty), but was present")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCopilotAuthBundle_EmailNameFields verifies bundle carries email and name through the pipeline.
|
||||
func TestCopilotAuthBundle_EmailNameFields(t *testing.T) {
|
||||
bundle := &CopilotAuthBundle{
|
||||
TokenData: &CopilotTokenData{AccessToken: "ghu_abc"},
|
||||
Username: "octocat",
|
||||
Email: "octocat@github.com",
|
||||
Name: "The Octocat",
|
||||
}
|
||||
if bundle.Email != "octocat@github.com" {
|
||||
t.Errorf("bundle.Email: got %q, want %q", bundle.Email, "octocat@github.com")
|
||||
}
|
||||
if bundle.Name != "The Octocat" {
|
||||
t.Errorf("bundle.Name: got %q, want %q", bundle.Name, "The Octocat")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGitHubUserInfo_Struct verifies the exported GitHubUserInfo struct fields are accessible.
|
||||
func TestGitHubUserInfo_Struct(t *testing.T) {
|
||||
info := GitHubUserInfo{
|
||||
Login: "octocat",
|
||||
Email: "octocat@github.com",
|
||||
Name: "The Octocat",
|
||||
}
|
||||
if info.Login == "" || info.Email == "" || info.Name == "" {
|
||||
t.Error("GitHubUserInfo fields should not be empty")
|
||||
}
|
||||
}
|
||||
101
internal/auth/copilot/token.go
Normal file
101
internal/auth/copilot/token.go
Normal file
@@ -0,0 +1,101 @@
|
||||
// Package copilot provides authentication and token management functionality
|
||||
// for GitHub Copilot AI services. It handles OAuth2 device flow token storage,
|
||||
// serialization, and retrieval for maintaining authenticated sessions with the Copilot API.
|
||||
package copilot
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
)
|
||||
|
||||
// CopilotTokenStorage stores OAuth2 token information for GitHub Copilot API authentication.
|
||||
// It maintains compatibility with the existing auth system while adding Copilot-specific fields
|
||||
// for managing access tokens and user account information.
|
||||
type CopilotTokenStorage struct {
|
||||
// AccessToken is the OAuth2 access token used for authenticating API requests.
|
||||
AccessToken string `json:"access_token"`
|
||||
// TokenType is the type of token, typically "bearer".
|
||||
TokenType string `json:"token_type"`
|
||||
// Scope is the OAuth2 scope granted to the token.
|
||||
Scope string `json:"scope"`
|
||||
// ExpiresAt is the timestamp when the access token expires (if provided).
|
||||
ExpiresAt string `json:"expires_at,omitempty"`
|
||||
// Username is the GitHub username associated with this token.
|
||||
Username string `json:"username"`
|
||||
// Email is the GitHub email address associated with this token.
|
||||
Email string `json:"email,omitempty"`
|
||||
// Name is the GitHub display name associated with this token.
|
||||
Name string `json:"name,omitempty"`
|
||||
// Type indicates the authentication provider type, always "github-copilot" for this storage.
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// CopilotTokenData holds the raw OAuth token response from GitHub.
|
||||
type CopilotTokenData struct {
|
||||
// AccessToken is the OAuth2 access token.
|
||||
AccessToken string `json:"access_token"`
|
||||
// TokenType is the type of token, typically "bearer".
|
||||
TokenType string `json:"token_type"`
|
||||
// Scope is the OAuth2 scope granted to the token.
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
// CopilotAuthBundle bundles authentication data for storage.
|
||||
type CopilotAuthBundle struct {
|
||||
// TokenData contains the OAuth token information.
|
||||
TokenData *CopilotTokenData
|
||||
// Username is the GitHub username.
|
||||
Username string
|
||||
// Email is the GitHub email address.
|
||||
Email string
|
||||
// Name is the GitHub display name.
|
||||
Name string
|
||||
}
|
||||
|
||||
// DeviceCodeResponse represents GitHub's device code response.
|
||||
type DeviceCodeResponse struct {
|
||||
// DeviceCode is the device verification code.
|
||||
DeviceCode string `json:"device_code"`
|
||||
// UserCode is the code the user must enter at the verification URI.
|
||||
UserCode string `json:"user_code"`
|
||||
// VerificationURI is the URL where the user should enter the code.
|
||||
VerificationURI string `json:"verification_uri"`
|
||||
// ExpiresIn is the number of seconds until the device code expires.
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
// Interval is the minimum number of seconds to wait between polling requests.
|
||||
Interval int `json:"interval"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Copilot token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if the operation fails, nil otherwise
|
||||
func (ts *CopilotTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "github-copilot"
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
|
||||
f, err := os.Create(authFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
492
internal/auth/gitlab/gitlab.go
Normal file
492
internal/auth/gitlab/gitlab.go
Normal file
@@ -0,0 +1,492 @@
|
||||
package gitlab
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultBaseURL = "https://gitlab.com"
|
||||
DefaultCallbackPort = 17171
|
||||
defaultOAuthScope = "api read_user"
|
||||
)
|
||||
|
||||
type PKCECodes struct {
|
||||
CodeVerifier string
|
||||
CodeChallenge string
|
||||
}
|
||||
|
||||
type OAuthResult struct {
|
||||
Code string
|
||||
State string
|
||||
Error string
|
||||
}
|
||||
|
||||
type OAuthServer struct {
|
||||
server *http.Server
|
||||
port int
|
||||
resultChan chan *OAuthResult
|
||||
errorChan chan error
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
}
|
||||
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Scope string `json:"scope"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID int64 `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
PublicEmail string `json:"public_email"`
|
||||
}
|
||||
|
||||
type PersonalAccessTokenSelf struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Scopes []string `json:"scopes"`
|
||||
UserID int64 `json:"user_id"`
|
||||
}
|
||||
|
||||
type ModelDetails struct {
|
||||
ModelProvider string `json:"model_provider"`
|
||||
ModelName string `json:"model_name"`
|
||||
}
|
||||
|
||||
type DirectAccessResponse struct {
|
||||
BaseURL string `json:"base_url"`
|
||||
Token string `json:"token"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
Headers map[string]string `json:"headers"`
|
||||
ModelDetails *ModelDetails `json:"model_details,omitempty"`
|
||||
}
|
||||
|
||||
type DiscoveredModel struct {
|
||||
ModelProvider string
|
||||
ModelName string
|
||||
}
|
||||
|
||||
type AuthClient struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewAuthClient(cfg *config.Config) *AuthClient {
|
||||
client := &http.Client{}
|
||||
if cfg != nil {
|
||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||
}
|
||||
return &AuthClient{httpClient: client}
|
||||
}
|
||||
|
||||
func NormalizeBaseURL(raw string) string {
|
||||
value := strings.TrimSpace(raw)
|
||||
if value == "" {
|
||||
return DefaultBaseURL
|
||||
}
|
||||
if !strings.Contains(value, "://") {
|
||||
value = "https://" + value
|
||||
}
|
||||
value = strings.TrimRight(value, "/")
|
||||
return value
|
||||
}
|
||||
|
||||
func TokenExpiry(now time.Time, token *TokenResponse) time.Time {
|
||||
if token == nil {
|
||||
return time.Time{}
|
||||
}
|
||||
if token.CreatedAt > 0 && token.ExpiresIn > 0 {
|
||||
return time.Unix(token.CreatedAt+int64(token.ExpiresIn), 0).UTC()
|
||||
}
|
||||
if token.ExpiresIn > 0 {
|
||||
return now.UTC().Add(time.Duration(token.ExpiresIn) * time.Second)
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
func GeneratePKCECodes() (*PKCECodes, error) {
|
||||
verifierBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(verifierBytes); err != nil {
|
||||
return nil, fmt.Errorf("gitlab pkce generation failed: %w", err)
|
||||
}
|
||||
verifier := base64.RawURLEncoding.EncodeToString(verifierBytes)
|
||||
sum := sha256.Sum256([]byte(verifier))
|
||||
challenge := base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
return &PKCECodes{
|
||||
CodeVerifier: verifier,
|
||||
CodeChallenge: challenge,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewOAuthServer(port int) *OAuthServer {
|
||||
return &OAuthServer{
|
||||
port: port,
|
||||
resultChan: make(chan *OAuthResult, 1),
|
||||
errorChan: make(chan error, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OAuthServer) Start() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.running {
|
||||
return fmt.Errorf("gitlab oauth server already running")
|
||||
}
|
||||
if !s.isPortAvailable() {
|
||||
return fmt.Errorf("port %d is already in use", s.port)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/auth/callback", s.handleCallback)
|
||||
|
||||
s.server = &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", s.port),
|
||||
Handler: mux,
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
}
|
||||
s.running = true
|
||||
|
||||
go func() {
|
||||
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
s.errorChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OAuthServer) Stop(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if !s.running || s.server == nil {
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
s.running = false
|
||||
s.server = nil
|
||||
}()
|
||||
return s.server.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func (s *OAuthServer) WaitForCallback(timeout time.Duration) (*OAuthResult, error) {
|
||||
select {
|
||||
case result := <-s.resultChan:
|
||||
return result, nil
|
||||
case err := <-s.errorChan:
|
||||
return nil, err
|
||||
case <-time.After(timeout):
|
||||
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OAuthServer) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
query := r.URL.Query()
|
||||
if errParam := strings.TrimSpace(query.Get("error")); errParam != "" {
|
||||
s.sendResult(&OAuthResult{Error: errParam})
|
||||
http.Error(w, errParam, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
code := strings.TrimSpace(query.Get("code"))
|
||||
state := strings.TrimSpace(query.Get("state"))
|
||||
if code == "" || state == "" {
|
||||
s.sendResult(&OAuthResult{Error: "missing_code_or_state"})
|
||||
http.Error(w, "missing code or state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
s.sendResult(&OAuthResult{Code: code, State: state})
|
||||
_, _ = w.Write([]byte("GitLab authentication received. You can close this tab."))
|
||||
}
|
||||
|
||||
func (s *OAuthServer) sendResult(result *OAuthResult) {
|
||||
select {
|
||||
case s.resultChan <- result:
|
||||
default:
|
||||
log.Debug("gitlab oauth result channel full, dropping callback result")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OAuthServer) isPortAvailable() bool {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
_ = listener.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
func RedirectURL(port int) string {
|
||||
return fmt.Sprintf("http://localhost:%d/auth/callback", port)
|
||||
}
|
||||
|
||||
func (c *AuthClient) GenerateAuthURL(baseURL, clientID, redirectURI, state string, pkce *PKCECodes) (string, error) {
|
||||
if pkce == nil {
|
||||
return "", fmt.Errorf("gitlab auth URL generation failed: PKCE codes are required")
|
||||
}
|
||||
if strings.TrimSpace(clientID) == "" {
|
||||
return "", fmt.Errorf("gitlab auth URL generation failed: client ID is required")
|
||||
}
|
||||
baseURL = NormalizeBaseURL(baseURL)
|
||||
params := url.Values{
|
||||
"client_id": {strings.TrimSpace(clientID)},
|
||||
"response_type": {"code"},
|
||||
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||
"scope": {defaultOAuthScope},
|
||||
"state": {strings.TrimSpace(state)},
|
||||
"code_challenge": {pkce.CodeChallenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
}
|
||||
return fmt.Sprintf("%s/oauth/authorize?%s", baseURL, params.Encode()), nil
|
||||
}
|
||||
|
||||
func (c *AuthClient) ExchangeCodeForTokens(ctx context.Context, baseURL, clientID, clientSecret, redirectURI, code, codeVerifier string) (*TokenResponse, error) {
|
||||
form := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"client_id": {strings.TrimSpace(clientID)},
|
||||
"code": {strings.TrimSpace(code)},
|
||||
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||
"code_verifier": {strings.TrimSpace(codeVerifier)},
|
||||
}
|
||||
if secret := strings.TrimSpace(clientSecret); secret != "" {
|
||||
form.Set("client_secret", secret)
|
||||
}
|
||||
return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form)
|
||||
}
|
||||
|
||||
func (c *AuthClient) RefreshTokens(ctx context.Context, baseURL, clientID, clientSecret, refreshToken string) (*TokenResponse, error) {
|
||||
form := url.Values{
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {strings.TrimSpace(refreshToken)},
|
||||
}
|
||||
if clientID = strings.TrimSpace(clientID); clientID != "" {
|
||||
form.Set("client_id", clientID)
|
||||
}
|
||||
if secret := strings.TrimSpace(clientSecret); secret != "" {
|
||||
form.Set("client_secret", secret)
|
||||
}
|
||||
return c.postToken(ctx, NormalizeBaseURL(baseURL)+"/oauth/token", form)
|
||||
}
|
||||
|
||||
func (c *AuthClient) postToken(ctx context.Context, tokenURL string, form url.Values) (*TokenResponse, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab token request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab token request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab token response read failed: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("gitlab token request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
var token TokenResponse
|
||||
if err := json.Unmarshal(body, &token); err != nil {
|
||||
return nil, fmt.Errorf("gitlab token response decode failed: %w", err)
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func (c *AuthClient) GetCurrentUser(ctx context.Context, baseURL, token string) (*User, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/user", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab user request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab user request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab user response read failed: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("gitlab user request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var user User
|
||||
if err := json.Unmarshal(body, &user); err != nil {
|
||||
return nil, fmt.Errorf("gitlab user response decode failed: %w", err)
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (c *AuthClient) GetPersonalAccessTokenSelf(ctx context.Context, baseURL, token string) (*PersonalAccessTokenSelf, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, NormalizeBaseURL(baseURL)+"/api/v4/personal_access_tokens/self", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab PAT self request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab PAT self request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab PAT self response read failed: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("gitlab PAT self request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var pat PersonalAccessTokenSelf
|
||||
if err := json.Unmarshal(body, &pat); err != nil {
|
||||
return nil, fmt.Errorf("gitlab PAT self response decode failed: %w", err)
|
||||
}
|
||||
return &pat, nil
|
||||
}
|
||||
|
||||
func (c *AuthClient) FetchDirectAccess(ctx context.Context, baseURL, token string) (*DirectAccessResponse, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, NormalizeBaseURL(baseURL)+"/api/v4/code_suggestions/direct_access", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab direct access request failed: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab direct access request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab direct access response read failed: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("gitlab direct access request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var direct DirectAccessResponse
|
||||
if err := json.Unmarshal(body, &direct); err != nil {
|
||||
return nil, fmt.Errorf("gitlab direct access response decode failed: %w", err)
|
||||
}
|
||||
if direct.Headers == nil {
|
||||
direct.Headers = make(map[string]string)
|
||||
}
|
||||
return &direct, nil
|
||||
}
|
||||
|
||||
func ExtractDiscoveredModels(metadata map[string]any) []DiscoveredModel {
|
||||
if len(metadata) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
models := make([]DiscoveredModel, 0, 4)
|
||||
seen := make(map[string]struct{})
|
||||
appendModel := func(provider, name string) {
|
||||
provider = strings.TrimSpace(provider)
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
key := strings.ToLower(name)
|
||||
if _, ok := seen[key]; ok {
|
||||
return
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
models = append(models, DiscoveredModel{
|
||||
ModelProvider: provider,
|
||||
ModelName: name,
|
||||
})
|
||||
}
|
||||
|
||||
if raw, ok := metadata["model_details"]; ok {
|
||||
appendDiscoveredModels(raw, appendModel)
|
||||
}
|
||||
appendModel(stringValue(metadata["model_provider"]), stringValue(metadata["model_name"]))
|
||||
|
||||
for _, key := range []string{"models", "supported_models", "discovered_models"} {
|
||||
if raw, ok := metadata[key]; ok {
|
||||
appendDiscoveredModels(raw, appendModel)
|
||||
}
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
func appendDiscoveredModels(raw any, appendModel func(provider, name string)) {
|
||||
switch typed := raw.(type) {
|
||||
case map[string]any:
|
||||
appendModel(stringValue(typed["model_provider"]), stringValue(typed["model_name"]))
|
||||
appendModel(stringValue(typed["provider"]), stringValue(typed["name"]))
|
||||
if nested, ok := typed["models"]; ok {
|
||||
appendDiscoveredModels(nested, appendModel)
|
||||
}
|
||||
case []any:
|
||||
for _, item := range typed {
|
||||
appendDiscoveredModels(item, appendModel)
|
||||
}
|
||||
case []string:
|
||||
for _, item := range typed {
|
||||
appendModel("", item)
|
||||
}
|
||||
case string:
|
||||
appendModel("", typed)
|
||||
}
|
||||
}
|
||||
|
||||
func stringValue(raw any) string {
|
||||
switch typed := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(typed)
|
||||
case fmt.Stringer:
|
||||
return strings.TrimSpace(typed.String())
|
||||
case json.Number:
|
||||
return typed.String()
|
||||
case int:
|
||||
return strconv.Itoa(typed)
|
||||
case int64:
|
||||
return strconv.FormatInt(typed, 10)
|
||||
case float64:
|
||||
return strconv.FormatInt(int64(typed), 10)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
138
internal/auth/gitlab/gitlab_test.go
Normal file
138
internal/auth/gitlab/gitlab_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package gitlab
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAuthClientGenerateAuthURLIncludesPKCE(t *testing.T) {
|
||||
client := NewAuthClient(nil)
|
||||
pkce, err := GeneratePKCECodes()
|
||||
if err != nil {
|
||||
t.Fatalf("GeneratePKCECodes() error = %v", err)
|
||||
}
|
||||
|
||||
rawURL, err := client.GenerateAuthURL("https://gitlab.example.com", "client-id", RedirectURL(17171), "state-123", pkce)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateAuthURL() error = %v", err)
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Parse(authURL) error = %v", err)
|
||||
}
|
||||
if got := parsed.Path; got != "/oauth/authorize" {
|
||||
t.Fatalf("expected /oauth/authorize path, got %q", got)
|
||||
}
|
||||
query := parsed.Query()
|
||||
if got := query.Get("client_id"); got != "client-id" {
|
||||
t.Fatalf("expected client_id, got %q", got)
|
||||
}
|
||||
if got := query.Get("scope"); got != defaultOAuthScope {
|
||||
t.Fatalf("expected scope %q, got %q", defaultOAuthScope, got)
|
||||
}
|
||||
if got := query.Get("code_challenge_method"); got != "S256" {
|
||||
t.Fatalf("expected PKCE method S256, got %q", got)
|
||||
}
|
||||
if got := query.Get("code_challenge"); got == "" {
|
||||
t.Fatal("expected non-empty code_challenge")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthClientExchangeCodeForTokens(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/oauth/token" {
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Fatalf("ParseForm() error = %v", err)
|
||||
}
|
||||
if got := r.Form.Get("grant_type"); got != "authorization_code" {
|
||||
t.Fatalf("expected authorization_code grant, got %q", got)
|
||||
}
|
||||
if got := r.Form.Get("code_verifier"); got != "verifier-123" {
|
||||
t.Fatalf("expected code_verifier, got %q", got)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"access_token": "oauth-access",
|
||||
"refresh_token": "oauth-refresh",
|
||||
"token_type": "Bearer",
|
||||
"scope": "api read_user",
|
||||
"created_at": 1710000000,
|
||||
"expires_in": 3600,
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := NewAuthClient(nil)
|
||||
token, err := client.ExchangeCodeForTokens(context.Background(), srv.URL, "client-id", "client-secret", RedirectURL(17171), "auth-code", "verifier-123")
|
||||
if err != nil {
|
||||
t.Fatalf("ExchangeCodeForTokens() error = %v", err)
|
||||
}
|
||||
if token.AccessToken != "oauth-access" {
|
||||
t.Fatalf("expected access token, got %q", token.AccessToken)
|
||||
}
|
||||
if token.RefreshToken != "oauth-refresh" {
|
||||
t.Fatalf("expected refresh token, got %q", token.RefreshToken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractDiscoveredModels(t *testing.T) {
|
||||
models := ExtractDiscoveredModels(map[string]any{
|
||||
"model_details": map[string]any{
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
"supported_models": []any{
|
||||
map[string]any{"model_provider": "openai", "model_name": "gpt-4.1"},
|
||||
"claude-sonnet-4-5",
|
||||
},
|
||||
})
|
||||
if len(models) != 2 {
|
||||
t.Fatalf("expected 2 unique models, got %d", len(models))
|
||||
}
|
||||
if models[0].ModelName != "claude-sonnet-4-5" {
|
||||
t.Fatalf("unexpected first model %q", models[0].ModelName)
|
||||
}
|
||||
if models[1].ModelName != "gpt-4.1" {
|
||||
t.Fatalf("unexpected second model %q", models[1].ModelName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchDirectAccessDecodesModelDetails(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/v4/code_suggestions/direct_access" {
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
if got := r.Header.Get("Authorization"); !strings.Contains(got, "token-123") {
|
||||
t.Fatalf("expected bearer token, got %q", got)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"base_url": "https://cloud.gitlab.example.com",
|
||||
"token": "gateway-token",
|
||||
"expires_at": 1710003600,
|
||||
"headers": map[string]string{
|
||||
"X-Gitlab-Realm": "saas",
|
||||
},
|
||||
"model_details": map[string]any{
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
client := NewAuthClient(nil)
|
||||
direct, err := client.FetchDirectAccess(context.Background(), srv.URL, "token-123")
|
||||
if err != nil {
|
||||
t.Fatalf("FetchDirectAccess() error = %v", err)
|
||||
}
|
||||
if direct.ModelDetails == nil || direct.ModelDetails.ModelName != "claude-sonnet-4-5" {
|
||||
t.Fatalf("expected model details, got %+v", direct.ModelDetails)
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -28,10 +29,21 @@ const (
|
||||
iFlowAPIKeyEndpoint = "https://platform.iflow.cn/api/openapi/apikey"
|
||||
|
||||
// Client credentials provided by iFlow for the Code Assist integration.
|
||||
iFlowOAuthClientID = "10009311001"
|
||||
iFlowOAuthClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW"
|
||||
iFlowOAuthClientID = "10009311001"
|
||||
// Default client secret (can be overridden via IFLOW_CLIENT_SECRET env var)
|
||||
defaultIFlowClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW"
|
||||
)
|
||||
|
||||
// getIFlowClientSecret returns the iFlow OAuth client secret.
|
||||
// It first checks the IFLOW_CLIENT_SECRET environment variable,
|
||||
// falling back to the default value if not set.
|
||||
func getIFlowClientSecret() string {
|
||||
if secret := os.Getenv("IFLOW_CLIENT_SECRET"); secret != "" {
|
||||
return secret
|
||||
}
|
||||
return defaultIFlowClientSecret
|
||||
}
|
||||
|
||||
// DefaultAPIBaseURL is the canonical chat completions endpoint.
|
||||
const DefaultAPIBaseURL = "https://apis.iflow.cn/v1"
|
||||
|
||||
@@ -72,7 +84,7 @@ func (ia *IFlowAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectUR
|
||||
form.Set("code", code)
|
||||
form.Set("redirect_uri", redirectURI)
|
||||
form.Set("client_id", iFlowOAuthClientID)
|
||||
form.Set("client_secret", iFlowOAuthClientSecret)
|
||||
form.Set("client_secret", getIFlowClientSecret())
|
||||
|
||||
req, err := ia.newTokenRequest(ctx, form)
|
||||
if err != nil {
|
||||
@@ -88,7 +100,7 @@ func (ia *IFlowAuth) RefreshTokens(ctx context.Context, refreshToken string) (*I
|
||||
form.Set("grant_type", "refresh_token")
|
||||
form.Set("refresh_token", refreshToken)
|
||||
form.Set("client_id", iFlowOAuthClientID)
|
||||
form.Set("client_secret", iFlowOAuthClientSecret)
|
||||
form.Set("client_secret", getIFlowClientSecret())
|
||||
|
||||
req, err := ia.newTokenRequest(ctx, form)
|
||||
if err != nil {
|
||||
@@ -104,7 +116,7 @@ func (ia *IFlowAuth) newTokenRequest(ctx context.Context, form url.Values) (*htt
|
||||
return nil, fmt.Errorf("iflow token: create request failed: %w", err)
|
||||
}
|
||||
|
||||
basic := base64.StdEncoding.EncodeToString([]byte(iFlowOAuthClientID + ":" + iFlowOAuthClientSecret))
|
||||
basic := base64.StdEncoding.EncodeToString([]byte(iFlowOAuthClientID + ":" + getIFlowClientSecret()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Basic "+basic)
|
||||
|
||||
168
internal/auth/kilo/kilo_auth.go
Normal file
168
internal/auth/kilo/kilo_auth.go
Normal file
@@ -0,0 +1,168 @@
|
||||
// Package kilo provides authentication and token management functionality
|
||||
// for Kilo AI services.
|
||||
package kilo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// BaseURL is the base URL for the Kilo AI API.
|
||||
BaseURL = "https://api.kilo.ai/api"
|
||||
)
|
||||
|
||||
// DeviceAuthResponse represents the response from initiating device flow.
|
||||
type DeviceAuthResponse struct {
|
||||
Code string `json:"code"`
|
||||
VerificationURL string `json:"verificationUrl"`
|
||||
ExpiresIn int `json:"expiresIn"`
|
||||
}
|
||||
|
||||
// DeviceStatusResponse represents the response when polling for device flow status.
|
||||
type DeviceStatusResponse struct {
|
||||
Status string `json:"status"`
|
||||
Token string `json:"token"`
|
||||
UserEmail string `json:"userEmail"`
|
||||
}
|
||||
|
||||
// Profile represents the user profile from Kilo AI.
|
||||
type Profile struct {
|
||||
Email string `json:"email"`
|
||||
Orgs []Organization `json:"organizations"`
|
||||
}
|
||||
|
||||
// Organization represents a Kilo AI organization.
|
||||
type Organization struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Defaults represents default settings for an organization or user.
|
||||
type Defaults struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
// KiloAuth provides methods for handling the Kilo AI authentication flow.
|
||||
type KiloAuth struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewKiloAuth creates a new instance of KiloAuth.
|
||||
func NewKiloAuth() *KiloAuth {
|
||||
return &KiloAuth{
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
// InitiateDeviceFlow starts the device authentication flow.
|
||||
func (k *KiloAuth) InitiateDeviceFlow(ctx context.Context) (*DeviceAuthResponse, error) {
|
||||
resp, err := k.client.Post(BaseURL+"/device-auth/codes", "application/json", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to initiate device flow: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var data DeviceAuthResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
// PollForToken polls for the device flow completion.
|
||||
func (k *KiloAuth) PollForToken(ctx context.Context, code string) (*DeviceStatusResponse, error) {
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
resp, err := k.client.Get(BaseURL + "/device-auth/codes/" + code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var data DeviceStatusResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch data.Status {
|
||||
case "approved":
|
||||
return &data, nil
|
||||
case "denied", "expired":
|
||||
return nil, fmt.Errorf("device flow %s", data.Status)
|
||||
case "pending":
|
||||
continue
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown status: %s", data.Status)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetProfile fetches the user's profile.
|
||||
func (k *KiloAuth) GetProfile(ctx context.Context, token string) (*Profile, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", BaseURL+"/profile", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create get profile request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
resp, err := k.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to get profile: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var profile Profile
|
||||
if err := json.NewDecoder(resp.Body).Decode(&profile); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &profile, nil
|
||||
}
|
||||
|
||||
// GetDefaults fetches default settings for an organization.
|
||||
func (k *KiloAuth) GetDefaults(ctx context.Context, token, orgID string) (*Defaults, error) {
|
||||
url := BaseURL + "/defaults"
|
||||
if orgID != "" {
|
||||
url = BaseURL + "/organizations/" + orgID + "/defaults"
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create get defaults request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
resp, err := k.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to get defaults: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var defaults Defaults
|
||||
if err := json.NewDecoder(resp.Body).Decode(&defaults); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &defaults, nil
|
||||
}
|
||||
60
internal/auth/kilo/kilo_token.go
Normal file
60
internal/auth/kilo/kilo_token.go
Normal file
@@ -0,0 +1,60 @@
|
||||
// Package kilo provides authentication and token management functionality
|
||||
// for Kilo AI services.
|
||||
package kilo
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// KiloTokenStorage stores token information for Kilo AI authentication.
|
||||
type KiloTokenStorage struct {
|
||||
// Token is the Kilo access token.
|
||||
Token string `json:"kilocodeToken"`
|
||||
|
||||
// OrganizationID is the Kilo organization ID.
|
||||
OrganizationID string `json:"kilocodeOrganizationId"`
|
||||
|
||||
// Model is the default model to use.
|
||||
Model string `json:"kilocodeModel"`
|
||||
|
||||
// Email is the email address of the authenticated user.
|
||||
Email string `json:"email"`
|
||||
|
||||
// Type indicates the authentication provider type, always "kilo" for this storage.
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Kilo token storage to a JSON file.
|
||||
func (ts *KiloTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "kilo"
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
|
||||
f, err := os.Create(authFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := f.Close(); errClose != nil {
|
||||
log.Errorf("failed to close file: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CredentialFileName returns the filename used to persist Kilo credentials.
|
||||
func CredentialFileName(email string) string {
|
||||
return fmt.Sprintf("kilo-%s.json", email)
|
||||
}
|
||||
681
internal/auth/kiro/aws.go
Normal file
681
internal/auth/kiro/aws.go
Normal file
@@ -0,0 +1,681 @@
|
||||
// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API.
|
||||
// It includes interfaces and implementations for token storage and authentication methods.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
|
||||
type PKCECodes struct {
|
||||
// CodeVerifier is the cryptographically random string used to correlate
|
||||
// the authorization request to the token request
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
// CodeChallenge is the SHA256 hash of the code verifier, base64url-encoded
|
||||
CodeChallenge string `json:"code_challenge"`
|
||||
}
|
||||
|
||||
// KiroTokenData holds OAuth token information from AWS CodeWhisperer (Kiro)
|
||||
type KiroTokenData struct {
|
||||
// AccessToken is the OAuth2 access token for API access
|
||||
AccessToken string `json:"accessToken"`
|
||||
// RefreshToken is used to obtain new access tokens
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
// ProfileArn is the AWS CodeWhisperer profile ARN
|
||||
ProfileArn string `json:"profileArn"`
|
||||
// ExpiresAt is the timestamp when the token expires
|
||||
ExpiresAt string `json:"expiresAt"`
|
||||
// AuthMethod indicates the authentication method used (e.g., "builder-id", "social", "idc")
|
||||
AuthMethod string `json:"authMethod"`
|
||||
// Provider indicates the OAuth provider (e.g., "AWS", "Google", "Enterprise")
|
||||
Provider string `json:"provider"`
|
||||
// ClientID is the OIDC client ID (needed for token refresh)
|
||||
ClientID string `json:"clientId,omitempty"`
|
||||
// ClientSecret is the OIDC client secret (needed for token refresh)
|
||||
ClientSecret string `json:"clientSecret,omitempty"`
|
||||
// ClientIDHash is the hash of client ID used to locate device registration file
|
||||
// (Enterprise Kiro IDE stores clientId/clientSecret in ~/.aws/sso/cache/{clientIdHash}.json)
|
||||
ClientIDHash string `json:"clientIdHash,omitempty"`
|
||||
// Email is the user's email address (used for file naming)
|
||||
Email string `json:"email,omitempty"`
|
||||
// StartURL is the IDC/Identity Center start URL (only for IDC auth method)
|
||||
StartURL string `json:"startUrl,omitempty"`
|
||||
// Region is the OIDC region for IDC login and token refresh
|
||||
Region string `json:"region,omitempty"`
|
||||
}
|
||||
|
||||
// KiroAuthBundle aggregates authentication data after OAuth flow completion
|
||||
type KiroAuthBundle struct {
|
||||
// TokenData contains the OAuth tokens from the authentication flow
|
||||
TokenData KiroTokenData `json:"token_data"`
|
||||
// LastRefresh is the timestamp of the last token refresh
|
||||
LastRefresh string `json:"last_refresh"`
|
||||
}
|
||||
|
||||
// KiroUsageInfo represents usage information from CodeWhisperer API
|
||||
type KiroUsageInfo struct {
|
||||
// SubscriptionTitle is the subscription plan name (e.g., "KIRO FREE")
|
||||
SubscriptionTitle string `json:"subscription_title"`
|
||||
// CurrentUsage is the current credit usage
|
||||
CurrentUsage float64 `json:"current_usage"`
|
||||
// UsageLimit is the maximum credit limit
|
||||
UsageLimit float64 `json:"usage_limit"`
|
||||
// NextReset is the timestamp of the next usage reset
|
||||
NextReset string `json:"next_reset"`
|
||||
}
|
||||
|
||||
// KiroModel represents a model available through the CodeWhisperer API
|
||||
type KiroModel struct {
|
||||
// ModelID is the unique identifier for the model
|
||||
ModelID string `json:"modelId"`
|
||||
// ModelName is the human-readable name
|
||||
ModelName string `json:"modelName"`
|
||||
// Description is the model description
|
||||
Description string `json:"description"`
|
||||
// RateMultiplier is the credit multiplier for this model
|
||||
RateMultiplier float64 `json:"rateMultiplier"`
|
||||
// RateUnit is the unit for rate calculation (e.g., "credit")
|
||||
RateUnit string `json:"rateUnit"`
|
||||
// MaxInputTokens is the maximum input token limit
|
||||
MaxInputTokens int `json:"maxInputTokens,omitempty"`
|
||||
}
|
||||
|
||||
// KiroIDETokenFile is the default path to Kiro IDE's token file
|
||||
const KiroIDETokenFile = ".aws/sso/cache/kiro-auth-token.json"
|
||||
|
||||
// Default retry configuration for file reading
|
||||
const (
|
||||
defaultTokenReadMaxAttempts = 10 // Maximum retry attempts
|
||||
defaultTokenReadBaseDelay = 50 * time.Millisecond // Base delay between retries
|
||||
)
|
||||
|
||||
// isTransientFileError checks if the error is a transient file access error
|
||||
// that may be resolved by retrying (e.g., file locked by another process on Windows).
|
||||
func isTransientFileError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for OS-level file access errors (Windows sharing violation, etc.)
|
||||
var pathErr *os.PathError
|
||||
if errors.As(err, &pathErr) {
|
||||
// Windows sharing violation (ERROR_SHARING_VIOLATION = 32)
|
||||
// Windows lock violation (ERROR_LOCK_VIOLATION = 33)
|
||||
errStr := pathErr.Err.Error()
|
||||
if strings.Contains(errStr, "being used by another process") ||
|
||||
strings.Contains(errStr, "sharing violation") ||
|
||||
strings.Contains(errStr, "lock violation") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check error message for common transient patterns
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
transientPatterns := []string{
|
||||
"being used by another process",
|
||||
"sharing violation",
|
||||
"lock violation",
|
||||
"access is denied",
|
||||
"unexpected end of json",
|
||||
"unexpected eof",
|
||||
}
|
||||
for _, pattern := range transientPatterns {
|
||||
if strings.Contains(errMsg, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// LoadKiroIDETokenWithRetry loads token data from Kiro IDE's token file with retry logic.
|
||||
// This handles transient file access errors (e.g., file locked by Kiro IDE during write).
|
||||
// maxAttempts: maximum number of retry attempts (default 10 if <= 0)
|
||||
// baseDelay: base delay between retries with exponential backoff (default 50ms if <= 0)
|
||||
func LoadKiroIDETokenWithRetry(maxAttempts int, baseDelay time.Duration) (*KiroTokenData, error) {
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = defaultTokenReadMaxAttempts
|
||||
}
|
||||
if baseDelay <= 0 {
|
||||
baseDelay = defaultTokenReadBaseDelay
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||
token, err := LoadKiroIDEToken()
|
||||
if err == nil {
|
||||
return token, nil
|
||||
}
|
||||
lastErr = err
|
||||
|
||||
// Only retry for transient errors
|
||||
if !isTransientFileError(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Exponential backoff: delay * 2^attempt, capped at 500ms
|
||||
delay := baseDelay * time.Duration(1<<uint(attempt))
|
||||
if delay > 500*time.Millisecond {
|
||||
delay = 500 * time.Millisecond
|
||||
}
|
||||
time.Sleep(delay)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to read token file after %d attempts: %w", maxAttempts, lastErr)
|
||||
}
|
||||
|
||||
// LoadKiroIDEToken loads token data from Kiro IDE's token file.
|
||||
// For Enterprise Kiro IDE (IDC auth), it also loads clientId and clientSecret
|
||||
// from the device registration file referenced by clientIdHash.
|
||||
func LoadKiroIDEToken() (*KiroTokenData, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||
}
|
||||
|
||||
tokenPath := filepath.Join(homeDir, KiroIDETokenFile)
|
||||
data, err := os.ReadFile(tokenPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read Kiro IDE token file (%s): %w", tokenPath, err)
|
||||
}
|
||||
|
||||
var token KiroTokenData
|
||||
if err := json.Unmarshal(data, &token); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse Kiro IDE token: %w", err)
|
||||
}
|
||||
|
||||
if token.AccessToken == "" {
|
||||
return nil, fmt.Errorf("access token is empty in Kiro IDE token file")
|
||||
}
|
||||
|
||||
// Normalize AuthMethod to lowercase (Kiro IDE uses "IdC" but we expect "idc")
|
||||
token.AuthMethod = strings.ToLower(token.AuthMethod)
|
||||
|
||||
// For Enterprise Kiro IDE (IDC auth), load clientId and clientSecret from device registration
|
||||
// The device registration file is located at ~/.aws/sso/cache/{clientIdHash}.json
|
||||
if token.ClientIDHash != "" && token.ClientID == "" {
|
||||
if err := loadDeviceRegistration(homeDir, token.ClientIDHash, &token); err != nil {
|
||||
// Log warning but don't fail - token might still work for some operations
|
||||
fmt.Printf("warning: failed to load device registration for clientIdHash %s: %v\n", token.ClientIDHash, err)
|
||||
}
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// loadDeviceRegistration loads clientId and clientSecret from the device registration file.
|
||||
// Enterprise Kiro IDE stores these in ~/.aws/sso/cache/{clientIdHash}.json
|
||||
func loadDeviceRegistration(homeDir, clientIDHash string, token *KiroTokenData) error {
|
||||
if clientIDHash == "" {
|
||||
return fmt.Errorf("clientIdHash is empty")
|
||||
}
|
||||
|
||||
// Sanitize clientIdHash to prevent path traversal
|
||||
if strings.Contains(clientIDHash, "/") || strings.Contains(clientIDHash, "\\") || strings.Contains(clientIDHash, "..") {
|
||||
return fmt.Errorf("invalid clientIdHash: contains path separator")
|
||||
}
|
||||
|
||||
deviceRegPath := filepath.Join(homeDir, ".aws", "sso", "cache", clientIDHash+".json")
|
||||
data, err := os.ReadFile(deviceRegPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read device registration file (%s): %w", deviceRegPath, err)
|
||||
}
|
||||
|
||||
// Device registration file structure
|
||||
var deviceReg struct {
|
||||
ClientID string `json:"clientId"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
ExpiresAt string `json:"expiresAt"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &deviceReg); err != nil {
|
||||
return fmt.Errorf("failed to parse device registration: %w", err)
|
||||
}
|
||||
|
||||
if deviceReg.ClientID == "" || deviceReg.ClientSecret == "" {
|
||||
return fmt.Errorf("device registration missing clientId or clientSecret")
|
||||
}
|
||||
|
||||
token.ClientID = deviceReg.ClientID
|
||||
token.ClientSecret = deviceReg.ClientSecret
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadKiroTokenFromPath loads token data from a custom path.
|
||||
// This supports multiple accounts by allowing different token files.
|
||||
// For Enterprise Kiro IDE (IDC auth), it also loads clientId and clientSecret
|
||||
// from the device registration file referenced by clientIdHash.
|
||||
func LoadKiroTokenFromPath(tokenPath string) (*KiroTokenData, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||
}
|
||||
|
||||
// Expand ~ to home directory
|
||||
if len(tokenPath) > 0 && tokenPath[0] == '~' {
|
||||
tokenPath = filepath.Join(homeDir, tokenPath[1:])
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(tokenPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read token file (%s): %w", tokenPath, err)
|
||||
}
|
||||
|
||||
var token KiroTokenData
|
||||
if err := json.Unmarshal(data, &token); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token file: %w", err)
|
||||
}
|
||||
|
||||
if token.AccessToken == "" {
|
||||
return nil, fmt.Errorf("access token is empty in token file")
|
||||
}
|
||||
|
||||
// Normalize AuthMethod to lowercase (Kiro IDE uses "IdC" but we expect "idc")
|
||||
token.AuthMethod = strings.ToLower(token.AuthMethod)
|
||||
|
||||
// For Enterprise Kiro IDE (IDC auth), load clientId and clientSecret from device registration
|
||||
if token.ClientIDHash != "" && token.ClientID == "" {
|
||||
if err := loadDeviceRegistration(homeDir, token.ClientIDHash, &token); err != nil {
|
||||
// Log warning but don't fail - token might still work for some operations
|
||||
fmt.Printf("warning: failed to load device registration for clientIdHash %s: %v\n", token.ClientIDHash, err)
|
||||
}
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// ListKiroTokenFiles lists all Kiro token files in the cache directory.
|
||||
// This supports multiple accounts by finding all token files.
|
||||
func ListKiroTokenFiles() ([]string, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||
}
|
||||
|
||||
cacheDir := filepath.Join(homeDir, ".aws", "sso", "cache")
|
||||
|
||||
// Check if directory exists
|
||||
if _, err := os.Stat(cacheDir); os.IsNotExist(err) {
|
||||
return nil, nil // No token files
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(cacheDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read cache directory: %w", err)
|
||||
}
|
||||
|
||||
var tokenFiles []string
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
// Look for kiro token files only (avoid matching unrelated AWS SSO cache files)
|
||||
if strings.HasSuffix(name, ".json") && strings.HasPrefix(name, "kiro") {
|
||||
tokenFiles = append(tokenFiles, filepath.Join(cacheDir, name))
|
||||
}
|
||||
}
|
||||
|
||||
return tokenFiles, nil
|
||||
}
|
||||
|
||||
// LoadAllKiroTokens loads all Kiro tokens from the cache directory.
|
||||
// This supports multiple accounts.
|
||||
func LoadAllKiroTokens() ([]*KiroTokenData, error) {
|
||||
files, err := ListKiroTokenFiles()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tokens []*KiroTokenData
|
||||
for _, file := range files {
|
||||
token, err := LoadKiroTokenFromPath(file)
|
||||
if err != nil {
|
||||
// Skip invalid token files
|
||||
continue
|
||||
}
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// JWTClaims represents the claims we care about from a JWT token.
|
||||
// JWT tokens from Kiro/AWS contain user information in the payload.
|
||||
type JWTClaims struct {
|
||||
Email string `json:"email,omitempty"`
|
||||
Sub string `json:"sub,omitempty"`
|
||||
PreferredUser string `json:"preferred_username,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Iss string `json:"iss,omitempty"`
|
||||
}
|
||||
|
||||
// ExtractEmailFromJWT extracts the user's email from a JWT access token.
|
||||
// JWT tokens typically have format: header.payload.signature
|
||||
// The payload is base64url-encoded JSON containing user claims.
|
||||
func ExtractEmailFromJWT(accessToken string) string {
|
||||
if accessToken == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// JWT format: header.payload.signature
|
||||
parts := strings.Split(accessToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Decode the payload (second part)
|
||||
payload := parts[1]
|
||||
|
||||
// Add padding if needed (base64url requires padding)
|
||||
switch len(payload) % 4 {
|
||||
case 2:
|
||||
payload += "=="
|
||||
case 3:
|
||||
payload += "="
|
||||
}
|
||||
|
||||
decoded, err := base64.URLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
// Try RawURLEncoding (no padding)
|
||||
decoded, err = base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
var claims JWTClaims
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Return email if available
|
||||
if claims.Email != "" {
|
||||
return claims.Email
|
||||
}
|
||||
|
||||
// Fallback to preferred_username (some providers use this)
|
||||
if claims.PreferredUser != "" && strings.Contains(claims.PreferredUser, "@") {
|
||||
return claims.PreferredUser
|
||||
}
|
||||
|
||||
// Fallback to sub if it looks like an email
|
||||
if claims.Sub != "" && strings.Contains(claims.Sub, "@") {
|
||||
return claims.Sub
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// SanitizeEmailForFilename sanitizes an email address for use in a filename.
|
||||
// Replaces special characters with underscores and prevents path traversal attacks.
|
||||
// Also handles URL-encoded characters to prevent encoded path traversal attempts.
|
||||
func SanitizeEmailForFilename(email string) string {
|
||||
if email == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
result := email
|
||||
|
||||
// First, handle URL-encoded path traversal attempts (%2F, %2E, %5C, etc.)
|
||||
// This prevents encoded characters from bypassing the sanitization.
|
||||
// Note: We replace % last to catch any remaining encodings including double-encoding (%252F)
|
||||
result = strings.ReplaceAll(result, "%2F", "_") // /
|
||||
result = strings.ReplaceAll(result, "%2f", "_")
|
||||
result = strings.ReplaceAll(result, "%5C", "_") // \
|
||||
result = strings.ReplaceAll(result, "%5c", "_")
|
||||
result = strings.ReplaceAll(result, "%2E", "_") // .
|
||||
result = strings.ReplaceAll(result, "%2e", "_")
|
||||
result = strings.ReplaceAll(result, "%00", "_") // null byte
|
||||
result = strings.ReplaceAll(result, "%", "_") // Catch remaining % to prevent double-encoding attacks
|
||||
|
||||
// Replace characters that are problematic in filenames
|
||||
// Keep @ and . in middle but replace other special characters
|
||||
for _, char := range []string{"/", "\\", ":", "*", "?", "\"", "<", ">", "|", " ", "\x00"} {
|
||||
result = strings.ReplaceAll(result, char, "_")
|
||||
}
|
||||
|
||||
// Prevent path traversal: replace leading dots in each path component
|
||||
// This handles cases like "../../../etc/passwd" → "_.._.._.._etc_passwd"
|
||||
parts := strings.Split(result, "_")
|
||||
for i, part := range parts {
|
||||
for strings.HasPrefix(part, ".") {
|
||||
part = "_" + part[1:]
|
||||
}
|
||||
parts[i] = part
|
||||
}
|
||||
result = strings.Join(parts, "_")
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ExtractIDCIdentifier extracts a unique identifier from IDC startUrl.
|
||||
// Examples:
|
||||
// - "https://d-1234567890.awsapps.com/start" -> "d-1234567890"
|
||||
// - "https://my-company.awsapps.com/start" -> "my-company"
|
||||
// - "https://acme-corp.awsapps.com/start" -> "acme-corp"
|
||||
func ExtractIDCIdentifier(startURL string) string {
|
||||
if startURL == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Remove protocol prefix
|
||||
url := strings.TrimPrefix(startURL, "https://")
|
||||
url = strings.TrimPrefix(url, "http://")
|
||||
|
||||
// Extract subdomain (first part before the first dot)
|
||||
// Format: {identifier}.awsapps.com/start
|
||||
parts := strings.Split(url, ".")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
identifier := parts[0]
|
||||
// Sanitize for filename safety
|
||||
identifier = strings.ReplaceAll(identifier, "/", "_")
|
||||
identifier = strings.ReplaceAll(identifier, "\\", "_")
|
||||
identifier = strings.ReplaceAll(identifier, ":", "_")
|
||||
return identifier
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// GenerateTokenFileName generates a unique filename for token storage.
|
||||
// Priority: email > startUrl identifier (for IDC) > authMethod only
|
||||
// Email is unique, so no sequence suffix needed. Sequence is only added
|
||||
// when email is unavailable to prevent filename collisions.
|
||||
// Format: kiro-{authMethod}-{identifier}[-{seq}].json
|
||||
func GenerateTokenFileName(tokenData *KiroTokenData) string {
|
||||
authMethod := tokenData.AuthMethod
|
||||
if authMethod == "" {
|
||||
authMethod = "unknown"
|
||||
}
|
||||
|
||||
// Priority 1: Use email if available (no sequence needed, email is unique)
|
||||
if tokenData.Email != "" {
|
||||
// Sanitize email for filename (replace @ and . with -)
|
||||
sanitizedEmail := tokenData.Email
|
||||
sanitizedEmail = strings.ReplaceAll(sanitizedEmail, "@", "-")
|
||||
sanitizedEmail = strings.ReplaceAll(sanitizedEmail, ".", "-")
|
||||
return fmt.Sprintf("kiro-%s-%s.json", authMethod, sanitizedEmail)
|
||||
}
|
||||
|
||||
// Generate sequence only when email is unavailable
|
||||
seq := time.Now().UnixNano() % 100000
|
||||
|
||||
// Priority 2: For IDC, use startUrl identifier with sequence
|
||||
if authMethod == "idc" && tokenData.StartURL != "" {
|
||||
identifier := ExtractIDCIdentifier(tokenData.StartURL)
|
||||
if identifier != "" {
|
||||
return fmt.Sprintf("kiro-%s-%s-%05d.json", authMethod, identifier, seq)
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 3: Fallback to authMethod only with sequence
|
||||
return fmt.Sprintf("kiro-%s-%05d.json", authMethod, seq)
|
||||
}
|
||||
|
||||
// DefaultKiroRegion is the fallback region when none is specified.
|
||||
const DefaultKiroRegion = "us-east-1"
|
||||
|
||||
// GetCodeWhispererLegacyEndpoint returns the legacy CodeWhisperer JSON-RPC endpoint.
|
||||
// This endpoint supports JSON-RPC style requests with x-amz-target headers.
|
||||
// The Q endpoint (q.{region}.amazonaws.com) does NOT support JSON-RPC style.
|
||||
func GetCodeWhispererLegacyEndpoint(region string) string {
|
||||
if region == "" {
|
||||
region = DefaultKiroRegion
|
||||
}
|
||||
return "https://codewhisperer." + region + ".amazonaws.com"
|
||||
}
|
||||
|
||||
// ProfileARN represents a parsed AWS CodeWhisperer profile ARN.
|
||||
// ARN format: arn:partition:service:region:account-id:resource-type/resource-id
|
||||
// Example: arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL
|
||||
type ProfileARN struct {
|
||||
// Raw is the original ARN string
|
||||
Raw string
|
||||
// Partition is the AWS partition (aws)
|
||||
Partition string
|
||||
// Service is the AWS service name (codewhisperer)
|
||||
Service string
|
||||
// Region is the AWS region (us-east-1, ap-southeast-1, etc.)
|
||||
Region string
|
||||
// AccountID is the AWS account ID
|
||||
AccountID string
|
||||
// ResourceType is the resource type (profile)
|
||||
ResourceType string
|
||||
// ResourceID is the resource identifier (e.g., ABCDEFGHIJKL)
|
||||
ResourceID string
|
||||
}
|
||||
|
||||
// ParseProfileARN parses an AWS ARN string into a ProfileARN struct.
|
||||
// Returns nil if the ARN is empty, invalid, or not a codewhisperer ARN.
|
||||
func ParseProfileARN(arn string) *ProfileARN {
|
||||
if arn == "" {
|
||||
return nil
|
||||
}
|
||||
// ARN format: arn:partition:service:region:account-id:resource
|
||||
// Minimum 6 parts separated by ":"
|
||||
parts := strings.Split(arn, ":")
|
||||
if len(parts) < 6 {
|
||||
log.Warnf("invalid ARN format: %s", arn)
|
||||
return nil
|
||||
}
|
||||
// Validate ARN prefix
|
||||
if parts[0] != "arn" {
|
||||
return nil
|
||||
}
|
||||
// Validate partition
|
||||
partition := parts[1]
|
||||
if partition == "" {
|
||||
return nil
|
||||
}
|
||||
// Validate service is codewhisperer
|
||||
service := parts[2]
|
||||
if service != "codewhisperer" {
|
||||
return nil
|
||||
}
|
||||
// Validate region format (must contain "-")
|
||||
region := parts[3]
|
||||
if region == "" || !strings.Contains(region, "-") {
|
||||
return nil
|
||||
}
|
||||
// Account ID
|
||||
accountID := parts[4]
|
||||
|
||||
// Parse resource (format: resource-type/resource-id)
|
||||
// Join remaining parts in case resource contains ":"
|
||||
resource := strings.Join(parts[5:], ":")
|
||||
resourceType := ""
|
||||
resourceID := ""
|
||||
if idx := strings.Index(resource, "/"); idx > 0 {
|
||||
resourceType = resource[:idx]
|
||||
resourceID = resource[idx+1:]
|
||||
} else {
|
||||
resourceType = resource
|
||||
}
|
||||
|
||||
return &ProfileARN{
|
||||
Raw: arn,
|
||||
Partition: partition,
|
||||
Service: service,
|
||||
Region: region,
|
||||
AccountID: accountID,
|
||||
ResourceType: resourceType,
|
||||
ResourceID: resourceID,
|
||||
}
|
||||
}
|
||||
|
||||
// GetKiroAPIEndpoint returns the Q API endpoint for the specified region.
|
||||
// If region is empty, defaults to us-east-1.
|
||||
func GetKiroAPIEndpoint(region string) string {
|
||||
if region == "" {
|
||||
region = DefaultKiroRegion
|
||||
}
|
||||
return "https://q." + region + ".amazonaws.com"
|
||||
}
|
||||
|
||||
// GetKiroAPIEndpointFromProfileArn extracts region from profileArn and returns the endpoint.
|
||||
// Returns default us-east-1 endpoint if region cannot be extracted.
|
||||
func GetKiroAPIEndpointFromProfileArn(profileArn string) string {
|
||||
region := ExtractRegionFromProfileArn(profileArn)
|
||||
return GetKiroAPIEndpoint(region)
|
||||
}
|
||||
|
||||
// ExtractRegionFromProfileArn extracts the AWS region from a ProfileARN string.
|
||||
// Returns empty string if ARN is invalid or region cannot be extracted.
|
||||
func ExtractRegionFromProfileArn(profileArn string) string {
|
||||
parsed := ParseProfileARN(profileArn)
|
||||
if parsed == nil {
|
||||
return ""
|
||||
}
|
||||
return parsed.Region
|
||||
}
|
||||
|
||||
// ExtractRegionFromMetadata extracts API region from auth metadata.
|
||||
// Priority: api_region > profile_arn > DefaultKiroRegion
|
||||
func ExtractRegionFromMetadata(metadata map[string]interface{}) string {
|
||||
if metadata == nil {
|
||||
return DefaultKiroRegion
|
||||
}
|
||||
|
||||
// Priority 1: Explicit api_region override
|
||||
if r, ok := metadata["api_region"].(string); ok && r != "" {
|
||||
return r
|
||||
}
|
||||
|
||||
// Priority 2: Extract from ProfileARN
|
||||
if profileArn, ok := metadata["profile_arn"].(string); ok && profileArn != "" {
|
||||
if region := ExtractRegionFromProfileArn(profileArn); region != "" {
|
||||
return region
|
||||
}
|
||||
}
|
||||
|
||||
return DefaultKiroRegion
|
||||
}
|
||||
|
||||
func buildURL(endpoint, path string, queryParams map[string]string) string {
|
||||
fullURL := fmt.Sprintf("%s/%s", endpoint, path)
|
||||
if len(queryParams) > 0 {
|
||||
values := url.Values{}
|
||||
for key, value := range queryParams {
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
values.Set(key, value)
|
||||
}
|
||||
if encoded := values.Encode(); encoded != "" {
|
||||
fullURL = fullURL + "?" + encoded
|
||||
}
|
||||
}
|
||||
return fullURL
|
||||
}
|
||||
326
internal/auth/kiro/aws_auth.go
Normal file
326
internal/auth/kiro/aws_auth.go
Normal file
@@ -0,0 +1,326 @@
|
||||
// Package kiro provides OAuth2 authentication functionality for AWS CodeWhisperer (Kiro) API.
|
||||
// This package implements token loading, refresh, and API communication with CodeWhisperer.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
pathGetUsageLimits = "getUsageLimits"
|
||||
pathListAvailableModels = "ListAvailableModels"
|
||||
)
|
||||
|
||||
// KiroAuth handles AWS CodeWhisperer authentication and API communication.
|
||||
// It provides methods for loading tokens, refreshing expired tokens,
|
||||
// and communicating with the CodeWhisperer API.
|
||||
type KiroAuth struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewKiroAuth creates a new Kiro authentication service.
|
||||
// It initializes the HTTP client with proxy settings from the configuration.
|
||||
//
|
||||
// Parameters:
|
||||
// - cfg: The application configuration containing proxy settings
|
||||
//
|
||||
// Returns:
|
||||
// - *KiroAuth: A new Kiro authentication service instance
|
||||
func NewKiroAuth(cfg *config.Config) *KiroAuth {
|
||||
return &KiroAuth{
|
||||
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 120 * time.Second}),
|
||||
}
|
||||
}
|
||||
|
||||
// LoadTokenFromFile loads token data from a file path.
|
||||
// This method reads and parses the token file, expanding ~ to the home directory.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenFile: Path to the token file (supports ~ expansion)
|
||||
//
|
||||
// Returns:
|
||||
// - *KiroTokenData: The parsed token data
|
||||
// - error: An error if file reading or parsing fails
|
||||
func (k *KiroAuth) LoadTokenFromFile(tokenFile string) (*KiroTokenData, error) {
|
||||
// Expand ~ to home directory
|
||||
if strings.HasPrefix(tokenFile, "~") {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||
}
|
||||
tokenFile = filepath.Join(home, tokenFile[1:])
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(tokenFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read token file: %w", err)
|
||||
}
|
||||
|
||||
var tokenData KiroTokenData
|
||||
if err := json.Unmarshal(data, &tokenData); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token file: %w", err)
|
||||
}
|
||||
|
||||
return &tokenData, nil
|
||||
}
|
||||
|
||||
// IsTokenExpired checks if the token has expired.
|
||||
// This method parses the expiration timestamp and compares it with the current time.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenData: The token data to check
|
||||
//
|
||||
// Returns:
|
||||
// - bool: True if the token has expired, false otherwise
|
||||
func (k *KiroAuth) IsTokenExpired(tokenData *KiroTokenData) bool {
|
||||
if tokenData.ExpiresAt == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
// Try alternate format
|
||||
expiresAt, err = time.Parse("2006-01-02T15:04:05.000Z", tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return time.Now().After(expiresAt)
|
||||
}
|
||||
|
||||
// makeRequest sends a REST-style GET request to the CodeWhisperer API.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for the request
|
||||
// - path: The API path (e.g., "getUsageLimits")
|
||||
// - tokenData: The token data containing access token, refresh token, and profile ARN
|
||||
// - queryParams: Query parameters to add to the URL
|
||||
//
|
||||
// Returns:
|
||||
// - []byte: The response body
|
||||
// - error: An error if the request fails
|
||||
func (k *KiroAuth) makeRequest(ctx context.Context, path string, tokenData *KiroTokenData, queryParams map[string]string) ([]byte, error) {
|
||||
// Get endpoint from profileArn (defaults to us-east-1 if empty)
|
||||
profileArn := queryParams["profileArn"]
|
||||
endpoint := GetKiroAPIEndpointFromProfileArn(profileArn)
|
||||
url := buildURL(endpoint, path, queryParams)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
accountKey := GetAccountKey(tokenData.ClientID, tokenData.RefreshToken)
|
||||
setRuntimeHeaders(req, tokenData.AccessToken, accountKey)
|
||||
|
||||
resp, err := k.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("failed to close response body: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// GetUsageLimits retrieves usage information from the CodeWhisperer API.
|
||||
// This method fetches the current usage statistics and subscription information.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for the request
|
||||
// - tokenData: The token data containing access token and profile ARN
|
||||
//
|
||||
// Returns:
|
||||
// - *KiroUsageInfo: The usage information
|
||||
// - error: An error if the request fails
|
||||
func (k *KiroAuth) GetUsageLimits(ctx context.Context, tokenData *KiroTokenData) (*KiroUsageInfo, error) {
|
||||
queryParams := map[string]string{
|
||||
"origin": "AI_EDITOR",
|
||||
"profileArn": tokenData.ProfileArn,
|
||||
"resourceType": "AGENTIC_REQUEST",
|
||||
}
|
||||
|
||||
body, err := k.makeRequest(ctx, pathGetUsageLimits, tokenData, queryParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result struct {
|
||||
SubscriptionInfo struct {
|
||||
SubscriptionTitle string `json:"subscriptionTitle"`
|
||||
} `json:"subscriptionInfo"`
|
||||
UsageBreakdownList []struct {
|
||||
CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
|
||||
UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
|
||||
} `json:"usageBreakdownList"`
|
||||
NextDateReset float64 `json:"nextDateReset"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse usage response: %w", err)
|
||||
}
|
||||
|
||||
usage := &KiroUsageInfo{
|
||||
SubscriptionTitle: result.SubscriptionInfo.SubscriptionTitle,
|
||||
NextReset: fmt.Sprintf("%v", result.NextDateReset),
|
||||
}
|
||||
|
||||
if len(result.UsageBreakdownList) > 0 {
|
||||
usage.CurrentUsage = result.UsageBreakdownList[0].CurrentUsageWithPrecision
|
||||
usage.UsageLimit = result.UsageBreakdownList[0].UsageLimitWithPrecision
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// ListAvailableModels retrieves available models from the CodeWhisperer API.
|
||||
// This method fetches the list of AI models available for the authenticated user.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for the request
|
||||
// - tokenData: The token data containing access token and profile ARN
|
||||
//
|
||||
// Returns:
|
||||
// - []*KiroModel: The list of available models
|
||||
// - error: An error if the request fails
|
||||
func (k *KiroAuth) ListAvailableModels(ctx context.Context, tokenData *KiroTokenData) ([]*KiroModel, error) {
|
||||
queryParams := map[string]string{
|
||||
"origin": "AI_EDITOR",
|
||||
"profileArn": tokenData.ProfileArn,
|
||||
}
|
||||
|
||||
body, err := k.makeRequest(ctx, pathListAvailableModels, tokenData, queryParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Models []struct {
|
||||
ModelID string `json:"modelId"`
|
||||
ModelName string `json:"modelName"`
|
||||
Description string `json:"description"`
|
||||
RateMultiplier float64 `json:"rateMultiplier"`
|
||||
RateUnit string `json:"rateUnit"`
|
||||
TokenLimits *struct {
|
||||
MaxInputTokens int `json:"maxInputTokens"`
|
||||
} `json:"tokenLimits"`
|
||||
} `json:"models"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse models response: %w", err)
|
||||
}
|
||||
|
||||
models := make([]*KiroModel, 0, len(result.Models))
|
||||
for _, m := range result.Models {
|
||||
maxInputTokens := 0
|
||||
if m.TokenLimits != nil {
|
||||
maxInputTokens = m.TokenLimits.MaxInputTokens
|
||||
}
|
||||
models = append(models, &KiroModel{
|
||||
ModelID: m.ModelID,
|
||||
ModelName: m.ModelName,
|
||||
Description: m.Description,
|
||||
RateMultiplier: m.RateMultiplier,
|
||||
RateUnit: m.RateUnit,
|
||||
MaxInputTokens: maxInputTokens,
|
||||
})
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
// CreateTokenStorage creates a new KiroTokenStorage from token data.
|
||||
// This method converts the token data into a storage structure suitable for persistence.
|
||||
//
|
||||
// Parameters:
|
||||
// - tokenData: The token data to convert
|
||||
//
|
||||
// Returns:
|
||||
// - *KiroTokenStorage: A new token storage instance
|
||||
func (k *KiroAuth) CreateTokenStorage(tokenData *KiroTokenData) *KiroTokenStorage {
|
||||
return &KiroTokenStorage{
|
||||
AccessToken: tokenData.AccessToken,
|
||||
RefreshToken: tokenData.RefreshToken,
|
||||
ProfileArn: tokenData.ProfileArn,
|
||||
ExpiresAt: tokenData.ExpiresAt,
|
||||
AuthMethod: tokenData.AuthMethod,
|
||||
Provider: tokenData.Provider,
|
||||
LastRefresh: time.Now().Format(time.RFC3339),
|
||||
ClientID: tokenData.ClientID,
|
||||
ClientSecret: tokenData.ClientSecret,
|
||||
Region: tokenData.Region,
|
||||
StartURL: tokenData.StartURL,
|
||||
Email: tokenData.Email,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateToken checks if the token is valid by making a test API call.
|
||||
// This method verifies the token by attempting to fetch usage limits.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for the request
|
||||
// - tokenData: The token data to validate
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if the token is invalid
|
||||
func (k *KiroAuth) ValidateToken(ctx context.Context, tokenData *KiroTokenData) error {
|
||||
_, err := k.GetUsageLimits(ctx, tokenData)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateTokenStorage updates an existing token storage with new token data.
|
||||
// This method refreshes the token storage with newly obtained access and refresh tokens.
|
||||
//
|
||||
// Parameters:
|
||||
// - storage: The existing token storage to update
|
||||
// - tokenData: The new token data to apply
|
||||
func (k *KiroAuth) UpdateTokenStorage(storage *KiroTokenStorage, tokenData *KiroTokenData) {
|
||||
storage.AccessToken = tokenData.AccessToken
|
||||
storage.RefreshToken = tokenData.RefreshToken
|
||||
storage.ProfileArn = tokenData.ProfileArn
|
||||
storage.ExpiresAt = tokenData.ExpiresAt
|
||||
storage.AuthMethod = tokenData.AuthMethod
|
||||
storage.Provider = tokenData.Provider
|
||||
storage.LastRefresh = time.Now().Format(time.RFC3339)
|
||||
if tokenData.ClientID != "" {
|
||||
storage.ClientID = tokenData.ClientID
|
||||
}
|
||||
if tokenData.ClientSecret != "" {
|
||||
storage.ClientSecret = tokenData.ClientSecret
|
||||
}
|
||||
if tokenData.Region != "" {
|
||||
storage.Region = tokenData.Region
|
||||
}
|
||||
if tokenData.StartURL != "" {
|
||||
storage.StartURL = tokenData.StartURL
|
||||
}
|
||||
if tokenData.Email != "" {
|
||||
storage.Email = tokenData.Email
|
||||
}
|
||||
}
|
||||
751
internal/auth/kiro/aws_test.go
Normal file
751
internal/auth/kiro/aws_test.go
Normal file
@@ -0,0 +1,751 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractEmailFromJWT(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Empty token",
|
||||
token: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Invalid token format",
|
||||
token: "not.a.valid.jwt",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Invalid token - not base64",
|
||||
token: "xxx.yyy.zzz",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Valid JWT with email",
|
||||
token: createTestJWT(map[string]any{"email": "test@example.com", "sub": "user123"}),
|
||||
expected: "test@example.com",
|
||||
},
|
||||
{
|
||||
name: "JWT without email but with preferred_username",
|
||||
token: createTestJWT(map[string]any{"preferred_username": "user@domain.com", "sub": "user123"}),
|
||||
expected: "user@domain.com",
|
||||
},
|
||||
{
|
||||
name: "JWT with email-like sub",
|
||||
token: createTestJWT(map[string]any{"sub": "another@test.com"}),
|
||||
expected: "another@test.com",
|
||||
},
|
||||
{
|
||||
name: "JWT without any email fields",
|
||||
token: createTestJWT(map[string]any{"sub": "user123", "name": "Test User"}),
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ExtractEmailFromJWT(tt.token)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ExtractEmailFromJWT() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeEmailForFilename(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Empty email",
|
||||
email: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Simple email",
|
||||
email: "user@example.com",
|
||||
expected: "user@example.com",
|
||||
},
|
||||
{
|
||||
name: "Email with space",
|
||||
email: "user name@example.com",
|
||||
expected: "user_name@example.com",
|
||||
},
|
||||
{
|
||||
name: "Email with special chars",
|
||||
email: "user:name@example.com",
|
||||
expected: "user_name@example.com",
|
||||
},
|
||||
{
|
||||
name: "Email with multiple special chars",
|
||||
email: "user/name:test@example.com",
|
||||
expected: "user_name_test@example.com",
|
||||
},
|
||||
{
|
||||
name: "Path traversal attempt",
|
||||
email: "../../../etc/passwd",
|
||||
expected: "_.__.__._etc_passwd",
|
||||
},
|
||||
{
|
||||
name: "Path traversal with backslash",
|
||||
email: `..\..\..\..\windows\system32`,
|
||||
expected: "_.__.__.__._windows_system32",
|
||||
},
|
||||
{
|
||||
name: "Null byte injection attempt",
|
||||
email: "user\x00@evil.com",
|
||||
expected: "user_@evil.com",
|
||||
},
|
||||
// URL-encoded path traversal tests
|
||||
{
|
||||
name: "URL-encoded slash",
|
||||
email: "user%2Fpath@example.com",
|
||||
expected: "user_path@example.com",
|
||||
},
|
||||
{
|
||||
name: "URL-encoded backslash",
|
||||
email: "user%5Cpath@example.com",
|
||||
expected: "user_path@example.com",
|
||||
},
|
||||
{
|
||||
name: "URL-encoded dot",
|
||||
email: "%2E%2E%2Fetc%2Fpasswd",
|
||||
expected: "___etc_passwd",
|
||||
},
|
||||
{
|
||||
name: "URL-encoded null",
|
||||
email: "user%00@evil.com",
|
||||
expected: "user_@evil.com",
|
||||
},
|
||||
{
|
||||
name: "Double URL-encoding attack",
|
||||
email: "%252F%252E%252E",
|
||||
expected: "_252F_252E_252E", // % replaced with _, remaining chars preserved (safe)
|
||||
},
|
||||
{
|
||||
name: "Mixed case URL-encoding",
|
||||
email: "%2f%2F%5c%5C",
|
||||
expected: "____",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeEmailForFilename(tt.email)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeEmailForFilename() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// createTestJWT creates a test JWT token with the given claims
|
||||
func createTestJWT(claims map[string]any) string {
|
||||
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`))
|
||||
|
||||
payloadBytes, _ := json.Marshal(claims)
|
||||
payload := base64.RawURLEncoding.EncodeToString(payloadBytes)
|
||||
|
||||
signature := base64.RawURLEncoding.EncodeToString([]byte("fake-signature"))
|
||||
|
||||
return header + "." + payload + "." + signature
|
||||
}
|
||||
|
||||
func TestExtractIDCIdentifier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
startURL string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Empty URL",
|
||||
startURL: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Standard IDC URL with d- prefix",
|
||||
startURL: "https://d-1234567890.awsapps.com/start",
|
||||
expected: "d-1234567890",
|
||||
},
|
||||
{
|
||||
name: "IDC URL with company name",
|
||||
startURL: "https://my-company.awsapps.com/start",
|
||||
expected: "my-company",
|
||||
},
|
||||
{
|
||||
name: "IDC URL with simple name",
|
||||
startURL: "https://acme-corp.awsapps.com/start",
|
||||
expected: "acme-corp",
|
||||
},
|
||||
{
|
||||
name: "IDC URL without https",
|
||||
startURL: "http://d-9876543210.awsapps.com/start",
|
||||
expected: "d-9876543210",
|
||||
},
|
||||
{
|
||||
name: "IDC URL with subdomain only",
|
||||
startURL: "https://test.awsapps.com/start",
|
||||
expected: "test",
|
||||
},
|
||||
{
|
||||
name: "Builder ID URL",
|
||||
startURL: "https://view.awsapps.com/start",
|
||||
expected: "view",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ExtractIDCIdentifier(tt.startURL)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ExtractIDCIdentifier() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateTokenFileName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenData *KiroTokenData
|
||||
exact string // exact match (for cases with email)
|
||||
prefix string // prefix match (for cases without email, where sequence is appended)
|
||||
}{
|
||||
{
|
||||
name: "IDC with email",
|
||||
tokenData: &KiroTokenData{
|
||||
AuthMethod: "idc",
|
||||
Email: "user@example.com",
|
||||
StartURL: "https://d-1234567890.awsapps.com/start",
|
||||
},
|
||||
exact: "kiro-idc-user-example-com.json",
|
||||
},
|
||||
{
|
||||
name: "IDC without email but with startUrl",
|
||||
tokenData: &KiroTokenData{
|
||||
AuthMethod: "idc",
|
||||
Email: "",
|
||||
StartURL: "https://d-1234567890.awsapps.com/start",
|
||||
},
|
||||
prefix: "kiro-idc-d-1234567890-",
|
||||
},
|
||||
{
|
||||
name: "IDC with company name in startUrl",
|
||||
tokenData: &KiroTokenData{
|
||||
AuthMethod: "idc",
|
||||
Email: "",
|
||||
StartURL: "https://my-company.awsapps.com/start",
|
||||
},
|
||||
prefix: "kiro-idc-my-company-",
|
||||
},
|
||||
{
|
||||
name: "IDC without email and without startUrl",
|
||||
tokenData: &KiroTokenData{
|
||||
AuthMethod: "idc",
|
||||
Email: "",
|
||||
StartURL: "",
|
||||
},
|
||||
prefix: "kiro-idc-",
|
||||
},
|
||||
{
|
||||
name: "Builder ID with email",
|
||||
tokenData: &KiroTokenData{
|
||||
AuthMethod: "builder-id",
|
||||
Email: "user@gmail.com",
|
||||
StartURL: "https://view.awsapps.com/start",
|
||||
},
|
||||
exact: "kiro-builder-id-user-gmail-com.json",
|
||||
},
|
||||
{
|
||||
name: "Builder ID without email",
|
||||
tokenData: &KiroTokenData{
|
||||
AuthMethod: "builder-id",
|
||||
Email: "",
|
||||
StartURL: "https://view.awsapps.com/start",
|
||||
},
|
||||
prefix: "kiro-builder-id-",
|
||||
},
|
||||
{
|
||||
name: "Social auth with email",
|
||||
tokenData: &KiroTokenData{
|
||||
AuthMethod: "google",
|
||||
Email: "user@gmail.com",
|
||||
},
|
||||
exact: "kiro-google-user-gmail-com.json",
|
||||
},
|
||||
{
|
||||
name: "Empty auth method",
|
||||
tokenData: &KiroTokenData{
|
||||
AuthMethod: "",
|
||||
Email: "",
|
||||
},
|
||||
prefix: "kiro-unknown-",
|
||||
},
|
||||
{
|
||||
name: "Email with special characters",
|
||||
tokenData: &KiroTokenData{
|
||||
AuthMethod: "idc",
|
||||
Email: "user.name+tag@sub.example.com",
|
||||
StartURL: "https://d-1234567890.awsapps.com/start",
|
||||
},
|
||||
exact: "kiro-idc-user-name+tag-sub-example-com.json",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GenerateTokenFileName(tt.tokenData)
|
||||
if tt.exact != "" {
|
||||
if result != tt.exact {
|
||||
t.Errorf("GenerateTokenFileName() = %q, want %q", result, tt.exact)
|
||||
}
|
||||
} else if tt.prefix != "" {
|
||||
if !strings.HasPrefix(result, tt.prefix) || !strings.HasSuffix(result, ".json") {
|
||||
t.Errorf("GenerateTokenFileName() = %q, want prefix %q with .json suffix", result, tt.prefix)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseProfileARN(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
arn string
|
||||
expected *ProfileARN
|
||||
}{
|
||||
{
|
||||
name: "Empty ARN",
|
||||
arn: "",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Invalid format - too few parts",
|
||||
arn: "arn:aws:codewhisperer",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Invalid prefix - not arn",
|
||||
arn: "notarn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Invalid service - not codewhisperer",
|
||||
arn: "arn:aws:s3:us-east-1:123456789012:bucket/mybucket",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Invalid region - no hyphen",
|
||||
arn: "arn:aws:codewhisperer:useast1:123456789012:profile/ABC",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Empty partition",
|
||||
arn: "arn::codewhisperer:us-east-1:123456789012:profile/ABC",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Empty region",
|
||||
arn: "arn:aws:codewhisperer::123456789012:profile/ABC",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - us-east-1",
|
||||
arn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL",
|
||||
expected: &ProfileARN{
|
||||
Raw: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL",
|
||||
Partition: "aws",
|
||||
Service: "codewhisperer",
|
||||
Region: "us-east-1",
|
||||
AccountID: "123456789012",
|
||||
ResourceType: "profile",
|
||||
ResourceID: "ABCDEFGHIJKL",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - ap-southeast-1",
|
||||
arn: "arn:aws:codewhisperer:ap-southeast-1:987654321098:profile/ZYXWVUTSRQ",
|
||||
expected: &ProfileARN{
|
||||
Raw: "arn:aws:codewhisperer:ap-southeast-1:987654321098:profile/ZYXWVUTSRQ",
|
||||
Partition: "aws",
|
||||
Service: "codewhisperer",
|
||||
Region: "ap-southeast-1",
|
||||
AccountID: "987654321098",
|
||||
ResourceType: "profile",
|
||||
ResourceID: "ZYXWVUTSRQ",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - eu-west-1",
|
||||
arn: "arn:aws:codewhisperer:eu-west-1:111222333444:profile/PROFILE123",
|
||||
expected: &ProfileARN{
|
||||
Raw: "arn:aws:codewhisperer:eu-west-1:111222333444:profile/PROFILE123",
|
||||
Partition: "aws",
|
||||
Service: "codewhisperer",
|
||||
Region: "eu-west-1",
|
||||
AccountID: "111222333444",
|
||||
ResourceType: "profile",
|
||||
ResourceID: "PROFILE123",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - aws-cn partition",
|
||||
arn: "arn:aws-cn:codewhisperer:cn-north-1:123456789012:profile/CHINAID",
|
||||
expected: &ProfileARN{
|
||||
Raw: "arn:aws-cn:codewhisperer:cn-north-1:123456789012:profile/CHINAID",
|
||||
Partition: "aws-cn",
|
||||
Service: "codewhisperer",
|
||||
Region: "cn-north-1",
|
||||
AccountID: "123456789012",
|
||||
ResourceType: "profile",
|
||||
ResourceID: "CHINAID",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - resource without slash",
|
||||
arn: "arn:aws:codewhisperer:us-west-2:123456789012:profile",
|
||||
expected: &ProfileARN{
|
||||
Raw: "arn:aws:codewhisperer:us-west-2:123456789012:profile",
|
||||
Partition: "aws",
|
||||
Service: "codewhisperer",
|
||||
Region: "us-west-2",
|
||||
AccountID: "123456789012",
|
||||
ResourceType: "profile",
|
||||
ResourceID: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - resource with colon",
|
||||
arn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC:extra",
|
||||
expected: &ProfileARN{
|
||||
Raw: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC:extra",
|
||||
Partition: "aws",
|
||||
Service: "codewhisperer",
|
||||
Region: "us-east-1",
|
||||
AccountID: "123456789012",
|
||||
ResourceType: "profile",
|
||||
ResourceID: "ABC:extra",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ParseProfileARN(tt.arn)
|
||||
if tt.expected == nil {
|
||||
if result != nil {
|
||||
t.Errorf("ParseProfileARN(%q) = %+v, want nil", tt.arn, result)
|
||||
}
|
||||
return
|
||||
}
|
||||
if result == nil {
|
||||
t.Errorf("ParseProfileARN(%q) = nil, want %+v", tt.arn, tt.expected)
|
||||
return
|
||||
}
|
||||
if result.Raw != tt.expected.Raw {
|
||||
t.Errorf("Raw = %q, want %q", result.Raw, tt.expected.Raw)
|
||||
}
|
||||
if result.Partition != tt.expected.Partition {
|
||||
t.Errorf("Partition = %q, want %q", result.Partition, tt.expected.Partition)
|
||||
}
|
||||
if result.Service != tt.expected.Service {
|
||||
t.Errorf("Service = %q, want %q", result.Service, tt.expected.Service)
|
||||
}
|
||||
if result.Region != tt.expected.Region {
|
||||
t.Errorf("Region = %q, want %q", result.Region, tt.expected.Region)
|
||||
}
|
||||
if result.AccountID != tt.expected.AccountID {
|
||||
t.Errorf("AccountID = %q, want %q", result.AccountID, tt.expected.AccountID)
|
||||
}
|
||||
if result.ResourceType != tt.expected.ResourceType {
|
||||
t.Errorf("ResourceType = %q, want %q", result.ResourceType, tt.expected.ResourceType)
|
||||
}
|
||||
if result.ResourceID != tt.expected.ResourceID {
|
||||
t.Errorf("ResourceID = %q, want %q", result.ResourceID, tt.expected.ResourceID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRegionFromProfileArn(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileArn string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Empty ARN",
|
||||
profileArn: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Invalid ARN",
|
||||
profileArn: "invalid-arn",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - us-east-1",
|
||||
profileArn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
|
||||
expected: "us-east-1",
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - ap-southeast-1",
|
||||
profileArn: "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
|
||||
expected: "ap-southeast-1",
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - eu-central-1",
|
||||
profileArn: "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC",
|
||||
expected: "eu-central-1",
|
||||
},
|
||||
{
|
||||
name: "Non-codewhisperer ARN",
|
||||
profileArn: "arn:aws:s3:us-east-1:123456789012:bucket/mybucket",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ExtractRegionFromProfileArn(tt.profileArn)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ExtractRegionFromProfileArn(%q) = %q, want %q", tt.profileArn, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetKiroAPIEndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
region string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Empty region - defaults to us-east-1",
|
||||
region: "",
|
||||
expected: "https://q.us-east-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "us-east-1",
|
||||
region: "us-east-1",
|
||||
expected: "https://q.us-east-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "us-west-2",
|
||||
region: "us-west-2",
|
||||
expected: "https://q.us-west-2.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "ap-southeast-1",
|
||||
region: "ap-southeast-1",
|
||||
expected: "https://q.ap-southeast-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "eu-west-1",
|
||||
region: "eu-west-1",
|
||||
expected: "https://q.eu-west-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "cn-north-1",
|
||||
region: "cn-north-1",
|
||||
expected: "https://q.cn-north-1.amazonaws.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetKiroAPIEndpoint(tt.region)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetKiroAPIEndpoint(%q) = %q, want %q", tt.region, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetKiroAPIEndpointFromProfileArn(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
profileArn string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Empty ARN - defaults to us-east-1",
|
||||
profileArn: "",
|
||||
expected: "https://q.us-east-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "Invalid ARN - defaults to us-east-1",
|
||||
profileArn: "invalid-arn",
|
||||
expected: "https://q.us-east-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - us-east-1",
|
||||
profileArn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
|
||||
expected: "https://q.us-east-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - ap-southeast-1",
|
||||
profileArn: "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
|
||||
expected: "https://q.ap-southeast-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "Valid ARN - eu-central-1",
|
||||
profileArn: "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC",
|
||||
expected: "https://q.eu-central-1.amazonaws.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetKiroAPIEndpointFromProfileArn(tt.profileArn)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetKiroAPIEndpointFromProfileArn(%q) = %q, want %q", tt.profileArn, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCodeWhispererLegacyEndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
region string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Empty region - defaults to us-east-1",
|
||||
region: "",
|
||||
expected: "https://codewhisperer.us-east-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "us-east-1",
|
||||
region: "us-east-1",
|
||||
expected: "https://codewhisperer.us-east-1.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "us-west-2",
|
||||
region: "us-west-2",
|
||||
expected: "https://codewhisperer.us-west-2.amazonaws.com",
|
||||
},
|
||||
{
|
||||
name: "ap-northeast-1",
|
||||
region: "ap-northeast-1",
|
||||
expected: "https://codewhisperer.ap-northeast-1.amazonaws.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetCodeWhispererLegacyEndpoint(tt.region)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetCodeWhispererLegacyEndpoint(%q) = %q, want %q", tt.region, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRegionFromMetadata(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
metadata map[string]interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Nil metadata - defaults to us-east-1",
|
||||
metadata: nil,
|
||||
expected: "us-east-1",
|
||||
},
|
||||
{
|
||||
name: "Empty metadata - defaults to us-east-1",
|
||||
metadata: map[string]interface{}{},
|
||||
expected: "us-east-1",
|
||||
},
|
||||
{
|
||||
name: "Priority 1: api_region override",
|
||||
metadata: map[string]interface{}{
|
||||
"api_region": "eu-west-1",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
|
||||
},
|
||||
expected: "eu-west-1",
|
||||
},
|
||||
{
|
||||
name: "Priority 2: profile_arn when api_region is empty",
|
||||
metadata: map[string]interface{}{
|
||||
"api_region": "",
|
||||
"profile_arn": "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
|
||||
},
|
||||
expected: "ap-southeast-1",
|
||||
},
|
||||
{
|
||||
name: "Priority 2: profile_arn when api_region is missing",
|
||||
metadata: map[string]interface{}{
|
||||
"profile_arn": "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC",
|
||||
},
|
||||
expected: "eu-central-1",
|
||||
},
|
||||
{
|
||||
name: "Fallback: default when profile_arn is invalid",
|
||||
metadata: map[string]interface{}{
|
||||
"profile_arn": "invalid-arn",
|
||||
},
|
||||
expected: "us-east-1",
|
||||
},
|
||||
{
|
||||
name: "Fallback: default when profile_arn is empty",
|
||||
metadata: map[string]interface{}{
|
||||
"profile_arn": "",
|
||||
},
|
||||
expected: "us-east-1",
|
||||
},
|
||||
{
|
||||
name: "OIDC region is NOT used for API region",
|
||||
metadata: map[string]interface{}{
|
||||
"region": "ap-northeast-2", // OIDC region - should be ignored
|
||||
},
|
||||
expected: "us-east-1",
|
||||
},
|
||||
{
|
||||
name: "api_region takes precedence over OIDC region",
|
||||
metadata: map[string]interface{}{
|
||||
"api_region": "us-west-2",
|
||||
"region": "ap-northeast-2", // OIDC region - should be ignored
|
||||
},
|
||||
expected: "us-west-2",
|
||||
},
|
||||
{
|
||||
name: "Non-string api_region is ignored",
|
||||
metadata: map[string]interface{}{
|
||||
"api_region": 123, // wrong type
|
||||
"profile_arn": "arn:aws:codewhisperer:ap-south-1:123456789012:profile/ABC",
|
||||
},
|
||||
expected: "ap-south-1",
|
||||
},
|
||||
{
|
||||
name: "Non-string profile_arn is ignored",
|
||||
metadata: map[string]interface{}{
|
||||
"profile_arn": 123, // wrong type
|
||||
},
|
||||
expected: "us-east-1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ExtractRegionFromMetadata(tt.metadata)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ExtractRegionFromMetadata(%v) = %q, want %q", tt.metadata, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
247
internal/auth/kiro/background_refresh.go
Normal file
247
internal/auth/kiro/background_refresh.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"golang.org/x/sync/semaphore"
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
ID string
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
ExpiresAt time.Time
|
||||
LastVerified time.Time
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
AuthMethod string
|
||||
Provider string
|
||||
StartURL string
|
||||
Region string
|
||||
}
|
||||
|
||||
type TokenRepository interface {
|
||||
FindOldestUnverified(limit int) []*Token
|
||||
UpdateToken(token *Token) error
|
||||
}
|
||||
|
||||
type RefresherOption func(*BackgroundRefresher)
|
||||
|
||||
func WithInterval(interval time.Duration) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.interval = interval
|
||||
}
|
||||
}
|
||||
|
||||
func WithBatchSize(size int) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.batchSize = size
|
||||
}
|
||||
}
|
||||
|
||||
func WithConcurrency(concurrency int) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.concurrency = concurrency
|
||||
}
|
||||
}
|
||||
|
||||
type BackgroundRefresher struct {
|
||||
interval time.Duration
|
||||
batchSize int
|
||||
concurrency int
|
||||
tokenRepo TokenRepository
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
oauth *KiroOAuth
|
||||
ssoClient *SSOOIDCClient
|
||||
callbackMu sync.RWMutex // 保护回调函数的并发访问
|
||||
onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) // 刷新成功回调
|
||||
}
|
||||
|
||||
func NewBackgroundRefresher(repo TokenRepository, opts ...RefresherOption) *BackgroundRefresher {
|
||||
r := &BackgroundRefresher{
|
||||
interval: time.Minute,
|
||||
batchSize: 50,
|
||||
concurrency: 10,
|
||||
tokenRepo: repo,
|
||||
stopCh: make(chan struct{}),
|
||||
oauth: nil, // Lazy init - will be set when config available
|
||||
ssoClient: nil, // Lazy init - will be set when config available
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(r)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// WithConfig sets the configuration for OAuth and SSO clients.
|
||||
func WithConfig(cfg *config.Config) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.oauth = NewKiroOAuth(cfg)
|
||||
r.ssoClient = NewSSOOIDCClient(cfg)
|
||||
}
|
||||
}
|
||||
|
||||
// WithOnTokenRefreshed sets the callback function to be called when a token is successfully refreshed.
|
||||
// The callback receives the token ID (filename) and the new token data.
|
||||
// This allows external components (e.g., Watcher) to be notified of token updates.
|
||||
func WithOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) RefresherOption {
|
||||
return func(r *BackgroundRefresher) {
|
||||
r.callbackMu.Lock()
|
||||
r.onTokenRefreshed = callback
|
||||
r.callbackMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *BackgroundRefresher) Start(ctx context.Context) {
|
||||
r.wg.Add(1)
|
||||
go func() {
|
||||
defer r.wg.Done()
|
||||
ticker := time.NewTicker(r.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
r.refreshBatch(ctx)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-r.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
r.refreshBatch(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (r *BackgroundRefresher) Stop() {
|
||||
close(r.stopCh)
|
||||
r.wg.Wait()
|
||||
}
|
||||
|
||||
func (r *BackgroundRefresher) refreshBatch(ctx context.Context) {
|
||||
tokens := r.tokenRepo.FindOldestUnverified(r.batchSize)
|
||||
if len(tokens) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
sem := semaphore.NewWeighted(int64(r.concurrency))
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i, token := range tokens {
|
||||
if i > 0 {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-r.stopCh:
|
||||
return
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
if err := sem.Acquire(ctx, 1); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(t *Token) {
|
||||
defer wg.Done()
|
||||
defer sem.Release(1)
|
||||
r.refreshSingle(ctx, t)
|
||||
}(token)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (r *BackgroundRefresher) refreshSingle(ctx context.Context, token *Token) {
|
||||
// Normalize auth method to lowercase for case-insensitive matching
|
||||
authMethod := strings.ToLower(token.AuthMethod)
|
||||
|
||||
// Create refresh function based on auth method
|
||||
refreshFunc := func(ctx context.Context) (*KiroTokenData, error) {
|
||||
switch authMethod {
|
||||
case "idc":
|
||||
return r.ssoClient.RefreshTokenWithRegion(
|
||||
ctx,
|
||||
token.ClientID,
|
||||
token.ClientSecret,
|
||||
token.RefreshToken,
|
||||
token.Region,
|
||||
token.StartURL,
|
||||
)
|
||||
case "builder-id":
|
||||
return r.ssoClient.RefreshToken(
|
||||
ctx,
|
||||
token.ClientID,
|
||||
token.ClientSecret,
|
||||
token.RefreshToken,
|
||||
)
|
||||
default:
|
||||
return r.oauth.RefreshTokenWithFingerprint(ctx, token.RefreshToken, token.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// Use graceful degradation for better reliability
|
||||
result := RefreshWithGracefulDegradation(
|
||||
ctx,
|
||||
refreshFunc,
|
||||
token.AccessToken,
|
||||
token.ExpiresAt,
|
||||
)
|
||||
|
||||
if result.Error != nil {
|
||||
log.Printf("failed to refresh token %s: %v", token.ID, result.Error)
|
||||
return
|
||||
}
|
||||
|
||||
newTokenData := result.TokenData
|
||||
if result.UsedFallback {
|
||||
log.Printf("token %s: using existing token as fallback (refresh failed but token still valid)", token.ID)
|
||||
// Don't update the token file if we're using fallback
|
||||
// Just update LastVerified to prevent immediate re-check
|
||||
token.LastVerified = time.Now()
|
||||
return
|
||||
}
|
||||
|
||||
token.AccessToken = newTokenData.AccessToken
|
||||
if newTokenData.RefreshToken != "" {
|
||||
token.RefreshToken = newTokenData.RefreshToken
|
||||
}
|
||||
token.LastVerified = time.Now()
|
||||
|
||||
if newTokenData.ExpiresAt != "" {
|
||||
if expTime, parseErr := time.Parse(time.RFC3339, newTokenData.ExpiresAt); parseErr == nil {
|
||||
token.ExpiresAt = expTime
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.tokenRepo.UpdateToken(token); err != nil {
|
||||
log.Printf("failed to update token %s: %v", token.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 方案 A: 刷新成功后触发回调,通知 Watcher 更新内存中的 Auth 对象
|
||||
r.callbackMu.RLock()
|
||||
callback := r.onTokenRefreshed
|
||||
r.callbackMu.RUnlock()
|
||||
|
||||
if callback != nil {
|
||||
// 使用 defer recover 隔离回调 panic,防止崩溃整个进程
|
||||
func() {
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
log.Printf("background refresh: callback panic for token %s: %v", token.ID, rec)
|
||||
}
|
||||
}()
|
||||
log.Printf("background refresh: notifying token refresh callback for %s", token.ID)
|
||||
callback(token.ID, newTokenData)
|
||||
}()
|
||||
}
|
||||
}
|
||||
153
internal/auth/kiro/codewhisperer_client.go
Normal file
153
internal/auth/kiro/codewhisperer_client.go
Normal file
@@ -0,0 +1,153 @@
|
||||
// Package kiro provides CodeWhisperer API client for fetching user info.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// CodeWhispererClient handles CodeWhisperer API calls.
|
||||
type CodeWhispererClient struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// UsageLimitsResponse represents the getUsageLimits API response.
|
||||
type UsageLimitsResponse struct {
|
||||
DaysUntilReset *int `json:"daysUntilReset,omitempty"`
|
||||
NextDateReset *float64 `json:"nextDateReset,omitempty"`
|
||||
UserInfo *UserInfo `json:"userInfo,omitempty"`
|
||||
SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
|
||||
UsageBreakdownList []UsageBreakdown `json:"usageBreakdownList,omitempty"`
|
||||
}
|
||||
|
||||
// UserInfo contains user information from the API.
|
||||
type UserInfo struct {
|
||||
Email string `json:"email,omitempty"`
|
||||
UserID string `json:"userId,omitempty"`
|
||||
}
|
||||
|
||||
// SubscriptionInfo contains subscription details.
|
||||
type SubscriptionInfo struct {
|
||||
SubscriptionTitle string `json:"subscriptionTitle,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
}
|
||||
|
||||
// UsageBreakdown contains usage details.
|
||||
type UsageBreakdown struct {
|
||||
UsageLimit *int `json:"usageLimit,omitempty"`
|
||||
CurrentUsage *int `json:"currentUsage,omitempty"`
|
||||
UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision,omitempty"`
|
||||
CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision,omitempty"`
|
||||
NextDateReset *float64 `json:"nextDateReset,omitempty"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
ResourceType string `json:"resourceType,omitempty"`
|
||||
}
|
||||
|
||||
// NewCodeWhispererClient creates a new CodeWhisperer client.
|
||||
func NewCodeWhispererClient(cfg *config.Config, machineID string) *CodeWhispererClient {
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
if cfg != nil {
|
||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||
}
|
||||
return &CodeWhispererClient{
|
||||
httpClient: client,
|
||||
}
|
||||
}
|
||||
|
||||
// GetUsageLimits fetches usage limits and user info from CodeWhisperer API.
|
||||
// This is the recommended way to get user email after login.
|
||||
func (c *CodeWhispererClient) GetUsageLimits(ctx context.Context, accessToken, clientID, refreshToken, profileArn string) (*UsageLimitsResponse, error) {
|
||||
queryParams := map[string]string{
|
||||
"origin": "AI_EDITOR",
|
||||
"resourceType": "AGENTIC_REQUEST",
|
||||
}
|
||||
// Determine endpoint based on profileArn region
|
||||
endpoint := GetKiroAPIEndpointFromProfileArn(profileArn)
|
||||
if profileArn != "" {
|
||||
queryParams["profileArn"] = profileArn
|
||||
} else {
|
||||
queryParams["isEmailRequired"] = "true"
|
||||
}
|
||||
url := buildURL(endpoint, pathGetUsageLimits, queryParams)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
accountKey := GetAccountKey(clientID, refreshToken)
|
||||
setRuntimeHeaders(req, accessToken, accountKey)
|
||||
|
||||
log.Debugf("codewhisperer: GET %s", url)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("codewhisperer: status=%d, body=%s", resp.StatusCode, string(body))
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result UsageLimitsResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// FetchUserEmailFromAPI fetches user email using CodeWhisperer getUsageLimits API.
|
||||
// This is more reliable than JWT parsing as it uses the official API.
|
||||
func (c *CodeWhispererClient) FetchUserEmailFromAPI(ctx context.Context, accessToken, clientID, refreshToken string) string {
|
||||
resp, err := c.GetUsageLimits(ctx, accessToken, clientID, refreshToken, "")
|
||||
if err != nil {
|
||||
log.Debugf("codewhisperer: failed to get usage limits: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
if resp.UserInfo != nil && resp.UserInfo.Email != "" {
|
||||
log.Debugf("codewhisperer: got email from API: %s", resp.UserInfo.Email)
|
||||
return resp.UserInfo.Email
|
||||
}
|
||||
|
||||
log.Debugf("codewhisperer: no email in response")
|
||||
return ""
|
||||
}
|
||||
|
||||
// FetchUserEmailWithFallback fetches user email with multiple fallback methods.
|
||||
// Priority: 1. CodeWhisperer API 2. userinfo endpoint 3. JWT parsing
|
||||
func FetchUserEmailWithFallback(ctx context.Context, cfg *config.Config, accessToken, clientID, refreshToken string) string {
|
||||
// Method 1: Try CodeWhisperer API (most reliable)
|
||||
cwClient := NewCodeWhispererClient(cfg, "")
|
||||
email := cwClient.FetchUserEmailFromAPI(ctx, accessToken, clientID, refreshToken)
|
||||
if email != "" {
|
||||
return email
|
||||
}
|
||||
|
||||
// Method 2: Try SSO OIDC userinfo endpoint
|
||||
ssoClient := NewSSOOIDCClient(cfg)
|
||||
email = ssoClient.FetchUserEmail(ctx, accessToken)
|
||||
if email != "" {
|
||||
return email
|
||||
}
|
||||
|
||||
// Method 3: Fallback to JWT parsing
|
||||
return ExtractEmailFromJWT(accessToken)
|
||||
}
|
||||
112
internal/auth/kiro/cooldown.go
Normal file
112
internal/auth/kiro/cooldown.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
CooldownReason429 = "rate_limit_exceeded"
|
||||
CooldownReasonSuspended = "account_suspended"
|
||||
CooldownReasonQuotaExhausted = "quota_exhausted"
|
||||
|
||||
DefaultShortCooldown = 1 * time.Minute
|
||||
MaxShortCooldown = 5 * time.Minute
|
||||
LongCooldown = 24 * time.Hour
|
||||
)
|
||||
|
||||
type CooldownManager struct {
|
||||
mu sync.RWMutex
|
||||
cooldowns map[string]time.Time
|
||||
reasons map[string]string
|
||||
}
|
||||
|
||||
func NewCooldownManager() *CooldownManager {
|
||||
return &CooldownManager{
|
||||
cooldowns: make(map[string]time.Time),
|
||||
reasons: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) SetCooldown(tokenKey string, duration time.Duration, reason string) {
|
||||
cm.mu.Lock()
|
||||
defer cm.mu.Unlock()
|
||||
cm.cooldowns[tokenKey] = time.Now().Add(duration)
|
||||
cm.reasons[tokenKey] = reason
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) IsInCooldown(tokenKey string) bool {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
endTime, exists := cm.cooldowns[tokenKey]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
return time.Now().Before(endTime)
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) GetRemainingCooldown(tokenKey string) time.Duration {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
endTime, exists := cm.cooldowns[tokenKey]
|
||||
if !exists {
|
||||
return 0
|
||||
}
|
||||
remaining := time.Until(endTime)
|
||||
if remaining < 0 {
|
||||
return 0
|
||||
}
|
||||
return remaining
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) GetCooldownReason(tokenKey string) string {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return cm.reasons[tokenKey]
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) ClearCooldown(tokenKey string) {
|
||||
cm.mu.Lock()
|
||||
defer cm.mu.Unlock()
|
||||
delete(cm.cooldowns, tokenKey)
|
||||
delete(cm.reasons, tokenKey)
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) CleanupExpired() {
|
||||
cm.mu.Lock()
|
||||
defer cm.mu.Unlock()
|
||||
now := time.Now()
|
||||
for tokenKey, endTime := range cm.cooldowns {
|
||||
if now.After(endTime) {
|
||||
delete(cm.cooldowns, tokenKey)
|
||||
delete(cm.reasons, tokenKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cm *CooldownManager) StartCleanupRoutine(interval time.Duration, stopCh <-chan struct{}) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
cm.CleanupExpired()
|
||||
case <-stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CalculateCooldownFor429(retryCount int) time.Duration {
|
||||
duration := DefaultShortCooldown * time.Duration(1<<retryCount)
|
||||
if duration > MaxShortCooldown {
|
||||
return MaxShortCooldown
|
||||
}
|
||||
return duration
|
||||
}
|
||||
|
||||
func CalculateCooldownUntilNextDay() time.Duration {
|
||||
now := time.Now()
|
||||
nextDay := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, now.Location())
|
||||
return time.Until(nextDay)
|
||||
}
|
||||
240
internal/auth/kiro/cooldown_test.go
Normal file
240
internal/auth/kiro/cooldown_test.go
Normal file
@@ -0,0 +1,240 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewCooldownManager(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
if cm == nil {
|
||||
t.Fatal("expected non-nil CooldownManager")
|
||||
}
|
||||
if cm.cooldowns == nil {
|
||||
t.Error("expected non-nil cooldowns map")
|
||||
}
|
||||
if cm.reasons == nil {
|
||||
t.Error("expected non-nil reasons map")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCooldown(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Minute, CooldownReason429)
|
||||
|
||||
if !cm.IsInCooldown("token1") {
|
||||
t.Error("expected token to be in cooldown")
|
||||
}
|
||||
if cm.GetCooldownReason("token1") != CooldownReason429 {
|
||||
t.Errorf("expected reason %s, got %s", CooldownReason429, cm.GetCooldownReason("token1"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsInCooldown_NotSet(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
if cm.IsInCooldown("nonexistent") {
|
||||
t.Error("expected non-existent token to not be in cooldown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsInCooldown_Expired(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Millisecond, CooldownReason429)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
if cm.IsInCooldown("token1") {
|
||||
t.Error("expected expired cooldown to return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRemainingCooldown(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Second, CooldownReason429)
|
||||
|
||||
remaining := cm.GetRemainingCooldown("token1")
|
||||
if remaining <= 0 || remaining > 1*time.Second {
|
||||
t.Errorf("expected remaining cooldown between 0 and 1s, got %v", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRemainingCooldown_NotSet(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
remaining := cm.GetRemainingCooldown("nonexistent")
|
||||
if remaining != 0 {
|
||||
t.Errorf("expected 0 remaining for non-existent, got %v", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRemainingCooldown_Expired(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Millisecond, CooldownReason429)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
remaining := cm.GetRemainingCooldown("token1")
|
||||
if remaining != 0 {
|
||||
t.Errorf("expected 0 remaining for expired, got %v", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCooldownReason(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Minute, CooldownReasonSuspended)
|
||||
|
||||
reason := cm.GetCooldownReason("token1")
|
||||
if reason != CooldownReasonSuspended {
|
||||
t.Errorf("expected reason %s, got %s", CooldownReasonSuspended, reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCooldownReason_NotSet(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
reason := cm.GetCooldownReason("nonexistent")
|
||||
if reason != "" {
|
||||
t.Errorf("expected empty reason for non-existent, got %s", reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearCooldown(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Minute, CooldownReason429)
|
||||
cm.ClearCooldown("token1")
|
||||
|
||||
if cm.IsInCooldown("token1") {
|
||||
t.Error("expected cooldown to be cleared")
|
||||
}
|
||||
if cm.GetCooldownReason("token1") != "" {
|
||||
t.Error("expected reason to be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearCooldown_NonExistent(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.ClearCooldown("nonexistent")
|
||||
}
|
||||
|
||||
func TestCleanupExpired(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("expired1", 1*time.Millisecond, CooldownReason429)
|
||||
cm.SetCooldown("expired2", 1*time.Millisecond, CooldownReason429)
|
||||
cm.SetCooldown("active", 1*time.Hour, CooldownReason429)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
cm.CleanupExpired()
|
||||
|
||||
if cm.GetCooldownReason("expired1") != "" {
|
||||
t.Error("expected expired1 to be cleaned up")
|
||||
}
|
||||
if cm.GetCooldownReason("expired2") != "" {
|
||||
t.Error("expected expired2 to be cleaned up")
|
||||
}
|
||||
if cm.GetCooldownReason("active") != CooldownReason429 {
|
||||
t.Error("expected active to remain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCooldownFor429_FirstRetry(t *testing.T) {
|
||||
duration := CalculateCooldownFor429(0)
|
||||
if duration != DefaultShortCooldown {
|
||||
t.Errorf("expected %v for retry 0, got %v", DefaultShortCooldown, duration)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCooldownFor429_Exponential(t *testing.T) {
|
||||
d1 := CalculateCooldownFor429(1)
|
||||
d2 := CalculateCooldownFor429(2)
|
||||
|
||||
if d2 <= d1 {
|
||||
t.Errorf("expected d2 > d1, got d1=%v, d2=%v", d1, d2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCooldownFor429_MaxCap(t *testing.T) {
|
||||
duration := CalculateCooldownFor429(10)
|
||||
if duration > MaxShortCooldown {
|
||||
t.Errorf("expected max %v, got %v", MaxShortCooldown, duration)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCooldownUntilNextDay(t *testing.T) {
|
||||
duration := CalculateCooldownUntilNextDay()
|
||||
if duration <= 0 || duration > 24*time.Hour {
|
||||
t.Errorf("expected duration between 0 and 24h, got %v", duration)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCooldownManager_ConcurrentAccess(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
const numGoroutines = 50
|
||||
const numOperations = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
tokenKey := "token" + string(rune('a'+id%10))
|
||||
for j := 0; j < numOperations; j++ {
|
||||
switch j % 6 {
|
||||
case 0:
|
||||
cm.SetCooldown(tokenKey, time.Duration(j)*time.Millisecond, CooldownReason429)
|
||||
case 1:
|
||||
cm.IsInCooldown(tokenKey)
|
||||
case 2:
|
||||
cm.GetRemainingCooldown(tokenKey)
|
||||
case 3:
|
||||
cm.GetCooldownReason(tokenKey)
|
||||
case 4:
|
||||
cm.ClearCooldown(tokenKey)
|
||||
case 5:
|
||||
cm.CleanupExpired()
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestCooldownReasonConstants(t *testing.T) {
|
||||
if CooldownReason429 != "rate_limit_exceeded" {
|
||||
t.Errorf("unexpected CooldownReason429: %s", CooldownReason429)
|
||||
}
|
||||
if CooldownReasonSuspended != "account_suspended" {
|
||||
t.Errorf("unexpected CooldownReasonSuspended: %s", CooldownReasonSuspended)
|
||||
}
|
||||
if CooldownReasonQuotaExhausted != "quota_exhausted" {
|
||||
t.Errorf("unexpected CooldownReasonQuotaExhausted: %s", CooldownReasonQuotaExhausted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConstants(t *testing.T) {
|
||||
if DefaultShortCooldown != 1*time.Minute {
|
||||
t.Errorf("unexpected DefaultShortCooldown: %v", DefaultShortCooldown)
|
||||
}
|
||||
if MaxShortCooldown != 5*time.Minute {
|
||||
t.Errorf("unexpected MaxShortCooldown: %v", MaxShortCooldown)
|
||||
}
|
||||
if LongCooldown != 24*time.Hour {
|
||||
t.Errorf("unexpected LongCooldown: %v", LongCooldown)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCooldown_OverwritesPrevious(t *testing.T) {
|
||||
cm := NewCooldownManager()
|
||||
cm.SetCooldown("token1", 1*time.Hour, CooldownReason429)
|
||||
cm.SetCooldown("token1", 1*time.Minute, CooldownReasonSuspended)
|
||||
|
||||
reason := cm.GetCooldownReason("token1")
|
||||
if reason != CooldownReasonSuspended {
|
||||
t.Errorf("expected reason to be overwritten to %s, got %s", CooldownReasonSuspended, reason)
|
||||
}
|
||||
|
||||
remaining := cm.GetRemainingCooldown("token1")
|
||||
if remaining > 1*time.Minute {
|
||||
t.Errorf("expected remaining <= 1 minute, got %v", remaining)
|
||||
}
|
||||
}
|
||||
278
internal/auth/kiro/fingerprint.go
Normal file
278
internal/auth/kiro/fingerprint.go
Normal file
@@ -0,0 +1,278 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Fingerprint holds multi-dimensional fingerprint data for runtime request disguise.
|
||||
type Fingerprint struct {
|
||||
OIDCSDKVersion string // 3.7xx (AWS SDK JS)
|
||||
RuntimeSDKVersion string // 1.0.x (runtime API)
|
||||
StreamingSDKVersion string // 1.0.x (streaming API)
|
||||
OSType string // darwin/windows/linux
|
||||
OSVersion string
|
||||
NodeVersion string
|
||||
KiroVersion string
|
||||
KiroHash string // SHA256
|
||||
}
|
||||
|
||||
// FingerprintConfig holds external fingerprint overrides.
|
||||
type FingerprintConfig struct {
|
||||
OIDCSDKVersion string
|
||||
RuntimeSDKVersion string
|
||||
StreamingSDKVersion string
|
||||
OSType string
|
||||
OSVersion string
|
||||
NodeVersion string
|
||||
KiroVersion string
|
||||
KiroHash string
|
||||
}
|
||||
|
||||
// FingerprintManager manages per-account fingerprint generation and caching.
|
||||
type FingerprintManager struct {
|
||||
mu sync.RWMutex
|
||||
fingerprints map[string]*Fingerprint // tokenKey -> fingerprint
|
||||
rng *rand.Rand
|
||||
config *FingerprintConfig // External config (Optional)
|
||||
}
|
||||
|
||||
var (
|
||||
// SDK versions
|
||||
oidcSDKVersions = []string{
|
||||
"3.980.0", "3.975.0", "3.972.0", "3.808.0",
|
||||
"3.738.0", "3.737.0", "3.736.0", "3.735.0",
|
||||
}
|
||||
// SDKVersions for getUsageLimits/ListAvailableModels/GetProfile (runtime API)
|
||||
runtimeSDKVersions = []string{"1.0.0"}
|
||||
// SDKVersions for generateAssistantResponse (streaming API)
|
||||
streamingSDKVersions = []string{"1.0.27"}
|
||||
// Valid OS types
|
||||
osTypes = []string{"darwin", "windows", "linux"}
|
||||
// OS versions
|
||||
osVersions = map[string][]string{
|
||||
"darwin": {"25.2.0", "25.1.0", "25.0.0", "24.5.0", "24.4.0", "24.3.0"},
|
||||
"windows": {"10.0.26200", "10.0.26100", "10.0.22631", "10.0.22621", "10.0.19045"},
|
||||
"linux": {"6.12.0", "6.11.0", "6.8.0", "6.6.0", "6.5.0", "6.1.0"},
|
||||
}
|
||||
// Node versions
|
||||
nodeVersions = []string{
|
||||
"22.21.1", "22.21.0", "22.20.0", "22.19.0", "22.18.0",
|
||||
"20.18.0", "20.17.0", "20.16.0",
|
||||
}
|
||||
// Kiro IDE versions
|
||||
kiroVersions = []string{
|
||||
"0.10.32", "0.10.16", "0.10.10",
|
||||
"0.9.47", "0.9.40", "0.9.2",
|
||||
"0.8.206", "0.8.140", "0.8.135", "0.8.86",
|
||||
}
|
||||
// Global singleton
|
||||
globalFingerprintManager *FingerprintManager
|
||||
globalFingerprintManagerOnce sync.Once
|
||||
)
|
||||
|
||||
func GlobalFingerprintManager() *FingerprintManager {
|
||||
globalFingerprintManagerOnce.Do(func() {
|
||||
globalFingerprintManager = NewFingerprintManager()
|
||||
})
|
||||
return globalFingerprintManager
|
||||
}
|
||||
|
||||
func SetGlobalFingerprintConfig(cfg *FingerprintConfig) {
|
||||
GlobalFingerprintManager().SetConfig(cfg)
|
||||
}
|
||||
|
||||
// SetConfig applies the config and clears the fingerprint cache.
|
||||
func (fm *FingerprintManager) SetConfig(cfg *FingerprintConfig) {
|
||||
fm.mu.Lock()
|
||||
defer fm.mu.Unlock()
|
||||
fm.config = cfg
|
||||
// Clear cached fingerprints so they regenerate with the new config
|
||||
fm.fingerprints = make(map[string]*Fingerprint)
|
||||
}
|
||||
|
||||
func NewFingerprintManager() *FingerprintManager {
|
||||
return &FingerprintManager{
|
||||
fingerprints: make(map[string]*Fingerprint),
|
||||
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
|
||||
// GetFingerprint returns the fingerprint for tokenKey, creating one if it doesn't exist.
|
||||
func (fm *FingerprintManager) GetFingerprint(tokenKey string) *Fingerprint {
|
||||
fm.mu.RLock()
|
||||
if fp, exists := fm.fingerprints[tokenKey]; exists {
|
||||
fm.mu.RUnlock()
|
||||
return fp
|
||||
}
|
||||
fm.mu.RUnlock()
|
||||
|
||||
fm.mu.Lock()
|
||||
defer fm.mu.Unlock()
|
||||
|
||||
if fp, exists := fm.fingerprints[tokenKey]; exists {
|
||||
return fp
|
||||
}
|
||||
|
||||
fp := fm.generateFingerprint(tokenKey)
|
||||
fm.fingerprints[tokenKey] = fp
|
||||
return fp
|
||||
}
|
||||
|
||||
func (fm *FingerprintManager) generateFingerprint(tokenKey string) *Fingerprint {
|
||||
if fm.config != nil {
|
||||
return fm.generateFromConfig(tokenKey)
|
||||
}
|
||||
return fm.generateRandom(tokenKey)
|
||||
}
|
||||
|
||||
// generateFromConfig uses config values, falling back to random for empty fields.
|
||||
func (fm *FingerprintManager) generateFromConfig(tokenKey string) *Fingerprint {
|
||||
cfg := fm.config
|
||||
|
||||
// Helper: config value or random selection
|
||||
configOrRandom := func(configVal string, choices []string) string {
|
||||
if configVal != "" {
|
||||
return configVal
|
||||
}
|
||||
return choices[fm.rng.Intn(len(choices))]
|
||||
}
|
||||
|
||||
osType := cfg.OSType
|
||||
if osType == "" {
|
||||
osType = runtime.GOOS
|
||||
if !slices.Contains(osTypes, osType) {
|
||||
osType = osTypes[fm.rng.Intn(len(osTypes))]
|
||||
}
|
||||
}
|
||||
|
||||
osVersion := cfg.OSVersion
|
||||
if osVersion == "" {
|
||||
if versions, ok := osVersions[osType]; ok {
|
||||
osVersion = versions[fm.rng.Intn(len(versions))]
|
||||
}
|
||||
}
|
||||
|
||||
kiroHash := cfg.KiroHash
|
||||
if kiroHash == "" {
|
||||
hash := sha256.Sum256([]byte(tokenKey))
|
||||
kiroHash = hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
return &Fingerprint{
|
||||
OIDCSDKVersion: configOrRandom(cfg.OIDCSDKVersion, oidcSDKVersions),
|
||||
RuntimeSDKVersion: configOrRandom(cfg.RuntimeSDKVersion, runtimeSDKVersions),
|
||||
StreamingSDKVersion: configOrRandom(cfg.StreamingSDKVersion, streamingSDKVersions),
|
||||
OSType: osType,
|
||||
OSVersion: osVersion,
|
||||
NodeVersion: configOrRandom(cfg.NodeVersion, nodeVersions),
|
||||
KiroVersion: configOrRandom(cfg.KiroVersion, kiroVersions),
|
||||
KiroHash: kiroHash,
|
||||
}
|
||||
}
|
||||
|
||||
// generateRandom generates a deterministic fingerprint seeded by accountKey hash.
|
||||
func (fm *FingerprintManager) generateRandom(accountKey string) *Fingerprint {
|
||||
// Use accountKey hash as seed for deterministic random selection
|
||||
hash := sha256.Sum256([]byte(accountKey))
|
||||
seed := int64(binary.BigEndian.Uint64(hash[:8]))
|
||||
rng := rand.New(rand.NewSource(seed))
|
||||
|
||||
osType := runtime.GOOS
|
||||
if !slices.Contains(osTypes, osType) {
|
||||
osType = osTypes[rng.Intn(len(osTypes))]
|
||||
}
|
||||
osVersion := osVersions[osType][rng.Intn(len(osVersions[osType]))]
|
||||
|
||||
return &Fingerprint{
|
||||
OIDCSDKVersion: oidcSDKVersions[rng.Intn(len(oidcSDKVersions))],
|
||||
RuntimeSDKVersion: runtimeSDKVersions[rng.Intn(len(runtimeSDKVersions))],
|
||||
StreamingSDKVersion: streamingSDKVersions[rng.Intn(len(streamingSDKVersions))],
|
||||
OSType: osType,
|
||||
OSVersion: osVersion,
|
||||
NodeVersion: nodeVersions[rng.Intn(len(nodeVersions))],
|
||||
KiroVersion: kiroVersions[rng.Intn(len(kiroVersions))],
|
||||
KiroHash: hex.EncodeToString(hash[:]),
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateAccountKey returns a 16-char hex key derived from SHA256(seed).
|
||||
func GenerateAccountKey(seed string) string {
|
||||
hash := sha256.Sum256([]byte(seed))
|
||||
return hex.EncodeToString(hash[:8])
|
||||
}
|
||||
|
||||
// GetAccountKey derives an account key from clientID > refreshToken > random UUID.
|
||||
func GetAccountKey(clientID, refreshToken string) string {
|
||||
// 1. Prefer ClientID
|
||||
if clientID != "" {
|
||||
return GenerateAccountKey(clientID)
|
||||
}
|
||||
|
||||
// 2. Fallback to RefreshToken
|
||||
if refreshToken != "" {
|
||||
return GenerateAccountKey(refreshToken)
|
||||
}
|
||||
|
||||
// 3. Random fallback
|
||||
return GenerateAccountKey(uuid.New().String())
|
||||
}
|
||||
|
||||
// BuildUserAgent format: aws-sdk-js/{SDKVersion} ua/2.1 os/{OSType}#{OSVersion} lang/js md/nodejs#{NodeVersion} api/codewhispererstreaming#{SDKVersion} m/E KiroIDE-{KiroVersion}-{KiroHash}
|
||||
func (fp *Fingerprint) BuildUserAgent() string {
|
||||
return fmt.Sprintf(
|
||||
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s",
|
||||
fp.StreamingSDKVersion,
|
||||
fp.OSType,
|
||||
fp.OSVersion,
|
||||
fp.NodeVersion,
|
||||
fp.StreamingSDKVersion,
|
||||
fp.KiroVersion,
|
||||
fp.KiroHash,
|
||||
)
|
||||
}
|
||||
|
||||
// BuildAmzUserAgent format: aws-sdk-js/{SDKVersion} KiroIDE-{KiroVersion}-{KiroHash}
|
||||
func (fp *Fingerprint) BuildAmzUserAgent() string {
|
||||
return fmt.Sprintf(
|
||||
"aws-sdk-js/%s KiroIDE-%s-%s",
|
||||
fp.StreamingSDKVersion,
|
||||
fp.KiroVersion,
|
||||
fp.KiroHash,
|
||||
)
|
||||
}
|
||||
|
||||
func SetOIDCHeaders(req *http.Request) {
|
||||
fp := GlobalFingerprintManager().GetFingerprint("oidc-session")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-amz-user-agent", fmt.Sprintf("aws-sdk-js/%s KiroIDE", fp.OIDCSDKVersion))
|
||||
req.Header.Set("User-Agent", fmt.Sprintf(
|
||||
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/%s#%s m/E KiroIDE",
|
||||
fp.OIDCSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, "sso-oidc", fp.OIDCSDKVersion))
|
||||
req.Header.Set("amz-sdk-invocation-id", uuid.New().String())
|
||||
req.Header.Set("amz-sdk-request", "attempt=1; max=4")
|
||||
}
|
||||
|
||||
func setRuntimeHeaders(req *http.Request, accessToken string, accountKey string) {
|
||||
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
|
||||
machineID := fp.KiroHash
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("x-amz-user-agent", fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s",
|
||||
fp.RuntimeSDKVersion, fp.KiroVersion, machineID))
|
||||
req.Header.Set("User-Agent", fmt.Sprintf(
|
||||
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererruntime#%s m/N,E KiroIDE-%s-%s",
|
||||
fp.RuntimeSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, fp.RuntimeSDKVersion,
|
||||
fp.KiroVersion, machineID))
|
||||
req.Header.Set("amz-sdk-invocation-id", uuid.New().String())
|
||||
req.Header.Set("amz-sdk-request", "attempt=1; max=1")
|
||||
}
|
||||
778
internal/auth/kiro/fingerprint_test.go
Normal file
778
internal/auth/kiro/fingerprint_test.go
Normal file
@@ -0,0 +1,778 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewFingerprintManager(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
if fm == nil {
|
||||
t.Fatal("expected non-nil FingerprintManager")
|
||||
}
|
||||
if fm.fingerprints == nil {
|
||||
t.Error("expected non-nil fingerprints map")
|
||||
}
|
||||
if fm.rng == nil {
|
||||
t.Error("expected non-nil rng")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFingerprint_NewToken(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
if fp == nil {
|
||||
t.Fatal("expected non-nil Fingerprint")
|
||||
}
|
||||
if fp.OIDCSDKVersion == "" {
|
||||
t.Error("expected non-empty OIDCSDKVersion")
|
||||
}
|
||||
if fp.RuntimeSDKVersion == "" {
|
||||
t.Error("expected non-empty RuntimeSDKVersion")
|
||||
}
|
||||
if fp.StreamingSDKVersion == "" {
|
||||
t.Error("expected non-empty StreamingSDKVersion")
|
||||
}
|
||||
if fp.OSType == "" {
|
||||
t.Error("expected non-empty OSType")
|
||||
}
|
||||
if fp.OSVersion == "" {
|
||||
t.Error("expected non-empty OSVersion")
|
||||
}
|
||||
if fp.NodeVersion == "" {
|
||||
t.Error("expected non-empty NodeVersion")
|
||||
}
|
||||
if fp.KiroVersion == "" {
|
||||
t.Error("expected non-empty KiroVersion")
|
||||
}
|
||||
if fp.KiroHash == "" {
|
||||
t.Error("expected non-empty KiroHash")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFingerprint_SameTokenReturnsSameFingerprint(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp1 := fm.GetFingerprint("token1")
|
||||
fp2 := fm.GetFingerprint("token1")
|
||||
|
||||
if fp1 != fp2 {
|
||||
t.Error("expected same fingerprint for same token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFingerprint_DifferentTokens(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp1 := fm.GetFingerprint("token1")
|
||||
fp2 := fm.GetFingerprint("token2")
|
||||
|
||||
if fp1 == fp2 {
|
||||
t.Error("expected different fingerprints for different tokens")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildUserAgent(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
ua := fp.BuildUserAgent()
|
||||
if ua == "" {
|
||||
t.Error("expected non-empty User-Agent")
|
||||
}
|
||||
|
||||
amzUA := fp.BuildAmzUserAgent()
|
||||
if amzUA == "" {
|
||||
t.Error("expected non-empty X-Amz-User-Agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFingerprint_OSVersionMatchesOSType(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
fp := fm.GetFingerprint("token" + string(rune('a'+i)))
|
||||
validVersions := osVersions[fp.OSType]
|
||||
found := false
|
||||
for _, v := range validVersions {
|
||||
if v == fp.OSVersion {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("OS version %s not valid for OS type %s", fp.OSVersion, fp.OSType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateFromConfig_OSTypeFromRuntimeGOOS(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
|
||||
// Set config with empty OSType to trigger runtime.GOOS fallback
|
||||
fm.SetConfig(&FingerprintConfig{
|
||||
OIDCSDKVersion: "3.738.0", // Set other fields to use config path
|
||||
})
|
||||
|
||||
fp := fm.GetFingerprint("test-token")
|
||||
|
||||
// Expected OS type based on runtime.GOOS mapping
|
||||
var expectedOS string
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
expectedOS = "darwin"
|
||||
case "windows":
|
||||
expectedOS = "windows"
|
||||
default:
|
||||
expectedOS = "linux"
|
||||
}
|
||||
|
||||
if fp.OSType != expectedOS {
|
||||
t.Errorf("expected OSType '%s' from runtime.GOOS '%s', got '%s'",
|
||||
expectedOS, runtime.GOOS, fp.OSType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFingerprintManager_ConcurrentAccess(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
const numGoroutines = 100
|
||||
const numOperations = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := range numGoroutines {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := range numOperations {
|
||||
tokenKey := "token" + string(rune('a'+id%26))
|
||||
switch j % 2 {
|
||||
case 0:
|
||||
fm.GetFingerprint(tokenKey)
|
||||
case 1:
|
||||
fp := fm.GetFingerprint(tokenKey)
|
||||
_ = fp.BuildUserAgent()
|
||||
_ = fp.BuildAmzUserAgent()
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestKiroHashStability(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
|
||||
// Same token should always return same hash
|
||||
fp1 := fm.GetFingerprint("token1")
|
||||
fp2 := fm.GetFingerprint("token1")
|
||||
if fp1.KiroHash != fp2.KiroHash {
|
||||
t.Errorf("same token should have same hash: %s vs %s", fp1.KiroHash, fp2.KiroHash)
|
||||
}
|
||||
|
||||
// Different tokens should have different hashes
|
||||
fp3 := fm.GetFingerprint("token2")
|
||||
if fp1.KiroHash == fp3.KiroHash {
|
||||
t.Errorf("different tokens should have different hashes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKiroHashFormat(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
if len(fp.KiroHash) != 64 {
|
||||
t.Errorf("expected KiroHash length 64 (SHA256 hex), got %d", len(fp.KiroHash))
|
||||
}
|
||||
|
||||
for _, c := range fp.KiroHash {
|
||||
if (c < '0' || c > '9') && (c < 'a' || c > 'f') {
|
||||
t.Errorf("invalid hex character in KiroHash: %c", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalFingerprintManager(t *testing.T) {
|
||||
fm1 := GlobalFingerprintManager()
|
||||
fm2 := GlobalFingerprintManager()
|
||||
|
||||
if fm1 == nil {
|
||||
t.Fatal("expected non-nil GlobalFingerprintManager")
|
||||
}
|
||||
if fm1 != fm2 {
|
||||
t.Error("expected GlobalFingerprintManager to return same instance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetOIDCHeaders(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
SetOIDCHeaders(req)
|
||||
|
||||
if req.Header.Get("Content-Type") != "application/json" {
|
||||
t.Error("expected Content-Type header to be set")
|
||||
}
|
||||
|
||||
amzUA := req.Header.Get("x-amz-user-agent")
|
||||
if amzUA == "" {
|
||||
t.Error("expected x-amz-user-agent header to be set")
|
||||
}
|
||||
if !strings.Contains(amzUA, "aws-sdk-js/") {
|
||||
t.Errorf("x-amz-user-agent should contain aws-sdk-js: %s", amzUA)
|
||||
}
|
||||
if !strings.Contains(amzUA, "KiroIDE") {
|
||||
t.Errorf("x-amz-user-agent should contain KiroIDE: %s", amzUA)
|
||||
}
|
||||
|
||||
ua := req.Header.Get("User-Agent")
|
||||
if ua == "" {
|
||||
t.Error("expected User-Agent header to be set")
|
||||
}
|
||||
if !strings.Contains(ua, "api/sso-oidc") {
|
||||
t.Errorf("User-Agent should contain api name: %s", ua)
|
||||
}
|
||||
|
||||
if req.Header.Get("amz-sdk-invocation-id") == "" {
|
||||
t.Error("expected amz-sdk-invocation-id header to be set")
|
||||
}
|
||||
if req.Header.Get("amz-sdk-request") != "attempt=1; max=4" {
|
||||
t.Errorf("unexpected amz-sdk-request header: %s", req.Header.Get("amz-sdk-request"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
endpoint string
|
||||
path string
|
||||
queryParams map[string]string
|
||||
want string
|
||||
wantContains []string
|
||||
}{
|
||||
{
|
||||
name: "no query params",
|
||||
endpoint: "https://api.example.com",
|
||||
path: "getUsageLimits",
|
||||
queryParams: nil,
|
||||
want: "https://api.example.com/getUsageLimits",
|
||||
},
|
||||
{
|
||||
name: "empty query params",
|
||||
endpoint: "https://api.example.com",
|
||||
path: "getUsageLimits",
|
||||
queryParams: map[string]string{},
|
||||
want: "https://api.example.com/getUsageLimits",
|
||||
},
|
||||
{
|
||||
name: "single query param",
|
||||
endpoint: "https://api.example.com",
|
||||
path: "getUsageLimits",
|
||||
queryParams: map[string]string{
|
||||
"origin": "AI_EDITOR",
|
||||
},
|
||||
want: "https://api.example.com/getUsageLimits?origin=AI_EDITOR",
|
||||
},
|
||||
{
|
||||
name: "multiple query params",
|
||||
endpoint: "https://api.example.com",
|
||||
path: "getUsageLimits",
|
||||
queryParams: map[string]string{
|
||||
"origin": "AI_EDITOR",
|
||||
"resourceType": "AGENTIC_REQUEST",
|
||||
"profileArn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEF",
|
||||
},
|
||||
wantContains: []string{
|
||||
"https://api.example.com/getUsageLimits?",
|
||||
"origin=AI_EDITOR",
|
||||
"profileArn=arn%3Aaws%3Acodewhisperer%3Aus-east-1%3A123456789012%3Aprofile%2FABCDEF",
|
||||
"resourceType=AGENTIC_REQUEST",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "omit empty params",
|
||||
endpoint: "https://api.example.com",
|
||||
path: "getUsageLimits",
|
||||
queryParams: map[string]string{
|
||||
"origin": "AI_EDITOR",
|
||||
"profileArn": "",
|
||||
},
|
||||
want: "https://api.example.com/getUsageLimits?origin=AI_EDITOR",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := buildURL(tt.endpoint, tt.path, tt.queryParams)
|
||||
if tt.want != "" {
|
||||
if got != tt.want {
|
||||
t.Errorf("buildURL() = %v, want %v", got, tt.want)
|
||||
}
|
||||
}
|
||||
if tt.wantContains != nil {
|
||||
for _, substr := range tt.wantContains {
|
||||
if !strings.Contains(got, substr) {
|
||||
t.Errorf("buildURL() = %v, want to contain %v", got, substr)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildUserAgentFormat(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
ua := fp.BuildUserAgent()
|
||||
requiredParts := []string{
|
||||
"aws-sdk-js/",
|
||||
"ua/2.1",
|
||||
"os/",
|
||||
"lang/js",
|
||||
"md/nodejs#",
|
||||
"api/codewhispererstreaming#",
|
||||
"m/E",
|
||||
"KiroIDE-",
|
||||
}
|
||||
for _, part := range requiredParts {
|
||||
if !strings.Contains(ua, part) {
|
||||
t.Errorf("User-Agent missing required part %q: %s", part, ua)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAmzUserAgentFormat(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
amzUA := fp.BuildAmzUserAgent()
|
||||
requiredParts := []string{
|
||||
"aws-sdk-js/",
|
||||
"KiroIDE-",
|
||||
}
|
||||
for _, part := range requiredParts {
|
||||
if !strings.Contains(amzUA, part) {
|
||||
t.Errorf("X-Amz-User-Agent missing required part %q: %s", part, amzUA)
|
||||
}
|
||||
}
|
||||
|
||||
// Amz-User-Agent should be shorter than User-Agent
|
||||
ua := fp.BuildUserAgent()
|
||||
if len(amzUA) >= len(ua) {
|
||||
t.Error("X-Amz-User-Agent should be shorter than User-Agent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetRuntimeHeaders(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
accessToken := "test-access-token-1234567890"
|
||||
clientID := "test-client-id-12345"
|
||||
accountKey := GenerateAccountKey(clientID)
|
||||
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
|
||||
machineID := fp.KiroHash
|
||||
|
||||
setRuntimeHeaders(req, accessToken, accountKey)
|
||||
|
||||
// Check Authorization header
|
||||
if req.Header.Get("Authorization") != "Bearer "+accessToken {
|
||||
t.Errorf("expected Authorization header 'Bearer %s', got '%s'", accessToken, req.Header.Get("Authorization"))
|
||||
}
|
||||
|
||||
// Check x-amz-user-agent header
|
||||
amzUA := req.Header.Get("x-amz-user-agent")
|
||||
if amzUA == "" {
|
||||
t.Error("expected x-amz-user-agent header to be set")
|
||||
}
|
||||
if !strings.Contains(amzUA, "aws-sdk-js/") {
|
||||
t.Errorf("x-amz-user-agent should contain aws-sdk-js: %s", amzUA)
|
||||
}
|
||||
if !strings.Contains(amzUA, "KiroIDE-") {
|
||||
t.Errorf("x-amz-user-agent should contain KiroIDE: %s", amzUA)
|
||||
}
|
||||
if !strings.Contains(amzUA, machineID) {
|
||||
t.Errorf("x-amz-user-agent should contain machineID: %s", amzUA)
|
||||
}
|
||||
|
||||
// Check User-Agent header
|
||||
ua := req.Header.Get("User-Agent")
|
||||
if ua == "" {
|
||||
t.Error("expected User-Agent header to be set")
|
||||
}
|
||||
if !strings.Contains(ua, "api/codewhispererruntime#") {
|
||||
t.Errorf("User-Agent should contain api/codewhispererruntime: %s", ua)
|
||||
}
|
||||
if !strings.Contains(ua, "m/N,E") {
|
||||
t.Errorf("User-Agent should contain m/N,E: %s", ua)
|
||||
}
|
||||
|
||||
// Check amz-sdk-invocation-id (should be a UUID)
|
||||
invocationID := req.Header.Get("amz-sdk-invocation-id")
|
||||
if invocationID == "" {
|
||||
t.Error("expected amz-sdk-invocation-id header to be set")
|
||||
}
|
||||
if len(invocationID) != 36 {
|
||||
t.Errorf("expected amz-sdk-invocation-id to be UUID (36 chars), got %d", len(invocationID))
|
||||
}
|
||||
|
||||
// Check amz-sdk-request
|
||||
if req.Header.Get("amz-sdk-request") != "attempt=1; max=1" {
|
||||
t.Errorf("unexpected amz-sdk-request header: %s", req.Header.Get("amz-sdk-request"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSDKVersionsAreValid(t *testing.T) {
|
||||
// Verify all OIDC SDK versions match expected format (3.xxx.x)
|
||||
for _, v := range oidcSDKVersions {
|
||||
if !strings.HasPrefix(v, "3.") {
|
||||
t.Errorf("OIDC SDK version should start with 3.: %s", v)
|
||||
}
|
||||
parts := strings.Split(v, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("OIDC SDK version should have 3 parts: %s", v)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range runtimeSDKVersions {
|
||||
parts := strings.Split(v, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("Runtime SDK version should have 3 parts: %s", v)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range streamingSDKVersions {
|
||||
parts := strings.Split(v, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("Streaming SDK version should have 3 parts: %s", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestKiroVersionsAreValid(t *testing.T) {
|
||||
// Verify all Kiro versions match expected format (0.x.xxx)
|
||||
for _, v := range kiroVersions {
|
||||
if !strings.HasPrefix(v, "0.") {
|
||||
t.Errorf("Kiro version should start with 0.: %s", v)
|
||||
}
|
||||
parts := strings.Split(v, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("Kiro version should have 3 parts: %s", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNodeVersionsAreValid(t *testing.T) {
|
||||
// Verify all Node versions match expected format (xx.xx.x)
|
||||
for _, v := range nodeVersions {
|
||||
parts := strings.Split(v, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("Node version should have 3 parts: %s", v)
|
||||
}
|
||||
// Should be Node 20.x or 22.x
|
||||
if !strings.HasPrefix(v, "20.") && !strings.HasPrefix(v, "22.") {
|
||||
t.Errorf("Node version should be 20.x or 22.x LTS: %s", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFingerprintManager_SetConfig(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
|
||||
// Without config, should generate random fingerprint
|
||||
fp1 := fm.GetFingerprint("token1")
|
||||
if fp1 == nil {
|
||||
t.Fatal("expected non-nil fingerprint")
|
||||
}
|
||||
|
||||
// Set config with all fields
|
||||
cfg := &FingerprintConfig{
|
||||
OIDCSDKVersion: "3.999.0",
|
||||
RuntimeSDKVersion: "9.9.9",
|
||||
StreamingSDKVersion: "8.8.8",
|
||||
OSType: "darwin",
|
||||
OSVersion: "99.0.0",
|
||||
NodeVersion: "99.99.99",
|
||||
KiroVersion: "9.9.999",
|
||||
KiroHash: "customhash123",
|
||||
}
|
||||
fm.SetConfig(cfg)
|
||||
|
||||
// After setting config, should use config values
|
||||
fp2 := fm.GetFingerprint("token2")
|
||||
if fp2.OIDCSDKVersion != "3.999.0" {
|
||||
t.Errorf("expected OIDCSDKVersion '3.999.0', got '%s'", fp2.OIDCSDKVersion)
|
||||
}
|
||||
if fp2.RuntimeSDKVersion != "9.9.9" {
|
||||
t.Errorf("expected RuntimeSDKVersion '9.9.9', got '%s'", fp2.RuntimeSDKVersion)
|
||||
}
|
||||
if fp2.StreamingSDKVersion != "8.8.8" {
|
||||
t.Errorf("expected StreamingSDKVersion '8.8.8', got '%s'", fp2.StreamingSDKVersion)
|
||||
}
|
||||
if fp2.OSType != "darwin" {
|
||||
t.Errorf("expected OSType 'darwin', got '%s'", fp2.OSType)
|
||||
}
|
||||
if fp2.OSVersion != "99.0.0" {
|
||||
t.Errorf("expected OSVersion '99.0.0', got '%s'", fp2.OSVersion)
|
||||
}
|
||||
if fp2.NodeVersion != "99.99.99" {
|
||||
t.Errorf("expected NodeVersion '99.99.99', got '%s'", fp2.NodeVersion)
|
||||
}
|
||||
if fp2.KiroVersion != "9.9.999" {
|
||||
t.Errorf("expected KiroVersion '9.9.999', got '%s'", fp2.KiroVersion)
|
||||
}
|
||||
if fp2.KiroHash != "customhash123" {
|
||||
t.Errorf("expected KiroHash 'customhash123', got '%s'", fp2.KiroHash)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFingerprintManager_SetConfig_PartialFields(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
|
||||
// Set config with only some fields
|
||||
cfg := &FingerprintConfig{
|
||||
KiroVersion: "1.2.345",
|
||||
KiroHash: "myhash",
|
||||
// Other fields empty - should use random
|
||||
}
|
||||
fm.SetConfig(cfg)
|
||||
|
||||
fp := fm.GetFingerprint("token1")
|
||||
|
||||
// Configured fields should use config values
|
||||
if fp.KiroVersion != "1.2.345" {
|
||||
t.Errorf("expected KiroVersion '1.2.345', got '%s'", fp.KiroVersion)
|
||||
}
|
||||
if fp.KiroHash != "myhash" {
|
||||
t.Errorf("expected KiroHash 'myhash', got '%s'", fp.KiroHash)
|
||||
}
|
||||
|
||||
// Empty fields should be randomly selected (non-empty)
|
||||
if fp.OIDCSDKVersion == "" {
|
||||
t.Error("expected non-empty OIDCSDKVersion")
|
||||
}
|
||||
if fp.OSType == "" {
|
||||
t.Error("expected non-empty OSType")
|
||||
}
|
||||
if fp.NodeVersion == "" {
|
||||
t.Error("expected non-empty NodeVersion")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFingerprintManager_SetConfig_ClearsCache(t *testing.T) {
|
||||
fm := NewFingerprintManager()
|
||||
|
||||
// Get fingerprint before config
|
||||
fp1 := fm.GetFingerprint("token1")
|
||||
originalHash := fp1.KiroHash
|
||||
|
||||
// Set config
|
||||
cfg := &FingerprintConfig{
|
||||
KiroHash: "newcustomhash",
|
||||
}
|
||||
fm.SetConfig(cfg)
|
||||
|
||||
// Same token should now return different fingerprint (cache cleared)
|
||||
fp2 := fm.GetFingerprint("token1")
|
||||
if fp2.KiroHash == originalHash {
|
||||
t.Error("expected cache to be cleared after SetConfig")
|
||||
}
|
||||
if fp2.KiroHash != "newcustomhash" {
|
||||
t.Errorf("expected KiroHash 'newcustomhash', got '%s'", fp2.KiroHash)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateAccountKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
seed string
|
||||
check func(t *testing.T, result string)
|
||||
}{
|
||||
{
|
||||
name: "Empty seed",
|
||||
seed: "",
|
||||
check: func(t *testing.T, result string) {
|
||||
if result == "" {
|
||||
t.Error("expected non-empty result for empty seed")
|
||||
}
|
||||
if len(result) != 16 {
|
||||
t.Errorf("expected 16 char hex string, got %d chars", len(result))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Simple seed",
|
||||
seed: "test-client-id",
|
||||
check: func(t *testing.T, result string) {
|
||||
if len(result) != 16 {
|
||||
t.Errorf("expected 16 char hex string, got %d chars", len(result))
|
||||
}
|
||||
// Verify it's valid hex
|
||||
for _, c := range result {
|
||||
if (c < '0' || c > '9') && (c < 'a' || c > 'f') {
|
||||
t.Errorf("invalid hex character: %c", c)
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Same seed produces same result",
|
||||
seed: "deterministic-seed",
|
||||
check: func(t *testing.T, result string) {
|
||||
result2 := GenerateAccountKey("deterministic-seed")
|
||||
if result != result2 {
|
||||
t.Errorf("same seed should produce same result: %s vs %s", result, result2)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Different seeds produce different results",
|
||||
seed: "seed-one",
|
||||
check: func(t *testing.T, result string) {
|
||||
result2 := GenerateAccountKey("seed-two")
|
||||
if result == result2 {
|
||||
t.Errorf("different seeds should produce different results: %s vs %s", result, result2)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GenerateAccountKey(tt.seed)
|
||||
tt.check(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
clientID string
|
||||
refreshToken string
|
||||
check func(t *testing.T, result string)
|
||||
}{
|
||||
{
|
||||
name: "Priority 1: clientID when both provided",
|
||||
clientID: "client-id-123",
|
||||
refreshToken: "refresh-token-456",
|
||||
check: func(t *testing.T, result string) {
|
||||
expected := GenerateAccountKey("client-id-123")
|
||||
if result != expected {
|
||||
t.Errorf("expected clientID-based key %s, got %s", expected, result)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Priority 2: refreshToken when clientID is empty",
|
||||
clientID: "",
|
||||
refreshToken: "refresh-token-789",
|
||||
check: func(t *testing.T, result string) {
|
||||
expected := GenerateAccountKey("refresh-token-789")
|
||||
if result != expected {
|
||||
t.Errorf("expected refreshToken-based key %s, got %s", expected, result)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Priority 3: random when both empty",
|
||||
clientID: "",
|
||||
refreshToken: "",
|
||||
check: func(t *testing.T, result string) {
|
||||
if len(result) != 16 {
|
||||
t.Errorf("expected 16 char key, got %d chars", len(result))
|
||||
}
|
||||
// Should be different each time (random UUID)
|
||||
result2 := GetAccountKey("", "")
|
||||
if result == result2 {
|
||||
t.Log("warning: random keys are the same (possible but unlikely)")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "clientID only",
|
||||
clientID: "solo-client-id",
|
||||
refreshToken: "",
|
||||
check: func(t *testing.T, result string) {
|
||||
expected := GenerateAccountKey("solo-client-id")
|
||||
if result != expected {
|
||||
t.Errorf("expected clientID-based key %s, got %s", expected, result)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "refreshToken only",
|
||||
clientID: "",
|
||||
refreshToken: "solo-refresh-token",
|
||||
check: func(t *testing.T, result string) {
|
||||
expected := GenerateAccountKey("solo-refresh-token")
|
||||
if result != expected {
|
||||
t.Errorf("expected refreshToken-based key %s, got %s", expected, result)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetAccountKey(tt.clientID, tt.refreshToken)
|
||||
tt.check(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountKey_Deterministic(t *testing.T) {
|
||||
// Verify that GetAccountKey produces deterministic results for same inputs
|
||||
clientID := "test-client-id-abc"
|
||||
refreshToken := "test-refresh-token-xyz"
|
||||
|
||||
// Call multiple times with same inputs
|
||||
results := make([]string, 10)
|
||||
for i := range 10 {
|
||||
results[i] = GetAccountKey(clientID, refreshToken)
|
||||
}
|
||||
|
||||
// All results should be identical
|
||||
for i := 1; i < 10; i++ {
|
||||
if results[i] != results[0] {
|
||||
t.Errorf("GetAccountKey should be deterministic: got %s and %s", results[0], results[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFingerprintDeterministic(t *testing.T) {
|
||||
// Verify that fingerprints are deterministic based on accountKey
|
||||
fm := NewFingerprintManager()
|
||||
|
||||
accountKey := GenerateAccountKey("test-client-id")
|
||||
|
||||
// Get fingerprint multiple times
|
||||
fp1 := fm.GetFingerprint(accountKey)
|
||||
fp2 := fm.GetFingerprint(accountKey)
|
||||
|
||||
// Should be the same pointer (cached)
|
||||
if fp1 != fp2 {
|
||||
t.Error("expected same fingerprint pointer for same key")
|
||||
}
|
||||
|
||||
// Create new manager and verify same values
|
||||
fm2 := NewFingerprintManager()
|
||||
fp3 := fm2.GetFingerprint(accountKey)
|
||||
|
||||
// Values should be identical (deterministic generation)
|
||||
if fp1.KiroHash != fp3.KiroHash {
|
||||
t.Errorf("KiroHash should be deterministic: %s vs %s", fp1.KiroHash, fp3.KiroHash)
|
||||
}
|
||||
if fp1.OSType != fp3.OSType {
|
||||
t.Errorf("OSType should be deterministic: %s vs %s", fp1.OSType, fp3.OSType)
|
||||
}
|
||||
if fp1.OSVersion != fp3.OSVersion {
|
||||
t.Errorf("OSVersion should be deterministic: %s vs %s", fp1.OSVersion, fp3.OSVersion)
|
||||
}
|
||||
if fp1.KiroVersion != fp3.KiroVersion {
|
||||
t.Errorf("KiroVersion should be deterministic: %s vs %s", fp1.KiroVersion, fp3.KiroVersion)
|
||||
}
|
||||
if fp1.NodeVersion != fp3.NodeVersion {
|
||||
t.Errorf("NodeVersion should be deterministic: %s vs %s", fp1.NodeVersion, fp3.NodeVersion)
|
||||
}
|
||||
}
|
||||
174
internal/auth/kiro/jitter.go
Normal file
174
internal/auth/kiro/jitter.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Jitter configuration constants
|
||||
const (
|
||||
// JitterPercent is the default percentage of jitter to apply (±30%)
|
||||
JitterPercent = 0.30
|
||||
|
||||
// Human-like delay ranges
|
||||
ShortDelayMin = 50 * time.Millisecond // Minimum for rapid consecutive operations
|
||||
ShortDelayMax = 200 * time.Millisecond // Maximum for rapid consecutive operations
|
||||
NormalDelayMin = 1 * time.Second // Minimum for normal thinking time
|
||||
NormalDelayMax = 3 * time.Second // Maximum for normal thinking time
|
||||
LongDelayMin = 5 * time.Second // Minimum for reading/resting
|
||||
LongDelayMax = 10 * time.Second // Maximum for reading/resting
|
||||
|
||||
// Probability thresholds for human-like behavior
|
||||
ShortDelayProbability = 0.20 // 20% chance of short delay (consecutive ops)
|
||||
LongDelayProbability = 0.05 // 5% chance of long delay (reading/resting)
|
||||
NormalDelayProbability = 0.75 // 75% chance of normal delay (thinking)
|
||||
)
|
||||
|
||||
var (
|
||||
jitterRand *rand.Rand
|
||||
jitterRandOnce sync.Once
|
||||
jitterMu sync.Mutex
|
||||
lastRequestTime time.Time
|
||||
)
|
||||
|
||||
// initJitterRand initializes the random number generator for jitter calculations.
|
||||
// Uses a time-based seed for unpredictable but reproducible randomness.
|
||||
func initJitterRand() {
|
||||
jitterRandOnce.Do(func() {
|
||||
jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
})
|
||||
}
|
||||
|
||||
// RandomDelay generates a random delay between min and max duration.
|
||||
// Thread-safe implementation using mutex protection.
|
||||
func RandomDelay(min, max time.Duration) time.Duration {
|
||||
initJitterRand()
|
||||
jitterMu.Lock()
|
||||
defer jitterMu.Unlock()
|
||||
|
||||
if min >= max {
|
||||
return min
|
||||
}
|
||||
|
||||
rangeMs := max.Milliseconds() - min.Milliseconds()
|
||||
randomMs := jitterRand.Int63n(rangeMs)
|
||||
return min + time.Duration(randomMs)*time.Millisecond
|
||||
}
|
||||
|
||||
// JitterDelay adds jitter to a base delay.
|
||||
// Applies ±jitterPercent variation to the base delay.
|
||||
// For example, JitterDelay(1*time.Second, 0.30) returns a value between 700ms and 1300ms.
|
||||
func JitterDelay(baseDelay time.Duration, jitterPercent float64) time.Duration {
|
||||
initJitterRand()
|
||||
jitterMu.Lock()
|
||||
defer jitterMu.Unlock()
|
||||
|
||||
if jitterPercent <= 0 || jitterPercent > 1 {
|
||||
jitterPercent = JitterPercent
|
||||
}
|
||||
|
||||
// Calculate jitter range: base * jitterPercent
|
||||
jitterRange := float64(baseDelay) * jitterPercent
|
||||
|
||||
// Generate random value in range [-jitterRange, +jitterRange]
|
||||
jitter := (jitterRand.Float64()*2 - 1) * jitterRange
|
||||
|
||||
result := time.Duration(float64(baseDelay) + jitter)
|
||||
if result < 0 {
|
||||
return 0
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// JitterDelayDefault applies the default ±30% jitter to a base delay.
|
||||
func JitterDelayDefault(baseDelay time.Duration) time.Duration {
|
||||
return JitterDelay(baseDelay, JitterPercent)
|
||||
}
|
||||
|
||||
// HumanLikeDelay generates a delay that mimics human behavior patterns.
|
||||
// The delay is selected based on probability distribution:
|
||||
// - 20% chance: Short delay (50-200ms) - simulates consecutive rapid operations
|
||||
// - 75% chance: Normal delay (1-3s) - simulates thinking/reading time
|
||||
// - 5% chance: Long delay (5-10s) - simulates breaks/reading longer content
|
||||
//
|
||||
// Returns the delay duration (caller should call time.Sleep with this value).
|
||||
func HumanLikeDelay() time.Duration {
|
||||
initJitterRand()
|
||||
jitterMu.Lock()
|
||||
defer jitterMu.Unlock()
|
||||
|
||||
// Track time since last request for adaptive behavior
|
||||
now := time.Now()
|
||||
timeSinceLastRequest := now.Sub(lastRequestTime)
|
||||
lastRequestTime = now
|
||||
|
||||
// If requests are very close together, use short delay
|
||||
if timeSinceLastRequest < 500*time.Millisecond && timeSinceLastRequest > 0 {
|
||||
rangeMs := ShortDelayMax.Milliseconds() - ShortDelayMin.Milliseconds()
|
||||
randomMs := jitterRand.Int63n(rangeMs)
|
||||
return ShortDelayMin + time.Duration(randomMs)*time.Millisecond
|
||||
}
|
||||
|
||||
// Otherwise, use probability-based selection
|
||||
roll := jitterRand.Float64()
|
||||
|
||||
var min, max time.Duration
|
||||
switch {
|
||||
case roll < ShortDelayProbability:
|
||||
// Short delay - consecutive operations
|
||||
min, max = ShortDelayMin, ShortDelayMax
|
||||
case roll < ShortDelayProbability+LongDelayProbability:
|
||||
// Long delay - reading/resting
|
||||
min, max = LongDelayMin, LongDelayMax
|
||||
default:
|
||||
// Normal delay - thinking time
|
||||
min, max = NormalDelayMin, NormalDelayMax
|
||||
}
|
||||
|
||||
rangeMs := max.Milliseconds() - min.Milliseconds()
|
||||
randomMs := jitterRand.Int63n(rangeMs)
|
||||
return min + time.Duration(randomMs)*time.Millisecond
|
||||
}
|
||||
|
||||
// ApplyHumanLikeDelay applies human-like delay by sleeping.
|
||||
// This is a convenience function that combines HumanLikeDelay with time.Sleep.
|
||||
func ApplyHumanLikeDelay() {
|
||||
delay := HumanLikeDelay()
|
||||
if delay > 0 {
|
||||
time.Sleep(delay)
|
||||
}
|
||||
}
|
||||
|
||||
// ExponentialBackoffWithJitter calculates retry delay using exponential backoff with jitter.
|
||||
// Formula: min(baseDelay * 2^attempt + jitter, maxDelay)
|
||||
// This helps prevent thundering herd problem when multiple clients retry simultaneously.
|
||||
func ExponentialBackoffWithJitter(attempt int, baseDelay, maxDelay time.Duration) time.Duration {
|
||||
if attempt < 0 {
|
||||
attempt = 0
|
||||
}
|
||||
|
||||
// Calculate exponential backoff: baseDelay * 2^attempt
|
||||
backoff := baseDelay * time.Duration(1<<uint(attempt))
|
||||
if backoff > maxDelay {
|
||||
backoff = maxDelay
|
||||
}
|
||||
|
||||
// Add ±30% jitter
|
||||
return JitterDelay(backoff, JitterPercent)
|
||||
}
|
||||
|
||||
// ShouldSkipDelay determines if delay should be skipped based on context.
|
||||
// Returns true for streaming responses, WebSocket connections, etc.
|
||||
// This function can be extended to check additional skip conditions.
|
||||
func ShouldSkipDelay(isStreaming bool) bool {
|
||||
return isStreaming
|
||||
}
|
||||
|
||||
// ResetLastRequestTime resets the last request time tracker.
|
||||
// Useful for testing or when starting a new session.
|
||||
func ResetLastRequestTime() {
|
||||
jitterMu.Lock()
|
||||
defer jitterMu.Unlock()
|
||||
lastRequestTime = time.Time{}
|
||||
}
|
||||
187
internal/auth/kiro/metrics.go
Normal file
187
internal/auth/kiro/metrics.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TokenMetrics holds performance metrics for a single token.
|
||||
type TokenMetrics struct {
|
||||
SuccessRate float64 // Success rate (0.0 - 1.0)
|
||||
AvgLatency float64 // Average latency in milliseconds
|
||||
QuotaRemaining float64 // Remaining quota (0.0 - 1.0)
|
||||
LastUsed time.Time // Last usage timestamp
|
||||
FailCount int // Consecutive failure count
|
||||
TotalRequests int // Total request count
|
||||
successCount int // Internal: successful request count
|
||||
totalLatency float64 // Internal: cumulative latency
|
||||
}
|
||||
|
||||
// TokenScorer manages token metrics and scoring.
|
||||
type TokenScorer struct {
|
||||
mu sync.RWMutex
|
||||
metrics map[string]*TokenMetrics
|
||||
|
||||
// Scoring weights
|
||||
successRateWeight float64
|
||||
quotaWeight float64
|
||||
latencyWeight float64
|
||||
lastUsedWeight float64
|
||||
failPenaltyMultiplier float64
|
||||
}
|
||||
|
||||
// NewTokenScorer creates a new TokenScorer with default weights.
|
||||
func NewTokenScorer() *TokenScorer {
|
||||
return &TokenScorer{
|
||||
metrics: make(map[string]*TokenMetrics),
|
||||
successRateWeight: 0.4,
|
||||
quotaWeight: 0.25,
|
||||
latencyWeight: 0.2,
|
||||
lastUsedWeight: 0.15,
|
||||
failPenaltyMultiplier: 0.1,
|
||||
}
|
||||
}
|
||||
|
||||
// getOrCreateMetrics returns existing metrics or creates new ones.
|
||||
func (s *TokenScorer) getOrCreateMetrics(tokenKey string) *TokenMetrics {
|
||||
if m, ok := s.metrics[tokenKey]; ok {
|
||||
return m
|
||||
}
|
||||
m := &TokenMetrics{
|
||||
SuccessRate: 1.0,
|
||||
QuotaRemaining: 1.0,
|
||||
}
|
||||
s.metrics[tokenKey] = m
|
||||
return m
|
||||
}
|
||||
|
||||
// RecordRequest records the result of a request for a token.
|
||||
func (s *TokenScorer) RecordRequest(tokenKey string, success bool, latency time.Duration) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
m := s.getOrCreateMetrics(tokenKey)
|
||||
m.TotalRequests++
|
||||
m.LastUsed = time.Now()
|
||||
m.totalLatency += float64(latency.Milliseconds())
|
||||
|
||||
if success {
|
||||
m.successCount++
|
||||
m.FailCount = 0
|
||||
} else {
|
||||
m.FailCount++
|
||||
}
|
||||
|
||||
// Update derived metrics
|
||||
if m.TotalRequests > 0 {
|
||||
m.SuccessRate = float64(m.successCount) / float64(m.TotalRequests)
|
||||
m.AvgLatency = m.totalLatency / float64(m.TotalRequests)
|
||||
}
|
||||
}
|
||||
|
||||
// SetQuotaRemaining updates the remaining quota for a token.
|
||||
func (s *TokenScorer) SetQuotaRemaining(tokenKey string, quota float64) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
m := s.getOrCreateMetrics(tokenKey)
|
||||
m.QuotaRemaining = quota
|
||||
}
|
||||
|
||||
// GetMetrics returns a copy of the metrics for a token.
|
||||
func (s *TokenScorer) GetMetrics(tokenKey string) *TokenMetrics {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if m, ok := s.metrics[tokenKey]; ok {
|
||||
copy := *m
|
||||
return ©
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CalculateScore computes the score for a token (higher is better).
|
||||
func (s *TokenScorer) CalculateScore(tokenKey string) float64 {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
m, ok := s.metrics[tokenKey]
|
||||
if !ok {
|
||||
return 1.0 // New tokens get a high initial score
|
||||
}
|
||||
|
||||
// Success rate component (0-1)
|
||||
successScore := m.SuccessRate
|
||||
|
||||
// Quota component (0-1)
|
||||
quotaScore := m.QuotaRemaining
|
||||
|
||||
// Latency component (normalized, lower is better)
|
||||
// Using exponential decay: score = e^(-latency/1000)
|
||||
// 1000ms latency -> ~0.37 score, 100ms -> ~0.90 score
|
||||
latencyScore := math.Exp(-m.AvgLatency / 1000.0)
|
||||
if m.TotalRequests == 0 {
|
||||
latencyScore = 1.0
|
||||
}
|
||||
|
||||
// Last used component (prefer tokens not recently used)
|
||||
// Score increases as time since last use increases
|
||||
timeSinceUse := time.Since(m.LastUsed).Seconds()
|
||||
// Normalize: 60 seconds -> ~0.63 score, 0 seconds -> 0 score
|
||||
lastUsedScore := 1.0 - math.Exp(-timeSinceUse/60.0)
|
||||
if m.LastUsed.IsZero() {
|
||||
lastUsedScore = 1.0
|
||||
}
|
||||
|
||||
// Calculate weighted score
|
||||
score := s.successRateWeight*successScore +
|
||||
s.quotaWeight*quotaScore +
|
||||
s.latencyWeight*latencyScore +
|
||||
s.lastUsedWeight*lastUsedScore
|
||||
|
||||
// Apply consecutive failure penalty
|
||||
if m.FailCount > 0 {
|
||||
penalty := s.failPenaltyMultiplier * float64(m.FailCount)
|
||||
score = score * math.Max(0, 1.0-penalty)
|
||||
}
|
||||
|
||||
return score
|
||||
}
|
||||
|
||||
// SelectBestToken selects the token with the highest score.
|
||||
func (s *TokenScorer) SelectBestToken(tokens []string) string {
|
||||
if len(tokens) == 0 {
|
||||
return ""
|
||||
}
|
||||
if len(tokens) == 1 {
|
||||
return tokens[0]
|
||||
}
|
||||
|
||||
bestToken := tokens[0]
|
||||
bestScore := s.CalculateScore(tokens[0])
|
||||
|
||||
for _, token := range tokens[1:] {
|
||||
score := s.CalculateScore(token)
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
bestToken = token
|
||||
}
|
||||
}
|
||||
|
||||
return bestToken
|
||||
}
|
||||
|
||||
// ResetMetrics clears all metrics for a token.
|
||||
func (s *TokenScorer) ResetMetrics(tokenKey string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.metrics, tokenKey)
|
||||
}
|
||||
|
||||
// ResetAllMetrics clears all stored metrics.
|
||||
func (s *TokenScorer) ResetAllMetrics() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.metrics = make(map[string]*TokenMetrics)
|
||||
}
|
||||
301
internal/auth/kiro/metrics_test.go
Normal file
301
internal/auth/kiro/metrics_test.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewTokenScorer(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
if s == nil {
|
||||
t.Fatal("expected non-nil TokenScorer")
|
||||
}
|
||||
if s.metrics == nil {
|
||||
t.Error("expected non-nil metrics map")
|
||||
}
|
||||
if s.successRateWeight != 0.4 {
|
||||
t.Errorf("expected successRateWeight 0.4, got %f", s.successRateWeight)
|
||||
}
|
||||
if s.quotaWeight != 0.25 {
|
||||
t.Errorf("expected quotaWeight 0.25, got %f", s.quotaWeight)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordRequest_Success(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m == nil {
|
||||
t.Fatal("expected non-nil metrics")
|
||||
}
|
||||
if m.TotalRequests != 1 {
|
||||
t.Errorf("expected TotalRequests 1, got %d", m.TotalRequests)
|
||||
}
|
||||
if m.SuccessRate != 1.0 {
|
||||
t.Errorf("expected SuccessRate 1.0, got %f", m.SuccessRate)
|
||||
}
|
||||
if m.FailCount != 0 {
|
||||
t.Errorf("expected FailCount 0, got %d", m.FailCount)
|
||||
}
|
||||
if m.AvgLatency != 100 {
|
||||
t.Errorf("expected AvgLatency 100, got %f", m.AvgLatency)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordRequest_Failure(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", false, 200*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.SuccessRate != 0.0 {
|
||||
t.Errorf("expected SuccessRate 0.0, got %f", m.SuccessRate)
|
||||
}
|
||||
if m.FailCount != 1 {
|
||||
t.Errorf("expected FailCount 1, got %d", m.FailCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordRequest_MixedResults(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.TotalRequests != 4 {
|
||||
t.Errorf("expected TotalRequests 4, got %d", m.TotalRequests)
|
||||
}
|
||||
if m.SuccessRate != 0.75 {
|
||||
t.Errorf("expected SuccessRate 0.75, got %f", m.SuccessRate)
|
||||
}
|
||||
if m.FailCount != 0 {
|
||||
t.Errorf("expected FailCount 0 (reset on success), got %d", m.FailCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordRequest_ConsecutiveFailures(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.FailCount != 3 {
|
||||
t.Errorf("expected FailCount 3, got %d", m.FailCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetQuotaRemaining(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.SetQuotaRemaining("token1", 0.5)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.QuotaRemaining != 0.5 {
|
||||
t.Errorf("expected QuotaRemaining 0.5, got %f", m.QuotaRemaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMetrics_NonExistent(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
m := s.GetMetrics("nonexistent")
|
||||
if m != nil {
|
||||
t.Error("expected nil metrics for non-existent token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetMetrics_ReturnsCopy(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m1 := s.GetMetrics("token1")
|
||||
m1.TotalRequests = 999
|
||||
|
||||
m2 := s.GetMetrics("token1")
|
||||
if m2.TotalRequests == 999 {
|
||||
t.Error("GetMetrics should return a copy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateScore_NewToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
score := s.CalculateScore("newtoken")
|
||||
if score != 1.0 {
|
||||
t.Errorf("expected score 1.0 for new token, got %f", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateScore_PerfectToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 50*time.Millisecond)
|
||||
s.SetQuotaRemaining("token1", 1.0)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
score := s.CalculateScore("token1")
|
||||
if score < 0.5 || score > 1.0 {
|
||||
t.Errorf("expected high score for perfect token, got %f", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateScore_FailedToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
for i := 0; i < 5; i++ {
|
||||
s.RecordRequest("token1", false, 1000*time.Millisecond)
|
||||
}
|
||||
s.SetQuotaRemaining("token1", 0.1)
|
||||
|
||||
score := s.CalculateScore("token1")
|
||||
if score > 0.5 {
|
||||
t.Errorf("expected low score for failed token, got %f", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateScore_FailPenalty(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
scoreNoFail := s.CalculateScore("token1")
|
||||
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", false, 100*time.Millisecond)
|
||||
scoreWithFail := s.CalculateScore("token1")
|
||||
|
||||
if scoreWithFail >= scoreNoFail {
|
||||
t.Errorf("expected lower score with consecutive failures: noFail=%f, withFail=%f", scoreNoFail, scoreWithFail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectBestToken_Empty(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
best := s.SelectBestToken([]string{})
|
||||
if best != "" {
|
||||
t.Errorf("expected empty string for empty tokens, got %s", best)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectBestToken_SingleToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
best := s.SelectBestToken([]string{"token1"})
|
||||
if best != "token1" {
|
||||
t.Errorf("expected token1, got %s", best)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectBestToken_MultipleTokens(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
|
||||
s.RecordRequest("bad", false, 1000*time.Millisecond)
|
||||
s.RecordRequest("bad", false, 1000*time.Millisecond)
|
||||
s.SetQuotaRemaining("bad", 0.1)
|
||||
|
||||
s.RecordRequest("good", true, 50*time.Millisecond)
|
||||
s.SetQuotaRemaining("good", 0.9)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
best := s.SelectBestToken([]string{"bad", "good"})
|
||||
if best != "good" {
|
||||
t.Errorf("expected good token to be selected, got %s", best)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetMetrics(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.ResetMetrics("token1")
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m != nil {
|
||||
t.Error("expected nil metrics after reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetAllMetrics(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token2", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token3", true, 100*time.Millisecond)
|
||||
|
||||
s.ResetAllMetrics()
|
||||
|
||||
if s.GetMetrics("token1") != nil {
|
||||
t.Error("expected nil metrics for token1 after reset all")
|
||||
}
|
||||
if s.GetMetrics("token2") != nil {
|
||||
t.Error("expected nil metrics for token2 after reset all")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenScorer_ConcurrentAccess(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
const numGoroutines = 50
|
||||
const numOperations = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
tokenKey := "token" + string(rune('a'+id%10))
|
||||
for j := 0; j < numOperations; j++ {
|
||||
switch j % 6 {
|
||||
case 0:
|
||||
s.RecordRequest(tokenKey, j%2 == 0, time.Duration(j)*time.Millisecond)
|
||||
case 1:
|
||||
s.SetQuotaRemaining(tokenKey, float64(j%100)/100)
|
||||
case 2:
|
||||
s.GetMetrics(tokenKey)
|
||||
case 3:
|
||||
s.CalculateScore(tokenKey)
|
||||
case 4:
|
||||
s.SelectBestToken([]string{tokenKey, "token_x", "token_y"})
|
||||
case 5:
|
||||
if j%20 == 0 {
|
||||
s.ResetMetrics(tokenKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAvgLatencyCalculation(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
s.RecordRequest("token1", true, 200*time.Millisecond)
|
||||
s.RecordRequest("token1", true, 300*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.AvgLatency != 200 {
|
||||
t.Errorf("expected AvgLatency 200, got %f", m.AvgLatency)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLastUsedUpdated(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
before := time.Now()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.LastUsed.Before(before) {
|
||||
t.Error("expected LastUsed to be after test start time")
|
||||
}
|
||||
if m.LastUsed.After(time.Now()) {
|
||||
t.Error("expected LastUsed to be before or equal to now")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultQuotaForNewToken(t *testing.T) {
|
||||
s := NewTokenScorer()
|
||||
s.RecordRequest("token1", true, 100*time.Millisecond)
|
||||
|
||||
m := s.GetMetrics("token1")
|
||||
if m.QuotaRemaining != 1.0 {
|
||||
t.Errorf("expected default QuotaRemaining 1.0, got %f", m.QuotaRemaining)
|
||||
}
|
||||
}
|
||||
319
internal/auth/kiro/oauth.go
Normal file
319
internal/auth/kiro/oauth.go
Normal file
@@ -0,0 +1,319 @@
|
||||
// Package kiro provides OAuth2 authentication for Kiro using native Google login.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// Kiro auth endpoint
|
||||
kiroAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
|
||||
|
||||
// Default callback port
|
||||
defaultCallbackPort = 9876
|
||||
|
||||
// Auth timeout
|
||||
authTimeout = 10 * time.Minute
|
||||
)
|
||||
|
||||
// KiroTokenResponse represents the response from Kiro token endpoint.
|
||||
type KiroTokenResponse struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ProfileArn string `json:"profileArn"`
|
||||
ExpiresIn int `json:"expiresIn"`
|
||||
}
|
||||
|
||||
// KiroOAuth handles the OAuth flow for Kiro authentication.
|
||||
type KiroOAuth struct {
|
||||
httpClient *http.Client
|
||||
cfg *config.Config
|
||||
machineID string
|
||||
kiroVersion string
|
||||
}
|
||||
|
||||
// NewKiroOAuth creates a new Kiro OAuth handler.
|
||||
func NewKiroOAuth(cfg *config.Config) *KiroOAuth {
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
if cfg != nil {
|
||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||
}
|
||||
fp := GlobalFingerprintManager().GetFingerprint("login")
|
||||
return &KiroOAuth{
|
||||
httpClient: client,
|
||||
cfg: cfg,
|
||||
machineID: fp.KiroHash,
|
||||
kiroVersion: fp.KiroVersion,
|
||||
}
|
||||
}
|
||||
|
||||
// generateCodeVerifier generates a random code verifier for PKCE.
|
||||
func generateCodeVerifier() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// generateCodeChallenge generates the code challenge from verifier.
|
||||
func generateCodeChallenge(verifier string) string {
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
return base64.RawURLEncoding.EncodeToString(h[:])
|
||||
}
|
||||
|
||||
// generateState generates a random state parameter.
|
||||
func generateState() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// AuthResult contains the authorization code and state from callback.
|
||||
type AuthResult struct {
|
||||
Code string
|
||||
State string
|
||||
Error string
|
||||
}
|
||||
|
||||
// startCallbackServer starts a local HTTP server to receive the OAuth callback.
|
||||
func (o *KiroOAuth) startCallbackServer(ctx context.Context, expectedState string) (string, <-chan AuthResult, error) {
|
||||
// Try to find an available port - use localhost like Kiro does
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", defaultCallbackPort))
|
||||
if err != nil {
|
||||
// Try with dynamic port (RFC 8252 allows dynamic ports for native apps)
|
||||
log.Warnf("kiro oauth: default port %d is busy, falling back to dynamic port", defaultCallbackPort)
|
||||
listener, err = net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to start callback server: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
// Use http scheme for local callback server
|
||||
redirectURI := fmt.Sprintf("http://localhost:%d/oauth/callback", port)
|
||||
resultChan := make(chan AuthResult, 1)
|
||||
|
||||
server := &http.Server{
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
errParam := r.URL.Query().Get("error")
|
||||
|
||||
if errParam != "" {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(w, `<html><body><h1>Login Failed</h1><p>%s</p><p>You can close this window.</p></body></html>`, html.EscapeString(errParam))
|
||||
resultChan <- AuthResult{Error: errParam}
|
||||
return
|
||||
}
|
||||
|
||||
if state != expectedState {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(w, `<html><body><h1>Login Failed</h1><p>Invalid state parameter</p><p>You can close this window.</p></body></html>`)
|
||||
resultChan <- AuthResult{Error: "state mismatch"}
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
fmt.Fprint(w, `<html><body><h1>Login Successful!</h1><p>You can close this window and return to the terminal.</p></body></html>`)
|
||||
resultChan <- AuthResult{Code: code, State: state}
|
||||
})
|
||||
|
||||
server.Handler = mux
|
||||
|
||||
go func() {
|
||||
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||
log.Debugf("callback server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(authTimeout):
|
||||
case <-resultChan:
|
||||
}
|
||||
_ = server.Shutdown(context.Background())
|
||||
}()
|
||||
|
||||
return redirectURI, resultChan, nil
|
||||
}
|
||||
|
||||
// LoginWithBuilderID performs OAuth login with AWS Builder ID using device code flow.
|
||||
func (o *KiroOAuth) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, error) {
|
||||
ssoClient := NewSSOOIDCClient(o.cfg)
|
||||
return ssoClient.LoginWithBuilderID(ctx)
|
||||
}
|
||||
|
||||
// LoginWithBuilderIDAuthCode performs OAuth login with AWS Builder ID using authorization code flow.
|
||||
// This provides a better UX than device code flow as it uses automatic browser callback.
|
||||
func (o *KiroOAuth) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTokenData, error) {
|
||||
ssoClient := NewSSOOIDCClient(o.cfg)
|
||||
return ssoClient.LoginWithBuilderIDAuthCode(ctx)
|
||||
}
|
||||
|
||||
// exchangeCodeForToken exchanges the authorization code for tokens.
|
||||
func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier, redirectURI string) (*KiroTokenData, error) {
|
||||
payload := map[string]string{
|
||||
"code": code,
|
||||
"code_verifier": codeVerifier,
|
||||
"redirect_uri": redirectURI,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
tokenURL := kiroAuthEndpoint + "/oauth/token"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", o.kiroVersion, o.machineID))
|
||||
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("token exchange failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("token exchange failed (status %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
var tokenResp KiroTokenResponse
|
||||
if err := json.Unmarshal(respBody, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||
}
|
||||
|
||||
// Validate ExpiresIn - use default 1 hour if invalid
|
||||
expiresIn := tokenResp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
return &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: tokenResp.ProfileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: "", // Caller should preserve original provider
|
||||
Region: "us-east-1",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an expired access token.
|
||||
// Uses KiroIDE-style User-Agent to match official Kiro IDE behavior.
|
||||
func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) {
|
||||
return o.RefreshTokenWithFingerprint(ctx, refreshToken, "")
|
||||
}
|
||||
|
||||
// RefreshTokenWithFingerprint refreshes an expired access token with a specific fingerprint.
|
||||
// tokenKey is used to generate a consistent fingerprint for the token.
|
||||
func (o *KiroOAuth) RefreshTokenWithFingerprint(ctx context.Context, refreshToken, tokenKey string) (*KiroTokenData, error) {
|
||||
payload := map[string]string{
|
||||
"refreshToken": refreshToken,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
refreshURL := kiroAuthEndpoint + "/refreshToken"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", o.kiroVersion, o.machineID))
|
||||
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
|
||||
resp, err := o.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refresh request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var tokenResp KiroTokenResponse
|
||||
if err := json.Unmarshal(respBody, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||
}
|
||||
|
||||
// Validate ExpiresIn - use default 1 hour if invalid
|
||||
expiresIn := tokenResp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
return &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: tokenResp.ProfileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: "", // Caller should preserve original provider
|
||||
Region: "us-east-1",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LoginWithGoogle performs OAuth login with Google using Kiro's social auth.
|
||||
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||
func (o *KiroOAuth) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) {
|
||||
socialClient := NewSocialAuthClient(o.cfg)
|
||||
return socialClient.LoginWithGoogle(ctx)
|
||||
}
|
||||
|
||||
// LoginWithGitHub performs OAuth login with GitHub using Kiro's social auth.
|
||||
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||
func (o *KiroOAuth) LoginWithGitHub(ctx context.Context) (*KiroTokenData, error) {
|
||||
socialClient := NewSocialAuthClient(o.cfg)
|
||||
return socialClient.LoginWithGitHub(ctx)
|
||||
}
|
||||
975
internal/auth/kiro/oauth_web.go
Normal file
975
internal/auth/kiro/oauth_web.go
Normal file
@@ -0,0 +1,975 @@
|
||||
// Package kiro provides OAuth Web authentication for Kiro.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultSessionExpiry = 10 * time.Minute
|
||||
pollIntervalSeconds = 5
|
||||
)
|
||||
|
||||
type authSessionStatus string
|
||||
|
||||
const (
|
||||
statusPending authSessionStatus = "pending"
|
||||
statusSuccess authSessionStatus = "success"
|
||||
statusFailed authSessionStatus = "failed"
|
||||
)
|
||||
|
||||
type webAuthSession struct {
|
||||
stateID string
|
||||
deviceCode string
|
||||
userCode string
|
||||
authURL string
|
||||
verificationURI string
|
||||
expiresIn int
|
||||
interval int
|
||||
status authSessionStatus
|
||||
startedAt time.Time
|
||||
completedAt time.Time
|
||||
expiresAt time.Time
|
||||
error string
|
||||
tokenData *KiroTokenData
|
||||
ssoClient *SSOOIDCClient
|
||||
clientID string
|
||||
clientSecret string
|
||||
region string
|
||||
cancelFunc context.CancelFunc
|
||||
authMethod string // "google", "github", "builder-id", "idc"
|
||||
startURL string // Used for IDC
|
||||
codeVerifier string // Used for social auth PKCE
|
||||
codeChallenge string // Used for social auth PKCE
|
||||
}
|
||||
|
||||
type OAuthWebHandler struct {
|
||||
cfg *config.Config
|
||||
sessions map[string]*webAuthSession
|
||||
mu sync.RWMutex
|
||||
onTokenObtained func(*KiroTokenData)
|
||||
}
|
||||
|
||||
func NewOAuthWebHandler(cfg *config.Config) *OAuthWebHandler {
|
||||
return &OAuthWebHandler{
|
||||
cfg: cfg,
|
||||
sessions: make(map[string]*webAuthSession),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) SetTokenCallback(callback func(*KiroTokenData)) {
|
||||
h.onTokenObtained = callback
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) RegisterRoutes(router gin.IRouter) {
|
||||
oauth := router.Group("/v0/oauth/kiro")
|
||||
{
|
||||
oauth.GET("", h.handleSelect)
|
||||
oauth.GET("/start", h.handleStart)
|
||||
oauth.GET("/callback", h.handleCallback)
|
||||
oauth.GET("/social/callback", h.handleSocialCallback)
|
||||
oauth.GET("/status", h.handleStatus)
|
||||
oauth.POST("/import", h.handleImportToken)
|
||||
oauth.POST("/refresh", h.handleManualRefresh)
|
||||
}
|
||||
}
|
||||
|
||||
func generateStateID() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleSelect(c *gin.Context) {
|
||||
h.renderSelectPage(c)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleStart(c *gin.Context) {
|
||||
method := c.Query("method")
|
||||
|
||||
if method == "" {
|
||||
c.Redirect(http.StatusFound, "/v0/oauth/kiro")
|
||||
return
|
||||
}
|
||||
|
||||
switch method {
|
||||
case "google", "github":
|
||||
// Google/GitHub social login is not supported for third-party apps
|
||||
// due to AWS Cognito redirect_uri restrictions
|
||||
h.renderError(c, "Google/GitHub login is not available for third-party applications. Please use AWS Builder ID or import your token from Kiro IDE.")
|
||||
case "builder-id":
|
||||
h.startBuilderIDAuth(c)
|
||||
case "idc":
|
||||
h.startIDCAuth(c)
|
||||
default:
|
||||
h.renderError(c, fmt.Sprintf("Unknown authentication method: %s", method))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) startSocialAuth(c *gin.Context, method string) {
|
||||
stateID, err := generateStateID()
|
||||
if err != nil {
|
||||
h.renderError(c, "Failed to generate state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
codeVerifier, codeChallenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
h.renderError(c, "Failed to generate PKCE parameters")
|
||||
return
|
||||
}
|
||||
|
||||
socialClient := NewSocialAuthClient(h.cfg)
|
||||
|
||||
var provider string
|
||||
if method == "google" {
|
||||
provider = string(ProviderGoogle)
|
||||
} else {
|
||||
provider = string(ProviderGitHub)
|
||||
}
|
||||
|
||||
redirectURI := h.getSocialCallbackURL(c)
|
||||
authURL := socialClient.buildLoginURL(provider, redirectURI, codeChallenge, stateID)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
|
||||
session := &webAuthSession{
|
||||
stateID: stateID,
|
||||
authMethod: method,
|
||||
authURL: authURL,
|
||||
status: statusPending,
|
||||
startedAt: time.Now(),
|
||||
expiresIn: 600,
|
||||
codeVerifier: codeVerifier,
|
||||
codeChallenge: codeChallenge,
|
||||
region: "us-east-1",
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
h.sessions[stateID] = session
|
||||
h.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
h.mu.Lock()
|
||||
if session.status == statusPending {
|
||||
session.status = statusFailed
|
||||
session.error = "Authentication timed out"
|
||||
}
|
||||
h.mu.Unlock()
|
||||
}()
|
||||
|
||||
c.Redirect(http.StatusFound, authURL)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) getSocialCallbackURL(c *gin.Context) string {
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https" {
|
||||
scheme = "https"
|
||||
}
|
||||
return fmt.Sprintf("%s://%s/v0/oauth/kiro/social/callback", scheme, c.Request.Host)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) startBuilderIDAuth(c *gin.Context) {
|
||||
stateID, err := generateStateID()
|
||||
if err != nil {
|
||||
h.renderError(c, "Failed to generate state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
region := defaultIDCRegion
|
||||
startURL := builderIDStartURL
|
||||
|
||||
ssoClient := NewSSOOIDCClient(h.cfg)
|
||||
|
||||
regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to register client: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to register client: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
authResp, err := ssoClient.StartDeviceAuthorizationWithIDC(
|
||||
c.Request.Context(),
|
||||
regResp.ClientID,
|
||||
regResp.ClientSecret,
|
||||
startURL,
|
||||
region,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to start device authorization: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second)
|
||||
|
||||
session := &webAuthSession{
|
||||
stateID: stateID,
|
||||
deviceCode: authResp.DeviceCode,
|
||||
userCode: authResp.UserCode,
|
||||
authURL: authResp.VerificationURIComplete,
|
||||
verificationURI: authResp.VerificationURI,
|
||||
expiresIn: authResp.ExpiresIn,
|
||||
interval: authResp.Interval,
|
||||
status: statusPending,
|
||||
startedAt: time.Now(),
|
||||
ssoClient: ssoClient,
|
||||
clientID: regResp.ClientID,
|
||||
clientSecret: regResp.ClientSecret,
|
||||
region: region,
|
||||
authMethod: "builder-id",
|
||||
startURL: startURL,
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
h.sessions[stateID] = session
|
||||
h.mu.Unlock()
|
||||
|
||||
go h.pollForToken(ctx, session)
|
||||
|
||||
h.renderStartPage(c, session)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) startIDCAuth(c *gin.Context) {
|
||||
startURL := c.Query("startUrl")
|
||||
region := c.Query("region")
|
||||
|
||||
if startURL == "" {
|
||||
h.renderError(c, "Missing startUrl parameter for IDC authentication")
|
||||
return
|
||||
}
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
|
||||
stateID, err := generateStateID()
|
||||
if err != nil {
|
||||
h.renderError(c, "Failed to generate state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
ssoClient := NewSSOOIDCClient(h.cfg)
|
||||
|
||||
regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to register client: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to register client: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
authResp, err := ssoClient.StartDeviceAuthorizationWithIDC(
|
||||
c.Request.Context(),
|
||||
regResp.ClientID,
|
||||
regResp.ClientSecret,
|
||||
startURL,
|
||||
region,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to start device authorization: %v", err)
|
||||
h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second)
|
||||
|
||||
session := &webAuthSession{
|
||||
stateID: stateID,
|
||||
deviceCode: authResp.DeviceCode,
|
||||
userCode: authResp.UserCode,
|
||||
authURL: authResp.VerificationURIComplete,
|
||||
verificationURI: authResp.VerificationURI,
|
||||
expiresIn: authResp.ExpiresIn,
|
||||
interval: authResp.Interval,
|
||||
status: statusPending,
|
||||
startedAt: time.Now(),
|
||||
ssoClient: ssoClient,
|
||||
clientID: regResp.ClientID,
|
||||
clientSecret: regResp.ClientSecret,
|
||||
region: region,
|
||||
authMethod: "idc",
|
||||
startURL: startURL,
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
h.sessions[stateID] = session
|
||||
h.mu.Unlock()
|
||||
|
||||
go h.pollForToken(ctx, session)
|
||||
|
||||
h.renderStartPage(c, session)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) pollForToken(ctx context.Context, session *webAuthSession) {
|
||||
defer session.cancelFunc()
|
||||
|
||||
interval := time.Duration(session.interval) * time.Second
|
||||
if interval < time.Duration(pollIntervalSeconds)*time.Second {
|
||||
interval = time.Duration(pollIntervalSeconds) * time.Second
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
h.mu.Lock()
|
||||
if session.status == statusPending {
|
||||
session.status = statusFailed
|
||||
session.error = "Authentication timed out"
|
||||
}
|
||||
h.mu.Unlock()
|
||||
return
|
||||
case <-ticker.C:
|
||||
tokenResp, err := h.ssoClient(session).CreateTokenWithRegion(
|
||||
ctx,
|
||||
session.clientID,
|
||||
session.clientSecret,
|
||||
session.deviceCode,
|
||||
session.region,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
if errStr == ErrAuthorizationPending.Error() {
|
||||
continue
|
||||
}
|
||||
if errStr == ErrSlowDown.Error() {
|
||||
interval += 5 * time.Second
|
||||
ticker.Reset(interval)
|
||||
continue
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
session.status = statusFailed
|
||||
session.error = errStr
|
||||
session.completedAt = time.Now()
|
||||
h.mu.Unlock()
|
||||
|
||||
log.Errorf("OAuth Web: token polling failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||
|
||||
// Fetch profileArn for IDC
|
||||
var profileArn string
|
||||
if session.authMethod == "idc" {
|
||||
profileArn = session.ssoClient.FetchProfileArn(ctx, tokenResp.AccessToken, session.clientID, tokenResp.RefreshToken)
|
||||
}
|
||||
|
||||
email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken, session.clientID, tokenResp.RefreshToken)
|
||||
|
||||
tokenData := &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: profileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: session.authMethod,
|
||||
Provider: "AWS",
|
||||
ClientID: session.clientID,
|
||||
ClientSecret: session.clientSecret,
|
||||
Email: email,
|
||||
Region: session.region,
|
||||
StartURL: session.startURL,
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
session.status = statusSuccess
|
||||
session.completedAt = time.Now()
|
||||
session.expiresAt = expiresAt
|
||||
session.tokenData = tokenData
|
||||
h.mu.Unlock()
|
||||
|
||||
if h.onTokenObtained != nil {
|
||||
h.onTokenObtained(tokenData)
|
||||
}
|
||||
|
||||
// Save token to file
|
||||
h.saveTokenToFile(tokenData)
|
||||
|
||||
log.Infof("OAuth Web: authentication successful for %s", email)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// saveTokenToFile saves the token data to the auth directory
|
||||
func (h *OAuthWebHandler) saveTokenToFile(tokenData *KiroTokenData) {
|
||||
// Get auth directory from config or use default
|
||||
authDir := ""
|
||||
if h.cfg != nil && h.cfg.AuthDir != "" {
|
||||
var err error
|
||||
authDir, err = util.ResolveAuthDir(h.cfg.AuthDir)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to resolve auth directory: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to default location
|
||||
if authDir == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to get home directory: %v", err)
|
||||
return
|
||||
}
|
||||
authDir = filepath.Join(home, ".cli-proxy-api")
|
||||
}
|
||||
|
||||
// Create directory if not exists
|
||||
if err := os.MkdirAll(authDir, 0700); err != nil {
|
||||
log.Errorf("OAuth Web: failed to create auth directory: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate filename using the unified function
|
||||
fileName := GenerateTokenFileName(tokenData)
|
||||
|
||||
authFilePath := filepath.Join(authDir, fileName)
|
||||
|
||||
// Convert to storage format and save
|
||||
storage := &KiroTokenStorage{
|
||||
Type: "kiro",
|
||||
AccessToken: tokenData.AccessToken,
|
||||
RefreshToken: tokenData.RefreshToken,
|
||||
ProfileArn: tokenData.ProfileArn,
|
||||
ExpiresAt: tokenData.ExpiresAt,
|
||||
AuthMethod: tokenData.AuthMethod,
|
||||
Provider: tokenData.Provider,
|
||||
LastRefresh: time.Now().Format(time.RFC3339),
|
||||
ClientID: tokenData.ClientID,
|
||||
ClientSecret: tokenData.ClientSecret,
|
||||
Region: tokenData.Region,
|
||||
StartURL: tokenData.StartURL,
|
||||
Email: tokenData.Email,
|
||||
}
|
||||
|
||||
if err := storage.SaveTokenToFile(authFilePath); err != nil {
|
||||
log.Errorf("OAuth Web: failed to save token to file: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("OAuth Web: token saved to %s", authFilePath)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) ssoClient(session *webAuthSession) *SSOOIDCClient {
|
||||
return session.ssoClient
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleCallback(c *gin.Context) {
|
||||
stateID := c.Query("state")
|
||||
errParam := c.Query("error")
|
||||
|
||||
if errParam != "" {
|
||||
h.renderError(c, errParam)
|
||||
return
|
||||
}
|
||||
|
||||
if stateID == "" {
|
||||
h.renderError(c, "Missing state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
session, exists := h.sessions[stateID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
h.renderError(c, "Invalid or expired session")
|
||||
return
|
||||
}
|
||||
|
||||
if session.status == statusSuccess {
|
||||
h.renderSuccess(c, session)
|
||||
} else if session.status == statusFailed {
|
||||
h.renderError(c, session.error)
|
||||
} else {
|
||||
c.Redirect(http.StatusFound, "/v0/oauth/kiro/start")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleSocialCallback(c *gin.Context) {
|
||||
stateID := c.Query("state")
|
||||
code := c.Query("code")
|
||||
errParam := c.Query("error")
|
||||
|
||||
if errParam != "" {
|
||||
h.renderError(c, errParam)
|
||||
return
|
||||
}
|
||||
|
||||
if stateID == "" {
|
||||
h.renderError(c, "Missing state parameter")
|
||||
return
|
||||
}
|
||||
|
||||
if code == "" {
|
||||
h.renderError(c, "Missing authorization code")
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
session, exists := h.sessions[stateID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
h.renderError(c, "Invalid or expired session")
|
||||
return
|
||||
}
|
||||
|
||||
if session.authMethod != "google" && session.authMethod != "github" {
|
||||
h.renderError(c, "Invalid session type for social callback")
|
||||
return
|
||||
}
|
||||
|
||||
socialClient := NewSocialAuthClient(h.cfg)
|
||||
redirectURI := h.getSocialCallbackURL(c)
|
||||
|
||||
tokenReq := &CreateTokenRequest{
|
||||
Code: code,
|
||||
CodeVerifier: session.codeVerifier,
|
||||
RedirectURI: redirectURI,
|
||||
}
|
||||
|
||||
tokenResp, err := socialClient.CreateToken(c.Request.Context(), tokenReq)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: social token exchange failed: %v", err)
|
||||
h.mu.Lock()
|
||||
session.status = statusFailed
|
||||
session.error = fmt.Sprintf("Token exchange failed: %v", err)
|
||||
session.completedAt = time.Now()
|
||||
h.mu.Unlock()
|
||||
h.renderError(c, session.error)
|
||||
return
|
||||
}
|
||||
|
||||
expiresIn := tokenResp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
email := ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||
|
||||
var provider string
|
||||
if session.authMethod == "google" {
|
||||
provider = string(ProviderGoogle)
|
||||
} else {
|
||||
provider = string(ProviderGitHub)
|
||||
}
|
||||
|
||||
tokenData := &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: tokenResp.ProfileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: session.authMethod,
|
||||
Provider: provider,
|
||||
Email: email,
|
||||
Region: "us-east-1",
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
session.status = statusSuccess
|
||||
session.completedAt = time.Now()
|
||||
session.expiresAt = expiresAt
|
||||
session.tokenData = tokenData
|
||||
h.mu.Unlock()
|
||||
|
||||
if session.cancelFunc != nil {
|
||||
session.cancelFunc()
|
||||
}
|
||||
|
||||
if h.onTokenObtained != nil {
|
||||
h.onTokenObtained(tokenData)
|
||||
}
|
||||
|
||||
// Save token to file
|
||||
h.saveTokenToFile(tokenData)
|
||||
|
||||
log.Infof("OAuth Web: social authentication successful for %s via %s", email, provider)
|
||||
h.renderSuccess(c, session)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) handleStatus(c *gin.Context) {
|
||||
stateID := c.Query("state")
|
||||
if stateID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing state parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
session, exists := h.sessions[stateID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
|
||||
return
|
||||
}
|
||||
|
||||
response := gin.H{
|
||||
"status": string(session.status),
|
||||
}
|
||||
|
||||
switch session.status {
|
||||
case statusPending:
|
||||
elapsed := time.Since(session.startedAt).Seconds()
|
||||
remaining := float64(session.expiresIn) - elapsed
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
response["remaining_seconds"] = int(remaining)
|
||||
case statusSuccess:
|
||||
response["completed_at"] = session.completedAt.Format(time.RFC3339)
|
||||
response["expires_at"] = session.expiresAt.Format(time.RFC3339)
|
||||
case statusFailed:
|
||||
response["error"] = session.error
|
||||
response["failed_at"] = session.completedAt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderStartPage(c *gin.Context, session *webAuthSession) {
|
||||
tmpl, err := template.New("start").Parse(oauthWebStartPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"AuthURL": session.authURL,
|
||||
"UserCode": session.userCode,
|
||||
"ExpiresIn": session.expiresIn,
|
||||
"StateID": session.stateID,
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
if err := tmpl.Execute(c.Writer, data); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderSelectPage(c *gin.Context) {
|
||||
tmpl, err := template.New("select").Parse(oauthWebSelectPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse select template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
if err := tmpl.Execute(c.Writer, nil); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render select template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderError(c *gin.Context, errMsg string) {
|
||||
tmpl, err := template.New("error").Parse(oauthWebErrorPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse error template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"Error": errMsg,
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.Status(http.StatusBadRequest)
|
||||
if err := tmpl.Execute(c.Writer, data); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render error template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) renderSuccess(c *gin.Context, session *webAuthSession) {
|
||||
tmpl, err := template.New("success").Parse(oauthWebSuccessPageHTML)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to parse success template: %v", err)
|
||||
c.String(http.StatusInternalServerError, "Template error")
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"ExpiresAt": session.expiresAt.Format(time.RFC3339),
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
if err := tmpl.Execute(c.Writer, data); err != nil {
|
||||
log.Errorf("OAuth Web: failed to render success template: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) CleanupExpiredSessions() {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for id, session := range h.sessions {
|
||||
if session.status != statusPending && now.Sub(session.completedAt) > 30*time.Minute {
|
||||
delete(h.sessions, id)
|
||||
} else if session.status == statusPending && now.Sub(session.startedAt) > defaultSessionExpiry {
|
||||
session.cancelFunc()
|
||||
delete(h.sessions, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OAuthWebHandler) GetSession(stateID string) (*webAuthSession, bool) {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
session, exists := h.sessions[stateID]
|
||||
return session, exists
|
||||
}
|
||||
|
||||
// ImportTokenRequest represents the request body for token import
|
||||
type ImportTokenRequest struct {
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
// handleImportToken handles manual refresh token import from Kiro IDE
|
||||
func (h *OAuthWebHandler) handleImportToken(c *gin.Context) {
|
||||
var req ImportTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"error": "Invalid request body",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
refreshToken := strings.TrimSpace(req.RefreshToken)
|
||||
if refreshToken == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"error": "Refresh token is required",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate token format
|
||||
if !strings.HasPrefix(refreshToken, "aorAAAAAG") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"error": "Invalid token format. Token should start with aorAAAAAG...",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Create social auth client to refresh and validate the token
|
||||
socialClient := NewSocialAuthClient(h.cfg)
|
||||
|
||||
// Refresh the token to validate it and get access token
|
||||
tokenData, err := socialClient.RefreshSocialToken(c.Request.Context(), refreshToken)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: token refresh failed during import: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"error": fmt.Sprintf("Token validation failed: %v", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Set the original refresh token (the refreshed one might be empty)
|
||||
if tokenData.RefreshToken == "" {
|
||||
tokenData.RefreshToken = refreshToken
|
||||
}
|
||||
tokenData.AuthMethod = "social"
|
||||
tokenData.Provider = "imported"
|
||||
|
||||
// Notify callback if set
|
||||
if h.onTokenObtained != nil {
|
||||
h.onTokenObtained(tokenData)
|
||||
}
|
||||
|
||||
// Save token to file
|
||||
h.saveTokenToFile(tokenData)
|
||||
|
||||
// Generate filename for response using the unified function
|
||||
fileName := GenerateTokenFileName(tokenData)
|
||||
|
||||
log.Infof("OAuth Web: token imported successfully")
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "Token imported successfully",
|
||||
"fileName": fileName,
|
||||
})
|
||||
}
|
||||
|
||||
// handleManualRefresh handles manual token refresh requests from the web UI.
|
||||
// This allows users to trigger a token refresh when needed, without waiting
|
||||
// for the automatic 30-second check and 20-minute-before-expiry refresh cycle.
|
||||
// Uses the same refresh logic as kiro_executor.Refresh for consistency.
|
||||
func (h *OAuthWebHandler) handleManualRefresh(c *gin.Context) {
|
||||
authDir := ""
|
||||
if h.cfg != nil && h.cfg.AuthDir != "" {
|
||||
var err error
|
||||
authDir, err = util.ResolveAuthDir(h.cfg.AuthDir)
|
||||
if err != nil {
|
||||
log.Errorf("OAuth Web: failed to resolve auth directory: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if authDir == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"error": "Failed to get home directory",
|
||||
})
|
||||
return
|
||||
}
|
||||
authDir = filepath.Join(home, ".cli-proxy-api")
|
||||
}
|
||||
|
||||
// Find all kiro token files in the auth directory
|
||||
files, err := os.ReadDir(authDir)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"error": fmt.Sprintf("Failed to read auth directory: %v", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var refreshedCount int
|
||||
var errors []string
|
||||
|
||||
for _, file := range files {
|
||||
if file.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := file.Name()
|
||||
if !strings.HasPrefix(name, "kiro-") || !strings.HasSuffix(name, ".json") {
|
||||
continue
|
||||
}
|
||||
|
||||
filePath := filepath.Join(authDir, name)
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
errors = append(errors, fmt.Sprintf("%s: read error - %v", name, err))
|
||||
continue
|
||||
}
|
||||
|
||||
var storage KiroTokenStorage
|
||||
if err := json.Unmarshal(data, &storage); err != nil {
|
||||
errors = append(errors, fmt.Sprintf("%s: parse error - %v", name, err))
|
||||
continue
|
||||
}
|
||||
|
||||
if storage.RefreshToken == "" {
|
||||
errors = append(errors, fmt.Sprintf("%s: no refresh token", name))
|
||||
continue
|
||||
}
|
||||
|
||||
// Refresh token using the same logic as kiro_executor.Refresh
|
||||
tokenData, err := h.refreshTokenData(c.Request.Context(), &storage)
|
||||
if err != nil {
|
||||
errors = append(errors, fmt.Sprintf("%s: refresh failed - %v", name, err))
|
||||
continue
|
||||
}
|
||||
|
||||
// Update storage with new token data
|
||||
storage.AccessToken = tokenData.AccessToken
|
||||
if tokenData.RefreshToken != "" {
|
||||
storage.RefreshToken = tokenData.RefreshToken
|
||||
}
|
||||
storage.ExpiresAt = tokenData.ExpiresAt
|
||||
storage.LastRefresh = time.Now().Format(time.RFC3339)
|
||||
if tokenData.ProfileArn != "" {
|
||||
storage.ProfileArn = tokenData.ProfileArn
|
||||
}
|
||||
|
||||
// Write updated token back to file
|
||||
updatedData, err := json.MarshalIndent(storage, "", " ")
|
||||
if err != nil {
|
||||
errors = append(errors, fmt.Sprintf("%s: marshal error - %v", name, err))
|
||||
continue
|
||||
}
|
||||
|
||||
tmpFile := filePath + ".tmp"
|
||||
if err := os.WriteFile(tmpFile, updatedData, 0600); err != nil {
|
||||
errors = append(errors, fmt.Sprintf("%s: write error - %v", name, err))
|
||||
continue
|
||||
}
|
||||
if err := os.Rename(tmpFile, filePath); err != nil {
|
||||
errors = append(errors, fmt.Sprintf("%s: rename error - %v", name, err))
|
||||
continue
|
||||
}
|
||||
|
||||
log.Infof("OAuth Web: manually refreshed token in %s, expires at %s", name, tokenData.ExpiresAt)
|
||||
refreshedCount++
|
||||
|
||||
// Notify callback if set
|
||||
if h.onTokenObtained != nil {
|
||||
h.onTokenObtained(tokenData)
|
||||
}
|
||||
}
|
||||
|
||||
if refreshedCount == 0 && len(errors) > 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"success": false,
|
||||
"error": fmt.Sprintf("All refresh attempts failed: %v", errors),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response := gin.H{
|
||||
"success": true,
|
||||
"message": fmt.Sprintf("Refreshed %d token(s)", refreshedCount),
|
||||
"refreshedCount": refreshedCount,
|
||||
}
|
||||
if len(errors) > 0 {
|
||||
response["warnings"] = errors
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// refreshTokenData refreshes a token using the appropriate method based on auth type.
|
||||
// This mirrors the logic in kiro_executor.Refresh for consistency.
|
||||
func (h *OAuthWebHandler) refreshTokenData(ctx context.Context, storage *KiroTokenStorage) (*KiroTokenData, error) {
|
||||
ssoClient := NewSSOOIDCClient(h.cfg)
|
||||
|
||||
switch {
|
||||
case storage.ClientID != "" && storage.ClientSecret != "" && storage.AuthMethod == "idc" && storage.Region != "":
|
||||
// IDC refresh with region-specific endpoint
|
||||
log.Debugf("OAuth Web: using SSO OIDC refresh for IDC (region=%s)", storage.Region)
|
||||
return ssoClient.RefreshTokenWithRegion(ctx, storage.ClientID, storage.ClientSecret, storage.RefreshToken, storage.Region, storage.StartURL)
|
||||
|
||||
case storage.ClientID != "" && storage.ClientSecret != "" && storage.AuthMethod == "builder-id":
|
||||
// Builder ID refresh with default endpoint
|
||||
log.Debugf("OAuth Web: using SSO OIDC refresh for AWS Builder ID")
|
||||
return ssoClient.RefreshToken(ctx, storage.ClientID, storage.ClientSecret, storage.RefreshToken)
|
||||
|
||||
default:
|
||||
// Fallback to Kiro's OAuth refresh endpoint (for social auth: Google/GitHub)
|
||||
log.Debugf("OAuth Web: using Kiro OAuth refresh endpoint")
|
||||
oauth := NewKiroOAuth(h.cfg)
|
||||
return oauth.RefreshToken(ctx, storage.RefreshToken)
|
||||
}
|
||||
}
|
||||
779
internal/auth/kiro/oauth_web_templates.go
Normal file
779
internal/auth/kiro/oauth_web_templates.go
Normal file
@@ -0,0 +1,779 @@
|
||||
// Package kiro provides OAuth Web authentication templates.
|
||||
package kiro
|
||||
|
||||
const (
|
||||
oauthWebStartPageHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>AWS SSO Authentication</title>
|
||||
<style>
|
||||
* { box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
.container {
|
||||
max-width: 500px;
|
||||
width: 100%;
|
||||
background: #fff;
|
||||
padding: 40px;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
|
||||
}
|
||||
h1 {
|
||||
margin: 0 0 10px;
|
||||
color: #333;
|
||||
font-size: 24px;
|
||||
text-align: center;
|
||||
}
|
||||
.subtitle {
|
||||
text-align: center;
|
||||
color: #666;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
.step {
|
||||
background: #f8f9fa;
|
||||
padding: 20px;
|
||||
border-radius: 8px;
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
.step-title {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
font-weight: 600;
|
||||
color: #333;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
.step-number {
|
||||
width: 28px;
|
||||
height: 28px;
|
||||
background: #667eea;
|
||||
color: white;
|
||||
border-radius: 50%;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: 14px;
|
||||
margin-right: 12px;
|
||||
}
|
||||
.user-code {
|
||||
background: #e7f3ff;
|
||||
border: 2px dashed #2196F3;
|
||||
border-radius: 8px;
|
||||
padding: 20px;
|
||||
text-align: center;
|
||||
margin-top: 10px;
|
||||
}
|
||||
.user-code-label {
|
||||
font-size: 12px;
|
||||
color: #666;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 1px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
.user-code-value {
|
||||
font-size: 32px;
|
||||
font-weight: bold;
|
||||
font-family: monospace;
|
||||
color: #2196F3;
|
||||
letter-spacing: 4px;
|
||||
}
|
||||
.auth-btn {
|
||||
display: block;
|
||||
width: 100%;
|
||||
padding: 15px;
|
||||
background: #667eea;
|
||||
color: white;
|
||||
text-align: center;
|
||||
text-decoration: none;
|
||||
border-radius: 8px;
|
||||
font-weight: 600;
|
||||
font-size: 16px;
|
||||
transition: all 0.3s;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
margin-top: 20px;
|
||||
}
|
||||
.auth-btn:hover {
|
||||
background: #5568d3;
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
|
||||
}
|
||||
.status {
|
||||
margin-top: 30px;
|
||||
padding: 20px;
|
||||
background: #f8f9fa;
|
||||
border-radius: 8px;
|
||||
text-align: center;
|
||||
}
|
||||
.status-pending { border-left: 4px solid #ffc107; }
|
||||
.status-success { border-left: 4px solid #28a745; }
|
||||
.status-failed { border-left: 4px solid #dc3545; }
|
||||
.spinner {
|
||||
border: 3px solid #f3f3f3;
|
||||
border-top: 3px solid #667eea;
|
||||
border-radius: 50%;
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
animation: spin 1s linear infinite;
|
||||
margin: 0 auto 15px;
|
||||
}
|
||||
@keyframes spin {
|
||||
0% { transform: rotate(0deg); }
|
||||
100% { transform: rotate(360deg); }
|
||||
}
|
||||
.timer {
|
||||
font-size: 24px;
|
||||
font-weight: bold;
|
||||
color: #667eea;
|
||||
margin: 10px 0;
|
||||
}
|
||||
.timer.warning { color: #ffc107; }
|
||||
.timer.danger { color: #dc3545; }
|
||||
.status-message { color: #666; line-height: 1.6; }
|
||||
.success-icon, .error-icon { font-size: 48px; margin-bottom: 15px; }
|
||||
.info-box {
|
||||
background: #e7f3ff;
|
||||
border-left: 4px solid #2196F3;
|
||||
padding: 15px;
|
||||
margin-top: 20px;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
color: #666;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>🔐 AWS SSO Authentication</h1>
|
||||
<p class="subtitle">Follow the steps below to complete authentication</p>
|
||||
|
||||
<div class="step">
|
||||
<div class="step-title">
|
||||
<span class="step-number">1</span>
|
||||
Click the button below to open the authorization page
|
||||
</div>
|
||||
<a href="{{.AuthURL}}" target="_blank" class="auth-btn" id="authBtn">
|
||||
🚀 Open Authorization Page
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div class="step">
|
||||
<div class="step-title">
|
||||
<span class="step-number">2</span>
|
||||
Enter the verification code below
|
||||
</div>
|
||||
<div class="user-code">
|
||||
<div class="user-code-label">Verification Code</div>
|
||||
<div class="user-code-value">{{.UserCode}}</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="step">
|
||||
<div class="step-title">
|
||||
<span class="step-number">3</span>
|
||||
Complete AWS SSO login
|
||||
</div>
|
||||
<p style="color: #666; font-size: 14px; margin-top: 10px;">
|
||||
Use your AWS SSO account to login and authorize
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div class="status status-pending" id="statusBox">
|
||||
<div class="spinner" id="spinner"></div>
|
||||
<div class="timer" id="timer">{{.ExpiresIn}}s</div>
|
||||
<div class="status-message" id="statusMessage">
|
||||
Waiting for authorization...
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="info-box">
|
||||
💡 <strong>Tip:</strong> The authorization page will open in a new tab. This page will automatically update once authorization is complete.
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let pollInterval;
|
||||
let timerInterval;
|
||||
let remainingSeconds = {{.ExpiresIn}};
|
||||
const stateID = "{{.StateID}}";
|
||||
|
||||
setTimeout(() => {
|
||||
document.getElementById('authBtn').click();
|
||||
}, 500);
|
||||
|
||||
function pollStatus() {
|
||||
fetch('/v0/oauth/kiro/status?state=' + stateID)
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
console.log('Status:', data);
|
||||
if (data.status === 'success') {
|
||||
clearInterval(pollInterval);
|
||||
clearInterval(timerInterval);
|
||||
showSuccess(data);
|
||||
} else if (data.status === 'failed') {
|
||||
clearInterval(pollInterval);
|
||||
clearInterval(timerInterval);
|
||||
showError(data);
|
||||
} else {
|
||||
remainingSeconds = data.remaining_seconds || 0;
|
||||
}
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Poll error:', error);
|
||||
});
|
||||
}
|
||||
|
||||
function updateTimer() {
|
||||
const timerEl = document.getElementById('timer');
|
||||
const minutes = Math.floor(remainingSeconds / 60);
|
||||
const seconds = remainingSeconds % 60;
|
||||
timerEl.textContent = minutes + ':' + seconds.toString().padStart(2, '0');
|
||||
|
||||
if (remainingSeconds < 60) {
|
||||
timerEl.className = 'timer danger';
|
||||
} else if (remainingSeconds < 180) {
|
||||
timerEl.className = 'timer warning';
|
||||
} else {
|
||||
timerEl.className = 'timer';
|
||||
}
|
||||
|
||||
remainingSeconds--;
|
||||
|
||||
if (remainingSeconds < 0) {
|
||||
clearInterval(timerInterval);
|
||||
clearInterval(pollInterval);
|
||||
showError({ error: 'Authentication timed out. Please refresh and try again.' });
|
||||
}
|
||||
}
|
||||
|
||||
function showSuccess(data) {
|
||||
const statusBox = document.getElementById('statusBox');
|
||||
statusBox.className = 'status status-success';
|
||||
statusBox.innerHTML = '<div class="success-icon">✅</div>' +
|
||||
'<div class="status-message">' +
|
||||
'<strong>Authentication Successful!</strong><br>' +
|
||||
'Token expires: ' + new Date(data.expires_at).toLocaleString() +
|
||||
'</div>';
|
||||
}
|
||||
|
||||
function showError(data) {
|
||||
const statusBox = document.getElementById('statusBox');
|
||||
statusBox.className = 'status status-failed';
|
||||
statusBox.innerHTML = '<div class="error-icon">❌</div>' +
|
||||
'<div class="status-message">' +
|
||||
'<strong>Authentication Failed</strong><br>' +
|
||||
(data.error || 'Unknown error') +
|
||||
'</div>' +
|
||||
'<button class="auth-btn" onclick="location.reload()" style="margin-top: 15px;">' +
|
||||
'🔄 Retry' +
|
||||
'</button>';
|
||||
}
|
||||
|
||||
pollInterval = setInterval(pollStatus, 3000);
|
||||
timerInterval = setInterval(updateTimer, 1000);
|
||||
pollStatus();
|
||||
</script>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
oauthWebErrorPageHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Authentication Failed</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
max-width: 600px;
|
||||
margin: 50px auto;
|
||||
padding: 20px;
|
||||
background: #f5f5f5;
|
||||
}
|
||||
.error {
|
||||
background: #fff;
|
||||
padding: 30px;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||
border-left: 4px solid #dc3545;
|
||||
}
|
||||
h1 { color: #dc3545; margin-top: 0; }
|
||||
.error-message { color: #666; line-height: 1.6; }
|
||||
.retry-btn {
|
||||
display: inline-block;
|
||||
margin-top: 20px;
|
||||
padding: 10px 20px;
|
||||
background: #007bff;
|
||||
color: white;
|
||||
text-decoration: none;
|
||||
border-radius: 4px;
|
||||
}
|
||||
.retry-btn:hover { background: #0056b3; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="error">
|
||||
<h1>❌ Authentication Failed</h1>
|
||||
<div class="error-message">
|
||||
<p><strong>Error:</strong></p>
|
||||
<p>{{.Error}}</p>
|
||||
</div>
|
||||
<a href="/v0/oauth/kiro/start" class="retry-btn">🔄 Retry</a>
|
||||
</div>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
oauthWebSuccessPageHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Authentication Successful</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
max-width: 600px;
|
||||
margin: 50px auto;
|
||||
padding: 20px;
|
||||
background: #f5f5f5;
|
||||
}
|
||||
.success {
|
||||
background: #fff;
|
||||
padding: 30px;
|
||||
border-radius: 8px;
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||
border-left: 4px solid #28a745;
|
||||
text-align: center;
|
||||
}
|
||||
h1 { color: #28a745; margin-top: 0; }
|
||||
.success-message { color: #666; line-height: 1.6; }
|
||||
.icon { font-size: 48px; margin-bottom: 15px; }
|
||||
.expires { font-size: 14px; color: #999; margin-top: 15px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="success">
|
||||
<div class="icon">✅</div>
|
||||
<h1>Authentication Successful!</h1>
|
||||
<div class="success-message">
|
||||
<p>You can close this window.</p>
|
||||
</div>
|
||||
<div class="expires">Token expires: {{.ExpiresAt}}</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
oauthWebSelectPageHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Select Authentication Method</title>
|
||||
<style>
|
||||
* { box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
.container {
|
||||
max-width: 500px;
|
||||
width: 100%;
|
||||
background: #fff;
|
||||
padding: 40px;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
|
||||
}
|
||||
h1 {
|
||||
margin: 0 0 10px;
|
||||
color: #333;
|
||||
font-size: 24px;
|
||||
text-align: center;
|
||||
}
|
||||
.subtitle {
|
||||
text-align: center;
|
||||
color: #666;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
.auth-methods {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 15px;
|
||||
}
|
||||
.auth-btn {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
width: 100%;
|
||||
padding: 15px 20px;
|
||||
background: #667eea;
|
||||
color: white;
|
||||
text-decoration: none;
|
||||
border-radius: 8px;
|
||||
font-weight: 600;
|
||||
font-size: 16px;
|
||||
transition: all 0.3s;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
}
|
||||
.auth-btn:hover {
|
||||
background: #5568d3;
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
|
||||
}
|
||||
.auth-btn .icon {
|
||||
font-size: 24px;
|
||||
margin-right: 15px;
|
||||
width: 32px;
|
||||
text-align: center;
|
||||
}
|
||||
.auth-btn.google { background: #4285F4; }
|
||||
.auth-btn.google:hover { background: #3367D6; }
|
||||
.auth-btn.github { background: #24292e; }
|
||||
.auth-btn.github:hover { background: #1a1e22; }
|
||||
.auth-btn.aws { background: #FF9900; }
|
||||
.auth-btn.aws:hover { background: #E68A00; }
|
||||
.auth-btn.idc { background: #232F3E; }
|
||||
.auth-btn.idc:hover { background: #1a242f; }
|
||||
.idc-form {
|
||||
background: #f8f9fa;
|
||||
padding: 20px;
|
||||
border-radius: 8px;
|
||||
margin-top: 15px;
|
||||
display: none;
|
||||
}
|
||||
.idc-form.show {
|
||||
display: block;
|
||||
}
|
||||
.form-group {
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
.form-group label {
|
||||
display: block;
|
||||
font-weight: 600;
|
||||
color: #333;
|
||||
margin-bottom: 8px;
|
||||
font-size: 14px;
|
||||
}
|
||||
.form-group input {
|
||||
width: 100%;
|
||||
padding: 12px;
|
||||
border: 2px solid #e0e0e0;
|
||||
border-radius: 6px;
|
||||
font-size: 14px;
|
||||
transition: border-color 0.3s;
|
||||
}
|
||||
.form-group input:focus {
|
||||
outline: none;
|
||||
border-color: #667eea;
|
||||
}
|
||||
.form-group .hint {
|
||||
font-size: 12px;
|
||||
color: #999;
|
||||
margin-top: 5px;
|
||||
}
|
||||
.submit-btn {
|
||||
display: block;
|
||||
width: 100%;
|
||||
padding: 15px;
|
||||
background: #232F3E;
|
||||
color: white;
|
||||
text-align: center;
|
||||
text-decoration: none;
|
||||
border-radius: 8px;
|
||||
font-weight: 600;
|
||||
font-size: 16px;
|
||||
transition: all 0.3s;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
}
|
||||
.submit-btn:hover {
|
||||
background: #1a242f;
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 12px rgba(35, 47, 62, 0.4);
|
||||
}
|
||||
.divider {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin: 20px 0;
|
||||
}
|
||||
.divider::before,
|
||||
.divider::after {
|
||||
content: "";
|
||||
flex: 1;
|
||||
border-bottom: 1px solid #e0e0e0;
|
||||
}
|
||||
.divider span {
|
||||
padding: 0 15px;
|
||||
color: #999;
|
||||
font-size: 14px;
|
||||
}
|
||||
.info-box {
|
||||
background: #e7f3ff;
|
||||
border-left: 4px solid #2196F3;
|
||||
padding: 15px;
|
||||
margin-top: 20px;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
color: #666;
|
||||
}
|
||||
.warning-box {
|
||||
background: #fff3cd;
|
||||
border-left: 4px solid #ffc107;
|
||||
padding: 15px;
|
||||
margin-top: 20px;
|
||||
border-radius: 4px;
|
||||
font-size: 14px;
|
||||
color: #856404;
|
||||
}
|
||||
.auth-btn.manual { background: #6c757d; }
|
||||
.auth-btn.manual:hover { background: #5a6268; }
|
||||
.auth-btn.refresh { background: #17a2b8; }
|
||||
.auth-btn.refresh:hover { background: #138496; }
|
||||
.auth-btn.refresh:disabled { background: #7fb3bd; cursor: not-allowed; }
|
||||
.manual-form {
|
||||
background: #f8f9fa;
|
||||
padding: 20px;
|
||||
border-radius: 8px;
|
||||
margin-top: 15px;
|
||||
display: none;
|
||||
}
|
||||
.manual-form.show {
|
||||
display: block;
|
||||
}
|
||||
.form-group textarea {
|
||||
width: 100%;
|
||||
padding: 12px;
|
||||
border: 2px solid #e0e0e0;
|
||||
border-radius: 6px;
|
||||
font-size: 14px;
|
||||
font-family: monospace;
|
||||
transition: border-color 0.3s;
|
||||
resize: vertical;
|
||||
min-height: 80px;
|
||||
}
|
||||
.form-group textarea:focus {
|
||||
outline: none;
|
||||
border-color: #667eea;
|
||||
}
|
||||
.status-message {
|
||||
padding: 15px;
|
||||
border-radius: 6px;
|
||||
margin-top: 15px;
|
||||
display: none;
|
||||
}
|
||||
.status-message.success {
|
||||
background: #d4edda;
|
||||
color: #155724;
|
||||
display: block;
|
||||
}
|
||||
.status-message.error {
|
||||
background: #f8d7da;
|
||||
color: #721c24;
|
||||
display: block;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>🔐 Select Authentication Method</h1>
|
||||
<p class="subtitle">Choose how you want to authenticate with Kiro</p>
|
||||
|
||||
<div class="auth-methods">
|
||||
<a href="/v0/oauth/kiro/start?method=builder-id" class="auth-btn aws">
|
||||
<span class="icon">🔶</span>
|
||||
AWS Builder ID (Recommended)
|
||||
</a>
|
||||
|
||||
<button type="button" class="auth-btn idc" onclick="toggleIdcForm()">
|
||||
<span class="icon">🏢</span>
|
||||
AWS Identity Center (IDC)
|
||||
</button>
|
||||
|
||||
<div class="divider"><span>or</span></div>
|
||||
|
||||
<button type="button" class="auth-btn manual" onclick="toggleManualForm()">
|
||||
<span class="icon">📋</span>
|
||||
Import RefreshToken from Kiro IDE
|
||||
</button>
|
||||
|
||||
<button type="button" class="auth-btn refresh" onclick="manualRefresh()" id="refreshBtn">
|
||||
<span class="icon">🔄</span>
|
||||
Manual Refresh All Tokens
|
||||
</button>
|
||||
|
||||
<div class="status-message" id="refreshStatus"></div>
|
||||
</div>
|
||||
|
||||
<div class="idc-form" id="idcForm">
|
||||
<form action="/v0/oauth/kiro/start" method="get">
|
||||
<input type="hidden" name="method" value="idc">
|
||||
|
||||
<div class="form-group">
|
||||
<label for="startUrl">Start URL</label>
|
||||
<input type="url" id="startUrl" name="startUrl" placeholder="https://your-org.awsapps.com/start" required>
|
||||
<div class="hint">Your AWS Identity Center Start URL</div>
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label for="region">Region</label>
|
||||
<input type="text" id="region" name="region" value="us-east-1" placeholder="us-east-1">
|
||||
<div class="hint">AWS Region for your Identity Center</div>
|
||||
</div>
|
||||
|
||||
<button type="submit" class="submit-btn">
|
||||
🚀 Continue with IDC
|
||||
</button>
|
||||
</form>
|
||||
</div>
|
||||
|
||||
<div class="manual-form" id="manualForm">
|
||||
<form id="importForm" onsubmit="submitImport(event)">
|
||||
<div class="form-group">
|
||||
<label for="refreshToken">Refresh Token</label>
|
||||
<textarea id="refreshToken" name="refreshToken" placeholder="Paste your refreshToken here (starts with aorAAAAAG...)" required></textarea>
|
||||
<div class="hint">Copy from Kiro IDE: ~/.kiro/kiro-auth-token.json → refreshToken field</div>
|
||||
</div>
|
||||
|
||||
<button type="submit" class="submit-btn" id="importBtn">
|
||||
📥 Import Token
|
||||
</button>
|
||||
|
||||
<div class="status-message" id="importStatus"></div>
|
||||
</form>
|
||||
</div>
|
||||
|
||||
<div class="warning-box">
|
||||
⚠️ <strong>Note:</strong> Google and GitHub login are not available for third-party applications due to AWS Cognito restrictions. Please use AWS Builder ID or import your token from Kiro IDE.
|
||||
</div>
|
||||
|
||||
<div class="info-box">
|
||||
💡 <strong>How to get RefreshToken:</strong><br>
|
||||
1. Open Kiro IDE and login with Google/GitHub<br>
|
||||
2. Find the token file: <code>~/.kiro/kiro-auth-token.json</code><br>
|
||||
3. Copy the <code>refreshToken</code> value and paste it above
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
function toggleIdcForm() {
|
||||
const idcForm = document.getElementById('idcForm');
|
||||
const manualForm = document.getElementById('manualForm');
|
||||
manualForm.classList.remove('show');
|
||||
idcForm.classList.toggle('show');
|
||||
if (idcForm.classList.contains('show')) {
|
||||
document.getElementById('startUrl').focus();
|
||||
}
|
||||
}
|
||||
|
||||
function toggleManualForm() {
|
||||
const idcForm = document.getElementById('idcForm');
|
||||
const manualForm = document.getElementById('manualForm');
|
||||
idcForm.classList.remove('show');
|
||||
manualForm.classList.toggle('show');
|
||||
if (manualForm.classList.contains('show')) {
|
||||
document.getElementById('refreshToken').focus();
|
||||
}
|
||||
}
|
||||
|
||||
async function submitImport(event) {
|
||||
event.preventDefault();
|
||||
const refreshToken = document.getElementById('refreshToken').value.trim();
|
||||
const statusEl = document.getElementById('importStatus');
|
||||
const btn = document.getElementById('importBtn');
|
||||
|
||||
if (!refreshToken) {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = 'Please enter a refresh token';
|
||||
return;
|
||||
}
|
||||
|
||||
if (!refreshToken.startsWith('aorAAAAAG')) {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = 'Invalid token format. Token should start with aorAAAAAG...';
|
||||
return;
|
||||
}
|
||||
|
||||
btn.disabled = true;
|
||||
btn.textContent = '⏳ Importing...';
|
||||
statusEl.className = 'status-message';
|
||||
statusEl.style.display = 'none';
|
||||
|
||||
try {
|
||||
const response = await fetch('/v0/oauth/kiro/import', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ refreshToken: refreshToken })
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (response.ok && data.success) {
|
||||
statusEl.className = 'status-message success';
|
||||
statusEl.textContent = '✅ Token imported successfully! File: ' + (data.fileName || 'kiro-token.json');
|
||||
} else {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = '❌ ' + (data.error || data.message || 'Import failed');
|
||||
}
|
||||
} catch (error) {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = '❌ Network error: ' + error.message;
|
||||
} finally {
|
||||
btn.disabled = false;
|
||||
btn.textContent = '📥 Import Token';
|
||||
}
|
||||
}
|
||||
|
||||
async function manualRefresh() {
|
||||
const btn = document.getElementById('refreshBtn');
|
||||
const statusEl = document.getElementById('refreshStatus');
|
||||
|
||||
btn.disabled = true;
|
||||
btn.innerHTML = '<span class="icon">⏳</span> Refreshing...';
|
||||
statusEl.className = 'status-message';
|
||||
statusEl.style.display = 'none';
|
||||
|
||||
try {
|
||||
const response = await fetch('/v0/oauth/kiro/refresh', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' }
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (response.ok && data.success) {
|
||||
statusEl.className = 'status-message success';
|
||||
let msg = '✅ ' + data.message;
|
||||
if (data.warnings && data.warnings.length > 0) {
|
||||
msg += ' (Warnings: ' + data.warnings.join('; ') + ')';
|
||||
}
|
||||
statusEl.textContent = msg;
|
||||
} else {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = '❌ ' + (data.error || data.message || 'Refresh failed');
|
||||
}
|
||||
} catch (error) {
|
||||
statusEl.className = 'status-message error';
|
||||
statusEl.textContent = '❌ Network error: ' + error.message;
|
||||
} finally {
|
||||
btn.disabled = false;
|
||||
btn.innerHTML = '<span class="icon">🔄</span> Manual Refresh All Tokens';
|
||||
}
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>`
|
||||
)
|
||||
725
internal/auth/kiro/protocol_handler.go
Normal file
725
internal/auth/kiro/protocol_handler.go
Normal file
@@ -0,0 +1,725 @@
|
||||
// Package kiro provides custom protocol handler registration for Kiro OAuth.
|
||||
// This enables the CLI to intercept kiro:// URIs for social authentication (Google/GitHub).
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"html"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// KiroProtocol is the custom URI scheme used by Kiro
|
||||
KiroProtocol = "kiro"
|
||||
|
||||
// KiroAuthority is the URI authority for authentication callbacks
|
||||
KiroAuthority = "kiro.kiroAgent"
|
||||
|
||||
// KiroAuthPath is the path for successful authentication
|
||||
KiroAuthPath = "/authenticate-success"
|
||||
|
||||
// KiroRedirectURI is the full redirect URI for social auth
|
||||
KiroRedirectURI = "kiro://kiro.kiroAgent/authenticate-success"
|
||||
|
||||
// DefaultHandlerPort is the default port for the local callback server
|
||||
DefaultHandlerPort = 19876
|
||||
|
||||
// HandlerTimeout is how long to wait for the OAuth callback
|
||||
HandlerTimeout = 10 * time.Minute
|
||||
)
|
||||
|
||||
// ProtocolHandler manages the custom kiro:// protocol handler for OAuth callbacks.
|
||||
type ProtocolHandler struct {
|
||||
port int
|
||||
server *http.Server
|
||||
listener net.Listener
|
||||
resultChan chan *AuthCallback
|
||||
stopChan chan struct{}
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
}
|
||||
|
||||
// AuthCallback contains the OAuth callback parameters.
|
||||
type AuthCallback struct {
|
||||
Code string
|
||||
State string
|
||||
Error string
|
||||
}
|
||||
|
||||
// NewProtocolHandler creates a new protocol handler.
|
||||
func NewProtocolHandler() *ProtocolHandler {
|
||||
return &ProtocolHandler{
|
||||
port: DefaultHandlerPort,
|
||||
resultChan: make(chan *AuthCallback, 1),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the local callback server that receives redirects from the protocol handler.
|
||||
func (h *ProtocolHandler) Start(ctx context.Context) (int, error) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if h.running {
|
||||
return h.port, nil
|
||||
}
|
||||
|
||||
// Drain any stale results from previous runs
|
||||
select {
|
||||
case <-h.resultChan:
|
||||
default:
|
||||
}
|
||||
|
||||
// Reset stopChan for reuse - close old channel first to unblock any waiting goroutines
|
||||
if h.stopChan != nil {
|
||||
select {
|
||||
case <-h.stopChan:
|
||||
// Already closed
|
||||
default:
|
||||
close(h.stopChan)
|
||||
}
|
||||
}
|
||||
h.stopChan = make(chan struct{})
|
||||
|
||||
// Try ports in known range (must match handler script port range)
|
||||
var listener net.Listener
|
||||
var err error
|
||||
portRange := []int{DefaultHandlerPort, DefaultHandlerPort + 1, DefaultHandlerPort + 2, DefaultHandlerPort + 3, DefaultHandlerPort + 4}
|
||||
|
||||
for _, port := range portRange {
|
||||
listener, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
log.Debugf("kiro protocol handler: port %d busy, trying next", port)
|
||||
}
|
||||
|
||||
if listener == nil {
|
||||
return 0, fmt.Errorf("failed to start callback server: all ports %d-%d are busy", DefaultHandlerPort, DefaultHandlerPort+4)
|
||||
}
|
||||
|
||||
h.listener = listener
|
||||
h.port = listener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/oauth/callback", h.handleCallback)
|
||||
|
||||
h.server = &http.Server{
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := h.server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||
log.Debugf("kiro protocol handler server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
h.running = true
|
||||
log.Debugf("kiro protocol handler started on port %d", h.port)
|
||||
|
||||
// Auto-shutdown after context done, timeout, or explicit stop
|
||||
// Capture references to prevent race with new Start() calls
|
||||
currentStopChan := h.stopChan
|
||||
currentServer := h.server
|
||||
currentListener := h.listener
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(HandlerTimeout):
|
||||
case <-currentStopChan:
|
||||
return // Already stopped, exit goroutine
|
||||
}
|
||||
// Only stop if this is still the current server/listener instance
|
||||
h.mu.Lock()
|
||||
if h.server == currentServer && h.listener == currentListener {
|
||||
h.mu.Unlock()
|
||||
h.Stop()
|
||||
} else {
|
||||
h.mu.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
return h.port, nil
|
||||
}
|
||||
|
||||
// Stop stops the callback server.
|
||||
func (h *ProtocolHandler) Stop() {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if !h.running {
|
||||
return
|
||||
}
|
||||
|
||||
// Signal the auto-shutdown goroutine to exit.
|
||||
// This select pattern is safe because stopChan is only modified while holding h.mu,
|
||||
// and we hold the lock here. The select prevents panic from double-close.
|
||||
select {
|
||||
case <-h.stopChan:
|
||||
// Already closed
|
||||
default:
|
||||
close(h.stopChan)
|
||||
}
|
||||
|
||||
if h.server != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = h.server.Shutdown(ctx)
|
||||
}
|
||||
|
||||
h.running = false
|
||||
log.Debug("kiro protocol handler stopped")
|
||||
}
|
||||
|
||||
// WaitForCallback waits for the OAuth callback and returns the result.
|
||||
func (h *ProtocolHandler) WaitForCallback(ctx context.Context) (*AuthCallback, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(HandlerTimeout):
|
||||
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
||||
case result := <-h.resultChan:
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetPort returns the port the handler is listening on.
|
||||
func (h *ProtocolHandler) GetPort() int {
|
||||
return h.port
|
||||
}
|
||||
|
||||
// handleCallback processes the OAuth callback from the protocol handler script.
|
||||
func (h *ProtocolHandler) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
errParam := r.URL.Query().Get("error")
|
||||
|
||||
result := &AuthCallback{
|
||||
Code: code,
|
||||
State: state,
|
||||
Error: errParam,
|
||||
}
|
||||
|
||||
// Send result
|
||||
select {
|
||||
case h.resultChan <- result:
|
||||
default:
|
||||
// Channel full, ignore duplicate callbacks
|
||||
}
|
||||
|
||||
// Send success response
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if errParam != "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(w, `<!DOCTYPE html>
|
||||
<html>
|
||||
<head><title>Login Failed</title></head>
|
||||
<body>
|
||||
<h1>Login Failed</h1>
|
||||
<p>Error: %s</p>
|
||||
<p>You can close this window.</p>
|
||||
</body>
|
||||
</html>`, html.EscapeString(errParam))
|
||||
} else {
|
||||
fmt.Fprint(w, `<!DOCTYPE html>
|
||||
<html>
|
||||
<head><title>Login Successful</title></head>
|
||||
<body>
|
||||
<h1>Login Successful!</h1>
|
||||
<p>You can close this window and return to the terminal.</p>
|
||||
<script>window.close();</script>
|
||||
</body>
|
||||
</html>`)
|
||||
}
|
||||
}
|
||||
|
||||
// IsProtocolHandlerInstalled checks if the kiro:// protocol handler is installed.
|
||||
func IsProtocolHandlerInstalled() bool {
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
return isLinuxHandlerInstalled()
|
||||
case "windows":
|
||||
return isWindowsHandlerInstalled()
|
||||
case "darwin":
|
||||
return isDarwinHandlerInstalled()
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// InstallProtocolHandler installs the kiro:// protocol handler for the current platform.
|
||||
func InstallProtocolHandler(handlerPort int) error {
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
return installLinuxHandler(handlerPort)
|
||||
case "windows":
|
||||
return installWindowsHandler(handlerPort)
|
||||
case "darwin":
|
||||
return installDarwinHandler(handlerPort)
|
||||
default:
|
||||
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
|
||||
// UninstallProtocolHandler removes the kiro:// protocol handler.
|
||||
func UninstallProtocolHandler() error {
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
return uninstallLinuxHandler()
|
||||
case "windows":
|
||||
return uninstallWindowsHandler()
|
||||
case "darwin":
|
||||
return uninstallDarwinHandler()
|
||||
default:
|
||||
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Linux Implementation ---
|
||||
|
||||
func getLinuxDesktopPath() string {
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
return filepath.Join(homeDir, ".local", "share", "applications", "kiro-oauth-handler.desktop")
|
||||
}
|
||||
|
||||
func getLinuxHandlerScriptPath() string {
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
return filepath.Join(homeDir, ".local", "bin", "kiro-oauth-handler")
|
||||
}
|
||||
|
||||
func isLinuxHandlerInstalled() bool {
|
||||
desktopPath := getLinuxDesktopPath()
|
||||
_, err := os.Stat(desktopPath)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func installLinuxHandler(handlerPort int) error {
|
||||
// Create directories
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
binDir := filepath.Join(homeDir, ".local", "bin")
|
||||
appDir := filepath.Join(homeDir, ".local", "share", "applications")
|
||||
|
||||
if err := os.MkdirAll(binDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create bin directory: %w", err)
|
||||
}
|
||||
if err := os.MkdirAll(appDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create applications directory: %w", err)
|
||||
}
|
||||
|
||||
// Create handler script - tries multiple ports to handle dynamic port allocation
|
||||
scriptPath := getLinuxHandlerScriptPath()
|
||||
scriptContent := fmt.Sprintf(`#!/bin/bash
|
||||
# Kiro OAuth Protocol Handler
|
||||
# Handles kiro:// URIs - tries CLI first, then forwards to Kiro IDE
|
||||
|
||||
URL="$1"
|
||||
|
||||
# Check curl availability
|
||||
if ! command -v curl &> /dev/null; then
|
||||
echo "Error: curl is required for Kiro OAuth handler" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Extract code and state from URL
|
||||
[[ "$URL" =~ code=([^&]+) ]] && CODE="${BASH_REMATCH[1]}"
|
||||
[[ "$URL" =~ state=([^&]+) ]] && STATE="${BASH_REMATCH[1]}"
|
||||
[[ "$URL" =~ error=([^&]+) ]] && ERROR="${BASH_REMATCH[1]}"
|
||||
|
||||
# Try CLI proxy on multiple possible ports (default + dynamic range)
|
||||
CLI_OK=0
|
||||
for PORT in %d %d %d %d %d; do
|
||||
if [ -n "$ERROR" ]; then
|
||||
curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?error=$ERROR" && CLI_OK=1 && break
|
||||
elif [ -n "$CODE" ] && [ -n "$STATE" ]; then
|
||||
curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?code=$CODE&state=$STATE" && CLI_OK=1 && break
|
||||
fi
|
||||
done
|
||||
|
||||
# If CLI not available, forward to Kiro IDE
|
||||
if [ $CLI_OK -eq 0 ] && [ -x "/usr/share/kiro/kiro" ]; then
|
||||
/usr/share/kiro/kiro --open-url "$URL" &
|
||||
fi
|
||||
`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4)
|
||||
|
||||
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
|
||||
return fmt.Errorf("failed to write handler script: %w", err)
|
||||
}
|
||||
|
||||
// Create .desktop file
|
||||
desktopPath := getLinuxDesktopPath()
|
||||
desktopContent := fmt.Sprintf(`[Desktop Entry]
|
||||
Name=Kiro OAuth Handler
|
||||
Comment=Handle kiro:// protocol for CLI Proxy API authentication
|
||||
Exec=%s %%u
|
||||
Type=Application
|
||||
Terminal=false
|
||||
NoDisplay=true
|
||||
MimeType=x-scheme-handler/kiro;
|
||||
Categories=Utility;
|
||||
`, scriptPath)
|
||||
|
||||
if err := os.WriteFile(desktopPath, []byte(desktopContent), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write desktop file: %w", err)
|
||||
}
|
||||
|
||||
// Register handler with xdg-mime
|
||||
cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro")
|
||||
if err := cmd.Run(); err != nil {
|
||||
log.Warnf("xdg-mime registration failed (may need manual setup): %v", err)
|
||||
}
|
||||
|
||||
// Update desktop database
|
||||
cmd = exec.Command("update-desktop-database", appDir)
|
||||
_ = cmd.Run() // Ignore errors, not critical
|
||||
|
||||
log.Info("Kiro protocol handler installed for Linux")
|
||||
return nil
|
||||
}
|
||||
|
||||
func uninstallLinuxHandler() error {
|
||||
desktopPath := getLinuxDesktopPath()
|
||||
scriptPath := getLinuxHandlerScriptPath()
|
||||
|
||||
if err := os.Remove(desktopPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to remove desktop file: %w", err)
|
||||
}
|
||||
if err := os.Remove(scriptPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to remove handler script: %w", err)
|
||||
}
|
||||
|
||||
log.Info("Kiro protocol handler uninstalled")
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Windows Implementation ---
|
||||
|
||||
func isWindowsHandlerInstalled() bool {
|
||||
// Check registry key existence
|
||||
cmd := exec.Command("reg", "query", `HKCU\Software\Classes\kiro`, "/ve")
|
||||
return cmd.Run() == nil
|
||||
}
|
||||
|
||||
func installWindowsHandler(handlerPort int) error {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create handler script (PowerShell)
|
||||
scriptDir := filepath.Join(homeDir, ".cliproxyapi")
|
||||
if err := os.MkdirAll(scriptDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create script directory: %w", err)
|
||||
}
|
||||
|
||||
scriptPath := filepath.Join(scriptDir, "kiro-oauth-handler.ps1")
|
||||
scriptContent := fmt.Sprintf(`# Kiro OAuth Protocol Handler for Windows
|
||||
param([string]$url)
|
||||
|
||||
# Load required assembly for HttpUtility
|
||||
Add-Type -AssemblyName System.Web
|
||||
|
||||
# Parse URL parameters
|
||||
$uri = [System.Uri]$url
|
||||
$query = [System.Web.HttpUtility]::ParseQueryString($uri.Query)
|
||||
$code = $query["code"]
|
||||
$state = $query["state"]
|
||||
$errorParam = $query["error"]
|
||||
|
||||
# Try multiple ports (default + dynamic range)
|
||||
$ports = @(%d, %d, %d, %d, %d)
|
||||
$success = $false
|
||||
|
||||
foreach ($port in $ports) {
|
||||
if ($success) { break }
|
||||
$callbackUrl = "http://127.0.0.1:$port/oauth/callback"
|
||||
try {
|
||||
if ($errorParam) {
|
||||
$fullUrl = $callbackUrl + "?error=" + $errorParam
|
||||
Invoke-WebRequest -Uri $fullUrl -UseBasicParsing -TimeoutSec 1 -ErrorAction Stop | Out-Null
|
||||
$success = $true
|
||||
} elseif ($code -and $state) {
|
||||
$fullUrl = $callbackUrl + "?code=" + $code + "&state=" + $state
|
||||
Invoke-WebRequest -Uri $fullUrl -UseBasicParsing -TimeoutSec 1 -ErrorAction Stop | Out-Null
|
||||
$success = $true
|
||||
}
|
||||
} catch {
|
||||
# Try next port
|
||||
}
|
||||
}
|
||||
`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4)
|
||||
|
||||
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write handler script: %w", err)
|
||||
}
|
||||
|
||||
// Create batch wrapper
|
||||
batchPath := filepath.Join(scriptDir, "kiro-oauth-handler.bat")
|
||||
batchContent := fmt.Sprintf("@echo off\npowershell -ExecutionPolicy Bypass -File \"%s\" %%1\n", scriptPath)
|
||||
|
||||
if err := os.WriteFile(batchPath, []byte(batchContent), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write batch wrapper: %w", err)
|
||||
}
|
||||
|
||||
// Register in Windows registry
|
||||
commands := [][]string{
|
||||
{"reg", "add", `HKCU\Software\Classes\kiro`, "/ve", "/d", "URL:Kiro Protocol", "/f"},
|
||||
{"reg", "add", `HKCU\Software\Classes\kiro`, "/v", "URL Protocol", "/d", "", "/f"},
|
||||
{"reg", "add", `HKCU\Software\Classes\kiro\shell`, "/f"},
|
||||
{"reg", "add", `HKCU\Software\Classes\kiro\shell\open`, "/f"},
|
||||
{"reg", "add", `HKCU\Software\Classes\kiro\shell\open\command`, "/ve", "/d", fmt.Sprintf("\"%s\" \"%%1\"", batchPath), "/f"},
|
||||
}
|
||||
|
||||
for _, args := range commands {
|
||||
cmd := exec.Command(args[0], args[1:]...)
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("failed to run registry command: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("Kiro protocol handler installed for Windows")
|
||||
return nil
|
||||
}
|
||||
|
||||
func uninstallWindowsHandler() error {
|
||||
// Remove registry keys
|
||||
cmd := exec.Command("reg", "delete", `HKCU\Software\Classes\kiro`, "/f")
|
||||
if err := cmd.Run(); err != nil {
|
||||
log.Warnf("failed to remove registry key: %v", err)
|
||||
}
|
||||
|
||||
// Remove scripts
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
scriptDir := filepath.Join(homeDir, ".cliproxyapi")
|
||||
_ = os.Remove(filepath.Join(scriptDir, "kiro-oauth-handler.ps1"))
|
||||
_ = os.Remove(filepath.Join(scriptDir, "kiro-oauth-handler.bat"))
|
||||
|
||||
log.Info("Kiro protocol handler uninstalled")
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- macOS Implementation ---
|
||||
|
||||
func getDarwinAppPath() string {
|
||||
homeDir, _ := os.UserHomeDir()
|
||||
return filepath.Join(homeDir, "Applications", "KiroOAuthHandler.app")
|
||||
}
|
||||
|
||||
func isDarwinHandlerInstalled() bool {
|
||||
appPath := getDarwinAppPath()
|
||||
_, err := os.Stat(appPath)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func installDarwinHandler(handlerPort int) error {
|
||||
// Create app bundle structure
|
||||
appPath := getDarwinAppPath()
|
||||
contentsPath := filepath.Join(appPath, "Contents")
|
||||
macOSPath := filepath.Join(contentsPath, "MacOS")
|
||||
|
||||
if err := os.MkdirAll(macOSPath, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create app bundle: %w", err)
|
||||
}
|
||||
|
||||
// Create Info.plist
|
||||
plistPath := filepath.Join(contentsPath, "Info.plist")
|
||||
plistContent := `<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>CFBundleIdentifier</key>
|
||||
<string>com.cliproxyapi.kiro-oauth-handler</string>
|
||||
<key>CFBundleName</key>
|
||||
<string>KiroOAuthHandler</string>
|
||||
<key>CFBundleExecutable</key>
|
||||
<string>kiro-oauth-handler</string>
|
||||
<key>CFBundleVersion</key>
|
||||
<string>1.0</string>
|
||||
<key>CFBundleURLTypes</key>
|
||||
<array>
|
||||
<dict>
|
||||
<key>CFBundleURLName</key>
|
||||
<string>Kiro Protocol</string>
|
||||
<key>CFBundleURLSchemes</key>
|
||||
<array>
|
||||
<string>kiro</string>
|
||||
</array>
|
||||
</dict>
|
||||
</array>
|
||||
<key>LSBackgroundOnly</key>
|
||||
<true/>
|
||||
</dict>
|
||||
</plist>`
|
||||
|
||||
if err := os.WriteFile(plistPath, []byte(plistContent), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write Info.plist: %w", err)
|
||||
}
|
||||
|
||||
// Create executable script - tries multiple ports to handle dynamic port allocation
|
||||
execPath := filepath.Join(macOSPath, "kiro-oauth-handler")
|
||||
execContent := fmt.Sprintf(`#!/bin/bash
|
||||
# Kiro OAuth Protocol Handler for macOS
|
||||
|
||||
URL="$1"
|
||||
|
||||
# Check curl availability (should always exist on macOS)
|
||||
if [ ! -x /usr/bin/curl ]; then
|
||||
echo "Error: curl is required for Kiro OAuth handler" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Extract code and state from URL
|
||||
[[ "$URL" =~ code=([^&]+) ]] && CODE="${BASH_REMATCH[1]}"
|
||||
[[ "$URL" =~ state=([^&]+) ]] && STATE="${BASH_REMATCH[1]}"
|
||||
[[ "$URL" =~ error=([^&]+) ]] && ERROR="${BASH_REMATCH[1]}"
|
||||
|
||||
# Try multiple ports (default + dynamic range)
|
||||
for PORT in %d %d %d %d %d; do
|
||||
if [ -n "$ERROR" ]; then
|
||||
/usr/bin/curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?error=$ERROR" && exit 0
|
||||
elif [ -n "$CODE" ] && [ -n "$STATE" ]; then
|
||||
/usr/bin/curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?code=$CODE&state=$STATE" && exit 0
|
||||
fi
|
||||
done
|
||||
`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4)
|
||||
|
||||
if err := os.WriteFile(execPath, []byte(execContent), 0755); err != nil {
|
||||
return fmt.Errorf("failed to write executable: %w", err)
|
||||
}
|
||||
|
||||
// Register the app with Launch Services
|
||||
cmd := exec.Command("/System/Library/Frameworks/CoreServices.framework/Frameworks/LaunchServices.framework/Support/lsregister",
|
||||
"-f", appPath)
|
||||
if err := cmd.Run(); err != nil {
|
||||
log.Warnf("lsregister failed (handler may still work): %v", err)
|
||||
}
|
||||
|
||||
log.Info("Kiro protocol handler installed for macOS")
|
||||
return nil
|
||||
}
|
||||
|
||||
func uninstallDarwinHandler() error {
|
||||
appPath := getDarwinAppPath()
|
||||
|
||||
// Unregister from Launch Services
|
||||
cmd := exec.Command("/System/Library/Frameworks/CoreServices.framework/Frameworks/LaunchServices.framework/Support/lsregister",
|
||||
"-u", appPath)
|
||||
_ = cmd.Run()
|
||||
|
||||
// Remove app bundle
|
||||
if err := os.RemoveAll(appPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to remove app bundle: %w", err)
|
||||
}
|
||||
|
||||
log.Info("Kiro protocol handler uninstalled")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseKiroURI parses a kiro:// URI and extracts the callback parameters.
|
||||
func ParseKiroURI(rawURI string) (*AuthCallback, error) {
|
||||
u, err := url.Parse(rawURI)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid URI: %w", err)
|
||||
}
|
||||
|
||||
if u.Scheme != KiroProtocol {
|
||||
return nil, fmt.Errorf("invalid scheme: expected %s, got %s", KiroProtocol, u.Scheme)
|
||||
}
|
||||
|
||||
if u.Host != KiroAuthority {
|
||||
return nil, fmt.Errorf("invalid authority: expected %s, got %s", KiroAuthority, u.Host)
|
||||
}
|
||||
|
||||
query := u.Query()
|
||||
return &AuthCallback{
|
||||
Code: query.Get("code"),
|
||||
State: query.Get("state"),
|
||||
Error: query.Get("error"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetHandlerInstructions returns platform-specific instructions for manual handler setup.
|
||||
func GetHandlerInstructions() string {
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
return `To manually set up the Kiro protocol handler on Linux:
|
||||
|
||||
1. Create ~/.local/share/applications/kiro-oauth-handler.desktop:
|
||||
[Desktop Entry]
|
||||
Name=Kiro OAuth Handler
|
||||
Exec=~/.local/bin/kiro-oauth-handler %u
|
||||
Type=Application
|
||||
Terminal=false
|
||||
MimeType=x-scheme-handler/kiro;
|
||||
|
||||
2. Create ~/.local/bin/kiro-oauth-handler (make it executable):
|
||||
#!/bin/bash
|
||||
URL="$1"
|
||||
# ... (see generated script for full content)
|
||||
|
||||
3. Run: xdg-mime default kiro-oauth-handler.desktop x-scheme-handler/kiro`
|
||||
|
||||
case "windows":
|
||||
return `To manually set up the Kiro protocol handler on Windows:
|
||||
|
||||
1. Open Registry Editor (regedit.exe)
|
||||
2. Create key: HKEY_CURRENT_USER\Software\Classes\kiro
|
||||
3. Set default value to: URL:Kiro Protocol
|
||||
4. Create string value "URL Protocol" with empty data
|
||||
5. Create subkey: shell\open\command
|
||||
6. Set default value to: "C:\path\to\handler.bat" "%1"`
|
||||
|
||||
case "darwin":
|
||||
return `To manually set up the Kiro protocol handler on macOS:
|
||||
|
||||
1. Create ~/Applications/KiroOAuthHandler.app bundle
|
||||
2. Add Info.plist with CFBundleURLTypes containing "kiro" scheme
|
||||
3. Create executable in Contents/MacOS/
|
||||
4. Run: /System/Library/.../lsregister -f ~/Applications/KiroOAuthHandler.app`
|
||||
|
||||
default:
|
||||
return "Protocol handler setup is not supported on this platform."
|
||||
}
|
||||
}
|
||||
|
||||
// SetupProtocolHandlerIfNeeded checks and installs the protocol handler if needed.
|
||||
func SetupProtocolHandlerIfNeeded(handlerPort int) error {
|
||||
if IsProtocolHandlerInstalled() {
|
||||
log.Debug("Kiro protocol handler already installed")
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
|
||||
fmt.Println("║ Kiro Protocol Handler Setup Required ║")
|
||||
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||
fmt.Println("\nTo enable Google/GitHub login, we need to install a protocol handler.")
|
||||
fmt.Println("This allows your browser to redirect back to the CLI after authentication.")
|
||||
fmt.Println("\nInstalling protocol handler...")
|
||||
|
||||
if err := InstallProtocolHandler(handlerPort); err != nil {
|
||||
fmt.Printf("\n⚠ Automatic installation failed: %v\n", err)
|
||||
fmt.Println("\nManual setup instructions:")
|
||||
fmt.Println(strings.Repeat("-", 60))
|
||||
fmt.Println(GetHandlerInstructions())
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Protocol handler installed successfully!")
|
||||
return nil
|
||||
}
|
||||
316
internal/auth/kiro/rate_limiter.go
Normal file
316
internal/auth/kiro/rate_limiter.go
Normal file
@@ -0,0 +1,316 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMinTokenInterval = 1 * time.Second
|
||||
DefaultMaxTokenInterval = 2 * time.Second
|
||||
DefaultDailyMaxRequests = 500
|
||||
DefaultJitterPercent = 0.3
|
||||
DefaultBackoffBase = 30 * time.Second
|
||||
DefaultBackoffMax = 5 * time.Minute
|
||||
DefaultBackoffMultiplier = 1.5
|
||||
DefaultSuspendCooldown = 1 * time.Hour
|
||||
)
|
||||
|
||||
// TokenState Token 状态
|
||||
type TokenState struct {
|
||||
LastRequest time.Time
|
||||
RequestCount int
|
||||
CooldownEnd time.Time
|
||||
FailCount int
|
||||
DailyRequests int
|
||||
DailyResetTime time.Time
|
||||
IsSuspended bool
|
||||
SuspendedAt time.Time
|
||||
SuspendReason string
|
||||
}
|
||||
|
||||
// RateLimiter 频率限制器
|
||||
type RateLimiter struct {
|
||||
mu sync.RWMutex
|
||||
states map[string]*TokenState
|
||||
minTokenInterval time.Duration
|
||||
maxTokenInterval time.Duration
|
||||
dailyMaxRequests int
|
||||
jitterPercent float64
|
||||
backoffBase time.Duration
|
||||
backoffMax time.Duration
|
||||
backoffMultiplier float64
|
||||
suspendCooldown time.Duration
|
||||
rng *rand.Rand
|
||||
}
|
||||
|
||||
// NewRateLimiter 创建默认配置的频率限制器
|
||||
func NewRateLimiter() *RateLimiter {
|
||||
return &RateLimiter{
|
||||
states: make(map[string]*TokenState),
|
||||
minTokenInterval: DefaultMinTokenInterval,
|
||||
maxTokenInterval: DefaultMaxTokenInterval,
|
||||
dailyMaxRequests: DefaultDailyMaxRequests,
|
||||
jitterPercent: DefaultJitterPercent,
|
||||
backoffBase: DefaultBackoffBase,
|
||||
backoffMax: DefaultBackoffMax,
|
||||
backoffMultiplier: DefaultBackoffMultiplier,
|
||||
suspendCooldown: DefaultSuspendCooldown,
|
||||
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimiterConfig 频率限制器配置
|
||||
type RateLimiterConfig struct {
|
||||
MinTokenInterval time.Duration
|
||||
MaxTokenInterval time.Duration
|
||||
DailyMaxRequests int
|
||||
JitterPercent float64
|
||||
BackoffBase time.Duration
|
||||
BackoffMax time.Duration
|
||||
BackoffMultiplier float64
|
||||
SuspendCooldown time.Duration
|
||||
}
|
||||
|
||||
// NewRateLimiterWithConfig 使用自定义配置创建频率限制器
|
||||
func NewRateLimiterWithConfig(cfg RateLimiterConfig) *RateLimiter {
|
||||
rl := NewRateLimiter()
|
||||
if cfg.MinTokenInterval > 0 {
|
||||
rl.minTokenInterval = cfg.MinTokenInterval
|
||||
}
|
||||
if cfg.MaxTokenInterval > 0 {
|
||||
rl.maxTokenInterval = cfg.MaxTokenInterval
|
||||
}
|
||||
if cfg.DailyMaxRequests > 0 {
|
||||
rl.dailyMaxRequests = cfg.DailyMaxRequests
|
||||
}
|
||||
if cfg.JitterPercent > 0 {
|
||||
rl.jitterPercent = cfg.JitterPercent
|
||||
}
|
||||
if cfg.BackoffBase > 0 {
|
||||
rl.backoffBase = cfg.BackoffBase
|
||||
}
|
||||
if cfg.BackoffMax > 0 {
|
||||
rl.backoffMax = cfg.BackoffMax
|
||||
}
|
||||
if cfg.BackoffMultiplier > 0 {
|
||||
rl.backoffMultiplier = cfg.BackoffMultiplier
|
||||
}
|
||||
if cfg.SuspendCooldown > 0 {
|
||||
rl.suspendCooldown = cfg.SuspendCooldown
|
||||
}
|
||||
return rl
|
||||
}
|
||||
|
||||
// getOrCreateState 获取或创建 Token 状态
|
||||
func (rl *RateLimiter) getOrCreateState(tokenKey string) *TokenState {
|
||||
state, exists := rl.states[tokenKey]
|
||||
if !exists {
|
||||
state = &TokenState{
|
||||
DailyResetTime: time.Now().Truncate(24 * time.Hour).Add(24 * time.Hour),
|
||||
}
|
||||
rl.states[tokenKey] = state
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
// resetDailyIfNeeded 如果需要则重置每日计数
|
||||
func (rl *RateLimiter) resetDailyIfNeeded(state *TokenState) {
|
||||
now := time.Now()
|
||||
if now.After(state.DailyResetTime) {
|
||||
state.DailyRequests = 0
|
||||
state.DailyResetTime = now.Truncate(24 * time.Hour).Add(24 * time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
// calculateInterval 计算带抖动的随机间隔
|
||||
func (rl *RateLimiter) calculateInterval() time.Duration {
|
||||
baseInterval := rl.minTokenInterval + time.Duration(rl.rng.Int63n(int64(rl.maxTokenInterval-rl.minTokenInterval)))
|
||||
jitter := time.Duration(float64(baseInterval) * rl.jitterPercent * (rl.rng.Float64()*2 - 1))
|
||||
return baseInterval + jitter
|
||||
}
|
||||
|
||||
// WaitForToken 等待 Token 可用(带抖动的随机间隔)
|
||||
func (rl *RateLimiter) WaitForToken(tokenKey string) {
|
||||
rl.mu.Lock()
|
||||
state := rl.getOrCreateState(tokenKey)
|
||||
rl.resetDailyIfNeeded(state)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// 检查是否在冷却期
|
||||
if now.Before(state.CooldownEnd) {
|
||||
waitTime := state.CooldownEnd.Sub(now)
|
||||
rl.mu.Unlock()
|
||||
time.Sleep(waitTime)
|
||||
rl.mu.Lock()
|
||||
state = rl.getOrCreateState(tokenKey)
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
// 计算距离上次请求的间隔
|
||||
interval := rl.calculateInterval()
|
||||
nextAllowedTime := state.LastRequest.Add(interval)
|
||||
|
||||
if now.Before(nextAllowedTime) {
|
||||
waitTime := nextAllowedTime.Sub(now)
|
||||
rl.mu.Unlock()
|
||||
time.Sleep(waitTime)
|
||||
rl.mu.Lock()
|
||||
state = rl.getOrCreateState(tokenKey)
|
||||
}
|
||||
|
||||
state.LastRequest = time.Now()
|
||||
state.RequestCount++
|
||||
state.DailyRequests++
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
|
||||
// MarkTokenFailed 标记 Token 失败
|
||||
func (rl *RateLimiter) MarkTokenFailed(tokenKey string) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
state := rl.getOrCreateState(tokenKey)
|
||||
state.FailCount++
|
||||
state.CooldownEnd = time.Now().Add(rl.calculateBackoff(state.FailCount))
|
||||
}
|
||||
|
||||
// MarkTokenSuccess 标记 Token 成功
|
||||
func (rl *RateLimiter) MarkTokenSuccess(tokenKey string) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
state := rl.getOrCreateState(tokenKey)
|
||||
state.FailCount = 0
|
||||
state.CooldownEnd = time.Time{}
|
||||
}
|
||||
|
||||
// CheckAndMarkSuspended 检测暂停错误并标记
|
||||
func (rl *RateLimiter) CheckAndMarkSuspended(tokenKey string, errorMsg string) bool {
|
||||
suspendKeywords := []string{
|
||||
"suspended",
|
||||
"banned",
|
||||
"disabled",
|
||||
"account has been",
|
||||
"access denied",
|
||||
"rate limit exceeded",
|
||||
"too many requests",
|
||||
"quota exceeded",
|
||||
}
|
||||
|
||||
lowerMsg := strings.ToLower(errorMsg)
|
||||
for _, keyword := range suspendKeywords {
|
||||
if strings.Contains(lowerMsg, keyword) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
state := rl.getOrCreateState(tokenKey)
|
||||
state.IsSuspended = true
|
||||
state.SuspendedAt = time.Now()
|
||||
state.SuspendReason = errorMsg
|
||||
state.CooldownEnd = time.Now().Add(rl.suspendCooldown)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsTokenAvailable 检查 Token 是否可用
|
||||
func (rl *RateLimiter) IsTokenAvailable(tokenKey string) bool {
|
||||
rl.mu.RLock()
|
||||
defer rl.mu.RUnlock()
|
||||
|
||||
state, exists := rl.states[tokenKey]
|
||||
if !exists {
|
||||
return true
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// 检查是否被暂停
|
||||
if state.IsSuspended {
|
||||
if now.After(state.SuspendedAt.Add(rl.suspendCooldown)) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否在冷却期
|
||||
if now.Before(state.CooldownEnd) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查每日请求限制
|
||||
rl.mu.RUnlock()
|
||||
rl.mu.Lock()
|
||||
rl.resetDailyIfNeeded(state)
|
||||
dailyRequests := state.DailyRequests
|
||||
dailyMax := rl.dailyMaxRequests
|
||||
rl.mu.Unlock()
|
||||
rl.mu.RLock()
|
||||
|
||||
if dailyRequests >= dailyMax {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// calculateBackoff 计算指数退避时间
|
||||
func (rl *RateLimiter) calculateBackoff(failCount int) time.Duration {
|
||||
if failCount <= 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
backoff := float64(rl.backoffBase) * math.Pow(rl.backoffMultiplier, float64(failCount-1))
|
||||
|
||||
// 添加抖动
|
||||
jitter := backoff * rl.jitterPercent * (rl.rng.Float64()*2 - 1)
|
||||
backoff += jitter
|
||||
|
||||
if time.Duration(backoff) > rl.backoffMax {
|
||||
return rl.backoffMax
|
||||
}
|
||||
return time.Duration(backoff)
|
||||
}
|
||||
|
||||
// GetTokenState 获取 Token 状态(只读)
|
||||
func (rl *RateLimiter) GetTokenState(tokenKey string) *TokenState {
|
||||
rl.mu.RLock()
|
||||
defer rl.mu.RUnlock()
|
||||
|
||||
state, exists := rl.states[tokenKey]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 返回副本以防止外部修改
|
||||
stateCopy := *state
|
||||
return &stateCopy
|
||||
}
|
||||
|
||||
// ClearTokenState 清除 Token 状态
|
||||
func (rl *RateLimiter) ClearTokenState(tokenKey string) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
delete(rl.states, tokenKey)
|
||||
}
|
||||
|
||||
// ResetSuspension 重置暂停状态
|
||||
func (rl *RateLimiter) ResetSuspension(tokenKey string) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
state, exists := rl.states[tokenKey]
|
||||
if exists {
|
||||
state.IsSuspended = false
|
||||
state.SuspendedAt = time.Time{}
|
||||
state.SuspendReason = ""
|
||||
state.CooldownEnd = time.Time{}
|
||||
state.FailCount = 0
|
||||
}
|
||||
}
|
||||
46
internal/auth/kiro/rate_limiter_singleton.go
Normal file
46
internal/auth/kiro/rate_limiter_singleton.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
globalRateLimiter *RateLimiter
|
||||
globalRateLimiterOnce sync.Once
|
||||
|
||||
globalCooldownManager *CooldownManager
|
||||
globalCooldownManagerOnce sync.Once
|
||||
cooldownStopCh chan struct{}
|
||||
)
|
||||
|
||||
// GetGlobalRateLimiter returns the singleton RateLimiter instance.
|
||||
func GetGlobalRateLimiter() *RateLimiter {
|
||||
globalRateLimiterOnce.Do(func() {
|
||||
globalRateLimiter = NewRateLimiter()
|
||||
log.Info("kiro: global RateLimiter initialized")
|
||||
})
|
||||
return globalRateLimiter
|
||||
}
|
||||
|
||||
// GetGlobalCooldownManager returns the singleton CooldownManager instance.
|
||||
func GetGlobalCooldownManager() *CooldownManager {
|
||||
globalCooldownManagerOnce.Do(func() {
|
||||
globalCooldownManager = NewCooldownManager()
|
||||
cooldownStopCh = make(chan struct{})
|
||||
go globalCooldownManager.StartCleanupRoutine(5*time.Minute, cooldownStopCh)
|
||||
log.Info("kiro: global CooldownManager initialized with cleanup routine")
|
||||
})
|
||||
return globalCooldownManager
|
||||
}
|
||||
|
||||
// ShutdownRateLimiters stops the cooldown cleanup routine.
|
||||
// Should be called during application shutdown.
|
||||
func ShutdownRateLimiters() {
|
||||
if cooldownStopCh != nil {
|
||||
close(cooldownStopCh)
|
||||
log.Info("kiro: rate limiter cleanup routine stopped")
|
||||
}
|
||||
}
|
||||
304
internal/auth/kiro/rate_limiter_test.go
Normal file
304
internal/auth/kiro/rate_limiter_test.go
Normal file
@@ -0,0 +1,304 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewRateLimiter(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
if rl == nil {
|
||||
t.Fatal("expected non-nil RateLimiter")
|
||||
}
|
||||
if rl.states == nil {
|
||||
t.Error("expected non-nil states map")
|
||||
}
|
||||
if rl.minTokenInterval != DefaultMinTokenInterval {
|
||||
t.Errorf("expected minTokenInterval %v, got %v", DefaultMinTokenInterval, rl.minTokenInterval)
|
||||
}
|
||||
if rl.maxTokenInterval != DefaultMaxTokenInterval {
|
||||
t.Errorf("expected maxTokenInterval %v, got %v", DefaultMaxTokenInterval, rl.maxTokenInterval)
|
||||
}
|
||||
if rl.dailyMaxRequests != DefaultDailyMaxRequests {
|
||||
t.Errorf("expected dailyMaxRequests %d, got %d", DefaultDailyMaxRequests, rl.dailyMaxRequests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRateLimiterWithConfig(t *testing.T) {
|
||||
cfg := RateLimiterConfig{
|
||||
MinTokenInterval: 5 * time.Second,
|
||||
MaxTokenInterval: 15 * time.Second,
|
||||
DailyMaxRequests: 100,
|
||||
JitterPercent: 0.2,
|
||||
BackoffBase: 1 * time.Minute,
|
||||
BackoffMax: 30 * time.Minute,
|
||||
BackoffMultiplier: 1.5,
|
||||
SuspendCooldown: 12 * time.Hour,
|
||||
}
|
||||
|
||||
rl := NewRateLimiterWithConfig(cfg)
|
||||
if rl.minTokenInterval != 5*time.Second {
|
||||
t.Errorf("expected minTokenInterval 5s, got %v", rl.minTokenInterval)
|
||||
}
|
||||
if rl.maxTokenInterval != 15*time.Second {
|
||||
t.Errorf("expected maxTokenInterval 15s, got %v", rl.maxTokenInterval)
|
||||
}
|
||||
if rl.dailyMaxRequests != 100 {
|
||||
t.Errorf("expected dailyMaxRequests 100, got %d", rl.dailyMaxRequests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRateLimiterWithConfig_PartialConfig(t *testing.T) {
|
||||
cfg := RateLimiterConfig{
|
||||
MinTokenInterval: 5 * time.Second,
|
||||
}
|
||||
|
||||
rl := NewRateLimiterWithConfig(cfg)
|
||||
if rl.minTokenInterval != 5*time.Second {
|
||||
t.Errorf("expected minTokenInterval 5s, got %v", rl.minTokenInterval)
|
||||
}
|
||||
if rl.maxTokenInterval != DefaultMaxTokenInterval {
|
||||
t.Errorf("expected default maxTokenInterval, got %v", rl.maxTokenInterval)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTokenState_NonExistent(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
state := rl.GetTokenState("nonexistent")
|
||||
if state != nil {
|
||||
t.Error("expected nil state for non-existent token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsTokenAvailable_NewToken(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
if !rl.IsTokenAvailable("newtoken") {
|
||||
t.Error("expected new token to be available")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkTokenFailed(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.MarkTokenFailed("token1")
|
||||
|
||||
state := rl.GetTokenState("token1")
|
||||
if state == nil {
|
||||
t.Fatal("expected non-nil state")
|
||||
}
|
||||
if state.FailCount != 1 {
|
||||
t.Errorf("expected FailCount 1, got %d", state.FailCount)
|
||||
}
|
||||
if state.CooldownEnd.IsZero() {
|
||||
t.Error("expected non-zero CooldownEnd")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkTokenSuccess(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.MarkTokenFailed("token1")
|
||||
rl.MarkTokenFailed("token1")
|
||||
rl.MarkTokenSuccess("token1")
|
||||
|
||||
state := rl.GetTokenState("token1")
|
||||
if state == nil {
|
||||
t.Fatal("expected non-nil state")
|
||||
}
|
||||
if state.FailCount != 0 {
|
||||
t.Errorf("expected FailCount 0, got %d", state.FailCount)
|
||||
}
|
||||
if !state.CooldownEnd.IsZero() {
|
||||
t.Error("expected zero CooldownEnd after success")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckAndMarkSuspended_Suspended(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
|
||||
testCases := []string{
|
||||
"Account has been suspended",
|
||||
"You are banned from this service",
|
||||
"Account disabled",
|
||||
"Access denied permanently",
|
||||
"Rate limit exceeded",
|
||||
"Too many requests",
|
||||
"Quota exceeded for today",
|
||||
}
|
||||
|
||||
for i, msg := range testCases {
|
||||
tokenKey := "token" + string(rune('a'+i))
|
||||
if !rl.CheckAndMarkSuspended(tokenKey, msg) {
|
||||
t.Errorf("expected suspension detected for: %s", msg)
|
||||
}
|
||||
state := rl.GetTokenState(tokenKey)
|
||||
if !state.IsSuspended {
|
||||
t.Errorf("expected IsSuspended true for: %s", msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckAndMarkSuspended_NotSuspended(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
|
||||
normalErrors := []string{
|
||||
"connection timeout",
|
||||
"internal server error",
|
||||
"bad request",
|
||||
"invalid token format",
|
||||
}
|
||||
|
||||
for i, msg := range normalErrors {
|
||||
tokenKey := "token" + string(rune('a'+i))
|
||||
if rl.CheckAndMarkSuspended(tokenKey, msg) {
|
||||
t.Errorf("unexpected suspension for: %s", msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsTokenAvailable_Suspended(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.CheckAndMarkSuspended("token1", "Account suspended")
|
||||
|
||||
if rl.IsTokenAvailable("token1") {
|
||||
t.Error("expected suspended token to be unavailable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearTokenState(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.MarkTokenFailed("token1")
|
||||
rl.ClearTokenState("token1")
|
||||
|
||||
state := rl.GetTokenState("token1")
|
||||
if state != nil {
|
||||
t.Error("expected nil state after clear")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetSuspension(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.CheckAndMarkSuspended("token1", "Account suspended")
|
||||
rl.ResetSuspension("token1")
|
||||
|
||||
state := rl.GetTokenState("token1")
|
||||
if state.IsSuspended {
|
||||
t.Error("expected IsSuspended false after reset")
|
||||
}
|
||||
if state.FailCount != 0 {
|
||||
t.Errorf("expected FailCount 0, got %d", state.FailCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetSuspension_NonExistent(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.ResetSuspension("nonexistent")
|
||||
}
|
||||
|
||||
func TestCalculateBackoff_ZeroFailCount(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
backoff := rl.calculateBackoff(0)
|
||||
if backoff != 0 {
|
||||
t.Errorf("expected 0 backoff for 0 fails, got %v", backoff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateBackoff_Exponential(t *testing.T) {
|
||||
cfg := RateLimiterConfig{
|
||||
BackoffBase: 1 * time.Minute,
|
||||
BackoffMax: 60 * time.Minute,
|
||||
BackoffMultiplier: 2.0,
|
||||
JitterPercent: 0.3,
|
||||
}
|
||||
rl := NewRateLimiterWithConfig(cfg)
|
||||
|
||||
backoff1 := rl.calculateBackoff(1)
|
||||
if backoff1 < 40*time.Second || backoff1 > 80*time.Second {
|
||||
t.Errorf("expected ~1min (with jitter) for fail 1, got %v", backoff1)
|
||||
}
|
||||
|
||||
backoff2 := rl.calculateBackoff(2)
|
||||
if backoff2 < 80*time.Second || backoff2 > 160*time.Second {
|
||||
t.Errorf("expected ~2min (with jitter) for fail 2, got %v", backoff2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateBackoff_MaxCap(t *testing.T) {
|
||||
cfg := RateLimiterConfig{
|
||||
BackoffBase: 1 * time.Minute,
|
||||
BackoffMax: 10 * time.Minute,
|
||||
BackoffMultiplier: 2.0,
|
||||
JitterPercent: 0,
|
||||
}
|
||||
rl := NewRateLimiterWithConfig(cfg)
|
||||
|
||||
backoff := rl.calculateBackoff(10)
|
||||
if backoff > 10*time.Minute {
|
||||
t.Errorf("expected backoff capped at 10min, got %v", backoff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTokenState_ReturnsCopy(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
rl.MarkTokenFailed("token1")
|
||||
|
||||
state1 := rl.GetTokenState("token1")
|
||||
state1.FailCount = 999
|
||||
|
||||
state2 := rl.GetTokenState("token1")
|
||||
if state2.FailCount == 999 {
|
||||
t.Error("GetTokenState should return a copy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_ConcurrentAccess(t *testing.T) {
|
||||
rl := NewRateLimiter()
|
||||
const numGoroutines = 50
|
||||
const numOperations = 50
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
tokenKey := "token" + string(rune('a'+id%10))
|
||||
for j := 0; j < numOperations; j++ {
|
||||
switch j % 6 {
|
||||
case 0:
|
||||
rl.IsTokenAvailable(tokenKey)
|
||||
case 1:
|
||||
rl.MarkTokenFailed(tokenKey)
|
||||
case 2:
|
||||
rl.MarkTokenSuccess(tokenKey)
|
||||
case 3:
|
||||
rl.GetTokenState(tokenKey)
|
||||
case 4:
|
||||
rl.CheckAndMarkSuspended(tokenKey, "test error")
|
||||
case 5:
|
||||
rl.ResetSuspension(tokenKey)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestCalculateInterval_WithinRange(t *testing.T) {
|
||||
cfg := RateLimiterConfig{
|
||||
MinTokenInterval: 10 * time.Second,
|
||||
MaxTokenInterval: 30 * time.Second,
|
||||
JitterPercent: 0.3,
|
||||
}
|
||||
rl := NewRateLimiterWithConfig(cfg)
|
||||
|
||||
minAllowed := 7 * time.Second
|
||||
maxAllowed := 40 * time.Second
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
interval := rl.calculateInterval()
|
||||
if interval < minAllowed || interval > maxAllowed {
|
||||
t.Errorf("interval %v outside expected range [%v, %v]", interval, minAllowed, maxAllowed)
|
||||
}
|
||||
}
|
||||
}
|
||||
202
internal/auth/kiro/refresh_manager.go
Normal file
202
internal/auth/kiro/refresh_manager.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// RefreshManager is a singleton manager for background token refreshing.
|
||||
type RefreshManager struct {
|
||||
mu sync.Mutex
|
||||
refresher *BackgroundRefresher
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
started bool
|
||||
onTokenRefreshed func(tokenID string, tokenData *KiroTokenData)
|
||||
}
|
||||
|
||||
var (
|
||||
globalRefreshManager *RefreshManager
|
||||
managerOnce sync.Once
|
||||
)
|
||||
|
||||
// GetRefreshManager returns the global RefreshManager singleton.
|
||||
func GetRefreshManager() *RefreshManager {
|
||||
managerOnce.Do(func() {
|
||||
globalRefreshManager = &RefreshManager{}
|
||||
})
|
||||
return globalRefreshManager
|
||||
}
|
||||
|
||||
// Initialize sets up the background refresher.
|
||||
func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.started {
|
||||
log.Debug("refresh manager: already initialized")
|
||||
return nil
|
||||
}
|
||||
|
||||
if baseDir == "" {
|
||||
log.Warn("refresh manager: base directory not provided, skipping initialization")
|
||||
return nil
|
||||
}
|
||||
|
||||
resolvedBaseDir, err := util.ResolveAuthDir(baseDir)
|
||||
if err != nil {
|
||||
log.Warnf("refresh manager: failed to resolve auth directory %s: %v", baseDir, err)
|
||||
}
|
||||
if resolvedBaseDir != "" {
|
||||
baseDir = resolvedBaseDir
|
||||
}
|
||||
|
||||
repo := NewFileTokenRepository(baseDir)
|
||||
|
||||
opts := []RefresherOption{
|
||||
WithInterval(time.Minute),
|
||||
WithBatchSize(50),
|
||||
WithConcurrency(10),
|
||||
WithConfig(cfg),
|
||||
}
|
||||
|
||||
// Pass callback to BackgroundRefresher if already set
|
||||
if m.onTokenRefreshed != nil {
|
||||
opts = append(opts, WithOnTokenRefreshed(m.onTokenRefreshed))
|
||||
}
|
||||
|
||||
m.refresher = NewBackgroundRefresher(repo, opts...)
|
||||
|
||||
log.Infof("refresh manager: initialized with base directory %s", baseDir)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start begins background token refreshing.
|
||||
func (m *RefreshManager) Start() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.started {
|
||||
log.Debug("refresh manager: already started")
|
||||
return
|
||||
}
|
||||
|
||||
if m.refresher == nil {
|
||||
log.Warn("refresh manager: not initialized, cannot start")
|
||||
return
|
||||
}
|
||||
|
||||
m.ctx, m.cancel = context.WithCancel(context.Background())
|
||||
m.refresher.Start(m.ctx)
|
||||
m.started = true
|
||||
|
||||
log.Info("refresh manager: background refresh started")
|
||||
}
|
||||
|
||||
// Stop halts background token refreshing.
|
||||
func (m *RefreshManager) Stop() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if !m.started {
|
||||
return
|
||||
}
|
||||
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
}
|
||||
|
||||
if m.refresher != nil {
|
||||
m.refresher.Stop()
|
||||
}
|
||||
|
||||
m.started = false
|
||||
log.Info("refresh manager: background refresh stopped")
|
||||
}
|
||||
|
||||
// IsRunning reports whether background refreshing is active.
|
||||
func (m *RefreshManager) IsRunning() bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.started
|
||||
}
|
||||
|
||||
// UpdateBaseDir changes the token directory at runtime.
|
||||
func (m *RefreshManager) UpdateBaseDir(baseDir string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.refresher != nil && m.refresher.tokenRepo != nil {
|
||||
if repo, ok := m.refresher.tokenRepo.(*FileTokenRepository); ok {
|
||||
repo.SetBaseDir(baseDir)
|
||||
log.Infof("refresh manager: updated base directory to %s", baseDir)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetOnTokenRefreshed registers a callback invoked after a successful token refresh.
|
||||
// Can be called at any time; supports runtime callback updates.
|
||||
func (m *RefreshManager) SetOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.onTokenRefreshed = callback
|
||||
|
||||
// Update the refresher's callback in a thread-safe manner if already created
|
||||
if m.refresher != nil {
|
||||
m.refresher.callbackMu.Lock()
|
||||
m.refresher.onTokenRefreshed = callback
|
||||
m.refresher.callbackMu.Unlock()
|
||||
}
|
||||
|
||||
log.Debug("refresh manager: token refresh callback registered")
|
||||
}
|
||||
|
||||
// InitializeAndStart initializes and starts background refreshing (convenience method).
|
||||
func InitializeAndStart(baseDir string, cfg *config.Config) {
|
||||
// Initialize global fingerprint config
|
||||
initGlobalFingerprintConfig(cfg)
|
||||
|
||||
manager := GetRefreshManager()
|
||||
if err := manager.Initialize(baseDir, cfg); err != nil {
|
||||
log.Errorf("refresh manager: initialization failed: %v", err)
|
||||
return
|
||||
}
|
||||
manager.Start()
|
||||
}
|
||||
|
||||
// initGlobalFingerprintConfig loads fingerprint settings from application config.
|
||||
func initGlobalFingerprintConfig(cfg *config.Config) {
|
||||
if cfg == nil || cfg.KiroFingerprint == nil {
|
||||
return
|
||||
}
|
||||
fpCfg := cfg.KiroFingerprint
|
||||
SetGlobalFingerprintConfig(&FingerprintConfig{
|
||||
OIDCSDKVersion: fpCfg.OIDCSDKVersion,
|
||||
RuntimeSDKVersion: fpCfg.RuntimeSDKVersion,
|
||||
StreamingSDKVersion: fpCfg.StreamingSDKVersion,
|
||||
OSType: fpCfg.OSType,
|
||||
OSVersion: fpCfg.OSVersion,
|
||||
NodeVersion: fpCfg.NodeVersion,
|
||||
KiroVersion: fpCfg.KiroVersion,
|
||||
KiroHash: fpCfg.KiroHash,
|
||||
})
|
||||
log.Debug("kiro: global fingerprint config loaded")
|
||||
}
|
||||
|
||||
// InitFingerprintConfig initializes the global fingerprint config from application config.
|
||||
func InitFingerprintConfig(cfg *config.Config) {
|
||||
initGlobalFingerprintConfig(cfg)
|
||||
}
|
||||
|
||||
// StopGlobalRefreshManager stops the global refresh manager.
|
||||
func StopGlobalRefreshManager() {
|
||||
if globalRefreshManager != nil {
|
||||
globalRefreshManager.Stop()
|
||||
}
|
||||
}
|
||||
159
internal/auth/kiro/refresh_utils.go
Normal file
159
internal/auth/kiro/refresh_utils.go
Normal file
@@ -0,0 +1,159 @@
|
||||
// Package kiro provides refresh utilities for Kiro token management.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// RefreshResult contains the result of a token refresh attempt.
|
||||
type RefreshResult struct {
|
||||
TokenData *KiroTokenData
|
||||
Error error
|
||||
UsedFallback bool // True if we used the existing token as fallback
|
||||
}
|
||||
|
||||
// RefreshWithGracefulDegradation attempts to refresh a token with graceful degradation.
|
||||
// If refresh fails but the existing access token is still valid, it returns the existing token.
|
||||
// This matches kiro-openai-gateway's behavior for better reliability.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for the request
|
||||
// - refreshFunc: Function to perform the actual refresh
|
||||
// - existingAccessToken: Current access token (for fallback)
|
||||
// - expiresAt: Expiration time of the existing token
|
||||
//
|
||||
// Returns:
|
||||
// - RefreshResult containing the new or existing token data
|
||||
func RefreshWithGracefulDegradation(
|
||||
ctx context.Context,
|
||||
refreshFunc func(ctx context.Context) (*KiroTokenData, error),
|
||||
existingAccessToken string,
|
||||
expiresAt time.Time,
|
||||
) RefreshResult {
|
||||
// Try to refresh the token
|
||||
newTokenData, err := refreshFunc(ctx)
|
||||
if err == nil {
|
||||
return RefreshResult{
|
||||
TokenData: newTokenData,
|
||||
Error: nil,
|
||||
UsedFallback: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Refresh failed - check if we can use the existing token
|
||||
log.Warnf("kiro: token refresh failed: %v", err)
|
||||
|
||||
// Check if existing token is still valid (not expired)
|
||||
if existingAccessToken != "" && time.Now().Before(expiresAt) {
|
||||
remainingTime := time.Until(expiresAt)
|
||||
log.Warnf("kiro: using existing access token (expires in %v). Will retry refresh later.", remainingTime.Round(time.Second))
|
||||
|
||||
return RefreshResult{
|
||||
TokenData: &KiroTokenData{
|
||||
AccessToken: existingAccessToken,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
},
|
||||
Error: nil,
|
||||
UsedFallback: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Token is expired and refresh failed - return the error
|
||||
return RefreshResult{
|
||||
TokenData: nil,
|
||||
Error: fmt.Errorf("token refresh failed and existing token is expired: %w", err),
|
||||
UsedFallback: false,
|
||||
}
|
||||
}
|
||||
|
||||
// IsTokenExpiringSoon checks if a token is expiring within the given threshold.
|
||||
// Default threshold is 5 minutes if not specified.
|
||||
func IsTokenExpiringSoon(expiresAt time.Time, threshold time.Duration) bool {
|
||||
if threshold == 0 {
|
||||
threshold = 5 * time.Minute
|
||||
}
|
||||
return time.Now().Add(threshold).After(expiresAt)
|
||||
}
|
||||
|
||||
// IsTokenExpired checks if a token has already expired.
|
||||
func IsTokenExpired(expiresAt time.Time) bool {
|
||||
return time.Now().After(expiresAt)
|
||||
}
|
||||
|
||||
// ParseExpiresAt parses an expiration time string in RFC3339 format.
|
||||
// Returns zero time if parsing fails.
|
||||
func ParseExpiresAt(expiresAtStr string) time.Time {
|
||||
if expiresAtStr == "" {
|
||||
return time.Time{}
|
||||
}
|
||||
t, err := time.Parse(time.RFC3339, expiresAtStr)
|
||||
if err != nil {
|
||||
log.Debugf("kiro: failed to parse expiresAt '%s': %v", expiresAtStr, err)
|
||||
return time.Time{}
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// RefreshConfig contains configuration for token refresh behavior.
|
||||
type RefreshConfig struct {
|
||||
// MaxRetries is the maximum number of refresh attempts (default: 1)
|
||||
MaxRetries int
|
||||
// RetryDelay is the delay between retry attempts (default: 1 second)
|
||||
RetryDelay time.Duration
|
||||
// RefreshThreshold is how early to refresh before expiration (default: 5 minutes)
|
||||
RefreshThreshold time.Duration
|
||||
// EnableGracefulDegradation allows using existing token if refresh fails (default: true)
|
||||
EnableGracefulDegradation bool
|
||||
}
|
||||
|
||||
// DefaultRefreshConfig returns the default refresh configuration.
|
||||
func DefaultRefreshConfig() RefreshConfig {
|
||||
return RefreshConfig{
|
||||
MaxRetries: 1,
|
||||
RetryDelay: time.Second,
|
||||
RefreshThreshold: 5 * time.Minute,
|
||||
EnableGracefulDegradation: true,
|
||||
}
|
||||
}
|
||||
|
||||
// RefreshWithRetry attempts to refresh a token with retry logic.
|
||||
func RefreshWithRetry(
|
||||
ctx context.Context,
|
||||
refreshFunc func(ctx context.Context) (*KiroTokenData, error),
|
||||
config RefreshConfig,
|
||||
) (*KiroTokenData, error) {
|
||||
var lastErr error
|
||||
|
||||
maxAttempts := config.MaxRetries + 1
|
||||
if maxAttempts < 1 {
|
||||
maxAttempts = 1
|
||||
}
|
||||
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
tokenData, err := refreshFunc(ctx)
|
||||
if err == nil {
|
||||
if attempt > 1 {
|
||||
log.Infof("kiro: token refresh succeeded on attempt %d", attempt)
|
||||
}
|
||||
return tokenData, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
log.Warnf("kiro: token refresh attempt %d/%d failed: %v", attempt, maxAttempts, err)
|
||||
|
||||
// Don't sleep after the last attempt
|
||||
if attempt < maxAttempts {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(config.RetryDelay):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxAttempts, lastErr)
|
||||
}
|
||||
488
internal/auth/kiro/social_auth.go
Normal file
488
internal/auth/kiro/social_auth.go
Normal file
@@ -0,0 +1,488 @@
|
||||
// Package kiro provides social authentication (Google/GitHub) for Kiro via AuthServiceClient.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
const (
|
||||
// Kiro AuthService endpoint
|
||||
kiroAuthServiceEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
|
||||
|
||||
// OAuth timeout
|
||||
socialAuthTimeout = 10 * time.Minute
|
||||
|
||||
// Default callback port for social auth HTTP server
|
||||
socialAuthCallbackPort = 9876
|
||||
)
|
||||
|
||||
// SocialProvider represents the social login provider.
|
||||
type SocialProvider string
|
||||
|
||||
const (
|
||||
// ProviderGoogle is Google OAuth provider
|
||||
ProviderGoogle SocialProvider = "Google"
|
||||
// ProviderGitHub is GitHub OAuth provider
|
||||
ProviderGitHub SocialProvider = "Github"
|
||||
// Note: AWS Builder ID is NOT supported by Kiro's auth service.
|
||||
// It only supports: Google, Github, Cognito
|
||||
// AWS Builder ID must use device code flow via SSO OIDC.
|
||||
)
|
||||
|
||||
// CreateTokenRequest is sent to Kiro's /oauth/token endpoint.
|
||||
type CreateTokenRequest struct {
|
||||
Code string `json:"code"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
InvitationCode string `json:"invitation_code,omitempty"`
|
||||
}
|
||||
|
||||
// SocialTokenResponse from Kiro's /oauth/token endpoint for social auth.
|
||||
type SocialTokenResponse struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ProfileArn string `json:"profileArn"`
|
||||
ExpiresIn int `json:"expiresIn"`
|
||||
}
|
||||
|
||||
// RefreshTokenRequest is sent to Kiro's /refreshToken endpoint.
|
||||
type RefreshTokenRequest struct {
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
// WebCallbackResult contains the OAuth callback result from HTTP server.
|
||||
type WebCallbackResult struct {
|
||||
Code string
|
||||
State string
|
||||
Error string
|
||||
}
|
||||
|
||||
// SocialAuthClient handles social authentication with Kiro.
|
||||
type SocialAuthClient struct {
|
||||
httpClient *http.Client
|
||||
cfg *config.Config
|
||||
protocolHandler *ProtocolHandler
|
||||
machineID string
|
||||
kiroVersion string
|
||||
}
|
||||
|
||||
// NewSocialAuthClient creates a new social auth client.
|
||||
func NewSocialAuthClient(cfg *config.Config) *SocialAuthClient {
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
if cfg != nil {
|
||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||
}
|
||||
fp := GlobalFingerprintManager().GetFingerprint("login")
|
||||
return &SocialAuthClient{
|
||||
httpClient: client,
|
||||
cfg: cfg,
|
||||
protocolHandler: NewProtocolHandler(),
|
||||
machineID: fp.KiroHash,
|
||||
kiroVersion: fp.KiroVersion,
|
||||
}
|
||||
}
|
||||
|
||||
// startWebCallbackServer starts a local HTTP server to receive the OAuth callback.
|
||||
// This is used instead of the kiro:// protocol handler to avoid redirect_mismatch errors.
|
||||
func (c *SocialAuthClient) startWebCallbackServer(ctx context.Context, expectedState string) (string, <-chan WebCallbackResult, error) {
|
||||
// Try to find an available port - use localhost like Kiro does
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", socialAuthCallbackPort))
|
||||
if err != nil {
|
||||
// Try with dynamic port (RFC 8252 allows dynamic ports for native apps)
|
||||
log.Warnf("kiro social auth: default port %d is busy, falling back to dynamic port", socialAuthCallbackPort)
|
||||
listener, err = net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to start callback server: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
// Use http scheme for local callback server
|
||||
redirectURI := fmt.Sprintf("http://localhost:%d/oauth/callback", port)
|
||||
resultChan := make(chan WebCallbackResult, 1)
|
||||
|
||||
server := &http.Server{
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
errParam := r.URL.Query().Get("error")
|
||||
|
||||
if errParam != "" {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprintf(w, `<!DOCTYPE html>
|
||||
<html><head><title>Login Failed</title></head>
|
||||
<body><h1>Login Failed</h1><p>%s</p><p>You can close this window.</p></body></html>`, html.EscapeString(errParam))
|
||||
resultChan <- WebCallbackResult{Error: errParam}
|
||||
return
|
||||
}
|
||||
|
||||
if state != expectedState {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(w, `<!DOCTYPE html>
|
||||
<html><head><title>Login Failed</title></head>
|
||||
<body><h1>Login Failed</h1><p>Invalid state parameter</p><p>You can close this window.</p></body></html>`)
|
||||
resultChan <- WebCallbackResult{Error: "state mismatch"}
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
fmt.Fprint(w, `<!DOCTYPE html>
|
||||
<html><head><title>Login Successful</title></head>
|
||||
<body><h1>Login Successful!</h1><p>You can close this window and return to the terminal.</p>
|
||||
<script>window.close();</script></body></html>`)
|
||||
resultChan <- WebCallbackResult{Code: code, State: state}
|
||||
})
|
||||
|
||||
server.Handler = mux
|
||||
|
||||
go func() {
|
||||
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||
log.Debugf("kiro social auth callback server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(socialAuthTimeout):
|
||||
case <-resultChan:
|
||||
}
|
||||
_ = server.Shutdown(context.Background())
|
||||
}()
|
||||
|
||||
return redirectURI, resultChan, nil
|
||||
}
|
||||
|
||||
// generatePKCE generates PKCE code verifier and challenge.
|
||||
func generatePKCE() (verifier, challenge string, err error) {
|
||||
// Generate 32 bytes of random data for verifier
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||
}
|
||||
verifier = base64.RawURLEncoding.EncodeToString(b)
|
||||
|
||||
// Generate SHA256 hash of verifier for challenge
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
challenge = base64.RawURLEncoding.EncodeToString(h[:])
|
||||
|
||||
return verifier, challenge, nil
|
||||
}
|
||||
|
||||
// generateState generates a random state parameter.
|
||||
func generateStateParam() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// buildLoginURL constructs the Kiro OAuth login URL.
|
||||
// The login endpoint expects a GET request with query parameters.
|
||||
// Format: /login?idp=Google&redirect_uri=...&code_challenge=...&code_challenge_method=S256&state=...&prompt=select_account
|
||||
// The prompt=select_account parameter forces the account selection screen even if already logged in.
|
||||
func (c *SocialAuthClient) buildLoginURL(provider, redirectURI, codeChallenge, state string) string {
|
||||
return fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account",
|
||||
kiroAuthServiceEndpoint,
|
||||
provider,
|
||||
url.QueryEscape(redirectURI),
|
||||
codeChallenge,
|
||||
state,
|
||||
)
|
||||
}
|
||||
|
||||
// CreateToken exchanges the authorization code for tokens.
|
||||
func (c *SocialAuthClient) CreateToken(ctx context.Context, req *CreateTokenRequest) (*SocialTokenResponse, error) {
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal token request: %w", err)
|
||||
}
|
||||
|
||||
tokenURL := kiroAuthServiceEndpoint + "/oauth/token"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", c.kiroVersion, c.machineID))
|
||||
httpReq.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read token response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("token exchange failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("token exchange failed (status %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
var tokenResp SocialTokenResponse
|
||||
if err := json.Unmarshal(respBody, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
// RefreshSocialToken refreshes an expired social auth token.
|
||||
func (c *SocialAuthClient) RefreshSocialToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) {
|
||||
body, err := json.Marshal(&RefreshTokenRequest{RefreshToken: refreshToken})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal refresh request: %w", err)
|
||||
}
|
||||
|
||||
refreshURL := kiroAuthServiceEndpoint + "/refreshToken"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", c.kiroVersion, c.machineID))
|
||||
httpReq.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
|
||||
resp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refresh request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read refresh response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
var tokenResp SocialTokenResponse
|
||||
if err := json.Unmarshal(respBody, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse refresh response: %w", err)
|
||||
}
|
||||
|
||||
// Validate ExpiresIn - use default 1 hour if invalid
|
||||
expiresIn := tokenResp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600 // Default 1 hour
|
||||
}
|
||||
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
return &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: tokenResp.ProfileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: "", // Caller should preserve original provider
|
||||
Region: "us-east-1",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LoginWithSocial performs OAuth login with Google or GitHub.
|
||||
// Uses local HTTP callback server instead of custom protocol handler to avoid redirect_mismatch errors.
|
||||
func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialProvider) (*KiroTokenData, error) {
|
||||
providerName := string(provider)
|
||||
|
||||
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
|
||||
fmt.Printf("║ Kiro Authentication (%s) ║\n", providerName)
|
||||
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||
|
||||
// Step 1: Start local HTTP callback server (instead of kiro:// protocol handler)
|
||||
// This avoids redirect_mismatch errors with AWS Cognito
|
||||
fmt.Println("\nSetting up authentication...")
|
||||
|
||||
// Step 2: Generate PKCE codes
|
||||
codeVerifier, codeChallenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate PKCE: %w", err)
|
||||
}
|
||||
|
||||
// Step 3: Generate state
|
||||
state, err := generateStateParam()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Start local HTTP callback server
|
||||
redirectURI, resultChan, err := c.startWebCallbackServer(ctx, state)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start callback server: %w", err)
|
||||
}
|
||||
log.Debugf("kiro social auth: callback server started at %s", redirectURI)
|
||||
|
||||
// Step 5: Build the login URL using HTTP redirect URI
|
||||
authURL := c.buildLoginURL(providerName, redirectURI, codeChallenge, state)
|
||||
|
||||
// Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito)
|
||||
// Incognito mode enables multi-account support by bypassing cached sessions
|
||||
if c.cfg != nil {
|
||||
browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
|
||||
if !c.cfg.IncognitoBrowser {
|
||||
log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.")
|
||||
} else {
|
||||
log.Debug("kiro: using incognito mode for multi-account support")
|
||||
}
|
||||
} else {
|
||||
browser.SetIncognitoMode(true) // Default to incognito if no config
|
||||
log.Debug("kiro: using incognito mode for multi-account support (default)")
|
||||
}
|
||||
|
||||
// Step 6: Open browser for user authentication
|
||||
fmt.Println("\n════════════════════════════════════════════════════════════")
|
||||
fmt.Printf(" Opening browser for %s authentication...\n", providerName)
|
||||
fmt.Println("════════════════════════════════════════════════════════════")
|
||||
fmt.Printf("\n URL: %s\n\n", authURL)
|
||||
|
||||
if err := browser.OpenURL(authURL); err != nil {
|
||||
log.Warnf("Could not open browser automatically: %v", err)
|
||||
fmt.Println(" ⚠ Could not open browser automatically.")
|
||||
fmt.Println(" Please open the URL above in your browser manually.")
|
||||
} else {
|
||||
fmt.Println(" (Browser opened automatically)")
|
||||
}
|
||||
|
||||
fmt.Println("\n Waiting for authentication callback...")
|
||||
|
||||
// Step 7: Wait for callback from HTTP server
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(socialAuthTimeout):
|
||||
return nil, fmt.Errorf("authentication timed out")
|
||||
case callback := <-resultChan:
|
||||
if callback.Error != "" {
|
||||
return nil, fmt.Errorf("authentication error: %s", callback.Error)
|
||||
}
|
||||
|
||||
// State is already validated by the callback server
|
||||
if callback.Code == "" {
|
||||
return nil, fmt.Errorf("no authorization code received")
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Authorization received!")
|
||||
|
||||
// Step 8: Exchange code for tokens
|
||||
fmt.Println("Exchanging code for tokens...")
|
||||
|
||||
tokenReq := &CreateTokenRequest{
|
||||
Code: callback.Code,
|
||||
CodeVerifier: codeVerifier,
|
||||
RedirectURI: redirectURI, // Use HTTP redirect URI, not kiro:// protocol
|
||||
}
|
||||
|
||||
tokenResp, err := c.CreateToken(ctx, tokenReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("\n✓ Authentication successful!")
|
||||
|
||||
// Close the browser window
|
||||
if err := browser.CloseBrowser(); err != nil {
|
||||
log.Debugf("Failed to close browser: %v", err)
|
||||
}
|
||||
|
||||
// Validate ExpiresIn - use default 1 hour if invalid
|
||||
expiresIn := tokenResp.ExpiresIn
|
||||
if expiresIn <= 0 {
|
||||
expiresIn = 3600
|
||||
}
|
||||
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
// Try to extract email from JWT access token first
|
||||
email := ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||
|
||||
// If no email in JWT, ask user for account label (only in interactive mode)
|
||||
if email == "" && isInteractiveTerminal() {
|
||||
fmt.Print("\n Enter account label for file naming (optional, press Enter to skip): ")
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
var err error
|
||||
email, err = reader.ReadString('\n')
|
||||
if err != nil {
|
||||
log.Debugf("Failed to read account label: %v", err)
|
||||
}
|
||||
email = strings.TrimSpace(email)
|
||||
}
|
||||
|
||||
return &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: tokenResp.ProfileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "social",
|
||||
Provider: providerName,
|
||||
Email: email, // JWT email or user-provided label
|
||||
Region: "us-east-1",
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// LoginWithGoogle performs OAuth login with Google.
|
||||
func (c *SocialAuthClient) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) {
|
||||
return c.LoginWithSocial(ctx, ProviderGoogle)
|
||||
}
|
||||
|
||||
// LoginWithGitHub performs OAuth login with GitHub.
|
||||
func (c *SocialAuthClient) LoginWithGitHub(ctx context.Context) (*KiroTokenData, error) {
|
||||
return c.LoginWithSocial(ctx, ProviderGitHub)
|
||||
}
|
||||
|
||||
// forceDefaultProtocolHandler sets our protocol handler as the default for kiro:// URLs.
|
||||
// This prevents the "Open with" dialog from appearing on Linux.
|
||||
// On non-Linux platforms, this is a no-op as they use different mechanisms.
|
||||
func forceDefaultProtocolHandler() {
|
||||
if runtime.GOOS != "linux" {
|
||||
return // Non-Linux platforms use different handler mechanisms
|
||||
}
|
||||
|
||||
// Set our handler as default using xdg-mime
|
||||
cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro")
|
||||
if err := cmd.Run(); err != nil {
|
||||
log.Warnf("Failed to set default protocol handler: %v. You may see a handler selection dialog.", err)
|
||||
}
|
||||
}
|
||||
|
||||
// isInteractiveTerminal checks if stdin is connected to an interactive terminal.
|
||||
// Returns false in CI/automated environments or when stdin is piped.
|
||||
func isInteractiveTerminal() bool {
|
||||
return term.IsTerminal(int(os.Stdin.Fd()))
|
||||
}
|
||||
1603
internal/auth/kiro/sso_oidc.go
Normal file
1603
internal/auth/kiro/sso_oidc.go
Normal file
File diff suppressed because it is too large
Load Diff
261
internal/auth/kiro/sso_oidc_test.go
Normal file
261
internal/auth/kiro/sso_oidc_test.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type recordingRoundTripper struct {
|
||||
lastReq *http.Request
|
||||
}
|
||||
|
||||
func (rt *recordingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
rt.lastReq = req
|
||||
body := `{"nextToken":null,"profiles":[{"arn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC","profileName":"test"}]}`
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestTryListAvailableProfiles_UsesClientIDForAccountKey(t *testing.T) {
|
||||
rt := &recordingRoundTripper{}
|
||||
client := &SSOOIDCClient{
|
||||
httpClient: &http.Client{Transport: rt},
|
||||
}
|
||||
|
||||
profileArn := client.tryListAvailableProfiles(context.Background(), "access-token", "client-id-123", "refresh-token-456")
|
||||
if profileArn == "" {
|
||||
t.Fatal("expected profileArn, got empty result")
|
||||
}
|
||||
|
||||
accountKey := GetAccountKey("client-id-123", "refresh-token-456")
|
||||
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
|
||||
expected := fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s", fp.RuntimeSDKVersion, fp.KiroVersion, fp.KiroHash)
|
||||
got := rt.lastReq.Header.Get("X-Amz-User-Agent")
|
||||
if got != expected {
|
||||
t.Errorf("X-Amz-User-Agent = %q, want %q", got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTryListAvailableProfiles_UsesRefreshTokenWhenClientIDMissing(t *testing.T) {
|
||||
rt := &recordingRoundTripper{}
|
||||
client := &SSOOIDCClient{
|
||||
httpClient: &http.Client{Transport: rt},
|
||||
}
|
||||
|
||||
profileArn := client.tryListAvailableProfiles(context.Background(), "access-token", "", "refresh-token-789")
|
||||
if profileArn == "" {
|
||||
t.Fatal("expected profileArn, got empty result")
|
||||
}
|
||||
|
||||
accountKey := GetAccountKey("", "refresh-token-789")
|
||||
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
|
||||
expected := fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s", fp.RuntimeSDKVersion, fp.KiroVersion, fp.KiroHash)
|
||||
got := rt.lastReq.Header.Get("X-Amz-User-Agent")
|
||||
if got != expected {
|
||||
t.Errorf("X-Amz-User-Agent = %q, want %q", got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterClientForAuthCodeWithIDC(t *testing.T) {
|
||||
var capturedReq struct {
|
||||
Method string
|
||||
Path string
|
||||
Headers http.Header
|
||||
Body map[string]interface{}
|
||||
}
|
||||
|
||||
mockResp := RegisterClientResponse{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
ClientIDIssuedAt: 1700000000,
|
||||
ClientSecretExpiresAt: 1700086400,
|
||||
}
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedReq.Method = r.Method
|
||||
capturedReq.Path = r.URL.Path
|
||||
capturedReq.Headers = r.Header.Clone()
|
||||
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
json.Unmarshal(bodyBytes, &capturedReq.Body)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(mockResp)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
// Extract host to build a region that resolves to our test server.
|
||||
// Override getOIDCEndpoint by passing region="" and patching the endpoint.
|
||||
// Since getOIDCEndpoint builds "https://oidc.{region}.amazonaws.com", we
|
||||
// instead inject the test server URL directly via a custom HTTP client transport.
|
||||
client := &SSOOIDCClient{
|
||||
httpClient: ts.Client(),
|
||||
}
|
||||
|
||||
// We need to route the request to our test server. Use a transport that rewrites the URL.
|
||||
client.httpClient.Transport = &rewriteTransport{
|
||||
base: ts.Client().Transport,
|
||||
targetURL: ts.URL,
|
||||
}
|
||||
|
||||
resp, err := client.RegisterClientForAuthCodeWithIDC(
|
||||
context.Background(),
|
||||
"http://127.0.0.1:19877/oauth/callback",
|
||||
"https://my-idc-instance.awsapps.com/start",
|
||||
"us-east-1",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify request method and path
|
||||
if capturedReq.Method != http.MethodPost {
|
||||
t.Errorf("method = %q, want POST", capturedReq.Method)
|
||||
}
|
||||
if capturedReq.Path != "/client/register" {
|
||||
t.Errorf("path = %q, want /client/register", capturedReq.Path)
|
||||
}
|
||||
|
||||
// Verify headers
|
||||
if ct := capturedReq.Headers.Get("Content-Type"); ct != "application/json" {
|
||||
t.Errorf("Content-Type = %q, want application/json", ct)
|
||||
}
|
||||
ua := capturedReq.Headers.Get("User-Agent")
|
||||
if !strings.Contains(ua, "KiroIDE") {
|
||||
t.Errorf("User-Agent %q does not contain KiroIDE", ua)
|
||||
}
|
||||
if !strings.Contains(ua, "sso-oidc") {
|
||||
t.Errorf("User-Agent %q does not contain sso-oidc", ua)
|
||||
}
|
||||
xua := capturedReq.Headers.Get("X-Amz-User-Agent")
|
||||
if !strings.Contains(xua, "KiroIDE") {
|
||||
t.Errorf("x-amz-user-agent %q does not contain KiroIDE", xua)
|
||||
}
|
||||
|
||||
// Verify body fields
|
||||
if v, _ := capturedReq.Body["clientName"].(string); v != "Kiro IDE" {
|
||||
t.Errorf("clientName = %q, want %q", v, "Kiro IDE")
|
||||
}
|
||||
if v, _ := capturedReq.Body["clientType"].(string); v != "public" {
|
||||
t.Errorf("clientType = %q, want %q", v, "public")
|
||||
}
|
||||
if v, _ := capturedReq.Body["issuerUrl"].(string); v != "https://my-idc-instance.awsapps.com/start" {
|
||||
t.Errorf("issuerUrl = %q, want %q", v, "https://my-idc-instance.awsapps.com/start")
|
||||
}
|
||||
|
||||
// Verify scopes array
|
||||
scopesRaw, ok := capturedReq.Body["scopes"].([]interface{})
|
||||
if !ok || len(scopesRaw) != 5 {
|
||||
t.Fatalf("scopes: got %v, want 5-element array", capturedReq.Body["scopes"])
|
||||
}
|
||||
expectedScopes := []string{
|
||||
"codewhisperer:completions", "codewhisperer:analysis",
|
||||
"codewhisperer:conversations", "codewhisperer:transformations",
|
||||
"codewhisperer:taskassist",
|
||||
}
|
||||
for i, s := range expectedScopes {
|
||||
if scopesRaw[i].(string) != s {
|
||||
t.Errorf("scopes[%d] = %q, want %q", i, scopesRaw[i], s)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify grantTypes
|
||||
grantTypesRaw, ok := capturedReq.Body["grantTypes"].([]interface{})
|
||||
if !ok || len(grantTypesRaw) != 2 {
|
||||
t.Fatalf("grantTypes: got %v, want 2-element array", capturedReq.Body["grantTypes"])
|
||||
}
|
||||
if grantTypesRaw[0].(string) != "authorization_code" || grantTypesRaw[1].(string) != "refresh_token" {
|
||||
t.Errorf("grantTypes = %v, want [authorization_code, refresh_token]", grantTypesRaw)
|
||||
}
|
||||
|
||||
// Verify redirectUris
|
||||
redirectRaw, ok := capturedReq.Body["redirectUris"].([]interface{})
|
||||
if !ok || len(redirectRaw) != 1 {
|
||||
t.Fatalf("redirectUris: got %v, want 1-element array", capturedReq.Body["redirectUris"])
|
||||
}
|
||||
if redirectRaw[0].(string) != "http://127.0.0.1:19877/oauth/callback" {
|
||||
t.Errorf("redirectUris[0] = %q, want %q", redirectRaw[0], "http://127.0.0.1:19877/oauth/callback")
|
||||
}
|
||||
|
||||
// Verify response parsing
|
||||
if resp.ClientID != "test-client-id" {
|
||||
t.Errorf("ClientID = %q, want %q", resp.ClientID, "test-client-id")
|
||||
}
|
||||
if resp.ClientSecret != "test-client-secret" {
|
||||
t.Errorf("ClientSecret = %q, want %q", resp.ClientSecret, "test-client-secret")
|
||||
}
|
||||
if resp.ClientIDIssuedAt != 1700000000 {
|
||||
t.Errorf("ClientIDIssuedAt = %d, want %d", resp.ClientIDIssuedAt, 1700000000)
|
||||
}
|
||||
if resp.ClientSecretExpiresAt != 1700086400 {
|
||||
t.Errorf("ClientSecretExpiresAt = %d, want %d", resp.ClientSecretExpiresAt, 1700086400)
|
||||
}
|
||||
}
|
||||
|
||||
// rewriteTransport redirects all requests to the test server URL.
|
||||
type rewriteTransport struct {
|
||||
base http.RoundTripper
|
||||
targetURL string
|
||||
}
|
||||
|
||||
func (t *rewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
target, _ := url.Parse(t.targetURL)
|
||||
req.URL.Scheme = target.Scheme
|
||||
req.URL.Host = target.Host
|
||||
if t.base != nil {
|
||||
return t.base.RoundTrip(req)
|
||||
}
|
||||
return http.DefaultTransport.RoundTrip(req)
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL(t *testing.T) {
|
||||
scopes := "codewhisperer:completions,codewhisperer:analysis,codewhisperer:conversations,codewhisperer:transformations,codewhisperer:taskassist"
|
||||
endpoint := "https://oidc.us-east-1.amazonaws.com"
|
||||
redirectURI := "http://127.0.0.1:19877/oauth/callback"
|
||||
|
||||
authURL := buildAuthorizationURL(endpoint, "test-client-id", redirectURI, scopes, "random-state", "test-challenge")
|
||||
|
||||
// Verify colons and commas in scopes are percent-encoded
|
||||
if !strings.Contains(authURL, "codewhisperer%3Acompletions") {
|
||||
t.Errorf("expected colons in scopes to be percent-encoded, got: %s", authURL)
|
||||
}
|
||||
if !strings.Contains(authURL, "completions%2Ccodewhisperer") {
|
||||
t.Errorf("expected commas in scopes to be percent-encoded, got: %s", authURL)
|
||||
}
|
||||
|
||||
// Parse back and verify all parameters round-trip correctly
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(authURL, endpoint+"/authorize?") {
|
||||
t.Errorf("expected URL to start with %s/authorize?, got: %s", endpoint, authURL)
|
||||
}
|
||||
|
||||
q := parsed.Query()
|
||||
checks := map[string]string{
|
||||
"response_type": "code",
|
||||
"client_id": "test-client-id",
|
||||
"redirect_uri": redirectURI,
|
||||
"scopes": scopes,
|
||||
"state": "random-state",
|
||||
"code_challenge": "test-challenge",
|
||||
"code_challenge_method": "S256",
|
||||
}
|
||||
for key, want := range checks {
|
||||
if got := q.Get(key); got != want {
|
||||
t.Errorf("%s = %q, want %q", key, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
89
internal/auth/kiro/token.go
Normal file
89
internal/auth/kiro/token.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// KiroTokenStorage holds the persistent token data for Kiro authentication.
|
||||
type KiroTokenStorage struct {
|
||||
// Type is the provider type for management UI recognition (must be "kiro")
|
||||
Type string `json:"type"`
|
||||
// AccessToken is the OAuth2 access token for API access
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is used to obtain new access tokens
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
// ProfileArn is the AWS CodeWhisperer profile ARN
|
||||
ProfileArn string `json:"profile_arn"`
|
||||
// ExpiresAt is the timestamp when the token expires
|
||||
ExpiresAt string `json:"expires_at"`
|
||||
// AuthMethod indicates the authentication method used
|
||||
AuthMethod string `json:"auth_method"`
|
||||
// Provider indicates the OAuth provider
|
||||
Provider string `json:"provider"`
|
||||
// LastRefresh is the timestamp of the last token refresh
|
||||
LastRefresh string `json:"last_refresh"`
|
||||
// ClientID is the OAuth client ID (required for token refresh)
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
// ClientSecret is the OAuth client secret (required for token refresh)
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
// Region is the OIDC region for IDC login and token refresh
|
||||
Region string `json:"region,omitempty"`
|
||||
// StartURL is the AWS Identity Center start URL (for IDC auth)
|
||||
StartURL string `json:"start_url,omitempty"`
|
||||
// Email is the user's email address
|
||||
Email string `json:"email,omitempty"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile persists the token storage to the specified file path.
|
||||
func (s *KiroTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
dir := filepath.Dir(authFilePath)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %w", err)
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(s, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal token storage: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(authFilePath, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write token file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromFile loads token storage from the specified file path.
|
||||
func LoadFromFile(authFilePath string) (*KiroTokenStorage, error) {
|
||||
data, err := os.ReadFile(authFilePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read token file: %w", err)
|
||||
}
|
||||
|
||||
var storage KiroTokenStorage
|
||||
if err := json.Unmarshal(data, &storage); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token file: %w", err)
|
||||
}
|
||||
|
||||
return &storage, nil
|
||||
}
|
||||
|
||||
// ToTokenData converts storage to KiroTokenData for API use.
|
||||
func (s *KiroTokenStorage) ToTokenData() *KiroTokenData {
|
||||
return &KiroTokenData{
|
||||
AccessToken: s.AccessToken,
|
||||
RefreshToken: s.RefreshToken,
|
||||
ProfileArn: s.ProfileArn,
|
||||
ExpiresAt: s.ExpiresAt,
|
||||
AuthMethod: s.AuthMethod,
|
||||
Provider: s.Provider,
|
||||
ClientID: s.ClientID,
|
||||
ClientSecret: s.ClientSecret,
|
||||
Region: s.Region,
|
||||
StartURL: s.StartURL,
|
||||
Email: s.Email,
|
||||
}
|
||||
}
|
||||
260
internal/auth/kiro/token_repository.go
Normal file
260
internal/auth/kiro/token_repository.go
Normal file
@@ -0,0 +1,260 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// FileTokenRepository 实现 TokenRepository 接口,基于文件系统存储
|
||||
type FileTokenRepository struct {
|
||||
mu sync.RWMutex
|
||||
baseDir string
|
||||
}
|
||||
|
||||
// NewFileTokenRepository 创建一个新的文件 token 存储库
|
||||
func NewFileTokenRepository(baseDir string) *FileTokenRepository {
|
||||
return &FileTokenRepository{
|
||||
baseDir: baseDir,
|
||||
}
|
||||
}
|
||||
|
||||
// SetBaseDir 设置基础目录
|
||||
func (r *FileTokenRepository) SetBaseDir(dir string) {
|
||||
r.mu.Lock()
|
||||
r.baseDir = strings.TrimSpace(dir)
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
// FindOldestUnverified 查找需要刷新的 token(按最后验证时间排序)
|
||||
func (r *FileTokenRepository) FindOldestUnverified(limit int) []*Token {
|
||||
r.mu.RLock()
|
||||
baseDir := r.baseDir
|
||||
r.mu.RUnlock()
|
||||
|
||||
if baseDir == "" {
|
||||
log.Debug("token repository: base directory not configured")
|
||||
return nil
|
||||
}
|
||||
|
||||
var tokens []*Token
|
||||
|
||||
err := filepath.WalkDir(baseDir, func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return nil // 忽略错误,继续遍历
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if !strings.HasSuffix(strings.ToLower(d.Name()), ".json") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 只处理 kiro 相关的 token 文件
|
||||
if !strings.HasPrefix(d.Name(), "kiro-") {
|
||||
return nil
|
||||
}
|
||||
|
||||
token, err := r.readTokenFile(path)
|
||||
if err != nil {
|
||||
log.Debugf("token repository: failed to read token file %s: %v", path, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if token != nil && token.RefreshToken != "" {
|
||||
// 检查 token 是否需要刷新(过期前 5 分钟)
|
||||
if token.ExpiresAt.IsZero() || time.Until(token.ExpiresAt) < 5*time.Minute {
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Warnf("token repository: error walking directory: %v", err)
|
||||
}
|
||||
|
||||
// 按最后验证时间排序(最旧的优先)
|
||||
sort.Slice(tokens, func(i, j int) bool {
|
||||
return tokens[i].LastVerified.Before(tokens[j].LastVerified)
|
||||
})
|
||||
|
||||
// 限制返回数量
|
||||
if limit > 0 && len(tokens) > limit {
|
||||
tokens = tokens[:limit]
|
||||
}
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
// UpdateToken 更新 token 并持久化到文件
|
||||
func (r *FileTokenRepository) UpdateToken(token *Token) error {
|
||||
if token == nil {
|
||||
return fmt.Errorf("token repository: token is nil")
|
||||
}
|
||||
|
||||
r.mu.RLock()
|
||||
baseDir := r.baseDir
|
||||
r.mu.RUnlock()
|
||||
|
||||
if baseDir == "" {
|
||||
return fmt.Errorf("token repository: base directory not configured")
|
||||
}
|
||||
|
||||
// 构建文件路径
|
||||
filePath := filepath.Join(baseDir, token.ID)
|
||||
if !strings.HasSuffix(filePath, ".json") {
|
||||
filePath += ".json"
|
||||
}
|
||||
|
||||
// 读取现有文件内容
|
||||
existingData := make(map[string]any)
|
||||
if data, err := os.ReadFile(filePath); err == nil {
|
||||
_ = json.Unmarshal(data, &existingData)
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
existingData["access_token"] = token.AccessToken
|
||||
existingData["refresh_token"] = token.RefreshToken
|
||||
existingData["last_refresh"] = time.Now().Format(time.RFC3339)
|
||||
|
||||
if !token.ExpiresAt.IsZero() {
|
||||
existingData["expires_at"] = token.ExpiresAt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// 保持原有的关键字段
|
||||
if token.ClientID != "" {
|
||||
existingData["client_id"] = token.ClientID
|
||||
}
|
||||
if token.ClientSecret != "" {
|
||||
existingData["client_secret"] = token.ClientSecret
|
||||
}
|
||||
if token.AuthMethod != "" {
|
||||
existingData["auth_method"] = token.AuthMethod
|
||||
}
|
||||
if token.Region != "" {
|
||||
existingData["region"] = token.Region
|
||||
}
|
||||
if token.StartURL != "" {
|
||||
existingData["start_url"] = token.StartURL
|
||||
}
|
||||
|
||||
// 序列化并写入文件
|
||||
raw, err := json.MarshalIndent(existingData, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("token repository: marshal failed: %w", err)
|
||||
}
|
||||
|
||||
// 原子写入:先写入临时文件,再重命名
|
||||
tmpPath := filePath + ".tmp"
|
||||
if err := os.WriteFile(tmpPath, raw, 0o600); err != nil {
|
||||
return fmt.Errorf("token repository: write temp file failed: %w", err)
|
||||
}
|
||||
if err := os.Rename(tmpPath, filePath); err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return fmt.Errorf("token repository: rename failed: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("token repository: updated token %s", token.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// readTokenFile 从文件读取 token
|
||||
func (r *FileTokenRepository) readTokenFile(path string) (*Token, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var metadata map[string]any
|
||||
if err := json.Unmarshal(data, &metadata); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查是否是 kiro token
|
||||
tokenType, _ := metadata["type"].(string)
|
||||
if tokenType != "kiro" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 检查 auth_method (case-insensitive comparison to handle "IdC", "IDC", "idc", etc.)
|
||||
authMethod, _ := metadata["auth_method"].(string)
|
||||
authMethod = strings.ToLower(authMethod)
|
||||
if authMethod != "idc" && authMethod != "builder-id" {
|
||||
return nil, nil // 只处理 IDC 和 Builder ID token
|
||||
}
|
||||
|
||||
token := &Token{
|
||||
ID: filepath.Base(path),
|
||||
AuthMethod: authMethod,
|
||||
}
|
||||
|
||||
// 解析各字段
|
||||
token.AccessToken, _ = metadata["access_token"].(string)
|
||||
token.RefreshToken, _ = metadata["refresh_token"].(string)
|
||||
token.ClientID, _ = metadata["client_id"].(string)
|
||||
token.ClientSecret, _ = metadata["client_secret"].(string)
|
||||
token.Region, _ = metadata["region"].(string)
|
||||
token.StartURL, _ = metadata["start_url"].(string)
|
||||
token.Provider, _ = metadata["provider"].(string)
|
||||
|
||||
// 解析时间字段
|
||||
if expiresAtStr, ok := metadata["expires_at"].(string); ok && expiresAtStr != "" {
|
||||
if t, err := time.Parse(time.RFC3339, expiresAtStr); err == nil {
|
||||
token.ExpiresAt = t
|
||||
}
|
||||
}
|
||||
if lastRefreshStr, ok := metadata["last_refresh"].(string); ok && lastRefreshStr != "" {
|
||||
if t, err := time.Parse(time.RFC3339, lastRefreshStr); err == nil {
|
||||
token.LastVerified = t
|
||||
}
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// ListKiroTokens 列出所有 Kiro token(用于调试)
|
||||
func (r *FileTokenRepository) ListKiroTokens(ctx context.Context) ([]*Token, error) {
|
||||
r.mu.RLock()
|
||||
baseDir := r.baseDir
|
||||
r.mu.RUnlock()
|
||||
|
||||
if baseDir == "" {
|
||||
return nil, fmt.Errorf("token repository: base directory not configured")
|
||||
}
|
||||
|
||||
var tokens []*Token
|
||||
|
||||
err := filepath.WalkDir(baseDir, func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return nil
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if !strings.HasPrefix(d.Name(), "kiro-") || !strings.HasSuffix(d.Name(), ".json") {
|
||||
return nil
|
||||
}
|
||||
|
||||
token, err := r.readTokenFile(path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if token != nil {
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return tokens, err
|
||||
}
|
||||
236
internal/auth/kiro/usage_checker.go
Normal file
236
internal/auth/kiro/usage_checker.go
Normal file
@@ -0,0 +1,236 @@
|
||||
// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API.
|
||||
// This file implements usage quota checking and monitoring.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
|
||||
// UsageQuotaResponse represents the API response structure for usage quota checking.
|
||||
type UsageQuotaResponse struct {
|
||||
UsageBreakdownList []UsageBreakdownExtended `json:"usageBreakdownList"`
|
||||
SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
|
||||
NextDateReset float64 `json:"nextDateReset,omitempty"`
|
||||
}
|
||||
|
||||
// UsageBreakdownExtended represents detailed usage information for quota checking.
|
||||
// Note: UsageBreakdown is already defined in codewhisperer_client.go
|
||||
type UsageBreakdownExtended struct {
|
||||
ResourceType string `json:"resourceType"`
|
||||
UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
|
||||
CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
|
||||
FreeTrialInfo *FreeTrialInfoExtended `json:"freeTrialInfo,omitempty"`
|
||||
}
|
||||
|
||||
// FreeTrialInfoExtended represents free trial usage information.
|
||||
type FreeTrialInfoExtended struct {
|
||||
FreeTrialStatus string `json:"freeTrialStatus"`
|
||||
UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
|
||||
CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
|
||||
}
|
||||
|
||||
// QuotaStatus represents the quota status for a token.
|
||||
type QuotaStatus struct {
|
||||
TotalLimit float64
|
||||
CurrentUsage float64
|
||||
RemainingQuota float64
|
||||
IsExhausted bool
|
||||
ResourceType string
|
||||
NextReset time.Time
|
||||
}
|
||||
|
||||
// UsageChecker provides methods for checking token quota usage.
|
||||
type UsageChecker struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewUsageChecker creates a new UsageChecker instance.
|
||||
func NewUsageChecker(cfg *config.Config) *UsageChecker {
|
||||
return &UsageChecker{
|
||||
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}),
|
||||
}
|
||||
}
|
||||
|
||||
// NewUsageCheckerWithClient creates a UsageChecker with a custom HTTP client.
|
||||
func NewUsageCheckerWithClient(client *http.Client) *UsageChecker {
|
||||
return &UsageChecker{
|
||||
httpClient: client,
|
||||
}
|
||||
}
|
||||
|
||||
// CheckUsage retrieves usage limits for the given token.
|
||||
func (c *UsageChecker) CheckUsage(ctx context.Context, tokenData *KiroTokenData) (*UsageQuotaResponse, error) {
|
||||
if tokenData == nil {
|
||||
return nil, fmt.Errorf("token data is nil")
|
||||
}
|
||||
|
||||
if tokenData.AccessToken == "" {
|
||||
return nil, fmt.Errorf("access token is empty")
|
||||
}
|
||||
|
||||
queryParams := map[string]string{
|
||||
"origin": "AI_EDITOR",
|
||||
"profileArn": tokenData.ProfileArn,
|
||||
"resourceType": "AGENTIC_REQUEST",
|
||||
}
|
||||
|
||||
// Use endpoint from profileArn if available
|
||||
endpoint := GetKiroAPIEndpointFromProfileArn(tokenData.ProfileArn)
|
||||
url := buildURL(endpoint, pathGetUsageLimits, queryParams)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
accountKey := GetAccountKey(tokenData.ClientID, tokenData.RefreshToken)
|
||||
setRuntimeHeaders(req, tokenData.AccessToken, accountKey)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result UsageQuotaResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse usage response: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// CheckUsageByAccessToken retrieves usage limits using an access token and profile ARN directly.
|
||||
func (c *UsageChecker) CheckUsageByAccessToken(ctx context.Context, accessToken, profileArn string) (*UsageQuotaResponse, error) {
|
||||
tokenData := &KiroTokenData{
|
||||
AccessToken: accessToken,
|
||||
ProfileArn: profileArn,
|
||||
}
|
||||
return c.CheckUsage(ctx, tokenData)
|
||||
}
|
||||
|
||||
// GetRemainingQuota calculates the remaining quota from usage limits.
|
||||
func GetRemainingQuota(usage *UsageQuotaResponse) float64 {
|
||||
if usage == nil || len(usage.UsageBreakdownList) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
var totalRemaining float64
|
||||
for _, breakdown := range usage.UsageBreakdownList {
|
||||
remaining := breakdown.UsageLimitWithPrecision - breakdown.CurrentUsageWithPrecision
|
||||
if remaining > 0 {
|
||||
totalRemaining += remaining
|
||||
}
|
||||
|
||||
if breakdown.FreeTrialInfo != nil {
|
||||
freeRemaining := breakdown.FreeTrialInfo.UsageLimitWithPrecision - breakdown.FreeTrialInfo.CurrentUsageWithPrecision
|
||||
if freeRemaining > 0 {
|
||||
totalRemaining += freeRemaining
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return totalRemaining
|
||||
}
|
||||
|
||||
// IsQuotaExhausted checks if the quota is exhausted based on usage limits.
|
||||
func IsQuotaExhausted(usage *UsageQuotaResponse) bool {
|
||||
if usage == nil || len(usage.UsageBreakdownList) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, breakdown := range usage.UsageBreakdownList {
|
||||
if breakdown.CurrentUsageWithPrecision < breakdown.UsageLimitWithPrecision {
|
||||
return false
|
||||
}
|
||||
|
||||
if breakdown.FreeTrialInfo != nil {
|
||||
if breakdown.FreeTrialInfo.CurrentUsageWithPrecision < breakdown.FreeTrialInfo.UsageLimitWithPrecision {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// GetQuotaStatus retrieves a comprehensive quota status for a token.
|
||||
func (c *UsageChecker) GetQuotaStatus(ctx context.Context, tokenData *KiroTokenData) (*QuotaStatus, error) {
|
||||
usage, err := c.CheckUsage(ctx, tokenData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
status := &QuotaStatus{
|
||||
IsExhausted: IsQuotaExhausted(usage),
|
||||
}
|
||||
|
||||
if len(usage.UsageBreakdownList) > 0 {
|
||||
breakdown := usage.UsageBreakdownList[0]
|
||||
status.TotalLimit = breakdown.UsageLimitWithPrecision
|
||||
status.CurrentUsage = breakdown.CurrentUsageWithPrecision
|
||||
status.RemainingQuota = breakdown.UsageLimitWithPrecision - breakdown.CurrentUsageWithPrecision
|
||||
status.ResourceType = breakdown.ResourceType
|
||||
|
||||
if breakdown.FreeTrialInfo != nil {
|
||||
status.TotalLimit += breakdown.FreeTrialInfo.UsageLimitWithPrecision
|
||||
status.CurrentUsage += breakdown.FreeTrialInfo.CurrentUsageWithPrecision
|
||||
freeRemaining := breakdown.FreeTrialInfo.UsageLimitWithPrecision - breakdown.FreeTrialInfo.CurrentUsageWithPrecision
|
||||
if freeRemaining > 0 {
|
||||
status.RemainingQuota += freeRemaining
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if usage.NextDateReset > 0 {
|
||||
status.NextReset = time.Unix(int64(usage.NextDateReset/1000), 0)
|
||||
}
|
||||
|
||||
return status, nil
|
||||
}
|
||||
|
||||
// CalculateAvailableCount calculates the available request count based on usage limits.
|
||||
func CalculateAvailableCount(usage *UsageQuotaResponse) float64 {
|
||||
return GetRemainingQuota(usage)
|
||||
}
|
||||
|
||||
// GetUsagePercentage calculates the usage percentage.
|
||||
func GetUsagePercentage(usage *UsageQuotaResponse) float64 {
|
||||
if usage == nil || len(usage.UsageBreakdownList) == 0 {
|
||||
return 100.0
|
||||
}
|
||||
|
||||
var totalLimit, totalUsage float64
|
||||
for _, breakdown := range usage.UsageBreakdownList {
|
||||
totalLimit += breakdown.UsageLimitWithPrecision
|
||||
totalUsage += breakdown.CurrentUsageWithPrecision
|
||||
|
||||
if breakdown.FreeTrialInfo != nil {
|
||||
totalLimit += breakdown.FreeTrialInfo.UsageLimitWithPrecision
|
||||
totalUsage += breakdown.FreeTrialInfo.CurrentUsageWithPrecision
|
||||
}
|
||||
}
|
||||
|
||||
if totalLimit == 0 {
|
||||
return 100.0
|
||||
}
|
||||
|
||||
return (totalUsage / totalLimit) * 100
|
||||
}
|
||||
Reference in New Issue
Block a user