Merge branch 'main' into plus

This commit is contained in:
Luis Pater
2026-03-28 04:51:18 +08:00
committed by GitHub
145 changed files with 37863 additions and 478 deletions

View File

@@ -1,6 +1,7 @@
package management
import (
"bytes"
"context"
"encoding/json"
"fmt"
@@ -10,6 +11,7 @@ import (
"strings"
"time"
"github.com/fxamacker/cbor/v2"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -53,6 +55,7 @@ type apiCallResponse struct {
StatusCode int `json:"status_code"`
Header map[string][]string `json:"header"`
Body string `json:"body"`
Quota *QuotaSnapshots `json:"quota,omitempty"`
}
// APICall makes a generic HTTP request on behalf of the management API caller.
@@ -69,7 +72,7 @@ type apiCallResponse struct {
// - Authorization: Bearer <key>
// - X-Management-Key: <key>
//
// Request JSON:
// Request JSON (supports both application/json and application/cbor):
// - auth_index / authIndex / AuthIndex (optional):
// The credential "auth_index" from GET /v0/management/auth-files (or other endpoints returning it).
// If omitted or not found, credential-specific proxy/token substitution is skipped.
@@ -89,10 +92,14 @@ type apiCallResponse struct {
// 2. Global config proxy-url
// 3. Direct connect (environment proxies are not used)
//
// Response JSON (returned with HTTP 200 when the APICall itself succeeds):
// - status_code: Upstream HTTP status code.
// - header: Upstream response headers.
// - body: Upstream response body as string.
// Response (returned with HTTP 200 when the APICall itself succeeds):
//
// Format matches request Content-Type (application/json or application/cbor)
// - status_code: Upstream HTTP status code.
// - header: Upstream response headers.
// - body: Upstream response body as string.
// - quota (optional): For GitHub Copilot enterprise accounts, contains quota_snapshots
// with details for chat, completions, and premium_interactions.
//
// Example:
//
@@ -106,10 +113,28 @@ type apiCallResponse struct {
// -H "Content-Type: application/json" \
// -d '{"auth_index":"<AUTH_INDEX>","method":"POST","url":"https://api.example.com/v1/fetchAvailableModels","header":{"Authorization":"Bearer $TOKEN$","Content-Type":"application/json","User-Agent":"cliproxyapi"},"data":"{}"}'
func (h *Handler) APICall(c *gin.Context) {
// Detect content type
contentType := strings.ToLower(strings.TrimSpace(c.GetHeader("Content-Type")))
isCBOR := strings.Contains(contentType, "application/cbor")
var body apiCallRequest
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
// Parse request body based on content type
if isCBOR {
rawBody, errRead := io.ReadAll(c.Request.Body)
if errRead != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
return
}
if errUnmarshal := cbor.Unmarshal(rawBody, &body); errUnmarshal != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid cbor body"})
return
}
} else {
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
}
method := strings.ToUpper(strings.TrimSpace(body.Method))
@@ -163,9 +188,21 @@ func (h *Handler) APICall(c *gin.Context) {
reqHeaders[key] = strings.ReplaceAll(value, "$TOKEN$", token)
}
// When caller indicates CBOR in request headers, convert JSON string payload to CBOR bytes.
useCBORPayload := headerContainsValue(reqHeaders, "Content-Type", "application/cbor")
var requestBody io.Reader
if body.Data != "" {
requestBody = strings.NewReader(body.Data)
if useCBORPayload {
cborPayload, errEncode := encodeJSONStringToCBOR(body.Data)
if errEncode != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json data for cbor content-type"})
return
}
requestBody = bytes.NewReader(cborPayload)
} else {
requestBody = strings.NewReader(body.Data)
}
}
req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody)
@@ -208,11 +245,38 @@ func (h *Handler) APICall(c *gin.Context) {
return
}
c.JSON(http.StatusOK, apiCallResponse{
// For CBOR upstream responses, decode into plain text or JSON string before returning.
responseBodyText := string(respBody)
if headerContainsValue(reqHeaders, "Accept", "application/cbor") || strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "application/cbor") {
if decodedBody, errDecode := decodeCBORBodyToTextOrJSON(respBody); errDecode == nil {
responseBodyText = decodedBody
}
}
response := apiCallResponse{
StatusCode: resp.StatusCode,
Header: resp.Header,
Body: string(respBody),
})
Body: responseBodyText,
}
// If this is a GitHub Copilot token endpoint response, try to enrich with quota information
if resp.StatusCode == http.StatusOK &&
strings.Contains(urlStr, "copilot_internal") &&
strings.Contains(urlStr, "/token") {
response = h.enrichCopilotTokenResponse(c.Request.Context(), response, auth, urlStr)
}
// Return response in the same format as the request
if isCBOR {
cborData, errMarshal := cbor.Marshal(response)
if errMarshal != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to encode cbor response"})
return
}
c.Data(http.StatusOK, "application/cbor", cborData)
} else {
c.JSON(http.StatusOK, response)
}
}
func firstNonEmptyString(values ...*string) string {
@@ -666,3 +730,421 @@ func buildProxyTransport(proxyStr string) *http.Transport {
}
return transport
}
// headerContainsValue checks whether a header map contains a target value (case-insensitive key and value).
func headerContainsValue(headers map[string]string, targetKey, targetValue string) bool {
if len(headers) == 0 {
return false
}
for key, value := range headers {
if !strings.EqualFold(strings.TrimSpace(key), strings.TrimSpace(targetKey)) {
continue
}
if strings.Contains(strings.ToLower(value), strings.ToLower(strings.TrimSpace(targetValue))) {
return true
}
}
return false
}
// encodeJSONStringToCBOR converts a JSON string payload into CBOR bytes.
func encodeJSONStringToCBOR(jsonString string) ([]byte, error) {
var payload any
if errUnmarshal := json.Unmarshal([]byte(jsonString), &payload); errUnmarshal != nil {
return nil, errUnmarshal
}
return cbor.Marshal(payload)
}
// decodeCBORBodyToTextOrJSON decodes CBOR bytes to plain text (for string payloads) or JSON string.
func decodeCBORBodyToTextOrJSON(raw []byte) (string, error) {
if len(raw) == 0 {
return "", nil
}
var payload any
if errUnmarshal := cbor.Unmarshal(raw, &payload); errUnmarshal != nil {
return "", errUnmarshal
}
jsonCompatible := cborValueToJSONCompatible(payload)
switch typed := jsonCompatible.(type) {
case string:
return typed, nil
case []byte:
return string(typed), nil
default:
jsonBytes, errMarshal := json.Marshal(jsonCompatible)
if errMarshal != nil {
return "", errMarshal
}
return string(jsonBytes), nil
}
}
// cborValueToJSONCompatible recursively converts CBOR-decoded values into JSON-marshalable values.
func cborValueToJSONCompatible(value any) any {
switch typed := value.(type) {
case map[any]any:
out := make(map[string]any, len(typed))
for key, item := range typed {
out[fmt.Sprint(key)] = cborValueToJSONCompatible(item)
}
return out
case map[string]any:
out := make(map[string]any, len(typed))
for key, item := range typed {
out[key] = cborValueToJSONCompatible(item)
}
return out
case []any:
out := make([]any, len(typed))
for i, item := range typed {
out[i] = cborValueToJSONCompatible(item)
}
return out
default:
return typed
}
}
// QuotaDetail represents quota information for a specific resource type
type QuotaDetail struct {
Entitlement float64 `json:"entitlement"`
OverageCount float64 `json:"overage_count"`
OveragePermitted bool `json:"overage_permitted"`
PercentRemaining float64 `json:"percent_remaining"`
QuotaID string `json:"quota_id"`
QuotaRemaining float64 `json:"quota_remaining"`
Remaining float64 `json:"remaining"`
Unlimited bool `json:"unlimited"`
}
// QuotaSnapshots contains quota details for different resource types
type QuotaSnapshots struct {
Chat QuotaDetail `json:"chat"`
Completions QuotaDetail `json:"completions"`
PremiumInteractions QuotaDetail `json:"premium_interactions"`
}
// CopilotUsageResponse represents the GitHub Copilot usage information
type CopilotUsageResponse struct {
AccessTypeSKU string `json:"access_type_sku"`
AnalyticsTrackingID string `json:"analytics_tracking_id"`
AssignedDate string `json:"assigned_date"`
CanSignupForLimited bool `json:"can_signup_for_limited"`
ChatEnabled bool `json:"chat_enabled"`
CopilotPlan string `json:"copilot_plan"`
OrganizationLoginList []interface{} `json:"organization_login_list"`
OrganizationList []interface{} `json:"organization_list"`
QuotaResetDate string `json:"quota_reset_date"`
QuotaSnapshots QuotaSnapshots `json:"quota_snapshots"`
}
type copilotQuotaRequest struct {
AuthIndexSnake *string `json:"auth_index"`
AuthIndexCamel *string `json:"authIndex"`
AuthIndexPascal *string `json:"AuthIndex"`
}
// GetCopilotQuota fetches GitHub Copilot quota information from the /copilot_internal/user endpoint.
//
// Endpoint:
//
// GET /v0/management/copilot-quota
//
// Query Parameters (optional):
// - auth_index: The credential "auth_index" from GET /v0/management/auth-files.
// If omitted, uses the first available GitHub Copilot credential.
//
// Response:
//
// Returns the CopilotUsageResponse with quota_snapshots containing detailed quota information
// for chat, completions, and premium_interactions.
//
// Example:
//
// curl -sS -X GET "http://127.0.0.1:8317/v0/management/copilot-quota?auth_index=<AUTH_INDEX>" \
// -H "Authorization: Bearer <MANAGEMENT_KEY>"
func (h *Handler) GetCopilotQuota(c *gin.Context) {
authIndex := strings.TrimSpace(c.Query("auth_index"))
if authIndex == "" {
authIndex = strings.TrimSpace(c.Query("authIndex"))
}
if authIndex == "" {
authIndex = strings.TrimSpace(c.Query("AuthIndex"))
}
auth := h.findCopilotAuth(authIndex)
if auth == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "no github copilot credential found"})
return
}
token, tokenErr := h.resolveTokenForAuth(c.Request.Context(), auth)
if tokenErr != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to refresh copilot token"})
return
}
if token == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "copilot token not found"})
return
}
apiURL := "https://api.github.com/copilot_internal/user"
req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, apiURL, nil)
if errNewRequest != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to build request"})
return
}
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("User-Agent", "CLIProxyAPIPlus")
req.Header.Set("Accept", "application/json")
httpClient := &http.Client{
Timeout: defaultAPICallTimeout,
Transport: h.apiCallTransport(auth),
}
resp, errDo := httpClient.Do(req)
if errDo != nil {
log.WithError(errDo).Debug("copilot quota request failed")
c.JSON(http.StatusBadGateway, gin.H{"error": "request failed"})
return
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
}()
respBody, errReadAll := io.ReadAll(resp.Body)
if errReadAll != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": "failed to read response"})
return
}
if resp.StatusCode != http.StatusOK {
c.JSON(http.StatusBadGateway, gin.H{
"error": "github api request failed",
"status_code": resp.StatusCode,
"body": string(respBody),
})
return
}
var usage CopilotUsageResponse
if errUnmarshal := json.Unmarshal(respBody, &usage); errUnmarshal != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to parse response"})
return
}
c.JSON(http.StatusOK, usage)
}
// findCopilotAuth locates a GitHub Copilot credential by auth_index or returns the first available one
func (h *Handler) findCopilotAuth(authIndex string) *coreauth.Auth {
if h == nil || h.authManager == nil {
return nil
}
auths := h.authManager.List()
var firstCopilot *coreauth.Auth
for _, auth := range auths {
if auth == nil {
continue
}
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
if provider != "copilot" && provider != "github" && provider != "github-copilot" {
continue
}
if firstCopilot == nil {
firstCopilot = auth
}
if authIndex != "" {
auth.EnsureIndex()
if auth.Index == authIndex {
return auth
}
}
}
return firstCopilot
}
// enrichCopilotTokenResponse fetches quota information and adds it to the Copilot token response body
func (h *Handler) enrichCopilotTokenResponse(ctx context.Context, response apiCallResponse, auth *coreauth.Auth, originalURL string) apiCallResponse {
if auth == nil || response.Body == "" {
return response
}
// Parse the token response to check if it's enterprise (null limited_user_quotas)
var tokenResp map[string]interface{}
if err := json.Unmarshal([]byte(response.Body), &tokenResp); err != nil {
log.WithError(err).Debug("enrichCopilotTokenResponse: failed to parse copilot token response")
return response
}
// Get the GitHub token to call the copilot_internal/user endpoint
token, tokenErr := h.resolveTokenForAuth(ctx, auth)
if tokenErr != nil {
log.WithError(tokenErr).Debug("enrichCopilotTokenResponse: failed to resolve token")
return response
}
if token == "" {
return response
}
// Fetch quota information from /copilot_internal/user
// Derive the base URL from the original token request to support proxies and test servers
parsedURL, errParse := url.Parse(originalURL)
if errParse != nil {
log.WithError(errParse).Debug("enrichCopilotTokenResponse: failed to parse URL")
return response
}
quotaURL := fmt.Sprintf("%s://%s/copilot_internal/user", parsedURL.Scheme, parsedURL.Host)
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodGet, quotaURL, nil)
if errNewRequest != nil {
log.WithError(errNewRequest).Debug("enrichCopilotTokenResponse: failed to build request")
return response
}
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("User-Agent", "CLIProxyAPIPlus")
req.Header.Set("Accept", "application/json")
httpClient := &http.Client{
Timeout: defaultAPICallTimeout,
Transport: h.apiCallTransport(auth),
}
quotaResp, errDo := httpClient.Do(req)
if errDo != nil {
log.WithError(errDo).Debug("enrichCopilotTokenResponse: quota fetch HTTP request failed")
return response
}
defer func() {
if errClose := quotaResp.Body.Close(); errClose != nil {
log.Errorf("quota response body close error: %v", errClose)
}
}()
if quotaResp.StatusCode != http.StatusOK {
return response
}
quotaBody, errReadAll := io.ReadAll(quotaResp.Body)
if errReadAll != nil {
log.WithError(errReadAll).Debug("enrichCopilotTokenResponse: failed to read response")
return response
}
// Parse the quota response
var quotaData CopilotUsageResponse
if err := json.Unmarshal(quotaBody, &quotaData); err != nil {
log.WithError(err).Debug("enrichCopilotTokenResponse: failed to parse response")
return response
}
// Check if this is an enterprise account by looking for quota_snapshots in the response
// Enterprise accounts have quota_snapshots, non-enterprise have limited_user_quotas
var quotaRaw map[string]interface{}
if err := json.Unmarshal(quotaBody, &quotaRaw); err == nil {
if _, hasQuotaSnapshots := quotaRaw["quota_snapshots"]; hasQuotaSnapshots {
// Enterprise account - has quota_snapshots
tokenResp["quota_snapshots"] = quotaData.QuotaSnapshots
tokenResp["access_type_sku"] = quotaData.AccessTypeSKU
tokenResp["copilot_plan"] = quotaData.CopilotPlan
// Add quota reset date for enterprise (quota_reset_date_utc)
if quotaResetDateUTC, ok := quotaRaw["quota_reset_date_utc"]; ok {
tokenResp["quota_reset_date"] = quotaResetDateUTC
} else if quotaData.QuotaResetDate != "" {
tokenResp["quota_reset_date"] = quotaData.QuotaResetDate
}
} else {
// Non-enterprise account - build quota from limited_user_quotas and monthly_quotas
var quotaSnapshots QuotaSnapshots
// Get monthly quotas (total entitlement) and limited_user_quotas (remaining)
monthlyQuotas, hasMonthly := quotaRaw["monthly_quotas"].(map[string]interface{})
limitedQuotas, hasLimited := quotaRaw["limited_user_quotas"].(map[string]interface{})
// Process chat quota
if hasMonthly && hasLimited {
if chatTotal, ok := monthlyQuotas["chat"].(float64); ok {
chatRemaining := chatTotal // default to full if no limited quota
if chatLimited, ok := limitedQuotas["chat"].(float64); ok {
chatRemaining = chatLimited
}
percentRemaining := 0.0
if chatTotal > 0 {
percentRemaining = (chatRemaining / chatTotal) * 100.0
}
quotaSnapshots.Chat = QuotaDetail{
Entitlement: chatTotal,
Remaining: chatRemaining,
QuotaRemaining: chatRemaining,
PercentRemaining: percentRemaining,
QuotaID: "chat",
Unlimited: false,
}
}
// Process completions quota
if completionsTotal, ok := monthlyQuotas["completions"].(float64); ok {
completionsRemaining := completionsTotal // default to full if no limited quota
if completionsLimited, ok := limitedQuotas["completions"].(float64); ok {
completionsRemaining = completionsLimited
}
percentRemaining := 0.0
if completionsTotal > 0 {
percentRemaining = (completionsRemaining / completionsTotal) * 100.0
}
quotaSnapshots.Completions = QuotaDetail{
Entitlement: completionsTotal,
Remaining: completionsRemaining,
QuotaRemaining: completionsRemaining,
PercentRemaining: percentRemaining,
QuotaID: "completions",
Unlimited: false,
}
}
}
// Premium interactions don't exist for non-enterprise, leave as zero values
quotaSnapshots.PremiumInteractions = QuotaDetail{
QuotaID: "premium_interactions",
Unlimited: false,
}
// Add quota_snapshots to the token response
tokenResp["quota_snapshots"] = quotaSnapshots
tokenResp["access_type_sku"] = quotaData.AccessTypeSKU
tokenResp["copilot_plan"] = quotaData.CopilotPlan
// Add quota reset date for non-enterprise (limited_user_reset_date)
if limitedResetDate, ok := quotaRaw["limited_user_reset_date"]; ok {
tokenResp["quota_reset_date"] = limitedResetDate
}
}
}
// Re-serialize the enriched response
enrichedBody, errMarshal := json.Marshal(tokenResp)
if errMarshal != nil {
log.WithError(errMarshal).Debug("failed to marshal enriched response")
return response
}
response.Body = string(enrichedBody)
return response
}

