mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-08 06:43:41 +00:00
feat: add IDC auth support with Kiro IDE headers
This commit is contained in:
@@ -40,6 +40,10 @@ type KiroTokenData struct {
|
||||
ClientSecret string `json:"clientSecret,omitempty"`
|
||||
// Email is the user's email address (used for file naming)
|
||||
Email string `json:"email,omitempty"`
|
||||
// StartURL is the IDC/Identity Center start URL (only for IDC auth method)
|
||||
StartURL string `json:"startUrl,omitempty"`
|
||||
// Region is the AWS region for IDC authentication (only for IDC auth method)
|
||||
Region string `json:"region,omitempty"`
|
||||
}
|
||||
|
||||
// KiroAuthBundle aggregates authentication data after OAuth flow completion
|
||||
|
||||
408
internal/auth/kiro/cognito.go
Normal file
408
internal/auth/kiro/cognito.go
Normal file
@@ -0,0 +1,408 @@
|
||||
// Package kiro provides Cognito Identity credential exchange for IDC authentication.
|
||||
// AWS Identity Center (IDC) requires SigV4 signing with Cognito-exchanged credentials
|
||||
// instead of Bearer token authentication.
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// Cognito Identity endpoints
|
||||
cognitoIdentityEndpoint = "https://cognito-identity.us-east-1.amazonaws.com"
|
||||
|
||||
// Identity Pool ID for Q Developer / CodeWhisperer
|
||||
// This is the identity pool used by kiro-cli and Amazon Q CLI
|
||||
cognitoIdentityPoolID = "us-east-1:70717e99-906f-485d-8d89-c89a0b5d49c5"
|
||||
|
||||
// Cognito provider name for SSO OIDC
|
||||
cognitoProviderName = "cognito-identity.amazonaws.com"
|
||||
)
|
||||
|
||||
// CognitoCredentials holds temporary AWS credentials from Cognito Identity.
|
||||
type CognitoCredentials struct {
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key"`
|
||||
SessionToken string `json:"session_token"`
|
||||
Expiration time.Time `json:"expiration"`
|
||||
}
|
||||
|
||||
// CognitoIdentityClient handles Cognito Identity credential exchange.
|
||||
type CognitoIdentityClient struct {
|
||||
httpClient *http.Client
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewCognitoIdentityClient creates a new Cognito Identity client.
|
||||
func NewCognitoIdentityClient(cfg *config.Config) *CognitoIdentityClient {
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
if cfg != nil {
|
||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||
}
|
||||
return &CognitoIdentityClient{
|
||||
httpClient: client,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// GetIdentityID retrieves a Cognito Identity ID using the SSO access token.
|
||||
func (c *CognitoIdentityClient) GetIdentityID(ctx context.Context, accessToken, region string) (string, error) {
|
||||
if region == "" {
|
||||
region = "us-east-1"
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("https://cognito-identity.%s.amazonaws.com", region)
|
||||
|
||||
// Build the GetId request
|
||||
// The SSO token is passed as a login token for the identity pool
|
||||
payload := map[string]interface{}{
|
||||
"IdentityPoolId": cognitoIdentityPoolID,
|
||||
"Logins": map[string]string{
|
||||
// Use the OIDC provider URL as the key
|
||||
fmt.Sprintf("oidc.%s.amazonaws.com", region): accessToken,
|
||||
},
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal GetId request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create GetId request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-amz-json-1.1")
|
||||
req.Header.Set("X-Amz-Target", "AWSCognitoIdentityService.GetId")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("GetId request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read GetId response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("Cognito GetId failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return "", fmt.Errorf("GetId failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
IdentityID string `json:"IdentityId"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return "", fmt.Errorf("failed to parse GetId response: %w", err)
|
||||
}
|
||||
|
||||
if result.IdentityID == "" {
|
||||
return "", fmt.Errorf("empty IdentityId in GetId response")
|
||||
}
|
||||
|
||||
log.Debugf("Cognito Identity ID: %s", result.IdentityID)
|
||||
return result.IdentityID, nil
|
||||
}
|
||||
|
||||
// GetCredentialsForIdentity exchanges an identity ID and login token for temporary AWS credentials.
|
||||
func (c *CognitoIdentityClient) GetCredentialsForIdentity(ctx context.Context, identityID, accessToken, region string) (*CognitoCredentials, error) {
|
||||
if region == "" {
|
||||
region = "us-east-1"
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("https://cognito-identity.%s.amazonaws.com", region)
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"IdentityId": identityID,
|
||||
"Logins": map[string]string{
|
||||
fmt.Sprintf("oidc.%s.amazonaws.com", region): accessToken,
|
||||
},
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal GetCredentialsForIdentity request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GetCredentialsForIdentity request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-amz-json-1.1")
|
||||
req.Header.Set("X-Amz-Target", "AWSCognitoIdentityService.GetCredentialsForIdentity")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetCredentialsForIdentity request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read GetCredentialsForIdentity response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("Cognito GetCredentialsForIdentity failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("GetCredentialsForIdentity failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Credentials struct {
|
||||
AccessKeyID string `json:"AccessKeyId"`
|
||||
SecretKey string `json:"SecretKey"`
|
||||
SessionToken string `json:"SessionToken"`
|
||||
Expiration int64 `json:"Expiration"`
|
||||
} `json:"Credentials"`
|
||||
IdentityID string `json:"IdentityId"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse GetCredentialsForIdentity response: %w", err)
|
||||
}
|
||||
|
||||
if result.Credentials.AccessKeyID == "" {
|
||||
return nil, fmt.Errorf("empty AccessKeyId in GetCredentialsForIdentity response")
|
||||
}
|
||||
|
||||
// Expiration is in seconds since epoch
|
||||
expiration := time.Unix(result.Credentials.Expiration, 0)
|
||||
|
||||
log.Debugf("Cognito credentials obtained, expires: %s", expiration.Format(time.RFC3339))
|
||||
|
||||
return &CognitoCredentials{
|
||||
AccessKeyID: result.Credentials.AccessKeyID,
|
||||
SecretAccessKey: result.Credentials.SecretKey,
|
||||
SessionToken: result.Credentials.SessionToken,
|
||||
Expiration: expiration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExchangeSSOTokenForCredentials is a convenience method that performs the full
|
||||
// Cognito Identity credential exchange flow: GetId -> GetCredentialsForIdentity
|
||||
func (c *CognitoIdentityClient) ExchangeSSOTokenForCredentials(ctx context.Context, accessToken, region string) (*CognitoCredentials, error) {
|
||||
log.Debugf("Exchanging SSO token for Cognito credentials (region: %s)", region)
|
||||
|
||||
// Step 1: Get Identity ID
|
||||
identityID, err := c.GetIdentityID(ctx, accessToken, region)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get identity ID: %w", err)
|
||||
}
|
||||
|
||||
// Step 2: Get credentials for the identity
|
||||
creds, err := c.GetCredentialsForIdentity(ctx, identityID, accessToken, region)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get credentials for identity: %w", err)
|
||||
}
|
||||
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
// SigV4Signer provides AWS Signature Version 4 signing for HTTP requests.
|
||||
type SigV4Signer struct {
|
||||
credentials *CognitoCredentials
|
||||
region string
|
||||
service string
|
||||
}
|
||||
|
||||
// NewSigV4Signer creates a new SigV4 signer with the given credentials.
|
||||
func NewSigV4Signer(creds *CognitoCredentials, region, service string) *SigV4Signer {
|
||||
return &SigV4Signer{
|
||||
credentials: creds,
|
||||
region: region,
|
||||
service: service,
|
||||
}
|
||||
}
|
||||
|
||||
// SignRequest signs an HTTP request using AWS Signature Version 4.
|
||||
// The request body must be provided separately since it may have been read already.
|
||||
func (s *SigV4Signer) SignRequest(req *http.Request, body []byte) error {
|
||||
now := time.Now().UTC()
|
||||
amzDate := now.Format("20060102T150405Z")
|
||||
dateStamp := now.Format("20060102")
|
||||
|
||||
// Ensure required headers are set
|
||||
if req.Header.Get("Host") == "" {
|
||||
req.Header.Set("Host", req.URL.Host)
|
||||
}
|
||||
req.Header.Set("X-Amz-Date", amzDate)
|
||||
if s.credentials.SessionToken != "" {
|
||||
req.Header.Set("X-Amz-Security-Token", s.credentials.SessionToken)
|
||||
}
|
||||
|
||||
// Create canonical request
|
||||
canonicalRequest, signedHeaders := s.createCanonicalRequest(req, body)
|
||||
|
||||
// Create string to sign
|
||||
algorithm := "AWS4-HMAC-SHA256"
|
||||
credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, s.region, s.service)
|
||||
stringToSign := fmt.Sprintf("%s\n%s\n%s\n%s",
|
||||
algorithm,
|
||||
amzDate,
|
||||
credentialScope,
|
||||
hashSHA256([]byte(canonicalRequest)),
|
||||
)
|
||||
|
||||
// Calculate signature
|
||||
signingKey := s.getSignatureKey(dateStamp)
|
||||
signature := hex.EncodeToString(hmacSHA256(signingKey, []byte(stringToSign)))
|
||||
|
||||
// Build Authorization header
|
||||
authHeader := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
|
||||
algorithm,
|
||||
s.credentials.AccessKeyID,
|
||||
credentialScope,
|
||||
signedHeaders,
|
||||
signature,
|
||||
)
|
||||
|
||||
req.Header.Set("Authorization", authHeader)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createCanonicalRequest builds the canonical request string for SigV4.
|
||||
func (s *SigV4Signer) createCanonicalRequest(req *http.Request, body []byte) (string, string) {
|
||||
// HTTP method
|
||||
method := req.Method
|
||||
|
||||
// Canonical URI
|
||||
uri := req.URL.Path
|
||||
if uri == "" {
|
||||
uri = "/"
|
||||
}
|
||||
|
||||
// Canonical query string (sorted)
|
||||
queryString := s.buildCanonicalQueryString(req)
|
||||
|
||||
// Canonical headers (sorted, lowercase)
|
||||
canonicalHeaders, signedHeaders := s.buildCanonicalHeaders(req)
|
||||
|
||||
// Hashed payload
|
||||
payloadHash := hashSHA256(body)
|
||||
|
||||
canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
|
||||
method,
|
||||
uri,
|
||||
queryString,
|
||||
canonicalHeaders,
|
||||
signedHeaders,
|
||||
payloadHash,
|
||||
)
|
||||
|
||||
return canonicalRequest, signedHeaders
|
||||
}
|
||||
|
||||
// buildCanonicalQueryString builds a sorted, URI-encoded query string.
|
||||
func (s *SigV4Signer) buildCanonicalQueryString(req *http.Request) string {
|
||||
if req.URL.RawQuery == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Parse and sort query parameters
|
||||
params := make([]string, 0)
|
||||
for key, values := range req.URL.Query() {
|
||||
for _, value := range values {
|
||||
params = append(params, fmt.Sprintf("%s=%s", uriEncode(key), uriEncode(value)))
|
||||
}
|
||||
}
|
||||
sort.Strings(params)
|
||||
return strings.Join(params, "&")
|
||||
}
|
||||
|
||||
// buildCanonicalHeaders builds sorted, lowercase canonical headers.
|
||||
func (s *SigV4Signer) buildCanonicalHeaders(req *http.Request) (string, string) {
|
||||
// Headers to sign (must include host and x-amz-*)
|
||||
headerMap := make(map[string]string)
|
||||
headerMap["host"] = req.URL.Host
|
||||
|
||||
for key, values := range req.Header {
|
||||
lowKey := strings.ToLower(key)
|
||||
// Include x-amz-* headers and content-type
|
||||
if strings.HasPrefix(lowKey, "x-amz-") || lowKey == "content-type" {
|
||||
headerMap[lowKey] = strings.TrimSpace(values[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Sort header names
|
||||
headerNames := make([]string, 0, len(headerMap))
|
||||
for name := range headerMap {
|
||||
headerNames = append(headerNames, name)
|
||||
}
|
||||
sort.Strings(headerNames)
|
||||
|
||||
// Build canonical headers and signed headers
|
||||
var canonicalHeaders strings.Builder
|
||||
for _, name := range headerNames {
|
||||
canonicalHeaders.WriteString(name)
|
||||
canonicalHeaders.WriteString(":")
|
||||
canonicalHeaders.WriteString(headerMap[name])
|
||||
canonicalHeaders.WriteString("\n")
|
||||
}
|
||||
|
||||
signedHeaders := strings.Join(headerNames, ";")
|
||||
|
||||
return canonicalHeaders.String(), signedHeaders
|
||||
}
|
||||
|
||||
// getSignatureKey derives the signing key for SigV4.
|
||||
func (s *SigV4Signer) getSignatureKey(dateStamp string) []byte {
|
||||
kDate := hmacSHA256([]byte("AWS4"+s.credentials.SecretAccessKey), []byte(dateStamp))
|
||||
kRegion := hmacSHA256(kDate, []byte(s.region))
|
||||
kService := hmacSHA256(kRegion, []byte(s.service))
|
||||
kSigning := hmacSHA256(kService, []byte("aws4_request"))
|
||||
return kSigning
|
||||
}
|
||||
|
||||
// hmacSHA256 computes HMAC-SHA256.
|
||||
func hmacSHA256(key, data []byte) []byte {
|
||||
h := hmac.New(sha256.New, key)
|
||||
h.Write(data)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
// hashSHA256 computes SHA256 hash and returns hex string.
|
||||
func hashSHA256(data []byte) string {
|
||||
hash := sha256.Sum256(data)
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// uriEncode performs URI encoding for SigV4.
|
||||
func uriEncode(s string) string {
|
||||
var result strings.Builder
|
||||
for i := 0; i < len(s); i++ {
|
||||
c := s[i]
|
||||
if (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') ||
|
||||
(c >= '0' && c <= '9') || c == '-' || c == '.' || c == '_' || c == '~' {
|
||||
result.WriteByte(c)
|
||||
} else {
|
||||
result.WriteString(fmt.Sprintf("%%%02X", c))
|
||||
}
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// IsExpired checks if the credentials are expired or about to expire.
|
||||
func (c *CognitoCredentials) IsExpired() bool {
|
||||
// Consider expired if within 5 minutes of expiration
|
||||
return time.Now().Add(5 * time.Minute).After(c.Expiration)
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
package kiro
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -24,10 +26,13 @@ import (
|
||||
const (
|
||||
// AWS SSO OIDC endpoints
|
||||
ssoOIDCEndpoint = "https://oidc.us-east-1.amazonaws.com"
|
||||
|
||||
|
||||
// Kiro's start URL for Builder ID
|
||||
builderIDStartURL = "https://view.awsapps.com/start"
|
||||
|
||||
|
||||
// Default region for IDC
|
||||
defaultIDCRegion = "us-east-1"
|
||||
|
||||
// Polling interval
|
||||
pollInterval = 5 * time.Second
|
||||
|
||||
@@ -83,6 +88,429 @@ type CreateTokenResponse struct {
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
// getOIDCEndpoint returns the OIDC endpoint for the given region.
|
||||
func getOIDCEndpoint(region string) string {
|
||||
if region == "" {
|
||||
region = defaultIDCRegion
|
||||
}
|
||||
return fmt.Sprintf("https://oidc.%s.amazonaws.com", region)
|
||||
}
|
||||
|
||||
// promptInput prompts the user for input with an optional default value.
|
||||
func promptInput(prompt, defaultValue string) string {
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
if defaultValue != "" {
|
||||
fmt.Printf("%s [%s]: ", prompt, defaultValue)
|
||||
} else {
|
||||
fmt.Printf("%s: ", prompt)
|
||||
}
|
||||
input, _ := reader.ReadString('\n')
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
// promptSelect prompts the user to select from options using arrow keys or number input.
|
||||
func promptSelect(prompt string, options []string) int {
|
||||
fmt.Println(prompt)
|
||||
for i, opt := range options {
|
||||
fmt.Printf(" %d) %s\n", i+1, opt)
|
||||
}
|
||||
fmt.Print("Enter selection (1-", len(options), "): ")
|
||||
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
input, _ := reader.ReadString('\n')
|
||||
input = strings.TrimSpace(input)
|
||||
|
||||
// Parse the selection
|
||||
var selection int
|
||||
if _, err := fmt.Sscanf(input, "%d", &selection); err != nil || selection < 1 || selection > len(options) {
|
||||
return 0 // Default to first option
|
||||
}
|
||||
return selection - 1
|
||||
}
|
||||
|
||||
// RegisterClientWithRegion registers a new OIDC client with AWS using a specific region.
|
||||
func (c *SSOOIDCClient) RegisterClientWithRegion(ctx context.Context, region string) (*RegisterClientResponse, error) {
|
||||
endpoint := getOIDCEndpoint(region)
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"clientName": "Kiro IDE",
|
||||
"clientType": "public",
|
||||
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
|
||||
"grantTypes": []string{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"},
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/client/register", strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("register client failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result RegisterClientResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// StartDeviceAuthorizationWithIDC starts the device authorization flow for IDC.
|
||||
func (c *SSOOIDCClient) StartDeviceAuthorizationWithIDC(ctx context.Context, clientID, clientSecret, startURL, region string) (*StartDeviceAuthResponse, error) {
|
||||
endpoint := getOIDCEndpoint(region)
|
||||
|
||||
payload := map[string]string{
|
||||
"clientId": clientID,
|
||||
"clientSecret": clientSecret,
|
||||
"startUrl": startURL,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/device_authorization", strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("start device auth failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("start device auth failed (status %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result StartDeviceAuthResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// CreateTokenWithRegion polls for the access token after user authorization using a specific region.
|
||||
func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, clientSecret, deviceCode, region string) (*CreateTokenResponse, error) {
|
||||
endpoint := getOIDCEndpoint(region)
|
||||
|
||||
payload := map[string]string{
|
||||
"clientId": clientID,
|
||||
"clientSecret": clientSecret,
|
||||
"deviceCode": deviceCode,
|
||||
"grantType": "urn:ietf:params:oauth:grant-type:device_code",
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", kiroUserAgent)
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check for pending authorization
|
||||
if resp.StatusCode == http.StatusBadRequest {
|
||||
var errResp struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if json.Unmarshal(respBody, &errResp) == nil {
|
||||
if errResp.Error == "authorization_pending" {
|
||||
return nil, fmt.Errorf("authorization_pending")
|
||||
}
|
||||
if errResp.Error == "slow_down" {
|
||||
return nil, fmt.Errorf("slow_down")
|
||||
}
|
||||
}
|
||||
log.Debugf("create token failed: %s", string(respBody))
|
||||
return nil, fmt.Errorf("create token failed")
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("create token failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result CreateTokenResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// RefreshTokenWithRegion refreshes an access token using the refresh token with a specific region.
|
||||
func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, clientSecret, refreshToken, region, startURL string) (*KiroTokenData, error) {
|
||||
endpoint := getOIDCEndpoint(region)
|
||||
|
||||
payload := map[string]string{
|
||||
"clientId": clientID,
|
||||
"clientSecret": clientSecret,
|
||||
"refreshToken": refreshToken,
|
||||
"grantType": "refresh_token",
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set headers matching kiro2api's IDC token refresh
|
||||
// These headers are required for successful IDC token refresh
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region))
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
req.Header.Set("x-amz-user-agent", "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
req.Header.Set("Accept-Language", "*")
|
||||
req.Header.Set("sec-fetch-mode", "cors")
|
||||
req.Header.Set("User-Agent", "node")
|
||||
req.Header.Set("Accept-Encoding", "br, gzip, deflate")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result CreateTokenResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second)
|
||||
|
||||
return &KiroTokenData{
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "idc",
|
||||
Provider: "AWS",
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
StartURL: startURL,
|
||||
Region: region,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// LoginWithIDC performs the full device code flow for AWS Identity Center (IDC).
|
||||
func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region string) (*KiroTokenData, error) {
|
||||
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
|
||||
fmt.Println("║ Kiro Authentication (AWS Identity Center) ║")
|
||||
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||
|
||||
// Step 1: Register client with the specified region
|
||||
fmt.Println("\nRegistering client...")
|
||||
regResp, err := c.RegisterClientWithRegion(ctx, region)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to register client: %w", err)
|
||||
}
|
||||
log.Debugf("Client registered: %s", regResp.ClientID)
|
||||
|
||||
// Step 2: Start device authorization with IDC start URL
|
||||
fmt.Println("Starting device authorization...")
|
||||
authResp, err := c.StartDeviceAuthorizationWithIDC(ctx, regResp.ClientID, regResp.ClientSecret, startURL, region)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start device auth: %w", err)
|
||||
}
|
||||
|
||||
// Step 3: Show user the verification URL
|
||||
fmt.Printf("\n")
|
||||
fmt.Println("════════════════════════════════════════════════════════════")
|
||||
fmt.Printf(" Confirm the following code in the browser:\n")
|
||||
fmt.Printf(" Code: %s\n", authResp.UserCode)
|
||||
fmt.Println("════════════════════════════════════════════════════════════")
|
||||
fmt.Printf("\n Open this URL: %s\n\n", authResp.VerificationURIComplete)
|
||||
|
||||
// Set incognito mode based on config
|
||||
if c.cfg != nil {
|
||||
browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
|
||||
if !c.cfg.IncognitoBrowser {
|
||||
log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.")
|
||||
} else {
|
||||
log.Debug("kiro: using incognito mode for multi-account support")
|
||||
}
|
||||
} else {
|
||||
browser.SetIncognitoMode(true)
|
||||
log.Debug("kiro: using incognito mode for multi-account support (default)")
|
||||
}
|
||||
|
||||
// Open browser
|
||||
if err := browser.OpenURL(authResp.VerificationURIComplete); err != nil {
|
||||
log.Warnf("Could not open browser automatically: %v", err)
|
||||
fmt.Println(" Please open the URL manually in your browser.")
|
||||
} else {
|
||||
fmt.Println(" (Browser opened automatically)")
|
||||
}
|
||||
|
||||
// Step 4: Poll for token
|
||||
fmt.Println("Waiting for authorization...")
|
||||
|
||||
interval := pollInterval
|
||||
if authResp.Interval > 0 {
|
||||
interval = time.Duration(authResp.Interval) * time.Second
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
browser.CloseBrowser()
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(interval):
|
||||
tokenResp, err := c.CreateTokenWithRegion(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode, region)
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
if strings.Contains(errStr, "authorization_pending") {
|
||||
fmt.Print(".")
|
||||
continue
|
||||
}
|
||||
if strings.Contains(errStr, "slow_down") {
|
||||
interval += 5 * time.Second
|
||||
continue
|
||||
}
|
||||
browser.CloseBrowser()
|
||||
return nil, fmt.Errorf("token creation failed: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("\n\n✓ Authorization successful!")
|
||||
|
||||
// Close the browser window
|
||||
if err := browser.CloseBrowser(); err != nil {
|
||||
log.Debugf("Failed to close browser: %v", err)
|
||||
}
|
||||
|
||||
// Step 5: Get profile ARN from CodeWhisperer API
|
||||
fmt.Println("Fetching profile information...")
|
||||
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
|
||||
|
||||
// Fetch user email
|
||||
email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
|
||||
if email != "" {
|
||||
fmt.Printf(" Logged in as: %s\n", email)
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||
|
||||
return &KiroTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ProfileArn: profileArn,
|
||||
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||
AuthMethod: "idc",
|
||||
Provider: "AWS",
|
||||
ClientID: regResp.ClientID,
|
||||
ClientSecret: regResp.ClientSecret,
|
||||
Email: email,
|
||||
StartURL: startURL,
|
||||
Region: region,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Close browser on timeout
|
||||
if err := browser.CloseBrowser(); err != nil {
|
||||
log.Debugf("Failed to close browser on timeout: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("authorization timed out")
|
||||
}
|
||||
|
||||
// LoginWithMethodSelection prompts the user to select between Builder ID and IDC, then performs the login.
|
||||
func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context) (*KiroTokenData, error) {
|
||||
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
|
||||
fmt.Println("║ Kiro Authentication (AWS) ║")
|
||||
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||
|
||||
// Prompt for login method
|
||||
options := []string{
|
||||
"Use with Builder ID (personal AWS account)",
|
||||
"Use with IDC Account (organization SSO)",
|
||||
}
|
||||
selection := promptSelect("\n? Select login method:", options)
|
||||
|
||||
if selection == 0 {
|
||||
// Builder ID flow - use existing implementation
|
||||
return c.LoginWithBuilderID(ctx)
|
||||
}
|
||||
|
||||
// IDC flow - prompt for start URL and region
|
||||
fmt.Println()
|
||||
startURL := promptInput("? Enter Start URL", "")
|
||||
if startURL == "" {
|
||||
return nil, fmt.Errorf("start URL is required for IDC login")
|
||||
}
|
||||
|
||||
region := promptInput("? Enter Region", defaultIDCRegion)
|
||||
|
||||
return c.LoginWithIDC(ctx, startURL, region)
|
||||
}
|
||||
|
||||
// RegisterClient registers a new OIDC client with AWS.
|
||||
func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResponse, error) {
|
||||
payload := map[string]interface{}{
|
||||
|
||||
@@ -43,10 +43,15 @@ const (
|
||||
// Event Stream error type constants
|
||||
ErrStreamFatal = "fatal" // Connection/authentication errors, not recoverable
|
||||
ErrStreamMalformed = "malformed" // Format errors, data cannot be parsed
|
||||
// kiroUserAgent matches amq2api format for User-Agent header
|
||||
// kiroUserAgent matches amq2api format for User-Agent header (Amazon Q CLI style)
|
||||
kiroUserAgent = "aws-sdk-rust/1.3.9 os/macos lang/rust/1.87.0"
|
||||
// kiroFullUserAgent is the complete x-amz-user-agent header matching amq2api
|
||||
// kiroFullUserAgent is the complete x-amz-user-agent header matching amq2api (Amazon Q CLI style)
|
||||
kiroFullUserAgent = "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/macos lang/rust/1.87.0 m/E app/AmazonQ-For-CLI"
|
||||
|
||||
// Kiro IDE style headers (from kiro2api - for IDC auth)
|
||||
kiroIDEUserAgent = "aws-sdk-js/1.0.18 ua/2.1 os/darwin#25.0.0 lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-0.2.13-66c23a8c5d15afabec89ef9954ef52a119f10d369df04d548fc6c1eac694b0d1"
|
||||
kiroIDEAmzUserAgent = "aws-sdk-js/1.0.18 KiroIDE-0.2.13-66c23a8c5d15afabec89ef9954ef52a119f10d369df04d548fc6c1eac694b0d1"
|
||||
kiroIDEAgentModeSpec = "spec"
|
||||
)
|
||||
|
||||
// Real-time usage estimation configuration
|
||||
@@ -101,11 +106,24 @@ var kiroEndpointConfigs = []kiroEndpointConfig{
|
||||
|
||||
// getKiroEndpointConfigs returns the list of Kiro API endpoint configurations to try in order.
|
||||
// Supports reordering based on "preferred_endpoint" in auth metadata/attributes.
|
||||
// For IDC auth method, automatically uses CodeWhisperer endpoint with CLI origin.
|
||||
func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig {
|
||||
if auth == nil {
|
||||
return kiroEndpointConfigs
|
||||
}
|
||||
|
||||
// For IDC auth, use CodeWhisperer endpoint with AI_EDITOR origin (same as Social auth)
|
||||
// Based on kiro2api analysis: IDC tokens work with CodeWhisperer endpoint using Bearer auth
|
||||
// The difference is only in how tokens are refreshed (OIDC with clientId/clientSecret for IDC)
|
||||
// NOT in how API calls are made - both Social and IDC use the same endpoint/origin
|
||||
if auth.Metadata != nil {
|
||||
authMethod, _ := auth.Metadata["auth_method"].(string)
|
||||
if authMethod == "idc" {
|
||||
log.Debugf("kiro: IDC auth, using CodeWhisperer endpoint")
|
||||
return kiroEndpointConfigs
|
||||
}
|
||||
}
|
||||
|
||||
// Check for preference
|
||||
var preference string
|
||||
if auth.Metadata != nil {
|
||||
@@ -160,6 +178,79 @@ func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig {
|
||||
type KiroExecutor struct {
|
||||
cfg *config.Config
|
||||
refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions
|
||||
|
||||
// cognitoCredsCache caches Cognito credentials per auth ID for IDC authentication
|
||||
// Key: auth.ID, Value: *kiroauth.CognitoCredentials
|
||||
cognitoCredsCache sync.Map
|
||||
}
|
||||
|
||||
// getCachedCognitoCredentials retrieves cached Cognito credentials if they are still valid.
|
||||
func (e *KiroExecutor) getCachedCognitoCredentials(authID string) *kiroauth.CognitoCredentials {
|
||||
if cached, ok := e.cognitoCredsCache.Load(authID); ok {
|
||||
creds := cached.(*kiroauth.CognitoCredentials)
|
||||
if !creds.IsExpired() {
|
||||
return creds
|
||||
}
|
||||
// Credentials expired, remove from cache
|
||||
e.cognitoCredsCache.Delete(authID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cacheCognitoCredentials stores Cognito credentials in the cache.
|
||||
func (e *KiroExecutor) cacheCognitoCredentials(authID string, creds *kiroauth.CognitoCredentials) {
|
||||
e.cognitoCredsCache.Store(authID, creds)
|
||||
}
|
||||
|
||||
// getOrExchangeCognitoCredentials retrieves cached Cognito credentials or exchanges the SSO token for new ones.
|
||||
func (e *KiroExecutor) getOrExchangeCognitoCredentials(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) (*kiroauth.CognitoCredentials, error) {
|
||||
if auth == nil {
|
||||
return nil, fmt.Errorf("auth is nil")
|
||||
}
|
||||
|
||||
// Check cache first
|
||||
if creds := e.getCachedCognitoCredentials(auth.ID); creds != nil {
|
||||
log.Debugf("kiro: using cached Cognito credentials for auth %s (expires: %s)", auth.ID, creds.Expiration.Format(time.RFC3339))
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
// Get region from auth metadata
|
||||
region := "us-east-1"
|
||||
if auth.Metadata != nil {
|
||||
if r, ok := auth.Metadata["region"].(string); ok && r != "" {
|
||||
region = r
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("kiro: exchanging SSO token for Cognito credentials (region: %s)", region)
|
||||
|
||||
// Exchange SSO token for Cognito credentials
|
||||
cognitoClient := kiroauth.NewCognitoIdentityClient(e.cfg)
|
||||
creds, err := cognitoClient.ExchangeSSOTokenForCredentials(ctx, accessToken, region)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange SSO token for Cognito credentials: %w", err)
|
||||
}
|
||||
|
||||
// Cache the credentials
|
||||
e.cacheCognitoCredentials(auth.ID, creds)
|
||||
log.Infof("kiro: Cognito credentials obtained and cached (expires: %s)", creds.Expiration.Format(time.RFC3339))
|
||||
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
// isIDCAuth checks if the auth uses IDC (Identity Center) authentication method.
|
||||
func isIDCAuth(auth *cliproxyauth.Auth) bool {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return false
|
||||
}
|
||||
authMethod, _ := auth.Metadata["auth_method"].(string)
|
||||
return authMethod == "idc"
|
||||
}
|
||||
|
||||
// signRequestWithSigV4 signs an HTTP request with AWS SigV4 using Cognito credentials.
|
||||
func signRequestWithSigV4(req *http.Request, payload []byte, creds *kiroauth.CognitoCredentials, region, service string) error {
|
||||
signer := kiroauth.NewSigV4Signer(creds, region, service)
|
||||
return signer.SignRequest(req, payload)
|
||||
}
|
||||
|
||||
// buildKiroPayloadForFormat builds the Kiro API payload based on the source format.
|
||||
@@ -262,15 +353,60 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", kiroContentType)
|
||||
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
httpReq.Header.Set("Accept", kiroAcceptStream)
|
||||
// Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors)
|
||||
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
|
||||
httpReq.Header.Set("User-Agent", kiroUserAgent)
|
||||
httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
|
||||
|
||||
// Use different headers based on auth type
|
||||
// IDC auth uses Kiro IDE style headers (from kiro2api)
|
||||
// Other auth types use Amazon Q CLI style headers
|
||||
if isIDCAuth(auth) {
|
||||
httpReq.Header.Set("User-Agent", kiroIDEUserAgent)
|
||||
httpReq.Header.Set("X-Amz-User-Agent", kiroIDEAmzUserAgent)
|
||||
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec)
|
||||
log.Debugf("kiro: using Kiro IDE headers for IDC auth")
|
||||
} else {
|
||||
httpReq.Header.Set("User-Agent", kiroUserAgent)
|
||||
httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
|
||||
}
|
||||
httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||
httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
||||
|
||||
// Choose auth method: SigV4 for IDC, Bearer token for others
|
||||
// NOTE: Cognito credential exchange disabled for now - testing Bearer token first
|
||||
if false && isIDCAuth(auth) {
|
||||
// IDC auth requires SigV4 signing with Cognito-exchanged credentials
|
||||
cognitoCreds, err := e.getOrExchangeCognitoCredentials(ctx, auth, accessToken)
|
||||
if err != nil {
|
||||
log.Warnf("kiro: failed to get Cognito credentials for IDC auth: %v", err)
|
||||
return resp, fmt.Errorf("IDC auth requires Cognito credentials: %w", err)
|
||||
}
|
||||
|
||||
// Get region from auth metadata
|
||||
region := "us-east-1"
|
||||
if auth.Metadata != nil {
|
||||
if r, ok := auth.Metadata["region"].(string); ok && r != "" {
|
||||
region = r
|
||||
}
|
||||
}
|
||||
|
||||
// Determine service from URL
|
||||
service := "codewhisperer"
|
||||
if strings.Contains(url, "q.us-east-1.amazonaws.com") {
|
||||
service = "qdeveloper"
|
||||
}
|
||||
|
||||
// Sign the request with SigV4
|
||||
if err := signRequestWithSigV4(httpReq, kiroPayload, cognitoCreds, region, service); err != nil {
|
||||
log.Warnf("kiro: failed to sign request with SigV4: %v", err)
|
||||
return resp, fmt.Errorf("SigV4 signing failed: %w", err)
|
||||
}
|
||||
log.Debugf("kiro: request signed with SigV4 for IDC auth (service: %s, region: %s)", service, region)
|
||||
} else {
|
||||
// Standard Bearer token authentication for Builder ID, social auth, etc.
|
||||
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
}
|
||||
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
@@ -568,15 +704,60 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", kiroContentType)
|
||||
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
httpReq.Header.Set("Accept", kiroAcceptStream)
|
||||
// Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors)
|
||||
httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
|
||||
httpReq.Header.Set("User-Agent", kiroUserAgent)
|
||||
httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
|
||||
|
||||
// Use different headers based on auth type
|
||||
// IDC auth uses Kiro IDE style headers (from kiro2api)
|
||||
// Other auth types use Amazon Q CLI style headers
|
||||
if isIDCAuth(auth) {
|
||||
httpReq.Header.Set("User-Agent", kiroIDEUserAgent)
|
||||
httpReq.Header.Set("X-Amz-User-Agent", kiroIDEAmzUserAgent)
|
||||
httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeSpec)
|
||||
log.Debugf("kiro: using Kiro IDE headers for IDC auth")
|
||||
} else {
|
||||
httpReq.Header.Set("User-Agent", kiroUserAgent)
|
||||
httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
|
||||
}
|
||||
httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||
httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
||||
|
||||
// Choose auth method: SigV4 for IDC, Bearer token for others
|
||||
// NOTE: Cognito credential exchange disabled for now - testing Bearer token first
|
||||
if false && isIDCAuth(auth) {
|
||||
// IDC auth requires SigV4 signing with Cognito-exchanged credentials
|
||||
cognitoCreds, err := e.getOrExchangeCognitoCredentials(ctx, auth, accessToken)
|
||||
if err != nil {
|
||||
log.Warnf("kiro: failed to get Cognito credentials for IDC auth: %v", err)
|
||||
return nil, fmt.Errorf("IDC auth requires Cognito credentials: %w", err)
|
||||
}
|
||||
|
||||
// Get region from auth metadata
|
||||
region := "us-east-1"
|
||||
if auth.Metadata != nil {
|
||||
if r, ok := auth.Metadata["region"].(string); ok && r != "" {
|
||||
region = r
|
||||
}
|
||||
}
|
||||
|
||||
// Determine service from URL
|
||||
service := "codewhisperer"
|
||||
if strings.Contains(url, "q.us-east-1.amazonaws.com") {
|
||||
service = "qdeveloper"
|
||||
}
|
||||
|
||||
// Sign the request with SigV4
|
||||
if err := signRequestWithSigV4(httpReq, kiroPayload, cognitoCreds, region, service); err != nil {
|
||||
log.Warnf("kiro: failed to sign request with SigV4: %v", err)
|
||||
return nil, fmt.Errorf("SigV4 signing failed: %w", err)
|
||||
}
|
||||
log.Debugf("kiro: stream request signed with SigV4 for IDC auth (service: %s, region: %s)", service, region)
|
||||
} else {
|
||||
// Standard Bearer token authentication for Builder ID, social auth, etc.
|
||||
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
}
|
||||
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
@@ -1001,12 +1182,12 @@ func getEffectiveProfileArn(auth *cliproxyauth.Auth, profileArn string) string {
|
||||
// This consolidates the auth_method check that was previously done separately.
|
||||
func getEffectiveProfileArnWithWarning(auth *cliproxyauth.Auth, profileArn string) string {
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" {
|
||||
// builder-id auth doesn't need profileArn
|
||||
if authMethod, ok := auth.Metadata["auth_method"].(string); ok && (authMethod == "builder-id" || authMethod == "idc") {
|
||||
// builder-id and idc auth don't need profileArn
|
||||
return ""
|
||||
}
|
||||
}
|
||||
// For non-builder-id auth (social auth), profileArn is required
|
||||
// For non-builder-id/idc auth (social auth), profileArn is required
|
||||
if profileArn == "" {
|
||||
log.Warnf("kiro: profile ARN not found in auth, API calls may fail")
|
||||
}
|
||||
|
||||
116
sdk/auth/kiro.go
116
sdk/auth/kiro.go
@@ -53,20 +53,8 @@ func (a *KiroAuthenticator) RefreshLead() *time.Duration {
|
||||
return &d
|
||||
}
|
||||
|
||||
// Login performs OAuth login for Kiro with AWS Builder ID.
|
||||
func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||
}
|
||||
|
||||
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||
|
||||
// Use AWS Builder ID device code flow
|
||||
tokenData, err := oauth.LoginWithBuilderID(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login failed: %w", err)
|
||||
}
|
||||
|
||||
// createAuthRecord creates an auth record from token data.
|
||||
func (a *KiroAuthenticator) createAuthRecord(tokenData *kiroauth.KiroTokenData, source string) (*coreauth.Auth, error) {
|
||||
// Parse expires_at
|
||||
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||
if err != nil {
|
||||
@@ -76,34 +64,63 @@ func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
// Extract identifier for file naming
|
||||
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||
|
||||
// Determine label based on auth method
|
||||
label := fmt.Sprintf("kiro-%s", source)
|
||||
if tokenData.AuthMethod == "idc" {
|
||||
label = "kiro-idc"
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
fileName := fmt.Sprintf("kiro-aws-%s.json", idPart)
|
||||
fileName := fmt.Sprintf("%s-%s.json", label, idPart)
|
||||
|
||||
metadata := map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenData.AccessToken,
|
||||
"refresh_token": tokenData.RefreshToken,
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"expires_at": tokenData.ExpiresAt,
|
||||
"auth_method": tokenData.AuthMethod,
|
||||
"provider": tokenData.Provider,
|
||||
"client_id": tokenData.ClientID,
|
||||
"client_secret": tokenData.ClientSecret,
|
||||
"email": tokenData.Email,
|
||||
}
|
||||
|
||||
// Add IDC-specific fields if present
|
||||
if tokenData.StartURL != "" {
|
||||
metadata["start_url"] = tokenData.StartURL
|
||||
}
|
||||
if tokenData.Region != "" {
|
||||
metadata["region"] = tokenData.Region
|
||||
}
|
||||
|
||||
attributes := map[string]string{
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"source": source,
|
||||
"email": tokenData.Email,
|
||||
}
|
||||
|
||||
// Add IDC-specific attributes if present
|
||||
if tokenData.AuthMethod == "idc" {
|
||||
attributes["source"] = "aws-idc"
|
||||
if tokenData.StartURL != "" {
|
||||
attributes["start_url"] = tokenData.StartURL
|
||||
}
|
||||
if tokenData.Region != "" {
|
||||
attributes["region"] = tokenData.Region
|
||||
}
|
||||
}
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kiro",
|
||||
FileName: fileName,
|
||||
Label: "kiro-aws",
|
||||
Label: label,
|
||||
Status: coreauth.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Metadata: map[string]any{
|
||||
"type": "kiro",
|
||||
"access_token": tokenData.AccessToken,
|
||||
"refresh_token": tokenData.RefreshToken,
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"expires_at": tokenData.ExpiresAt,
|
||||
"auth_method": tokenData.AuthMethod,
|
||||
"provider": tokenData.Provider,
|
||||
"client_id": tokenData.ClientID,
|
||||
"client_secret": tokenData.ClientSecret,
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
Attributes: map[string]string{
|
||||
"profile_arn": tokenData.ProfileArn,
|
||||
"source": "aws-builder-id",
|
||||
"email": tokenData.Email,
|
||||
},
|
||||
Metadata: metadata,
|
||||
Attributes: attributes,
|
||||
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||
}
|
||||
@@ -117,6 +134,23 @@ func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// Login performs OAuth login for Kiro with AWS (Builder ID or IDC).
|
||||
// This shows a method selection prompt and handles both flows.
|
||||
func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||
}
|
||||
|
||||
// Use the unified method selection flow (Builder ID or IDC)
|
||||
ssoClient := kiroauth.NewSSOOIDCClient(cfg)
|
||||
tokenData, err := ssoClient.LoginWithMethodSelection(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login failed: %w", err)
|
||||
}
|
||||
|
||||
return a.createAuthRecord(tokenData, "aws")
|
||||
}
|
||||
|
||||
// LoginWithAuthCode performs OAuth login for Kiro with AWS Builder ID using authorization code flow.
|
||||
// This provides a better UX than device code flow as it uses automatic browser callback.
|
||||
func (a *KiroAuthenticator) LoginWithAuthCode(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
@@ -388,15 +422,23 @@ func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, aut
|
||||
clientID, _ := auth.Metadata["client_id"].(string)
|
||||
clientSecret, _ := auth.Metadata["client_secret"].(string)
|
||||
authMethod, _ := auth.Metadata["auth_method"].(string)
|
||||
startURL, _ := auth.Metadata["start_url"].(string)
|
||||
region, _ := auth.Metadata["region"].(string)
|
||||
|
||||
var tokenData *kiroauth.KiroTokenData
|
||||
var err error
|
||||
|
||||
// Use SSO OIDC refresh for AWS Builder ID, otherwise use Kiro's OAuth refresh endpoint
|
||||
if clientID != "" && clientSecret != "" && authMethod == "builder-id" {
|
||||
ssoClient := kiroauth.NewSSOOIDCClient(cfg)
|
||||
ssoClient := kiroauth.NewSSOOIDCClient(cfg)
|
||||
|
||||
// Use SSO OIDC refresh for AWS Builder ID or IDC, otherwise use Kiro's OAuth refresh endpoint
|
||||
switch {
|
||||
case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "":
|
||||
// IDC refresh with region-specific endpoint
|
||||
tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL)
|
||||
case clientID != "" && clientSecret != "" && authMethod == "builder-id":
|
||||
// Builder ID refresh with default endpoint
|
||||
tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken)
|
||||
} else {
|
||||
default:
|
||||
// Fallback to Kiro's refresh endpoint (for social auth: Google/GitHub)
|
||||
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||
tokenData, err = oauth.RefreshToken(ctx, refreshToken)
|
||||
|
||||
Reference in New Issue
Block a user