View File

@@ -0,0 +1,149 @@
package management
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/fxamacker/cbor/v2"
"github.com/gin-gonic/gin"
)
func TestAPICall_CBOR_Support(t *testing.T) {
gin.SetMode(gin.TestMode)
// Create a test handler
h := &Handler{}
// Create test request data
reqData := apiCallRequest{
Method: "GET",
URL: "https://httpbin.org/get",
Header: map[string]string{
"User-Agent": "test-client",
},
}
t.Run("JSON request and response", func(t *testing.T) {
// Marshal request as JSON
jsonData, err := json.Marshal(reqData)
if err != nil {
t.Fatalf("Failed to marshal JSON: %v", err)
}
// Create HTTP request
req := httptest.NewRequest(http.MethodPost, "/v0/management/api-call", bytes.NewReader(jsonData))
req.Header.Set("Content-Type", "application/json")
// Create response recorder
w := httptest.NewRecorder()
// Create Gin context
c, _ := gin.CreateTestContext(w)
c.Request = req
// Call handler
h.APICall(c)
// Verify response
if w.Code != http.StatusOK && w.Code != http.StatusBadGateway {
t.Logf("Response status: %d", w.Code)
t.Logf("Response body: %s", w.Body.String())
}
// Check content type
contentType := w.Header().Get("Content-Type")
if w.Code == http.StatusOK && !contains(contentType, "application/json") {
t.Errorf("Expected JSON response, got: %s", contentType)
}
})
t.Run("CBOR request and response", func(t *testing.T) {
// Marshal request as CBOR
cborData, err := cbor.Marshal(reqData)
if err != nil {
t.Fatalf("Failed to marshal CBOR: %v", err)
}
// Create HTTP request
req := httptest.NewRequest(http.MethodPost, "/v0/management/api-call", bytes.NewReader(cborData))
req.Header.Set("Content-Type", "application/cbor")
// Create response recorder
w := httptest.NewRecorder()
// Create Gin context
c, _ := gin.CreateTestContext(w)
c.Request = req
// Call handler
h.APICall(c)
// Verify response
if w.Code != http.StatusOK && w.Code != http.StatusBadGateway {
t.Logf("Response status: %d", w.Code)
t.Logf("Response body: %s", w.Body.String())
}
// Check content type
contentType := w.Header().Get("Content-Type")
if w.Code == http.StatusOK && !contains(contentType, "application/cbor") {
t.Errorf("Expected CBOR response, got: %s", contentType)
}
// Try to decode CBOR response
if w.Code == http.StatusOK {
var response apiCallResponse
if err := cbor.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Errorf("Failed to unmarshal CBOR response: %v", err)
} else {
t.Logf("CBOR response decoded successfully: status_code=%d", response.StatusCode)
}
}
})
t.Run("CBOR encoding and decoding consistency", func(t *testing.T) {
// Test data
testReq := apiCallRequest{
Method: "POST",
URL: "https://example.com/api",
Header: map[string]string{
"Authorization": "Bearer $TOKEN$",
"Content-Type": "application/json",
},
Data: `{"key":"value"}`,
}
// Encode to CBOR
cborData, err := cbor.Marshal(testReq)
if err != nil {
t.Fatalf("Failed to marshal to CBOR: %v", err)
}
// Decode from CBOR
var decoded apiCallRequest
if err := cbor.Unmarshal(cborData, &decoded); err != nil {
t.Fatalf("Failed to unmarshal from CBOR: %v", err)
}
// Verify fields
if decoded.Method != testReq.Method {
t.Errorf("Method mismatch: got %s, want %s", decoded.Method, testReq.Method)
}
if decoded.URL != testReq.URL {
t.Errorf("URL mismatch: got %s, want %s", decoded.URL, testReq.URL)
}
if decoded.Data != testReq.Data {
t.Errorf("Data mismatch: got %s, want %s", decoded.Data, testReq.Data)
}
if len(decoded.Header) != len(testReq.Header) {
t.Errorf("Header count mismatch: got %d, want %d", len(decoded.Header), len(testReq.Header))
}
})
}
func contains(s, substr string) bool {
return len(s) > 0 && len(substr) > 0 && (s == substr || len(s) >= len(substr) && s[:len(substr)] == substr || bytes.Contains([]byte(s), []byte(substr)))
}

View File

@@ -3,7 +3,9 @@ package management
import (
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
@@ -12,6 +14,7 @@ import (
"mime/multipart"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"runtime"
@@ -25,9 +28,13 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/antigravity"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
gitlabauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gitlab"
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kilo"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
@@ -49,6 +56,8 @@ const (
codexCallbackPort = 1455
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
geminiCLIVersion = "v1internal"
gitLabLoginModeOAuth = "oauth"
gitLabLoginModePAT = "pat"
)
type callbackForwarder struct {
@@ -1292,6 +1301,165 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
return store.Save(ctx, record)
}
func gitLabBaseURLFromRequest(c *gin.Context) string {
if c != nil {
if raw := strings.TrimSpace(c.Query("base_url")); raw != "" {
return gitlabauth.NormalizeBaseURL(raw)
}
}
if raw := strings.TrimSpace(os.Getenv("GITLAB_BASE_URL")); raw != "" {
return gitlabauth.NormalizeBaseURL(raw)
}
return gitlabauth.DefaultBaseURL
}
func buildGitLabAuthMetadata(baseURL, mode string, tokenResp *gitlabauth.TokenResponse, direct *gitlabauth.DirectAccessResponse) map[string]any {
metadata := map[string]any{
"type": "gitlab",
"auth_method": strings.TrimSpace(mode),
"base_url": gitlabauth.NormalizeBaseURL(baseURL),
"last_refresh": time.Now().UTC().Format(time.RFC3339),
"refresh_interval_seconds": 240,
}
if tokenResp != nil {
metadata["access_token"] = strings.TrimSpace(tokenResp.AccessToken)
if refreshToken := strings.TrimSpace(tokenResp.RefreshToken); refreshToken != "" {
metadata["refresh_token"] = refreshToken
}
if tokenType := strings.TrimSpace(tokenResp.TokenType); tokenType != "" {
metadata["token_type"] = tokenType
}
if scope := strings.TrimSpace(tokenResp.Scope); scope != "" {
metadata["scope"] = scope
}
if expiry := gitlabauth.TokenExpiry(time.Now(), tokenResp); !expiry.IsZero() {
metadata["oauth_expires_at"] = expiry.Format(time.RFC3339)
}
}
mergeGitLabDirectAccessMetadata(metadata, direct)
return metadata
}
func mergeGitLabDirectAccessMetadata(metadata map[string]any, direct *gitlabauth.DirectAccessResponse) {
if metadata == nil || direct == nil {
return
}
if base := strings.TrimSpace(direct.BaseURL); base != "" {
metadata["duo_gateway_base_url"] = base
}
if token := strings.TrimSpace(direct.Token); token != "" {
metadata["duo_gateway_token"] = token
}
if direct.ExpiresAt > 0 {
expiry := time.Unix(direct.ExpiresAt, 0).UTC()
metadata["duo_gateway_expires_at"] = expiry.Format(time.RFC3339)
now := time.Now().UTC()
if ttl := expiry.Sub(now); ttl > 0 {
interval := int(ttl.Seconds()) / 2
switch {
case interval < 60:
interval = 60
case interval > 240:
interval = 240
}
metadata["refresh_interval_seconds"] = interval
}
}
if len(direct.Headers) > 0 {
headers := make(map[string]string, len(direct.Headers))
for key, value := range direct.Headers {
key = strings.TrimSpace(key)
value = strings.TrimSpace(value)
if key == "" || value == "" {
continue
}
headers[key] = value
}
if len(headers) > 0 {
metadata["duo_gateway_headers"] = headers
}
}
if direct.ModelDetails != nil {
modelDetails := map[string]any{}
if provider := strings.TrimSpace(direct.ModelDetails.ModelProvider); provider != "" {
modelDetails["model_provider"] = provider
metadata["model_provider"] = provider
}
if model := strings.TrimSpace(direct.ModelDetails.ModelName); model != "" {
modelDetails["model_name"] = model
metadata["model_name"] = model
}
if len(modelDetails) > 0 {
metadata["model_details"] = modelDetails
}
}
}
func primaryGitLabEmail(user *gitlabauth.User) string {
if user == nil {
return ""
}
if value := strings.TrimSpace(user.Email); value != "" {
return value
}
return strings.TrimSpace(user.PublicEmail)
}
func gitLabAccountIdentifier(user *gitlabauth.User) string {
if user == nil {
return "user"
}
for _, value := range []string{user.Username, primaryGitLabEmail(user), user.Name} {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
return "user"
}
func sanitizeGitLabFileName(value string) string {
value = strings.TrimSpace(strings.ToLower(value))
if value == "" {
return "user"
}
var builder strings.Builder
lastDash := false
for _, r := range value {
switch {
case r >= 'a' && r <= 'z':
builder.WriteRune(r)
lastDash = false
case r >= '0' && r <= '9':
builder.WriteRune(r)
lastDash = false
case r == '-' || r == '_' || r == '.':
builder.WriteRune(r)
lastDash = false
default:
if !lastDash {
builder.WriteRune('-')
lastDash = true
}
}
}
result := strings.Trim(builder.String(), "-")
if result == "" {
return "user"
}
return result
}
func maskGitLabToken(token string) string {
trimmed := strings.TrimSpace(token)
if trimmed == "" {
return ""
}
if len(trimmed) <= 8 {
return trimmed
}
return trimmed[:4] + "..." + trimmed[len(trimmed)-4:]
}
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
@@ -1842,6 +2010,263 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
func (h *Handler) RequestGitLabToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing GitLab Duo authentication...")
baseURL := gitLabBaseURLFromRequest(c)
clientID := strings.TrimSpace(c.Query("client_id"))
clientSecret := strings.TrimSpace(c.Query("client_secret"))
if clientID == "" {
clientID = strings.TrimSpace(os.Getenv("GITLAB_OAUTH_CLIENT_ID"))
}
if clientSecret == "" {
clientSecret = strings.TrimSpace(os.Getenv("GITLAB_OAUTH_CLIENT_SECRET"))
}
if clientID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "gitlab client_id is required"})
return
}
pkceCodes, err := gitlabauth.GeneratePKCECodes()
if err != nil {
log.Errorf("Failed to generate GitLab PKCE codes: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"})
return
}
state, err := misc.GenerateRandomState()
if err != nil {
log.Errorf("Failed to generate GitLab state parameter: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
return
}
redirectURI := gitlabauth.RedirectURL(gitlabauth.DefaultCallbackPort)
authClient := gitlabauth.NewAuthClient(h.cfg)
authURL, err := authClient.GenerateAuthURL(baseURL, clientID, redirectURI, state, pkceCodes)
if err != nil {
log.Errorf("Failed to generate GitLab authorization URL: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
return
}
RegisterOAuthSession(state, "gitlab")
isWebUI := isWebUIRequest(c)
var forwarder *callbackForwarder
if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/gitlab/callback")
if errTarget != nil {
log.WithError(errTarget).Error("failed to compute gitlab callback target")
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return
}
var errStart error
if forwarder, errStart = startCallbackForwarder(gitlabauth.DefaultCallbackPort, "gitlab", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start gitlab callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return
}
}
go func() {
if isWebUI {
defer stopCallbackForwarderInstance(gitlabauth.DefaultCallbackPort, forwarder)
}
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-gitlab-%s.oauth", state))
deadline := time.Now().Add(5 * time.Minute)
var code string
for {
if !IsOAuthSessionPending(state, "gitlab") {
return
}
if time.Now().After(deadline) {
log.Error("gitlab oauth flow timed out")
SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
return
}
if data, errRead := os.ReadFile(waitFile); errRead == nil {
var payload map[string]string
_ = json.Unmarshal(data, &payload)
_ = os.Remove(waitFile)
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
SetOAuthSessionError(state, errStr)
return
}
if payloadState := strings.TrimSpace(payload["state"]); payloadState != state {
SetOAuthSessionError(state, "State code error")
return
}
code = strings.TrimSpace(payload["code"])
if code == "" {
SetOAuthSessionError(state, "Authorization code missing")
return
}
break
}
time.Sleep(500 * time.Millisecond)
}
tokenResp, errExchange := authClient.ExchangeCodeForTokens(ctx, baseURL, clientID, clientSecret, redirectURI, code, pkceCodes.CodeVerifier)
if errExchange != nil {
log.Errorf("Failed to exchange GitLab authorization code: %v", errExchange)
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
return
}
user, errUser := authClient.GetCurrentUser(ctx, baseURL, tokenResp.AccessToken)
if errUser != nil {
log.Errorf("Failed to fetch GitLab user profile: %v", errUser)
SetOAuthSessionError(state, "Failed to fetch account profile")
return
}
direct, errDirect := authClient.FetchDirectAccess(ctx, baseURL, tokenResp.AccessToken)
if errDirect != nil {
log.Errorf("Failed to fetch GitLab direct access metadata: %v", errDirect)
SetOAuthSessionError(state, "Failed to fetch GitLab Duo access")
return
}
identifier := gitLabAccountIdentifier(user)
fileName := fmt.Sprintf("gitlab-%s.json", sanitizeGitLabFileName(identifier))
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModeOAuth, tokenResp, direct)
metadata["auth_kind"] = "oauth"
metadata["oauth_client_id"] = clientID
if clientSecret != "" {
metadata["oauth_client_secret"] = clientSecret
}
metadata["username"] = strings.TrimSpace(user.Username)
if email := primaryGitLabEmail(user); email != "" {
metadata["email"] = email
}
metadata["name"] = strings.TrimSpace(user.Name)
record := &coreauth.Auth{
ID: fileName,
Provider: "gitlab",
FileName: fileName,
Label: identifier,
Metadata: metadata,
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Errorf("Failed to save GitLab auth record: %v", errSave)
SetOAuthSessionError(state, "Failed to save authentication tokens")
return
}
fmt.Printf("GitLab Duo authentication successful. Token saved to %s\n", savedPath)
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("gitlab")
}()
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
}
func (h *Handler) RequestGitLabPATToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
var payload struct {
BaseURL string `json:"base_url"`
PersonalAccessToken string `json:"personal_access_token"`
Token string `json:"token"`
}
if err := c.ShouldBindJSON(&payload); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid body"})
return
}
baseURL := gitlabauth.NormalizeBaseURL(strings.TrimSpace(payload.BaseURL))
if baseURL == "" {
baseURL = gitLabBaseURLFromRequest(nil)
}
pat := strings.TrimSpace(payload.PersonalAccessToken)
if pat == "" {
pat = strings.TrimSpace(payload.Token)
}
if pat == "" {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "personal_access_token is required"})
return
}
authClient := gitlabauth.NewAuthClient(h.cfg)
user, err := authClient.GetCurrentUser(ctx, baseURL, pat)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": err.Error()})
return
}
patSelf, err := authClient.GetPersonalAccessTokenSelf(ctx, baseURL, pat)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": err.Error()})
return
}
direct, err := authClient.FetchDirectAccess(ctx, baseURL, pat)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": err.Error()})
return
}
identifier := gitLabAccountIdentifier(user)
fileName := fmt.Sprintf("gitlab-%s-pat.json", sanitizeGitLabFileName(identifier))
metadata := buildGitLabAuthMetadata(baseURL, gitLabLoginModePAT, nil, direct)
metadata["auth_kind"] = "personal_access_token"
metadata["personal_access_token"] = pat
metadata["token_preview"] = maskGitLabToken(pat)
metadata["username"] = strings.TrimSpace(user.Username)
if email := primaryGitLabEmail(user); email != "" {
metadata["email"] = email
}
metadata["name"] = strings.TrimSpace(user.Name)
if patSelf != nil {
if name := strings.TrimSpace(patSelf.Name); name != "" {
metadata["pat_name"] = name
}
if len(patSelf.Scopes) > 0 {
metadata["pat_scopes"] = append([]string(nil), patSelf.Scopes...)
}
}
record := &coreauth.Auth{
ID: fileName,
Provider: "gitlab",
FileName: fileName,
Label: identifier + " (PAT)",
Metadata: metadata,
}
savedPath, err := h.saveTokenRecord(ctx, record)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to save authentication tokens"})
return
}
response := gin.H{
"status": "ok",
"saved_path": savedPath,
"username": strings.TrimSpace(user.Username),
"email": primaryGitLabEmail(user),
"token_label": identifier,
}
if direct != nil && direct.ModelDetails != nil {
if provider := strings.TrimSpace(direct.ModelDetails.ModelProvider); provider != "" {
response["model_provider"] = provider
}
if model := strings.TrimSpace(direct.ModelDetails.ModelName); model != "" {
response["model_name"] = model
}
}
fmt.Printf("GitLab Duo PAT authentication successful. Token saved to %s\n", savedPath)
c.JSON(http.StatusOK, response)
}
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
@@ -2254,6 +2679,117 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
}
func (h *Handler) RequestGitHubToken(c *gin.Context) {
ctx := context.Background()
fmt.Println("Initializing GitHub Copilot authentication...")
state := fmt.Sprintf("gh-%d", time.Now().UnixNano())
// Initialize Copilot auth service
deviceClient := copilot.NewDeviceFlowClient(h.cfg)
// Initiate device flow
deviceCode, err := deviceClient.RequestDeviceCode(ctx)
if err != nil {
log.Errorf("Failed to initiate device flow: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initiate device flow"})
return
}
authURL := deviceCode.VerificationURI
userCode := deviceCode.UserCode
RegisterOAuthSession(state, "github-copilot")
go func() {
fmt.Printf("Please visit %s and enter code: %s\n", authURL, userCode)
tokenData, errPoll := deviceClient.PollForToken(ctx, deviceCode)
if errPoll != nil {
SetOAuthSessionError(state, "Authentication failed")
fmt.Printf("Authentication failed: %v\n", errPoll)
return
}
userInfo, errUser := deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
if errUser != nil {
log.Warnf("Failed to fetch user info: %v", errUser)
}
username := userInfo.Login
if username == "" {
username = "github-user"
}
tokenStorage := &copilot.CopilotTokenStorage{
AccessToken: tokenData.AccessToken,
TokenType: tokenData.TokenType,
Scope: tokenData.Scope,
Username: username,
Email: userInfo.Email,
Name: userInfo.Name,
Type: "github-copilot",
}
fileName := fmt.Sprintf("github-copilot-%s.json", username)
label := userInfo.Email
if label == "" {
label = username
}
metadata, errMeta := copilotTokenMetadata(tokenStorage)
if errMeta != nil {
log.Errorf("Failed to build token metadata: %v", errMeta)
SetOAuthSessionError(state, "Failed to build token metadata")
return
}
record := &coreauth.Auth{
ID: fileName,
Provider: "github-copilot",
Label: label,
FileName: fileName,
Storage: tokenStorage,
Metadata: metadata,
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Errorf("Failed to save authentication tokens: %v", errSave)
SetOAuthSessionError(state, "Failed to save authentication tokens")
return
}
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
fmt.Println("You can now use GitHub Copilot services through this CLI")
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("github-copilot")
}()
c.JSON(200, gin.H{
"status": "ok",
"url": authURL,
"state": state,
"user_code": userCode,
"verification_uri": authURL,
})
}
func copilotTokenMetadata(storage *copilot.CopilotTokenStorage) (map[string]any, error) {
if storage == nil {
return nil, fmt.Errorf("token storage is nil")
}
payload, errMarshal := json.Marshal(storage)
if errMarshal != nil {
return nil, fmt.Errorf("marshal token storage: %w", errMarshal)
}
metadata := make(map[string]any)
if errUnmarshal := json.Unmarshal(payload, &metadata); errUnmarshal != nil {
return nil, fmt.Errorf("unmarshal token storage: %w", errUnmarshal)
}
return metadata, nil
}
func (h *Handler) RequestIFlowCookieToken(c *gin.Context) {
ctx := context.Background()
@@ -2756,6 +3292,25 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
return
}
if status != "" {
if strings.HasPrefix(status, "device_code|") {
parts := strings.SplitN(status, "|", 3)
if len(parts) == 3 {
c.JSON(http.StatusOK, gin.H{
"status": "device_code",
"verification_url": parts[1],
"user_code": parts[2],
})
return
}
}
if strings.HasPrefix(status, "auth_url|") {
authURL := strings.TrimPrefix(status, "auth_url|")
c.JSON(http.StatusOK, gin.H{
"status": "auth_url",
"url": authURL,
})
return
}
c.JSON(http.StatusOK, gin.H{"status": "error", "error": status})
return
}
@@ -2770,3 +3325,385 @@ func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context {
}
return coreauth.WithRequestInfo(ctx, info)
}
const kiroCallbackPort = 9876
func (h *Handler) RequestKiroToken(c *gin.Context) {
ctx := context.Background()
// Get the login method from query parameter (default: aws for device code flow)
method := strings.ToLower(strings.TrimSpace(c.Query("method")))
if method == "" {
method = "aws"
}
fmt.Println("Initializing Kiro authentication...")
state := fmt.Sprintf("kiro-%d", time.Now().UnixNano())
switch method {
case "aws", "builder-id":
RegisterOAuthSession(state, "kiro")
// AWS Builder ID uses device code flow (no callback needed)
go func() {
ssoClient := kiroauth.NewSSOOIDCClient(h.cfg)
// Step 1: Register client
fmt.Println("Registering client...")
regResp, errRegister := ssoClient.RegisterClient(ctx)
if errRegister != nil {
log.Errorf("Failed to register client: %v", errRegister)
SetOAuthSessionError(state, "Failed to register client")
return
}
// Step 2: Start device authorization
fmt.Println("Starting device authorization...")
authResp, errAuth := ssoClient.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret)
if errAuth != nil {
log.Errorf("Failed to start device auth: %v", errAuth)
SetOAuthSessionError(state, "Failed to start device authorization")
return
}
// Store the verification URL for the frontend to display.
// Using "|" as separator because URLs contain ":".
SetOAuthSessionError(state, "device_code|"+authResp.VerificationURIComplete+"|"+authResp.UserCode)
// Step 3: Poll for token
fmt.Println("Waiting for authorization...")
interval := 5 * time.Second
if authResp.Interval > 0 {
interval = time.Duration(authResp.Interval) * time.Second
}
deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)
for time.Now().Before(deadline) {
select {
case <-ctx.Done():
SetOAuthSessionError(state, "Authorization cancelled")
return
case <-time.After(interval):
tokenResp, errToken := ssoClient.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
if errToken != nil {
errStr := errToken.Error()
if strings.Contains(errStr, "authorization_pending") {
continue
}
if strings.Contains(errStr, "slow_down") {
interval += 5 * time.Second
continue
}
log.Errorf("Token creation failed: %v", errToken)
SetOAuthSessionError(state, "Token creation failed")
return
}
// Success! Save the token
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken)
idPart := kiroauth.SanitizeEmailForFilename(email)
if idPart == "" {
idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000)
}
now := time.Now()
fileName := fmt.Sprintf("kiro-aws-%s.json", idPart)
record := &coreauth.Auth{
ID: fileName,
Provider: "kiro",
FileName: fileName,
Metadata: map[string]any{
"type": "kiro",
"access_token": tokenResp.AccessToken,
"refresh_token": tokenResp.RefreshToken,
"expires_at": expiresAt.Format(time.RFC3339),
"auth_method": "builder-id",
"provider": "AWS",
"client_id": regResp.ClientID,
"client_secret": regResp.ClientSecret,
"email": email,
"last_refresh": now.Format(time.RFC3339),
},
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Errorf("Failed to save authentication tokens: %v", errSave)
SetOAuthSessionError(state, "Failed to save authentication tokens")
return
}
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
if email != "" {
fmt.Printf("Authenticated as: %s\n", email)
}
CompleteOAuthSession(state)
return
}
}
SetOAuthSessionError(state, "Authorization timed out")
}()
// Return immediately with the state for polling
c.JSON(http.StatusOK, gin.H{"status": "ok", "state": state, "method": "device_code"})
case "google", "github":
RegisterOAuthSession(state, "kiro")
// Social auth uses protocol handler - for WEB UI we use a callback forwarder
provider := "Google"
if method == "github" {
provider = "Github"
}
isWebUI := isWebUIRequest(c)
var forwarder *callbackForwarder
if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/kiro/callback")
if errTarget != nil {
log.WithError(errTarget).Error("failed to compute kiro callback target")
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return
}
var errStart error
if forwarder, errStart = startCallbackForwarder(kiroCallbackPort, "kiro", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start kiro callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return
}
}
go func() {
if isWebUI {
defer stopCallbackForwarderInstance(kiroCallbackPort, forwarder)
}
socialClient := kiroauth.NewSocialAuthClient(h.cfg)
// Generate PKCE codes
codeVerifier, codeChallenge, errPKCE := generateKiroPKCE()
if errPKCE != nil {
log.Errorf("Failed to generate PKCE: %v", errPKCE)
SetOAuthSessionError(state, "Failed to generate PKCE")
return
}
// Build login URL
authURL := fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account",
"https://prod.us-east-1.auth.desktop.kiro.dev",
provider,
url.QueryEscape(kiroauth.KiroRedirectURI),
codeChallenge,
state,
)
// Store auth URL for frontend.
// Using "|" as separator because URLs contain ":".
SetOAuthSessionError(state, "auth_url|"+authURL)
// Wait for callback file
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-kiro-%s.oauth", state))
deadline := time.Now().Add(5 * time.Minute)
for {
if time.Now().After(deadline) {
log.Error("oauth flow timed out")
SetOAuthSessionError(state, "OAuth flow timed out")
return
}
if data, errRead := os.ReadFile(waitFile); errRead == nil {
var m map[string]string
_ = json.Unmarshal(data, &m)
_ = os.Remove(waitFile)
if errStr := m["error"]; errStr != "" {
log.Errorf("Authentication failed: %s", errStr)
SetOAuthSessionError(state, "Authentication failed")
return
}
if m["state"] != state {
log.Errorf("State mismatch")
SetOAuthSessionError(state, "State mismatch")
return
}
code := m["code"]
if code == "" {
log.Error("No authorization code received")
SetOAuthSessionError(state, "No authorization code received")
return
}
// Exchange code for tokens
tokenReq := &kiroauth.CreateTokenRequest{
Code: code,
CodeVerifier: codeVerifier,
RedirectURI: kiroauth.KiroRedirectURI,
}
tokenResp, errToken := socialClient.CreateToken(ctx, tokenReq)
if errToken != nil {
log.Errorf("Failed to exchange code for tokens: %v", errToken)
SetOAuthSessionError(state, "Failed to exchange code for tokens")
return
}
// Save the token
expiresIn := tokenResp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken)
idPart := kiroauth.SanitizeEmailForFilename(email)
if idPart == "" {
idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000)
}
now := time.Now()
fileName := fmt.Sprintf("kiro-%s-%s.json", strings.ToLower(provider), idPart)
record := &coreauth.Auth{
ID: fileName,
Provider: "kiro",
FileName: fileName,
Metadata: map[string]any{
"type": "kiro",
"access_token": tokenResp.AccessToken,
"refresh_token": tokenResp.RefreshToken,
"profile_arn": tokenResp.ProfileArn,
"expires_at": expiresAt.Format(time.RFC3339),
"auth_method": "social",
"provider": provider,
"email": email,
"last_refresh": now.Format(time.RFC3339),
},
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Errorf("Failed to save authentication tokens: %v", errSave)
SetOAuthSessionError(state, "Failed to save authentication tokens")
return
}
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
if email != "" {
fmt.Printf("Authenticated as: %s\n", email)
}
CompleteOAuthSession(state)
return
}
time.Sleep(500 * time.Millisecond)
}
}()
c.JSON(http.StatusOK, gin.H{"status": "ok", "state": state, "method": "social"})
default:
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid method, use 'aws', 'google', or 'github'"})
}
}
// generateKiroPKCE generates PKCE code verifier and challenge for Kiro OAuth.
func generateKiroPKCE() (verifier, challenge string, err error) {
b := make([]byte, 32)
if _, errRead := io.ReadFull(rand.Reader, b); errRead != nil {
return "", "", fmt.Errorf("failed to generate random bytes: %w", errRead)
}
verifier = base64.RawURLEncoding.EncodeToString(b)
h := sha256.Sum256([]byte(verifier))
challenge = base64.RawURLEncoding.EncodeToString(h[:])
return verifier, challenge, nil
}
func (h *Handler) RequestKiloToken(c *gin.Context) {
ctx := context.Background()
fmt.Println("Initializing Kilo authentication...")
state := fmt.Sprintf("kil-%d", time.Now().UnixNano())
kilocodeAuth := kilo.NewKiloAuth()
resp, err := kilocodeAuth.InitiateDeviceFlow(ctx)
if err != nil {
log.Errorf("Failed to initiate device flow: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initiate device flow"})
return
}
RegisterOAuthSession(state, "kilo")
go func() {
fmt.Printf("Please visit %s and enter code: %s\n", resp.VerificationURL, resp.Code)
status, err := kilocodeAuth.PollForToken(ctx, resp.Code)
if err != nil {
SetOAuthSessionError(state, "Authentication failed")
fmt.Printf("Authentication failed: %v\n", err)
return
}
profile, err := kilocodeAuth.GetProfile(ctx, status.Token)
if err != nil {
log.Warnf("Failed to fetch profile: %v", err)
profile = &kilo.Profile{Email: status.UserEmail}
}
var orgID string
if len(profile.Orgs) > 0 {
orgID = profile.Orgs[0].ID
}
defaults, err := kilocodeAuth.GetDefaults(ctx, status.Token, orgID)
if err != nil {
defaults = &kilo.Defaults{}
}
ts := &kilo.KiloTokenStorage{
Token: status.Token,
OrganizationID: orgID,
Model: defaults.Model,
Email: status.UserEmail,
Type: "kilo",
}
fileName := kilo.CredentialFileName(status.UserEmail)
record := &coreauth.Auth{
ID: fileName,
Provider: "kilo",
FileName: fileName,
Storage: ts,
Metadata: map[string]any{
"email": status.UserEmail,
"organization_id": orgID,
"model": defaults.Model,
},
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Errorf("Failed to save authentication tokens: %v", errSave)
SetOAuthSessionError(state, "Failed to save authentication tokens")
return
}
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("kilo")
}()
c.JSON(200, gin.H{
"status": "ok",
"url": resp.VerificationURL,
"state": state,
"user_code": resp.Code,
"verification_uri": resp.VerificationURL,
})
}

View File

@@ -0,0 +1,164 @@
package management
import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func TestRequestGitLabPATToken_SavesAuthRecord(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if got := r.Header.Get("Authorization"); got != "Bearer glpat-test-token" {
t.Fatalf("authorization header = %q, want Bearer glpat-test-token", got)
}
w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/api/v4/user":
_ = json.NewEncoder(w).Encode(map[string]any{
"id": 42,
"username": "gitlab-user",
"name": "GitLab User",
"email": "gitlab@example.com",
})
case "/api/v4/personal_access_tokens/self":
_ = json.NewEncoder(w).Encode(map[string]any{
"id": 7,
"name": "management-center",
"scopes": []string{"api", "read_user"},
"user_id": 42,
})
case "/api/v4/code_suggestions/direct_access":
_ = json.NewEncoder(w).Encode(map[string]any{
"base_url": "https://cloud.gitlab.example.com",
"token": "gateway-token",
"expires_at": 1893456000,
"headers": map[string]string{
"X-Gitlab-Realm": "saas",
},
"model_details": map[string]any{
"model_provider": "anthropic",
"model_name": "claude-sonnet-4-5",
},
})
default:
http.NotFound(w, r)
}
}))
defer upstream.Close()
store := &memoryAuthStore{}
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, coreauth.NewManager(nil, nil, nil))
h.tokenStore = store
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
ctx.Request = httptest.NewRequest(http.MethodPost, "/v0/management/gitlab-auth-url", strings.NewReader(`{"base_url":"`+upstream.URL+`","personal_access_token":"glpat-test-token"}`))
ctx.Request.Header.Set("Content-Type", "application/json")
h.RequestGitLabPATToken(ctx)
if rec.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
}
var resp map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
t.Fatalf("decode response: %v", err)
}
if got := resp["status"]; got != "ok" {
t.Fatalf("status = %#v, want ok", got)
}
if got := resp["model_provider"]; got != "anthropic" {
t.Fatalf("model_provider = %#v, want anthropic", got)
}
if got := resp["model_name"]; got != "claude-sonnet-4-5" {
t.Fatalf("model_name = %#v, want claude-sonnet-4-5", got)
}
store.mu.Lock()
defer store.mu.Unlock()
if len(store.items) != 1 {
t.Fatalf("expected 1 saved auth record, got %d", len(store.items))
}
var saved *coreauth.Auth
for _, item := range store.items {
saved = item
}
if saved == nil {
t.Fatal("expected saved auth record")
}
if saved.Provider != "gitlab" {
t.Fatalf("provider = %q, want gitlab", saved.Provider)
}
if got := saved.Metadata["auth_kind"]; got != "personal_access_token" {
t.Fatalf("auth_kind = %#v, want personal_access_token", got)
}
if got := saved.Metadata["model_provider"]; got != "anthropic" {
t.Fatalf("saved model_provider = %#v, want anthropic", got)
}
if got := saved.Metadata["duo_gateway_token"]; got != "gateway-token" {
t.Fatalf("saved duo_gateway_token = %#v, want gateway-token", got)
}
}
func TestPostOAuthCallback_GitLabWritesPendingCallbackFile(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
authDir := t.TempDir()
state := "gitlab-state-123"
RegisterOAuthSession(state, "gitlab")
t.Cleanup(func() { CompleteOAuthSession(state) })
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, coreauth.NewManager(nil, nil, nil))
rec := httptest.NewRecorder()
ctx, _ := gin.CreateTestContext(rec)
ctx.Request = httptest.NewRequest(http.MethodPost, "/v0/management/oauth-callback", strings.NewReader(`{"provider":"gitlab","redirect_url":"http://localhost:17171/auth/callback?code=test-code&state=`+state+`"}`))
ctx.Request.Header.Set("Content-Type", "application/json")
h.PostOAuthCallback(ctx)
if rec.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
}
filePath := filepath.Join(authDir, ".oauth-gitlab-"+state+".oauth")
data, err := os.ReadFile(filePath)
if err != nil {
t.Fatalf("read callback file: %v", err)
}
var payload map[string]string
if err := json.Unmarshal(data, &payload); err != nil {
t.Fatalf("decode callback payload: %v", err)
}
if got := payload["code"]; got != "test-code" {
t.Fatalf("callback code = %q, want test-code", got)
}
if got := payload["state"]; got != state {
t.Fatalf("callback state = %q, want %q", got, state)
}
}
func TestNormalizeOAuthProvider_GitLab(t *testing.T) {
provider, err := NormalizeOAuthProvider("gitlab")
if err != nil {
t.Fatalf("NormalizeOAuthProvider returned error: %v", err)
}
if provider != "gitlab" {
t.Fatalf("provider = %q, want gitlab", provider)
}
}

View File

@@ -19,8 +19,8 @@ import (
)
const (
latestReleaseURL = "https://api.github.com/repos/router-for-me/CLIProxyAPI/releases/latest"
latestReleaseUserAgent = "CLIProxyAPI"
latestReleaseURL = "https://api.github.com/repos/router-for-me/CLIProxyAPIPlus/releases/latest"
latestReleaseUserAgent = "CLIProxyAPIPlus"
)
func (h *Handler) GetConfig(c *gin.Context) {

View File

@@ -761,18 +761,22 @@ func (h *Handler) PatchOAuthModelAlias(c *gin.Context) {
normalizedMap := sanitizedOAuthModelAlias(map[string][]config.OAuthModelAlias{channel: body.Aliases})
normalized := normalizedMap[channel]
if len(normalized) == 0 {
// Only delete if channel exists, otherwise just create empty entry
if h.cfg.OAuthModelAlias != nil {
if _, ok := h.cfg.OAuthModelAlias[channel]; ok {
delete(h.cfg.OAuthModelAlias, channel)
if len(h.cfg.OAuthModelAlias) == 0 {
h.cfg.OAuthModelAlias = nil
}
h.persist(c)
return
}
}
// Create new channel with empty aliases
if h.cfg.OAuthModelAlias == nil {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
if _, ok := h.cfg.OAuthModelAlias[channel]; !ok {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
delete(h.cfg.OAuthModelAlias, channel)
if len(h.cfg.OAuthModelAlias) == 0 {
h.cfg.OAuthModelAlias = nil
h.cfg.OAuthModelAlias = make(map[string][]config.OAuthModelAlias)
}
h.cfg.OAuthModelAlias[channel] = []config.OAuthModelAlias{}
h.persist(c)
return
}
@@ -800,10 +804,10 @@ func (h *Handler) DeleteOAuthModelAlias(c *gin.Context) {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
delete(h.cfg.OAuthModelAlias, channel)
if len(h.cfg.OAuthModelAlias) == 0 {
h.cfg.OAuthModelAlias = nil
}
// Set to nil instead of deleting the key so that the "explicitly disabled"
// marker survives config reload and prevents SanitizeOAuthModelAlias from
// re-injecting default aliases (fixes #222).
h.cfg.OAuthModelAlias[channel] = nil
h.persist(c)
}

View File

@@ -158,7 +158,12 @@ func (s *oauthSessionStore) IsPending(state, provider string) bool {
return false
}
if session.Status != "" {
return false
if !strings.EqualFold(session.Provider, "kiro") {
return false
}
if !strings.HasPrefix(session.Status, "device_code|") && !strings.HasPrefix(session.Status, "auth_url|") {
return false
}
}
if provider == "" {
return true
@@ -223,6 +228,8 @@ func NormalizeOAuthProvider(provider string) (string, error) {
return "anthropic", nil
case "codex", "openai":
return "codex", nil
case "gitlab":
return "gitlab", nil
case "gemini", "google":
return "gemini", nil
case "iflow", "i-flow":
@@ -231,6 +238,10 @@ func NormalizeOAuthProvider(provider string) (string, error) {
return "antigravity", nil
case "qwen":
return "qwen", nil
case "kiro":
return "kiro", nil
case "github":
return "github", nil
default:
return "", errUnsupportedOAuthFlow
}