Compare commits

..

42 Commits

Author SHA1 Message Date
Luis Pater
851712a49e Merge pull request #132 from ClubWeGo/codex/resolve-issue-#131
Resolve Issue #131
2026-01-26 23:36:16 +08:00
Luis Pater
9e34323a40 Merge branch 'router-for-me:main' into main 2026-01-26 23:35:07 +08:00
Luis Pater
70897247b2 feat(auth): add support for request_retry and disable_cooling overrides
Implement `request_retry` and `disable_cooling` metadata overrides for authentication management. Update retry and cooling logic accordingly across `Manager`, Antigravity executor, and file synthesizer. Add tests to validate new behaviors.
2026-01-26 21:59:08 +08:00
Luis Pater
9c341f5aa5 feat(auth): add skip persistence context key for file watcher events
Introduce `WithSkipPersist` to disable persistence during Manager Update/Register calls, preventing write-back loops caused by redundant file writes. Add corresponding tests and integrate with existing file store and conductor logic.
2026-01-26 18:20:19 +08:00
Darley
e3e741d0be Default Claude tool input schema 2026-01-26 09:15:38 +08:00
Darley
7c7c5fd967 Fix Kiro tool schema defaults 2026-01-26 08:27:53 +08:00
Luis Pater
fe8c7a62aa Merge branch 'router-for-me:main' into main 2026-01-26 06:23:41 +08:00
Luis Pater
2af4a8dc12 refactor(runtime): implement retry logic for Antigravity executor with improved error handling and capacity management 2026-01-26 06:22:46 +08:00
Luis Pater
0f53b952b2 Merge pull request #1225 from router-for-me/log
Add request_id to error logs and extract error messages
2026-01-25 22:08:46 +08:00
Luis Pater
7b2ae7377a chore(auth): add net/url import to auth_files.go for URL handling 2026-01-25 21:53:20 +08:00
Luis Pater
c2ab288c7d Merge pull request #130 from router-for-me/plus
v6.7.22
2026-01-25 21:51:20 +08:00
Luis Pater
dbb433fcf8 Merge branch 'main' into plus 2026-01-25 21:51:02 +08:00
Luis Pater
2abf00b5a6 Merge pull request #126 from jellyfish-p/main
feat(kiro): 添加用于令牌额度查询的api-call兼容
2026-01-25 21:49:07 +08:00
Luis Pater
275839e5c9 Merge pull request #124 from gogoing1024/main
fix(kiro): always attempt token refresh on 401 before checking retry …
2026-01-25 21:48:03 +08:00
hkfires
f30ffd5f5e feat(executor): add request_id to error logs
Extract error.message from JSON error responses when summarizing error bodies for debug logs
2026-01-25 21:31:46 +08:00
Luis Pater
bc9a24d705 docs(readme): reposition CPA-XXX Panel section for improved visibility 2026-01-25 18:58:32 +08:00
Luis Pater
2c879f13ef Merge pull request #1216 from ferretgeek/add-cpa-xxx-panel
docs: 新增 CPA-XXX 社区面板项目
2026-01-25 18:57:32 +08:00
Gemini
07b4a08979 docs: translate CPA-XXX description to English 2026-01-25 18:00:28 +08:00
jellyfish-p
497339f055 feat(kiro): 添加用于令牌额度查询的api-call兼容 2026-01-25 11:36:52 +08:00
Gemini
7f612bb069 docs: add CPA-XXX panel to community list
Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com>
2026-01-25 10:45:51 +08:00
hkfires
5743b78694 test(claude): update expectations for system message handling 2026-01-25 08:31:29 +08:00
Luis Pater
2e6a2b655c Merge pull request #1132 from XYenon/fix/gemini-models-displayname-override
fix(gemini): preserve displayName and description in models list
2026-01-25 03:40:04 +08:00
Luis Pater
cb47ac21bf Merge pull request #1179 from mallendeo/main
fix(claude): skip built-in tools in OAuth tool prefix
2026-01-25 03:31:58 +08:00
Luis Pater
a1394b4596 Merge pull request #1183 from Darley-Wey/fix/api-align
fix(api): enhance ClaudeModels response to align with api.anthropic.com
2026-01-25 03:30:14 +08:00
Luis Pater
9e97948f03 Merge pull request #1185 from router-for-me/auth
Refactor authentication handling for Antigravity, Claude, Codex, and Gemini
2026-01-25 03:28:53 +08:00
yuechenglong.5
8f780e7280 fix(kiro): always attempt token refresh on 401 before checking retry count
Refactor 401 error handling in both executeWithRetry and
executeStreamWithRetry to always attempt token refresh regardless of
remaining retry attempts. Previously, token refresh was only attempted
when retries remained, which could leave valid refreshed tokens unused.

Also add auth directory resolution in RefreshManager.Initialize to
properly resolve the base directory path before creating the token
repository.
2026-01-24 20:02:09 +08:00
Darley
46c6fb1e7a fix(api): enhance ClaudeModels response to align with api.anthropic.com 2026-01-24 04:41:08 +03:30
hkfires
9f9fec5d4c fix(auth): improve antigravity token exchange errors 2026-01-24 09:04:15 +08:00
hkfires
e95be10485 fix(auth): validate antigravity token userinfo email 2026-01-24 08:33:52 +08:00
hkfires
f3d58fa0ce fix(auth): correct antigravity oauth redirect and expiry 2026-01-24 08:33:52 +08:00
hkfires
8c0eaa1f71 refactor(auth): export Gemini constants and use in handler 2026-01-24 08:33:52 +08:00
hkfires
405df58f72 refactor(auth): export Codex constants and slim down handler 2026-01-24 08:33:52 +08:00
hkfires
e7f13aa008 refactor(api): slim down RequestAnthropicToken to use internal/auth 2026-01-24 08:33:51 +08:00
hkfires
7cb6a9b89a refactor(auth): export Claude OAuth constants for reuse 2026-01-24 08:33:51 +08:00
hkfires
9aa5344c29 refactor(api): slim down RequestAntigravityToken to use internal/auth 2026-01-24 08:33:51 +08:00
hkfires
8ba0ebbd2a refactor(sdk): slim down Antigravity authenticator to use internal/auth 2026-01-24 08:33:51 +08:00
hkfires
c65407ab9f refactor(auth): extract Antigravity OAuth constants to internal/auth 2026-01-24 08:33:51 +08:00
hkfires
9e59685212 refactor(auth): implement Antigravity AuthService in internal/auth 2026-01-24 08:33:51 +08:00
hkfires
4a4dfaa910 refactor(auth): replace sanitizeAntigravityFileName with antigravity.CredentialFileName 2026-01-24 08:33:51 +08:00
Luis Pater
0d6ecb0191 Fixed: #1077
refactor(translator): improve tools handling by separating functionDeclarations and googleSearch nodes
2026-01-24 05:51:11 +08:00
Mauricio Allende
f16461bfe7 fix(claude): skip built-in tools in OAuth tool prefix 2026-01-23 21:29:39 +00:00
XYenon
8c7c446f33 fix(gemini): preserve displayName and description in models list
Previously GeminiModels handler unconditionally overwrote displayName
and description with the model name, losing the original values defined
in model definitions (e.g., 'Gemini 3 Pro Preview').

Now only set these fields as fallback when they are missing or empty.
2026-01-22 15:19:27 +08:00
44 changed files with 1881 additions and 1172 deletions

2
go.mod
View File

@@ -40,6 +40,7 @@ require (
github.com/dlclark/regexp2 v1.11.5 // indirect github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect
github.com/emirpasic/gods v1.18.1 // indirect github.com/emirpasic/gods v1.18.1 // indirect
github.com/fxamacker/cbor/v2 v2.9.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-git/gcfg/v2 v2.0.2 // indirect github.com/go-git/gcfg/v2 v2.0.2 // indirect
@@ -69,6 +70,7 @@ require (
github.com/tidwall/pretty v1.2.0 // indirect github.com/tidwall/pretty v1.2.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.12 // indirect
github.com/x448/float16 v0.8.4 // indirect
golang.org/x/arch v0.8.0 // indirect golang.org/x/arch v0.8.0 // indirect
golang.org/x/sys v0.38.0 // indirect golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect golang.org/x/text v0.31.0 // indirect

4
go.sum
View File

@@ -35,6 +35,8 @@ github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM=
github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
@@ -157,6 +159,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=

View File

@@ -11,6 +11,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/fxamacker/cbor/v2"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -70,7 +71,7 @@ type apiCallResponse struct {
// - Authorization: Bearer <key> // - Authorization: Bearer <key>
// - X-Management-Key: <key> // - X-Management-Key: <key>
// //
// Request JSON: // Request JSON (supports both application/json and application/cbor):
// - auth_index / authIndex / AuthIndex (optional): // - auth_index / authIndex / AuthIndex (optional):
// The credential "auth_index" from GET /v0/management/auth-files (or other endpoints returning it). // The credential "auth_index" from GET /v0/management/auth-files (or other endpoints returning it).
// If omitted or not found, credential-specific proxy/token substitution is skipped. // If omitted or not found, credential-specific proxy/token substitution is skipped.
@@ -90,10 +91,12 @@ type apiCallResponse struct {
// 2. Global config proxy-url // 2. Global config proxy-url
// 3. Direct connect (environment proxies are not used) // 3. Direct connect (environment proxies are not used)
// //
// Response JSON (returned with HTTP 200 when the APICall itself succeeds): // Response (returned with HTTP 200 when the APICall itself succeeds):
// - status_code: Upstream HTTP status code. //
// - header: Upstream response headers. // Format matches request Content-Type (application/json or application/cbor)
// - body: Upstream response body as string. // - status_code: Upstream HTTP status code.
// - header: Upstream response headers.
// - body: Upstream response body as string.
// //
// Example: // Example:
// //
@@ -107,10 +110,28 @@ type apiCallResponse struct {
// -H "Content-Type: application/json" \ // -H "Content-Type: application/json" \
// -d '{"auth_index":"<AUTH_INDEX>","method":"POST","url":"https://api.example.com/v1/fetchAvailableModels","header":{"Authorization":"Bearer $TOKEN$","Content-Type":"application/json","User-Agent":"cliproxyapi"},"data":"{}"}' // -d '{"auth_index":"<AUTH_INDEX>","method":"POST","url":"https://api.example.com/v1/fetchAvailableModels","header":{"Authorization":"Bearer $TOKEN$","Content-Type":"application/json","User-Agent":"cliproxyapi"},"data":"{}"}'
func (h *Handler) APICall(c *gin.Context) { func (h *Handler) APICall(c *gin.Context) {
// Detect content type
contentType := strings.ToLower(strings.TrimSpace(c.GetHeader("Content-Type")))
isCBOR := strings.Contains(contentType, "application/cbor")
var body apiCallRequest var body apiCallRequest
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"}) // Parse request body based on content type
return if isCBOR {
rawBody, errRead := io.ReadAll(c.Request.Body)
if errRead != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
return
}
if errUnmarshal := cbor.Unmarshal(rawBody, &body); errUnmarshal != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid cbor body"})
return
}
} else {
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
} }
method := strings.ToUpper(strings.TrimSpace(body.Method)) method := strings.ToUpper(strings.TrimSpace(body.Method))
@@ -209,11 +230,23 @@ func (h *Handler) APICall(c *gin.Context) {
return return
} }
c.JSON(http.StatusOK, apiCallResponse{ response := apiCallResponse{
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
Header: resp.Header, Header: resp.Header,
Body: string(respBody), Body: string(respBody),
}) }
// Return response in the same format as the request
if isCBOR {
cborData, errMarshal := cbor.Marshal(response)
if errMarshal != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to encode cbor response"})
return
}
c.Data(http.StatusOK, "application/cbor", cborData)
} else {
c.JSON(http.StatusOK, response)
}
} }
func firstNonEmptyString(values ...*string) string { func firstNonEmptyString(values ...*string) string {

View File

@@ -0,0 +1,149 @@
package management
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/fxamacker/cbor/v2"
"github.com/gin-gonic/gin"
)
func TestAPICall_CBOR_Support(t *testing.T) {
gin.SetMode(gin.TestMode)
// Create a test handler
h := &Handler{}
// Create test request data
reqData := apiCallRequest{
Method: "GET",
URL: "https://httpbin.org/get",
Header: map[string]string{
"User-Agent": "test-client",
},
}
t.Run("JSON request and response", func(t *testing.T) {
// Marshal request as JSON
jsonData, err := json.Marshal(reqData)
if err != nil {
t.Fatalf("Failed to marshal JSON: %v", err)
}
// Create HTTP request
req := httptest.NewRequest(http.MethodPost, "/v0/management/api-call", bytes.NewReader(jsonData))
req.Header.Set("Content-Type", "application/json")
// Create response recorder
w := httptest.NewRecorder()
// Create Gin context
c, _ := gin.CreateTestContext(w)
c.Request = req
// Call handler
h.APICall(c)
// Verify response
if w.Code != http.StatusOK && w.Code != http.StatusBadGateway {
t.Logf("Response status: %d", w.Code)
t.Logf("Response body: %s", w.Body.String())
}
// Check content type
contentType := w.Header().Get("Content-Type")
if w.Code == http.StatusOK && !contains(contentType, "application/json") {
t.Errorf("Expected JSON response, got: %s", contentType)
}
})
t.Run("CBOR request and response", func(t *testing.T) {
// Marshal request as CBOR
cborData, err := cbor.Marshal(reqData)
if err != nil {
t.Fatalf("Failed to marshal CBOR: %v", err)
}
// Create HTTP request
req := httptest.NewRequest(http.MethodPost, "/v0/management/api-call", bytes.NewReader(cborData))
req.Header.Set("Content-Type", "application/cbor")
// Create response recorder
w := httptest.NewRecorder()
// Create Gin context
c, _ := gin.CreateTestContext(w)
c.Request = req
// Call handler
h.APICall(c)
// Verify response
if w.Code != http.StatusOK && w.Code != http.StatusBadGateway {
t.Logf("Response status: %d", w.Code)
t.Logf("Response body: %s", w.Body.String())
}
// Check content type
contentType := w.Header().Get("Content-Type")
if w.Code == http.StatusOK && !contains(contentType, "application/cbor") {
t.Errorf("Expected CBOR response, got: %s", contentType)
}
// Try to decode CBOR response
if w.Code == http.StatusOK {
var response apiCallResponse
if err := cbor.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Errorf("Failed to unmarshal CBOR response: %v", err)
} else {
t.Logf("CBOR response decoded successfully: status_code=%d", response.StatusCode)
}
}
})
t.Run("CBOR encoding and decoding consistency", func(t *testing.T) {
// Test data
testReq := apiCallRequest{
Method: "POST",
URL: "https://example.com/api",
Header: map[string]string{
"Authorization": "Bearer $TOKEN$",
"Content-Type": "application/json",
},
Data: `{"key":"value"}`,
}
// Encode to CBOR
cborData, err := cbor.Marshal(testReq)
if err != nil {
t.Fatalf("Failed to marshal to CBOR: %v", err)
}
// Decode from CBOR
var decoded apiCallRequest
if err := cbor.Unmarshal(cborData, &decoded); err != nil {
t.Fatalf("Failed to unmarshal from CBOR: %v", err)
}
// Verify fields
if decoded.Method != testReq.Method {
t.Errorf("Method mismatch: got %s, want %s", decoded.Method, testReq.Method)
}
if decoded.URL != testReq.URL {
t.Errorf("URL mismatch: got %s, want %s", decoded.URL, testReq.URL)
}
if decoded.Data != testReq.Data {
t.Errorf("Data mismatch: got %s, want %s", decoded.Data, testReq.Data)
}
if len(decoded.Header) != len(testReq.Header) {
t.Errorf("Header count mismatch: got %d, want %d", len(decoded.Header), len(testReq.Header))
}
})
}
func contains(s, substr string) bool {
return len(s) > 0 && len(substr) > 0 && (s == substr || len(s) >= len(substr) && s[:len(substr)] == substr || bytes.Contains([]byte(s), []byte(substr)))
}

View File

@@ -3,10 +3,10 @@ package management
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/hex"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -23,6 +23,7 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/antigravity"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
@@ -236,14 +237,6 @@ func stopForwarderInstance(port int, forwarder *callbackForwarder) {
log.Infof("callback forwarder on port %d stopped", port) log.Infof("callback forwarder on port %d stopped", port)
} }
func sanitizeAntigravityFileName(email string) string {
if strings.TrimSpace(email) == "" {
return "antigravity.json"
}
replacer := strings.NewReplacer("@", "_", ".", "_")
return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email))
}
func (h *Handler) managementCallbackURL(path string) (string, error) { func (h *Handler) managementCallbackURL(path string) (string, error) {
if h == nil || h.cfg == nil || h.cfg.Port <= 0 { if h == nil || h.cfg == nil || h.cfg.Port <= 0 {
return "", fmt.Errorf("server port is not configured") return "", fmt.Errorf("server port is not configured")
@@ -985,67 +978,14 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
rawCode := resultMap["code"] rawCode := resultMap["code"]
code := strings.Split(rawCode, "#")[0] code := strings.Split(rawCode, "#")[0]
// Exchange code for tokens (replicate logic using updated redirect_uri) // Exchange code for tokens using internal auth service
// Extract client_id from the modified auth URL bundle, errExchange := anthropicAuth.ExchangeCodeForTokens(ctx, code, state, pkceCodes)
clientID := "" if errExchange != nil {
if u2, errP := url.Parse(authURL); errP == nil { authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errExchange)
clientID = u2.Query().Get("client_id")
}
// Build request
bodyMap := map[string]any{
"code": code,
"state": state,
"grant_type": "authorization_code",
"client_id": clientID,
"redirect_uri": "http://localhost:54545/callback",
"code_verifier": pkceCodes.CodeVerifier,
}
bodyJSON, _ := json.Marshal(bodyMap)
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
req, _ := http.NewRequestWithContext(ctx, "POST", "https://console.anthropic.com/v1/oauth/token", strings.NewReader(string(bodyJSON)))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
resp, errDo := httpClient.Do(req)
if errDo != nil {
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens") SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
return return
} }
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("failed to close response body: %v", errClose)
}
}()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
SetOAuthSessionError(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode))
return
}
var tResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
Account struct {
EmailAddress string `json:"email_address"`
} `json:"account"`
}
if errU := json.Unmarshal(respBody, &tResp); errU != nil {
log.Errorf("failed to parse token response: %v", errU)
SetOAuthSessionError(state, "Failed to parse token response")
return
}
bundle := &claude.ClaudeAuthBundle{
TokenData: claude.ClaudeTokenData{
AccessToken: tResp.AccessToken,
RefreshToken: tResp.RefreshToken,
Email: tResp.Account.EmailAddress,
Expire: time.Now().Add(time.Duration(tResp.ExpiresIn) * time.Second).Format(time.RFC3339),
},
LastRefresh: time.Now().Format(time.RFC3339),
}
// Create token storage // Create token storage
tokenStorage := anthropicAuth.CreateTokenStorage(bundle) tokenStorage := anthropicAuth.CreateTokenStorage(bundle)
@@ -1085,17 +1025,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
fmt.Println("Initializing Google authentication...") fmt.Println("Initializing Google authentication...")
// OAuth2 configuration (mirrors internal/auth/gemini) // OAuth2 configuration using exported constants from internal/auth/gemini
conf := &oauth2.Config{ conf := &oauth2.Config{
ClientID: "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com", ClientID: geminiAuth.ClientID,
ClientSecret: "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", ClientSecret: geminiAuth.ClientSecret,
RedirectURL: "http://localhost:8085/oauth2callback", RedirectURL: fmt.Sprintf("http://localhost:%d/oauth2callback", geminiAuth.DefaultCallbackPort),
Scopes: []string{ Scopes: geminiAuth.Scopes,
"https://www.googleapis.com/auth/cloud-platform", Endpoint: google.Endpoint,
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
},
Endpoint: google.Endpoint,
} }
// Build authorization URL and return it immediately // Build authorization URL and return it immediately
@@ -1217,13 +1153,9 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
} }
ifToken["token_uri"] = "https://oauth2.googleapis.com/token" ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
ifToken["client_id"] = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" ifToken["client_id"] = geminiAuth.ClientID
ifToken["client_secret"] = "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" ifToken["client_secret"] = geminiAuth.ClientSecret
ifToken["scopes"] = []string{ ifToken["scopes"] = geminiAuth.Scopes
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
}
ifToken["universe_domain"] = "googleapis.com" ifToken["universe_domain"] = "googleapis.com"
ts := geminiAuth.GeminiTokenStorage{ ts := geminiAuth.GeminiTokenStorage{
@@ -1410,73 +1342,25 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
} }
log.Debug("Authorization code received, exchanging for tokens...") log.Debug("Authorization code received, exchanging for tokens...")
// Extract client_id from authURL // Exchange code for tokens using internal auth service
clientID := "" bundle, errExchange := openaiAuth.ExchangeCodeForTokens(ctx, code, pkceCodes)
if u2, errP := url.Parse(authURL); errP == nil { if errExchange != nil {
clientID = u2.Query().Get("client_id") authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errExchange)
}
// Exchange code for tokens with redirect equal to mgmtRedirect
form := url.Values{
"grant_type": {"authorization_code"},
"client_id": {clientID},
"code": {code},
"redirect_uri": {"http://localhost:1455/auth/callback"},
"code_verifier": {pkceCodes.CodeVerifier},
}
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
req, _ := http.NewRequestWithContext(ctx, "POST", "https://auth.openai.com/oauth/token", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
resp, errDo := httpClient.Do(req)
if errDo != nil {
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens") SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
return return
} }
defer func() { _ = resp.Body.Close() }()
respBody, _ := io.ReadAll(resp.Body) // Extract additional info for filename generation
if resp.StatusCode != http.StatusOK { claims, _ := codex.ParseJWTToken(bundle.TokenData.IDToken)
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode))
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
return
}
var tokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token"`
ExpiresIn int `json:"expires_in"`
}
if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
SetOAuthSessionError(state, "Failed to parse token response")
log.Errorf("failed to parse token response: %v", errU)
return
}
claims, _ := codex.ParseJWTToken(tokenResp.IDToken)
email := ""
accountID := ""
planType := "" planType := ""
if claims != nil {
email = claims.GetUserEmail()
accountID = claims.GetAccountID()
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
}
hashAccountID := "" hashAccountID := ""
if accountID != "" { if claims != nil {
digest := sha256.Sum256([]byte(accountID)) planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
hashAccountID = hex.EncodeToString(digest[:])[:8] if accountID := claims.GetAccountID(); accountID != "" {
} digest := sha256.Sum256([]byte(accountID))
// Build bundle compatible with existing storage hashAccountID = hex.EncodeToString(digest[:])[:8]
bundle := &codex.CodexAuthBundle{ }
TokenData: codex.CodexTokenData{
IDToken: tokenResp.IDToken,
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
AccountID: accountID,
Email: email,
Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
},
LastRefresh: time.Now().Format(time.RFC3339),
} }
// Create token storage and persist // Create token storage and persist
@@ -1511,23 +1395,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
} }
func (h *Handler) RequestAntigravityToken(c *gin.Context) { func (h *Handler) RequestAntigravityToken(c *gin.Context) {
const (
antigravityCallbackPort = 51121
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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"
)
var antigravityScopes = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
"https://www.googleapis.com/auth/cclog",
"https://www.googleapis.com/auth/experimentsandconfigs",
}
ctx := context.Background() ctx := context.Background()
fmt.Println("Initializing Antigravity authentication...") fmt.Println("Initializing Antigravity authentication...")
authSvc := antigravity.NewAntigravityAuth(h.cfg, nil)
state, errState := misc.GenerateRandomState() state, errState := misc.GenerateRandomState()
if errState != nil { if errState != nil {
log.Errorf("Failed to generate state parameter: %v", errState) log.Errorf("Failed to generate state parameter: %v", errState)
@@ -1535,17 +1408,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
return return
} }
redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", antigravityCallbackPort) redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", antigravity.CallbackPort)
authURL := authSvc.BuildAuthURL(state, redirectURI)
params := url.Values{}
params.Set("access_type", "offline")
params.Set("client_id", antigravityClientID)
params.Set("prompt", "consent")
params.Set("redirect_uri", redirectURI)
params.Set("response_type", "code")
params.Set("scope", strings.Join(antigravityScopes, " "))
params.Set("state", state)
authURL := "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode()
RegisterOAuthSession(state, "antigravity") RegisterOAuthSession(state, "antigravity")
@@ -1559,7 +1423,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
return return
} }
var errStart error var errStart error
if forwarder, errStart = startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil { if forwarder, errStart = startCallbackForwarder(antigravity.CallbackPort, "antigravity", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start antigravity callback forwarder") log.WithError(errStart).Error("failed to start antigravity callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return return
@@ -1568,7 +1432,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
go func() { go func() {
if isWebUI { if isWebUI {
defer stopCallbackForwarderInstance(antigravityCallbackPort, forwarder) defer stopCallbackForwarderInstance(antigravity.CallbackPort, forwarder)
} }
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state)) waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state))
@@ -1608,93 +1472,36 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
} }
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{}) tokenResp, errToken := authSvc.ExchangeCodeForTokens(ctx, authCode, redirectURI)
form := url.Values{} if errToken != nil {
form.Set("code", authCode) log.Errorf("Failed to exchange token: %v", errToken)
form.Set("client_id", antigravityClientID)
form.Set("client_secret", antigravityClientSecret)
form.Set("redirect_uri", redirectURI)
form.Set("grant_type", "authorization_code")
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
if errNewRequest != nil {
log.Errorf("Failed to build token request: %v", errNewRequest)
SetOAuthSessionError(state, "Failed to build token request")
return
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, errDo := httpClient.Do(req)
if errDo != nil {
log.Errorf("Failed to execute token request: %v", errDo)
SetOAuthSessionError(state, "Failed to exchange token") SetOAuthSessionError(state, "Failed to exchange token")
return return
} }
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity token exchange close error: %v", errClose)
}
}()
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { accessToken := strings.TrimSpace(tokenResp.AccessToken)
bodyBytes, _ := io.ReadAll(resp.Body) if accessToken == "" {
log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes)) log.Error("antigravity: token exchange returned empty access token")
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode)) SetOAuthSessionError(state, "Failed to exchange token")
return return
} }
var tokenResp struct { email, errInfo := authSvc.FetchUserInfo(ctx, accessToken)
AccessToken string `json:"access_token"` if errInfo != nil {
RefreshToken string `json:"refresh_token"` log.Errorf("Failed to fetch user info: %v", errInfo)
ExpiresIn int64 `json:"expires_in"` SetOAuthSessionError(state, "Failed to fetch user info")
TokenType string `json:"token_type"`
}
if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil {
log.Errorf("Failed to parse token response: %v", errDecode)
SetOAuthSessionError(state, "Failed to parse token response")
return return
} }
email = strings.TrimSpace(email)
email := "" if email == "" {
if strings.TrimSpace(tokenResp.AccessToken) != "" { log.Error("antigravity: user info returned empty email")
infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) SetOAuthSessionError(state, "Failed to fetch user info")
if errInfoReq != nil { return
log.Errorf("Failed to build user info request: %v", errInfoReq)
SetOAuthSessionError(state, "Failed to build user info request")
return
}
infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
infoResp, errInfo := httpClient.Do(infoReq)
if errInfo != nil {
log.Errorf("Failed to execute user info request: %v", errInfo)
SetOAuthSessionError(state, "Failed to execute user info request")
return
}
defer func() {
if errClose := infoResp.Body.Close(); errClose != nil {
log.Errorf("antigravity user info close error: %v", errClose)
}
}()
if infoResp.StatusCode >= http.StatusOK && infoResp.StatusCode < http.StatusMultipleChoices {
var infoPayload struct {
Email string `json:"email"`
}
if errDecodeInfo := json.NewDecoder(infoResp.Body).Decode(&infoPayload); errDecodeInfo == nil {
email = strings.TrimSpace(infoPayload.Email)
}
} else {
bodyBytes, _ := io.ReadAll(infoResp.Body)
log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes))
SetOAuthSessionError(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode))
return
}
} }
projectID := "" projectID := ""
if strings.TrimSpace(tokenResp.AccessToken) != "" { if accessToken != "" {
fetchedProjectID, errProject := sdkAuth.FetchAntigravityProjectID(ctx, tokenResp.AccessToken, httpClient) fetchedProjectID, errProject := authSvc.FetchProjectID(ctx, accessToken)
if errProject != nil { if errProject != nil {
log.Warnf("antigravity: failed to fetch project ID: %v", errProject) log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
} else { } else {
@@ -1719,7 +1526,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
metadata["project_id"] = projectID metadata["project_id"] = projectID
} }
fileName := sanitizeAntigravityFileName(email) fileName := antigravity.CredentialFileName(email)
label := strings.TrimSpace(email) label := strings.TrimSpace(email)
if label == "" { if label == "" {
label = "antigravity" label = "antigravity"

View File

@@ -0,0 +1,344 @@
// Package antigravity provides OAuth2 authentication functionality for the Antigravity provider.
package antigravity
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
// TokenResponse represents OAuth token response from Google
type TokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
TokenType string `json:"token_type"`
}
// userInfo represents Google user profile
type userInfo struct {
Email string `json:"email"`
}
// AntigravityAuth handles Antigravity OAuth authentication
type AntigravityAuth struct {
httpClient *http.Client
}
// NewAntigravityAuth creates a new Antigravity auth service.
func NewAntigravityAuth(cfg *config.Config, httpClient *http.Client) *AntigravityAuth {
if httpClient != nil {
return &AntigravityAuth{httpClient: httpClient}
}
if cfg == nil {
cfg = &config.Config{}
}
return &AntigravityAuth{
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}),
}
}
// BuildAuthURL generates the OAuth authorization URL.
func (o *AntigravityAuth) BuildAuthURL(state, redirectURI string) string {
if strings.TrimSpace(redirectURI) == "" {
redirectURI = fmt.Sprintf("http://localhost:%d/oauth-callback", CallbackPort)
}
params := url.Values{}
params.Set("access_type", "offline")
params.Set("client_id", ClientID)
params.Set("prompt", "consent")
params.Set("redirect_uri", redirectURI)
params.Set("response_type", "code")
params.Set("scope", strings.Join(Scopes, " "))
params.Set("state", state)
return AuthEndpoint + "?" + params.Encode()
}
// ExchangeCodeForTokens exchanges authorization code for access and refresh tokens
func (o *AntigravityAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectURI string) (*TokenResponse, error) {
data := url.Values{}
data.Set("code", code)
data.Set("client_id", ClientID)
data.Set("client_secret", ClientSecret)
data.Set("redirect_uri", redirectURI)
data.Set("grant_type", "authorization_code")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenEndpoint, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("antigravity token exchange: create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, errDo := o.httpClient.Do(req)
if errDo != nil {
return nil, fmt.Errorf("antigravity token exchange: execute request: %w", errDo)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity token exchange: close body error: %v", errClose)
}
}()
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
bodyBytes, errRead := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
if errRead != nil {
return nil, fmt.Errorf("antigravity token exchange: read response: %w", errRead)
}
body := strings.TrimSpace(string(bodyBytes))
if body == "" {
return nil, fmt.Errorf("antigravity token exchange: request failed: status %d", resp.StatusCode)
}
return nil, fmt.Errorf("antigravity token exchange: request failed: status %d: %s", resp.StatusCode, body)
}
var token TokenResponse
if errDecode := json.NewDecoder(resp.Body).Decode(&token); errDecode != nil {
return nil, fmt.Errorf("antigravity token exchange: decode response: %w", errDecode)
}
return &token, nil
}
// FetchUserInfo retrieves user email from Google
func (o *AntigravityAuth) FetchUserInfo(ctx context.Context, accessToken string) (string, error) {
accessToken = strings.TrimSpace(accessToken)
if accessToken == "" {
return "", fmt.Errorf("antigravity userinfo: missing access token")
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, UserInfoEndpoint, nil)
if err != nil {
return "", fmt.Errorf("antigravity userinfo: create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, errDo := o.httpClient.Do(req)
if errDo != nil {
return "", fmt.Errorf("antigravity userinfo: execute request: %w", errDo)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity userinfo: close body error: %v", errClose)
}
}()
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
bodyBytes, errRead := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
if errRead != nil {
return "", fmt.Errorf("antigravity userinfo: read response: %w", errRead)
}
body := strings.TrimSpace(string(bodyBytes))
if body == "" {
return "", fmt.Errorf("antigravity userinfo: request failed: status %d", resp.StatusCode)
}
return "", fmt.Errorf("antigravity userinfo: request failed: status %d: %s", resp.StatusCode, body)
}
var info userInfo
if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil {
return "", fmt.Errorf("antigravity userinfo: decode response: %w", errDecode)
}
email := strings.TrimSpace(info.Email)
if email == "" {
return "", fmt.Errorf("antigravity userinfo: response missing email")
}
return email, nil
}
// FetchProjectID retrieves the project ID for the authenticated user via loadCodeAssist
func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string) (string, error) {
loadReqBody := map[string]any{
"metadata": map[string]string{
"ideType": "ANTIGRAVITY",
"platform": "PLATFORM_UNSPECIFIED",
"pluginType": "GEMINI",
},
}
rawBody, errMarshal := json.Marshal(loadReqBody)
if errMarshal != nil {
return "", fmt.Errorf("marshal request body: %w", errMarshal)
}
endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", APIEndpoint, APIVersion)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
if err != nil {
return "", fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", APIUserAgent)
req.Header.Set("X-Goog-Api-Client", APIClient)
req.Header.Set("Client-Metadata", ClientMetadata)
resp, errDo := o.httpClient.Do(req)
if errDo != nil {
return "", fmt.Errorf("execute request: %w", errDo)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity loadCodeAssist: close body error: %v", errClose)
}
}()
bodyBytes, errRead := io.ReadAll(resp.Body)
if errRead != nil {
return "", fmt.Errorf("read response: %w", errRead)
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
}
var loadResp map[string]any
if errDecode := json.Unmarshal(bodyBytes, &loadResp); errDecode != nil {
return "", fmt.Errorf("decode response: %w", errDecode)
}
// Extract projectID from response
projectID := ""
if id, ok := loadResp["cloudaicompanionProject"].(string); ok {
projectID = strings.TrimSpace(id)
}
if projectID == "" {
if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok {
if id, okID := projectMap["id"].(string); okID {
projectID = strings.TrimSpace(id)
}
}
}
if projectID == "" {
tierID := "legacy-tier"
if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers {
for _, rawTier := range tiers {
tier, okTier := rawTier.(map[string]any)
if !okTier {
continue
}
if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault {
if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" {
tierID = strings.TrimSpace(id)
break
}
}
}
}
projectID, err = o.OnboardUser(ctx, accessToken, tierID)
if err != nil {
return "", err
}
return projectID, nil
}
return projectID, nil
}
// OnboardUser attempts to fetch the project ID via onboardUser by polling for completion
func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) {
log.Infof("Antigravity: onboarding user with tier: %s", tierID)
requestBody := map[string]any{
"tierId": tierID,
"metadata": map[string]string{
"ideType": "ANTIGRAVITY",
"platform": "PLATFORM_UNSPECIFIED",
"pluginType": "GEMINI",
},
}
rawBody, errMarshal := json.Marshal(requestBody)
if errMarshal != nil {
return "", fmt.Errorf("marshal request body: %w", errMarshal)
}
maxAttempts := 5
for attempt := 1; attempt <= maxAttempts; attempt++ {
log.Debugf("Polling attempt %d/%d", attempt, maxAttempts)
reqCtx := ctx
var cancel context.CancelFunc
if reqCtx == nil {
reqCtx = context.Background()
}
reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second)
endpointURL := fmt.Sprintf("%s/%s:onboardUser", APIEndpoint, APIVersion)
req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
if errRequest != nil {
cancel()
return "", fmt.Errorf("create request: %w", errRequest)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", APIUserAgent)
req.Header.Set("X-Goog-Api-Client", APIClient)
req.Header.Set("Client-Metadata", ClientMetadata)
resp, errDo := o.httpClient.Do(req)
if errDo != nil {
cancel()
return "", fmt.Errorf("execute request: %w", errDo)
}
bodyBytes, errRead := io.ReadAll(resp.Body)
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("close body error: %v", errClose)
}
cancel()
if errRead != nil {
return "", fmt.Errorf("read response: %w", errRead)
}
if resp.StatusCode == http.StatusOK {
var data map[string]any
if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil {
return "", fmt.Errorf("decode response: %w", errDecode)
}
if done, okDone := data["done"].(bool); okDone && done {
projectID := ""
if responseData, okResp := data["response"].(map[string]any); okResp {
switch projectValue := responseData["cloudaicompanionProject"].(type) {
case map[string]any:
if id, okID := projectValue["id"].(string); okID {
projectID = strings.TrimSpace(id)
}
case string:
projectID = strings.TrimSpace(projectValue)
}
}
if projectID != "" {
log.Infof("Successfully fetched project_id: %s", projectID)
return projectID, nil
}
return "", fmt.Errorf("no project_id in response")
}
time.Sleep(2 * time.Second)
continue
}
responsePreview := strings.TrimSpace(string(bodyBytes))
if len(responsePreview) > 500 {
responsePreview = responsePreview[:500]
}
responseErr := responsePreview
if len(responseErr) > 200 {
responseErr = responseErr[:200]
}
return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr)
}
return "", nil
}

View File

@@ -0,0 +1,34 @@
// Package antigravity provides OAuth2 authentication functionality for the Antigravity provider.
package antigravity
// OAuth client credentials and configuration
const (
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
ClientSecret = "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"
CallbackPort = 51121
)
// Scopes defines the OAuth scopes required for Antigravity authentication
var Scopes = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
"https://www.googleapis.com/auth/cclog",
"https://www.googleapis.com/auth/experimentsandconfigs",
}
// OAuth2 endpoints for Google authentication
const (
TokenEndpoint = "https://oauth2.googleapis.com/token"
AuthEndpoint = "https://accounts.google.com/o/oauth2/v2/auth"
UserInfoEndpoint = "https://www.googleapis.com/oauth2/v1/userinfo?alt=json"
)
// Antigravity API configuration
const (
APIEndpoint = "https://cloudcode-pa.googleapis.com"
APIVersion = "v1internal"
APIUserAgent = "google-api-nodejs-client/9.15.1"
APIClient = "google-cloud-sdk vscode_cloudshelleditor/0.1"
ClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}`
)

View File

@@ -0,0 +1,16 @@
package antigravity
import (
"fmt"
"strings"
)
// CredentialFileName returns the filename used to persist Antigravity credentials.
// It uses the email as a suffix to disambiguate accounts.
func CredentialFileName(email string) string {
email = strings.TrimSpace(email)
if email == "" {
return "antigravity.json"
}
return fmt.Sprintf("antigravity-%s.json", email)
}

View File

@@ -18,11 +18,12 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// OAuth configuration constants for Claude/Anthropic
const ( const (
anthropicAuthURL = "https://claude.ai/oauth/authorize" AuthURL = "https://claude.ai/oauth/authorize"
anthropicTokenURL = "https://console.anthropic.com/v1/oauth/token" TokenURL = "https://console.anthropic.com/v1/oauth/token"
anthropicClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
redirectURI = "http://localhost:54545/callback" RedirectURI = "http://localhost:54545/callback"
) )
// tokenResponse represents the response structure from Anthropic's OAuth token endpoint. // tokenResponse represents the response structure from Anthropic's OAuth token endpoint.
@@ -82,16 +83,16 @@ func (o *ClaudeAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string
params := url.Values{ params := url.Values{
"code": {"true"}, "code": {"true"},
"client_id": {anthropicClientID}, "client_id": {ClientID},
"response_type": {"code"}, "response_type": {"code"},
"redirect_uri": {redirectURI}, "redirect_uri": {RedirectURI},
"scope": {"org:create_api_key user:profile user:inference"}, "scope": {"org:create_api_key user:profile user:inference"},
"code_challenge": {pkceCodes.CodeChallenge}, "code_challenge": {pkceCodes.CodeChallenge},
"code_challenge_method": {"S256"}, "code_challenge_method": {"S256"},
"state": {state}, "state": {state},
} }
authURL := fmt.Sprintf("%s?%s", anthropicAuthURL, params.Encode()) authURL := fmt.Sprintf("%s?%s", AuthURL, params.Encode())
return authURL, state, nil return authURL, state, nil
} }
@@ -137,8 +138,8 @@ func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state stri
"code": newCode, "code": newCode,
"state": state, "state": state,
"grant_type": "authorization_code", "grant_type": "authorization_code",
"client_id": anthropicClientID, "client_id": ClientID,
"redirect_uri": redirectURI, "redirect_uri": RedirectURI,
"code_verifier": pkceCodes.CodeVerifier, "code_verifier": pkceCodes.CodeVerifier,
} }
@@ -154,7 +155,7 @@ func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state stri
// log.Debugf("Token exchange request: %s", string(jsonBody)) // log.Debugf("Token exchange request: %s", string(jsonBody))
req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody))) req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(string(jsonBody)))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err) return nil, fmt.Errorf("failed to create token request: %w", err)
} }
@@ -221,7 +222,7 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
} }
reqBody := map[string]interface{}{ reqBody := map[string]interface{}{
"client_id": anthropicClientID, "client_id": ClientID,
"grant_type": "refresh_token", "grant_type": "refresh_token",
"refresh_token": refreshToken, "refresh_token": refreshToken,
} }
@@ -231,7 +232,7 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
return nil, fmt.Errorf("failed to marshal request body: %w", err) return nil, fmt.Errorf("failed to marshal request body: %w", err)
} }
req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody))) req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(string(jsonBody)))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create refresh request: %w", err) return nil, fmt.Errorf("failed to create refresh request: %w", err)
} }

View File

@@ -19,11 +19,12 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// OAuth configuration constants for OpenAI Codex
const ( const (
openaiAuthURL = "https://auth.openai.com/oauth/authorize" AuthURL = "https://auth.openai.com/oauth/authorize"
openaiTokenURL = "https://auth.openai.com/oauth/token" TokenURL = "https://auth.openai.com/oauth/token"
openaiClientID = "app_EMoamEEZ73f0CkXaXp7hrann" ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
redirectURI = "http://localhost:1455/auth/callback" RedirectURI = "http://localhost:1455/auth/callback"
) )
// CodexAuth handles the OpenAI OAuth2 authentication flow. // CodexAuth handles the OpenAI OAuth2 authentication flow.
@@ -50,9 +51,9 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
} }
params := url.Values{ params := url.Values{
"client_id": {openaiClientID}, "client_id": {ClientID},
"response_type": {"code"}, "response_type": {"code"},
"redirect_uri": {redirectURI}, "redirect_uri": {RedirectURI},
"scope": {"openid email profile offline_access"}, "scope": {"openid email profile offline_access"},
"state": {state}, "state": {state},
"code_challenge": {pkceCodes.CodeChallenge}, "code_challenge": {pkceCodes.CodeChallenge},
@@ -62,7 +63,7 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
"codex_cli_simplified_flow": {"true"}, "codex_cli_simplified_flow": {"true"},
} }
authURL := fmt.Sprintf("%s?%s", openaiAuthURL, params.Encode()) authURL := fmt.Sprintf("%s?%s", AuthURL, params.Encode())
return authURL, nil return authURL, nil
} }
@@ -77,13 +78,13 @@ func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkce
// Prepare token exchange request // Prepare token exchange request
data := url.Values{ data := url.Values{
"grant_type": {"authorization_code"}, "grant_type": {"authorization_code"},
"client_id": {openaiClientID}, "client_id": {ClientID},
"code": {code}, "code": {code},
"redirect_uri": {redirectURI}, "redirect_uri": {RedirectURI},
"code_verifier": {pkceCodes.CodeVerifier}, "code_verifier": {pkceCodes.CodeVerifier},
} }
req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode())) req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(data.Encode()))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err) return nil, fmt.Errorf("failed to create token request: %w", err)
} }
@@ -163,13 +164,13 @@ func (o *CodexAuth) RefreshTokens(ctx context.Context, refreshToken string) (*Co
} }
data := url.Values{ data := url.Values{
"client_id": {openaiClientID}, "client_id": {ClientID},
"grant_type": {"refresh_token"}, "grant_type": {"refresh_token"},
"refresh_token": {refreshToken}, "refresh_token": {refreshToken},
"scope": {"openid profile email"}, "scope": {"openid profile email"},
} }
req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode())) req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(data.Encode()))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create refresh request: %w", err) return nil, fmt.Errorf("failed to create refresh request: %w", err)
} }

View File

@@ -28,19 +28,19 @@ import (
"golang.org/x/oauth2/google" "golang.org/x/oauth2/google"
) )
// OAuth configuration constants for Gemini
const ( const (
geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" ClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
geminiOauthClientSecret = "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" ClientSecret = "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"
geminiDefaultCallbackPort = 8085 DefaultCallbackPort = 8085
) )
var ( // OAuth scopes for Gemini authentication
geminiOauthScopes = []string{ var Scopes = []string{
"https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile", "https://www.googleapis.com/auth/userinfo.profile",
} }
)
// GeminiAuth provides methods for handling the Gemini OAuth2 authentication flow. // GeminiAuth provides methods for handling the Gemini OAuth2 authentication flow.
// It encapsulates the logic for obtaining, storing, and refreshing authentication tokens // It encapsulates the logic for obtaining, storing, and refreshing authentication tokens
@@ -74,7 +74,7 @@ func NewGeminiAuth() *GeminiAuth {
// - *http.Client: An HTTP client configured with authentication // - *http.Client: An HTTP client configured with authentication
// - error: An error if the client configuration fails, nil otherwise // - error: An error if the client configuration fails, nil otherwise
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) { func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
callbackPort := geminiDefaultCallbackPort callbackPort := DefaultCallbackPort
if opts != nil && opts.CallbackPort > 0 { if opts != nil && opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort callbackPort = opts.CallbackPort
} }
@@ -112,10 +112,10 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
// Configure the OAuth2 client. // Configure the OAuth2 client.
conf := &oauth2.Config{ conf := &oauth2.Config{
ClientID: geminiOauthClientID, ClientID: ClientID,
ClientSecret: geminiOauthClientSecret, ClientSecret: ClientSecret,
RedirectURL: callbackURL, // This will be used by the local server. RedirectURL: callbackURL, // This will be used by the local server.
Scopes: geminiOauthScopes, Scopes: Scopes,
Endpoint: google.Endpoint, Endpoint: google.Endpoint,
} }
@@ -198,9 +198,9 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
} }
ifToken["token_uri"] = "https://oauth2.googleapis.com/token" ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
ifToken["client_id"] = geminiOauthClientID ifToken["client_id"] = ClientID
ifToken["client_secret"] = geminiOauthClientSecret ifToken["client_secret"] = ClientSecret
ifToken["scopes"] = geminiOauthScopes ifToken["scopes"] = Scopes
ifToken["universe_domain"] = "googleapis.com" ifToken["universe_domain"] = "googleapis.com"
ts := GeminiTokenStorage{ ts := GeminiTokenStorage{
@@ -226,7 +226,7 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow // - *oauth2.Token: The OAuth2 token obtained from the authorization flow
// - error: An error if the token acquisition fails, nil otherwise // - error: An error if the token acquisition fails, nil otherwise
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) { func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
callbackPort := geminiDefaultCallbackPort callbackPort := DefaultCallbackPort
if opts != nil && opts.CallbackPort > 0 { if opts != nil && opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort callbackPort = opts.CallbackPort
} }

View File

@@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@@ -49,6 +50,14 @@ func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error {
return nil return nil
} }
resolvedBaseDir, err := util.ResolveAuthDir(baseDir)
if err != nil {
log.Warnf("refresh manager: failed to resolve auth directory %s: %v", baseDir, err)
}
if resolvedBaseDir != "" {
baseDir = resolvedBaseDir
}
// 创建 token 存储库 // 创建 token 存储库
repo := NewFileTokenRepository(baseDir) repo := NewFileTokenRepository(baseDir)

View File

@@ -1042,10 +1042,10 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
"owned_by": model.OwnedBy, "owned_by": model.OwnedBy,
} }
if model.Created > 0 { if model.Created > 0 {
result["created"] = model.Created result["created_at"] = model.Created
} }
if model.Type != "" { if model.Type != "" {
result["type"] = model.Type result["type"] = "model"
} }
if model.DisplayName != "" { if model.DisplayName != "" {
result["display_name"] = model.DisplayName result["display_name"] = model.DisplayName

View File

@@ -148,87 +148,108 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
baseURLs := antigravityBaseURLFallbackOrder(auth) baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
var lastStatus int attempts := antigravityRetryAttempts(auth, e.cfg)
var lastBody []byte
var lastErr error
for idx, baseURL := range baseURLs { attemptLoop:
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, false, opts.Alt, baseURL) for attempt := 0; attempt < attempts; attempt++ {
if errReq != nil { var lastStatus int
err = errReq var lastBody []byte
return resp, err var lastErr error
}
httpResp, errDo := httpClient.Do(httpReq) for idx, baseURL := range baseURLs {
if errDo != nil { httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, false, opts.Alt, baseURL)
recordAPIResponseError(ctx, e.cfg, errDo) if errReq != nil {
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) { err = errReq
return resp, errDo return resp, err
} }
lastStatus = 0
lastBody = nil httpResp, errDo := httpClient.Do(httpReq)
lastErr = errDo if errDo != nil {
if idx+1 < len(baseURLs) { recordAPIResponseError(ctx, e.cfg, errDo)
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
continue return resp, errDo
}
lastStatus = 0
lastBody = nil
lastErr = errDo
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
err = errDo
return resp, err
} }
err = errDo
return resp, err recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
err = errRead
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
log.Debugf("antigravity executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), bodyBytes))
lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), bodyBytes...)
lastErr = nil
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) {
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
if attempt+1 < attempts {
delay := antigravityNoCapacityRetryDelay(attempt)
log.Debugf("antigravity executor: no capacity for model %s, retrying in %s (attempt %d/%d)", baseModel, delay, attempt+1, attempts)
if errWait := antigravityWait(ctx, delay); errWait != nil {
return resp, errWait
}
continue attemptLoop
}
}
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
return resp, err
}
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bodyBytes, &param)
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
reporter.ensurePublished(ctx)
return resp, nil
} }
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) switch {
bodyBytes, errRead := io.ReadAll(httpResp.Body) case lastStatus != 0:
if errClose := httpResp.Body.Close(); errClose != nil { sErr := statusErr{code: lastStatus, msg: string(lastBody)}
log.Errorf("antigravity executor: close response body error: %v", errClose) if lastStatus == http.StatusTooManyRequests {
} if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
err = errRead
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
log.Debugf("antigravity executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), bodyBytes))
lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), bodyBytes...)
lastErr = nil
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter sErr.retryAfter = retryAfter
} }
} }
err = sErr err = sErr
return resp, err case lastErr != nil:
err = lastErr
default:
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
} }
return resp, err
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bodyBytes, &param)
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
reporter.ensurePublished(ctx)
return resp, nil
} }
switch {
case lastStatus != 0:
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
case lastErr != nil:
err = lastErr
default:
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
}
return resp, err return resp, err
} }
@@ -268,150 +289,171 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
baseURLs := antigravityBaseURLFallbackOrder(auth) baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
var lastStatus int attempts := antigravityRetryAttempts(auth, e.cfg)
var lastBody []byte
var lastErr error
for idx, baseURL := range baseURLs { attemptLoop:
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL) for attempt := 0; attempt < attempts; attempt++ {
if errReq != nil { var lastStatus int
err = errReq var lastBody []byte
return resp, err var lastErr error
}
httpResp, errDo := httpClient.Do(httpReq) for idx, baseURL := range baseURLs {
if errDo != nil { httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL)
recordAPIResponseError(ctx, e.cfg, errDo) if errReq != nil {
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) { err = errReq
return resp, errDo return resp, err
} }
lastStatus = 0
lastBody = nil httpResp, errDo := httpClient.Do(httpReq)
lastErr = errDo if errDo != nil {
if idx+1 < len(baseURLs) { recordAPIResponseError(ctx, e.cfg, errDo)
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
continue return resp, errDo
}
err = errDo
return resp, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
err = errRead
return resp, err
}
if errCtx := ctx.Err(); errCtx != nil {
err = errCtx
return resp, err
} }
lastStatus = 0 lastStatus = 0
lastBody = nil lastBody = nil
lastErr = errRead lastErr = errDo
if idx+1 < len(baseURLs) { if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue continue
} }
err = errRead err = errDo
return resp, err return resp, err
} }
appendAPIResponseChunk(ctx, e.cfg, bodyBytes) recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
lastStatus = httpResp.StatusCode if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
lastBody = append([]byte(nil), bodyBytes...) bodyBytes, errRead := io.ReadAll(httpResp.Body)
lastErr = nil if errClose := httpResp.Body.Close(); errClose != nil {
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) { log.Errorf("antigravity executor: close response body error: %v", errClose)
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) }
continue if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
err = errRead
return resp, err
}
if errCtx := ctx.Err(); errCtx != nil {
err = errCtx
return resp, err
}
lastStatus = 0
lastBody = nil
lastErr = errRead
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
err = errRead
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), bodyBytes...)
lastErr = nil
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) {
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
if attempt+1 < attempts {
delay := antigravityNoCapacityRetryDelay(attempt)
log.Debugf("antigravity executor: no capacity for model %s, retrying in %s (attempt %d/%d)", baseModel, delay, attempt+1, attempts)
if errWait := antigravityWait(ctx, delay); errWait != nil {
return resp, errWait
}
continue attemptLoop
}
}
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
return resp, err
} }
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests { out := make(chan cliproxyexecutor.StreamChunk)
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil { go func(resp *http.Response) {
defer close(out)
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
}()
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(nil, streamScannerBuffer)
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
// Filter usage metadata for all models
// Only retain usage statistics in the terminal chunk
line = FilterSSEUsageMetadata(line)
payload := jsonPayload(line)
if payload == nil {
continue
}
if detail, ok := parseAntigravityStreamUsage(payload); ok {
reporter.publish(ctx, detail)
}
out <- cliproxyexecutor.StreamChunk{Payload: payload}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
} else {
reporter.ensurePublished(ctx)
}
}(httpResp)
var buffer bytes.Buffer
for chunk := range out {
if chunk.Err != nil {
return resp, chunk.Err
}
if len(chunk.Payload) > 0 {
_, _ = buffer.Write(chunk.Payload)
_, _ = buffer.Write([]byte("\n"))
}
}
resp = cliproxyexecutor.Response{Payload: e.convertStreamToNonStream(buffer.Bytes())}
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, resp.Payload, &param)
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
reporter.ensurePublished(ctx)
return resp, nil
}
switch {
case lastStatus != 0:
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter sErr.retryAfter = retryAfter
} }
} }
err = sErr err = sErr
return resp, err case lastErr != nil:
err = lastErr
default:
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
} }
return resp, err
out := make(chan cliproxyexecutor.StreamChunk)
go func(resp *http.Response) {
defer close(out)
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
}()
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(nil, streamScannerBuffer)
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
// Filter usage metadata for all models
// Only retain usage statistics in the terminal chunk
line = FilterSSEUsageMetadata(line)
payload := jsonPayload(line)
if payload == nil {
continue
}
if detail, ok := parseAntigravityStreamUsage(payload); ok {
reporter.publish(ctx, detail)
}
out <- cliproxyexecutor.StreamChunk{Payload: payload}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
} else {
reporter.ensurePublished(ctx)
}
}(httpResp)
var buffer bytes.Buffer
for chunk := range out {
if chunk.Err != nil {
return resp, chunk.Err
}
if len(chunk.Payload) > 0 {
_, _ = buffer.Write(chunk.Payload)
_, _ = buffer.Write([]byte("\n"))
}
}
resp = cliproxyexecutor.Response{Payload: e.convertStreamToNonStream(buffer.Bytes())}
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, resp.Payload, &param)
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
reporter.ensurePublished(ctx)
return resp, nil
} }
switch {
case lastStatus != 0:
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
case lastErr != nil:
err = lastErr
default:
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
}
return resp, err return resp, err
} }
@@ -635,139 +677,160 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
baseURLs := antigravityBaseURLFallbackOrder(auth) baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
var lastStatus int attempts := antigravityRetryAttempts(auth, e.cfg)
var lastBody []byte
var lastErr error
for idx, baseURL := range baseURLs { attemptLoop:
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL) for attempt := 0; attempt < attempts; attempt++ {
if errReq != nil { var lastStatus int
err = errReq var lastBody []byte
return nil, err var lastErr error
}
httpResp, errDo := httpClient.Do(httpReq) for idx, baseURL := range baseURLs {
if errDo != nil { httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL)
recordAPIResponseError(ctx, e.cfg, errDo) if errReq != nil {
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) { err = errReq
return nil, errDo return nil, err
} }
lastStatus = 0 httpResp, errDo := httpClient.Do(httpReq)
lastBody = nil if errDo != nil {
lastErr = errDo recordAPIResponseError(ctx, e.cfg, errDo)
if idx+1 < len(baseURLs) { if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) return nil, errDo
continue
}
err = errDo
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
err = errRead
return nil, err
}
if errCtx := ctx.Err(); errCtx != nil {
err = errCtx
return nil, err
} }
lastStatus = 0 lastStatus = 0
lastBody = nil lastBody = nil
lastErr = errRead lastErr = errDo
if idx+1 < len(baseURLs) { if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue continue
} }
err = errRead err = errDo
return nil, err return nil, err
} }
appendAPIResponseChunk(ctx, e.cfg, bodyBytes) recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
lastStatus = httpResp.StatusCode if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
lastBody = append([]byte(nil), bodyBytes...) bodyBytes, errRead := io.ReadAll(httpResp.Body)
lastErr = nil if errClose := httpResp.Body.Close(); errClose != nil {
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) { log.Errorf("antigravity executor: close response body error: %v", errClose)
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) }
continue if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
err = errRead
return nil, err
}
if errCtx := ctx.Err(); errCtx != nil {
err = errCtx
return nil, err
}
lastStatus = 0
lastBody = nil
lastErr = errRead
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
err = errRead
return nil, err
}
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), bodyBytes...)
lastErr = nil
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) {
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
if attempt+1 < attempts {
delay := antigravityNoCapacityRetryDelay(attempt)
log.Debugf("antigravity executor: no capacity for model %s, retrying in %s (attempt %d/%d)", baseModel, delay, attempt+1, attempts)
if errWait := antigravityWait(ctx, delay); errWait != nil {
return nil, errWait
}
continue attemptLoop
}
}
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
return nil, err
} }
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests { out := make(chan cliproxyexecutor.StreamChunk)
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil { stream = out
go func(resp *http.Response) {
defer close(out)
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
}()
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(nil, streamScannerBuffer)
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
// Filter usage metadata for all models
// Only retain usage statistics in the terminal chunk
line = FilterSSEUsageMetadata(line)
payload := jsonPayload(line)
if payload == nil {
continue
}
if detail, ok := parseAntigravityStreamUsage(payload); ok {
reporter.publish(ctx, detail)
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(payload), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
}
}
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, []byte("[DONE]"), &param)
for i := range tail {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(tail[i])}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
} else {
reporter.ensurePublished(ctx)
}
}(httpResp)
return stream, nil
}
switch {
case lastStatus != 0:
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter sErr.retryAfter = retryAfter
} }
} }
err = sErr err = sErr
return nil, err case lastErr != nil:
err = lastErr
default:
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
} }
return nil, err
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func(resp *http.Response) {
defer close(out)
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
}()
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(nil, streamScannerBuffer)
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
// Filter usage metadata for all models
// Only retain usage statistics in the terminal chunk
line = FilterSSEUsageMetadata(line)
payload := jsonPayload(line)
if payload == nil {
continue
}
if detail, ok := parseAntigravityStreamUsage(payload); ok {
reporter.publish(ctx, detail)
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(payload), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
}
}
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, []byte("[DONE]"), &param)
for i := range tail {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(tail[i])}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
} else {
reporter.ensurePublished(ctx)
}
}(httpResp)
return stream, nil
} }
switch {
case lastStatus != 0:
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
case lastErr != nil:
err = lastErr
default:
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
}
return nil, err return nil, err
} }
@@ -997,7 +1060,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
now := time.Now().Unix() now := time.Now().Unix()
modelConfig := registry.GetAntigravityModelConfig() modelConfig := registry.GetAntigravityModelConfig()
models := make([]*registry.ModelInfo, 0, len(result.Map())) models := make([]*registry.ModelInfo, 0, len(result.Map()))
for originalName := range result.Map() { for originalName, modelData := range result.Map() {
modelID := strings.TrimSpace(originalName) modelID := strings.TrimSpace(originalName)
if modelID == "" { if modelID == "" {
continue continue
@@ -1007,12 +1070,18 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
continue continue
} }
modelCfg := modelConfig[modelID] modelCfg := modelConfig[modelID]
modelName := modelID
// Extract displayName from upstream response, fallback to modelID
displayName := modelData.Get("displayName").String()
if displayName == "" {
displayName = modelID
}
modelInfo := &registry.ModelInfo{ modelInfo := &registry.ModelInfo{
ID: modelID, ID: modelID,
Name: modelName, Name: modelID,
Description: modelID, Description: displayName,
DisplayName: modelID, DisplayName: displayName,
Version: modelID, Version: modelID,
Object: "model", Object: "model",
Created: now, Created: now,
@@ -1378,14 +1447,70 @@ func resolveUserAgent(auth *cliproxyauth.Auth) string {
return defaultAntigravityAgent return defaultAntigravityAgent
} }
func antigravityRetryAttempts(auth *cliproxyauth.Auth, cfg *config.Config) int {
retry := 0
if cfg != nil {
retry = cfg.RequestRetry
}
if auth != nil {
if override, ok := auth.RequestRetryOverride(); ok {
retry = override
}
}
if retry < 0 {
retry = 0
}
attempts := retry + 1
if attempts < 1 {
return 1
}
return attempts
}
func antigravityShouldRetryNoCapacity(statusCode int, body []byte) bool {
if statusCode != http.StatusServiceUnavailable {
return false
}
if len(body) == 0 {
return false
}
msg := strings.ToLower(string(body))
return strings.Contains(msg, "no capacity available")
}
func antigravityNoCapacityRetryDelay(attempt int) time.Duration {
if attempt < 0 {
attempt = 0
}
delay := time.Duration(attempt+1) * 250 * time.Millisecond
if delay > 2*time.Second {
delay = 2 * time.Second
}
return delay
}
func antigravityWait(ctx context.Context, wait time.Duration) error {
if wait <= 0 {
return nil
}
timer := time.NewTimer(wait)
defer timer.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}
func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string { func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string {
if base := resolveCustomAntigravityBaseURL(auth); base != "" { if base := resolveCustomAntigravityBaseURL(auth); base != "" {
return []string{base} return []string{base}
} }
return []string{ return []string{
antigravitySandboxBaseURLDaily,
antigravityBaseURLDaily, antigravityBaseURLDaily,
antigravityBaseURLProd, antigravitySandboxBaseURLDaily,
// antigravityBaseURLProd,
} }
} }

View File

@@ -163,7 +163,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)} err = statusErr{code: httpResp.StatusCode, msg: string(b)}
if errClose := httpResp.Body.Close(); errClose != nil { if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose) log.Errorf("response body close error: %v", errClose)
@@ -295,7 +295,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil { if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose) log.Errorf("response body close error: %v", errClose)
} }
@@ -733,6 +733,11 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() { if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
tools.ForEach(func(index, tool gjson.Result) bool { tools.ForEach(func(index, tool gjson.Result) bool {
// Skip built-in tools (web_search, code_execution, etc.) which have
// a "type" field and require their name to remain unchanged.
if tool.Get("type").Exists() && tool.Get("type").String() != "" {
return true
}
name := tool.Get("name").String() name := tool.Get("name").String()
if name == "" || strings.HasPrefix(name, prefix) { if name == "" || strings.HasPrefix(name, prefix) {
return true return true

View File

@@ -25,6 +25,18 @@ func TestApplyClaudeToolPrefix(t *testing.T) {
} }
} }
func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"},{"name":"my_custom_tool","input_schema":{"type":"object"}}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "web_search" {
t.Fatalf("built-in tool name should not be prefixed: tools.0.name = %q, want %q", got, "web_search")
}
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_my_custom_tool" {
t.Fatalf("custom tool should be prefixed: tools.1.name = %q, want %q", got, "proxy_my_custom_tool")
}
}
func TestStripClaudeToolPrefixFromResponse(t *testing.T) { func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`) input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
out := stripClaudeToolPrefixFromResponse(input, "proxy_") out := stripClaudeToolPrefixFromResponse(input, "proxy_")

View File

@@ -150,7 +150,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)} err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err return resp, err
} }
@@ -265,7 +265,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
return nil, readErr return nil, readErr
} }
appendAPIResponseChunk(ctx, e.cfg, data) appendAPIResponseChunk(ctx, e.cfg, data)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
err = statusErr{code: httpResp.StatusCode, msg: string(data)} err = statusErr{code: httpResp.StatusCode, msg: string(data)}
return nil, err return nil, err
} }

View File

@@ -227,7 +227,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
lastStatus = httpResp.StatusCode lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), data...) lastBody = append([]byte(nil), data...)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
if httpResp.StatusCode == 429 { if httpResp.StatusCode == 429 {
if idx+1 < len(models) { if idx+1 < len(models) {
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1]) log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
@@ -360,7 +360,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
appendAPIResponseChunk(ctx, e.cfg, data) appendAPIResponseChunk(ctx, e.cfg, data)
lastStatus = httpResp.StatusCode lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), data...) lastBody = append([]byte(nil), data...)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
if httpResp.StatusCode == 429 { if httpResp.StatusCode == 429 {
if idx+1 < len(models) { if idx+1 < len(models) {
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1]) log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])

View File

@@ -188,7 +188,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)} err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err return resp, err
} }
@@ -282,7 +282,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil { if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("gemini executor: close response body error: %v", errClose) log.Errorf("gemini executor: close response body error: %v", errClose)
} }
@@ -402,7 +402,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
} }
appendAPIResponseChunk(ctx, e.cfg, data) appendAPIResponseChunk(ctx, e.cfg, data)
if resp.StatusCode < 200 || resp.StatusCode >= 300 { if resp.StatusCode < 200 || resp.StatusCode >= 300 {
log.Debugf("request error, error status: %d, error body: %s", resp.StatusCode, summarizeErrorBody(resp.Header.Get("Content-Type"), data)) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", resp.StatusCode, summarizeErrorBody(resp.Header.Get("Content-Type"), data))
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(data)} return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(data)}
} }

View File

@@ -389,7 +389,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)} err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err return resp, err
} }
@@ -503,7 +503,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)} err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err return resp, err
} }
@@ -601,7 +601,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil { if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose) log.Errorf("vertex executor: close response body error: %v", errClose)
} }
@@ -725,7 +725,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil { if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose) log.Errorf("vertex executor: close response body error: %v", errClose)
} }
@@ -838,7 +838,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
} }
data, errRead := io.ReadAll(httpResp.Body) data, errRead := io.ReadAll(httpResp.Body)
@@ -922,7 +922,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
} }
data, errRead := io.ReadAll(httpResp.Body) data, errRead := io.ReadAll(httpResp.Body)

View File

@@ -142,7 +142,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("iflow request error: status %d body %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)} err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err return resp, err
} }
@@ -244,7 +244,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
log.Errorf("iflow executor: close response body error: %v", errClose) log.Errorf("iflow executor: close response body error: %v", errClose)
} }
appendAPIResponseChunk(ctx, e.cfg, data) appendAPIResponseChunk(ctx, e.cfg, data)
log.Debugf("iflow streaming error: status %d body %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
err = statusErr{code: httpResp.StatusCode, msg: string(data)} err = statusErr{code: httpResp.StatusCode, msg: string(data)}
return nil, err return nil, err
} }

View File

@@ -791,28 +791,28 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
_ = httpResp.Body.Close() _ = httpResp.Body.Close()
appendAPIResponseChunk(ctx, e.cfg, respBody) appendAPIResponseChunk(ctx, e.cfg, respBody)
if attempt < maxRetries { log.Warnf("kiro: received 401 error, attempting token refresh")
log.Warnf("kiro: received 401 error, attempting token refresh and retry (attempt %d/%d)", attempt+1, maxRetries+1) refreshedAuth, refreshErr := e.Refresh(ctx, auth)
if refreshErr != nil {
log.Errorf("kiro: token refresh failed: %v", refreshErr)
return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
}
refreshedAuth, refreshErr := e.Refresh(ctx, auth) if refreshedAuth != nil {
if refreshErr != nil { auth = refreshedAuth
log.Errorf("kiro: token refresh failed: %v", refreshErr) // Persist the refreshed auth to file so subsequent requests use it
return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
// Continue anyway - the token is valid for this request
} }
accessToken, profileArn = kiroCredentials(auth)
if refreshedAuth != nil { // Rebuild payload with new profile ARN if changed
auth = refreshedAuth kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
// Persist the refreshed auth to file so subsequent requests use it if attempt < maxRetries {
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { log.Infof("kiro: token refreshed successfully, retrying request (attempt %d/%d)", attempt+1, maxRetries+1)
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
// Continue anyway - the token is valid for this request
}
accessToken, profileArn = kiroCredentials(auth)
// Rebuild payload with new profile ARN if changed
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
log.Infof("kiro: token refreshed successfully, retrying request")
continue continue
} }
log.Infof("kiro: token refreshed successfully, no retries remaining")
} }
log.Warnf("kiro request error, status: 401, body: %s", summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) log.Warnf("kiro request error, status: 401, body: %s", summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody))
@@ -1199,28 +1199,28 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
_ = httpResp.Body.Close() _ = httpResp.Body.Close()
appendAPIResponseChunk(ctx, e.cfg, respBody) appendAPIResponseChunk(ctx, e.cfg, respBody)
if attempt < maxRetries { log.Warnf("kiro: stream received 401 error, attempting token refresh")
log.Warnf("kiro: stream received 401 error, attempting token refresh and retry (attempt %d/%d)", attempt+1, maxRetries+1) refreshedAuth, refreshErr := e.Refresh(ctx, auth)
if refreshErr != nil {
log.Errorf("kiro: token refresh failed: %v", refreshErr)
return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
}
refreshedAuth, refreshErr := e.Refresh(ctx, auth) if refreshedAuth != nil {
if refreshErr != nil { auth = refreshedAuth
log.Errorf("kiro: token refresh failed: %v", refreshErr) // Persist the refreshed auth to file so subsequent requests use it
return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
// Continue anyway - the token is valid for this request
} }
accessToken, profileArn = kiroCredentials(auth)
if refreshedAuth != nil { // Rebuild payload with new profile ARN if changed
auth = refreshedAuth kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
// Persist the refreshed auth to file so subsequent requests use it if attempt < maxRetries {
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil { log.Infof("kiro: token refreshed successfully, retrying stream request (attempt %d/%d)", attempt+1, maxRetries+1)
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
// Continue anyway - the token is valid for this request
}
accessToken, profileArn = kiroCredentials(auth)
// Rebuild payload with new profile ARN if changed
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
log.Infof("kiro: token refreshed successfully, retrying stream request")
continue continue
} }
log.Infof("kiro: token refreshed successfully, no retries remaining")
} }
log.Warnf("kiro stream error, status: 401, body: %s", string(respBody)) log.Warnf("kiro stream error, status: 401, body: %s", string(respBody))

View File

@@ -12,7 +12,10 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
) )
const ( const (
@@ -332,6 +335,12 @@ func summarizeErrorBody(contentType string, body []byte) string {
} }
return "[html body omitted]" return "[html body omitted]"
} }
// Try to extract error message from JSON response
if message := extractJSONErrorMessage(body); message != "" {
return message
}
return string(body) return string(body)
} }
@@ -358,3 +367,25 @@ func extractHTMLTitle(body []byte) string {
} }
return strings.Join(strings.Fields(title), " ") return strings.Join(strings.Fields(title), " ")
} }
// extractJSONErrorMessage attempts to extract error.message from JSON error responses
func extractJSONErrorMessage(body []byte) string {
result := gjson.GetBytes(body, "error.message")
if result.Exists() && result.String() != "" {
return result.String()
}
return ""
}
// logWithRequestID returns a logrus Entry with request_id field populated from context.
// If no request ID is found in context, it returns the standard logger.
func logWithRequestID(ctx context.Context) *log.Entry {
if ctx == nil {
return log.NewEntry(log.StandardLogger())
}
requestID := logging.GetRequestID(ctx)
if requestID == "" {
return log.NewEntry(log.StandardLogger())
}
return log.WithField("request_id", requestID)
}

View File

@@ -146,7 +146,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)} err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err return resp, err
} }
@@ -239,7 +239,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil { if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("openai compat executor: close response body error: %v", errClose) log.Errorf("openai compat executor: close response body error: %v", errClose)
} }

View File

@@ -133,7 +133,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)} err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err return resp, err
} }
@@ -222,7 +222,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body) b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b) appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil { if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("qwen executor: close response body error: %v", errClose) log.Errorf("qwen executor: close response body error: %v", errClose)
} }

View File

@@ -305,12 +305,12 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
} }
} }
// tools -> request.tools[0].functionDeclarations + request.tools[0].googleSearch passthrough // tools -> request.tools[].functionDeclarations + request.tools[].googleSearch passthrough
tools := gjson.GetBytes(rawJSON, "tools") tools := gjson.GetBytes(rawJSON, "tools")
if tools.IsArray() && len(tools.Array()) > 0 { if tools.IsArray() && len(tools.Array()) > 0 {
toolNode := []byte(`{}`) functionToolNode := []byte(`{}`)
hasTool := false
hasFunction := false hasFunction := false
googleSearchNodes := make([][]byte, 0)
for _, t := range tools.Array() { for _, t := range tools.Array() {
if t.Get("type").String() == "function" { if t.Get("type").String() == "function" {
fn := t.Get("function") fn := t.Get("function")
@@ -349,31 +349,37 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
} }
fnRaw, _ = sjson.Delete(fnRaw, "strict") fnRaw, _ = sjson.Delete(fnRaw, "strict")
if !hasFunction { if !hasFunction {
toolNode, _ = sjson.SetRawBytes(toolNode, "functionDeclarations", []byte("[]")) functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
} }
tmp, errSet := sjson.SetRawBytes(toolNode, "functionDeclarations.-1", []byte(fnRaw)) tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw))
if errSet != nil { if errSet != nil {
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
continue continue
} }
toolNode = tmp functionToolNode = tmp
hasFunction = true hasFunction = true
hasTool = true
} }
} }
if gs := t.Get("google_search"); gs.Exists() { if gs := t.Get("google_search"); gs.Exists() {
googleToolNode := []byte(`{}`)
var errSet error var errSet error
toolNode, errSet = sjson.SetRawBytes(toolNode, "googleSearch", []byte(gs.Raw)) googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw))
if errSet != nil { if errSet != nil {
log.Warnf("Failed to set googleSearch tool: %v", errSet) log.Warnf("Failed to set googleSearch tool: %v", errSet)
continue continue
} }
hasTool = true googleSearchNodes = append(googleSearchNodes, googleToolNode)
} }
} }
if hasTool { if hasFunction || len(googleSearchNodes) > 0 {
out, _ = sjson.SetRawBytes(out, "request.tools", []byte("[]")) toolsNode := []byte("[]")
out, _ = sjson.SetRawBytes(out, "request.tools.0", toolNode) if hasFunction {
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode)
}
for _, googleNode := range googleSearchNodes {
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", googleNode)
}
out, _ = sjson.SetRawBytes(out, "request.tools", toolsNode)
} }
} }

View File

@@ -283,12 +283,12 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
} }
} }
// tools -> request.tools[0].functionDeclarations + request.tools[0].googleSearch passthrough // tools -> request.tools[].functionDeclarations + request.tools[].googleSearch passthrough
tools := gjson.GetBytes(rawJSON, "tools") tools := gjson.GetBytes(rawJSON, "tools")
if tools.IsArray() && len(tools.Array()) > 0 { if tools.IsArray() && len(tools.Array()) > 0 {
toolNode := []byte(`{}`) functionToolNode := []byte(`{}`)
hasTool := false
hasFunction := false hasFunction := false
googleSearchNodes := make([][]byte, 0)
for _, t := range tools.Array() { for _, t := range tools.Array() {
if t.Get("type").String() == "function" { if t.Get("type").String() == "function" {
fn := t.Get("function") fn := t.Get("function")
@@ -327,31 +327,37 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
} }
fnRaw, _ = sjson.Delete(fnRaw, "strict") fnRaw, _ = sjson.Delete(fnRaw, "strict")
if !hasFunction { if !hasFunction {
toolNode, _ = sjson.SetRawBytes(toolNode, "functionDeclarations", []byte("[]")) functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
} }
tmp, errSet := sjson.SetRawBytes(toolNode, "functionDeclarations.-1", []byte(fnRaw)) tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw))
if errSet != nil { if errSet != nil {
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
continue continue
} }
toolNode = tmp functionToolNode = tmp
hasFunction = true hasFunction = true
hasTool = true
} }
} }
if gs := t.Get("google_search"); gs.Exists() { if gs := t.Get("google_search"); gs.Exists() {
googleToolNode := []byte(`{}`)
var errSet error var errSet error
toolNode, errSet = sjson.SetRawBytes(toolNode, "googleSearch", []byte(gs.Raw)) googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw))
if errSet != nil { if errSet != nil {
log.Warnf("Failed to set googleSearch tool: %v", errSet) log.Warnf("Failed to set googleSearch tool: %v", errSet)
continue continue
} }
hasTool = true googleSearchNodes = append(googleSearchNodes, googleToolNode)
} }
} }
if hasTool { if hasFunction || len(googleSearchNodes) > 0 {
out, _ = sjson.SetRawBytes(out, "request.tools", []byte("[]")) toolsNode := []byte("[]")
out, _ = sjson.SetRawBytes(out, "request.tools.0", toolNode) if hasFunction {
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode)
}
for _, googleNode := range googleSearchNodes {
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", googleNode)
}
out, _ = sjson.SetRawBytes(out, "request.tools", toolsNode)
} }
} }

View File

@@ -289,12 +289,12 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
} }
} }
// tools -> tools[0].functionDeclarations + tools[0].googleSearch passthrough // tools -> tools[].functionDeclarations + tools[].googleSearch passthrough
tools := gjson.GetBytes(rawJSON, "tools") tools := gjson.GetBytes(rawJSON, "tools")
if tools.IsArray() && len(tools.Array()) > 0 { if tools.IsArray() && len(tools.Array()) > 0 {
toolNode := []byte(`{}`) functionToolNode := []byte(`{}`)
hasTool := false
hasFunction := false hasFunction := false
googleSearchNodes := make([][]byte, 0)
for _, t := range tools.Array() { for _, t := range tools.Array() {
if t.Get("type").String() == "function" { if t.Get("type").String() == "function" {
fn := t.Get("function") fn := t.Get("function")
@@ -333,31 +333,37 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
} }
fnRaw, _ = sjson.Delete(fnRaw, "strict") fnRaw, _ = sjson.Delete(fnRaw, "strict")
if !hasFunction { if !hasFunction {
toolNode, _ = sjson.SetRawBytes(toolNode, "functionDeclarations", []byte("[]")) functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
} }
tmp, errSet := sjson.SetRawBytes(toolNode, "functionDeclarations.-1", []byte(fnRaw)) tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw))
if errSet != nil { if errSet != nil {
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
continue continue
} }
toolNode = tmp functionToolNode = tmp
hasFunction = true hasFunction = true
hasTool = true
} }
} }
if gs := t.Get("google_search"); gs.Exists() { if gs := t.Get("google_search"); gs.Exists() {
googleToolNode := []byte(`{}`)
var errSet error var errSet error
toolNode, errSet = sjson.SetRawBytes(toolNode, "googleSearch", []byte(gs.Raw)) googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw))
if errSet != nil { if errSet != nil {
log.Warnf("Failed to set googleSearch tool: %v", errSet) log.Warnf("Failed to set googleSearch tool: %v", errSet)
continue continue
} }
hasTool = true googleSearchNodes = append(googleSearchNodes, googleToolNode)
} }
} }
if hasTool { if hasFunction || len(googleSearchNodes) > 0 {
out, _ = sjson.SetRawBytes(out, "tools", []byte("[]")) toolsNode := []byte("[]")
out, _ = sjson.SetRawBytes(out, "tools.0", toolNode) if hasFunction {
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode)
}
for _, googleNode := range googleSearchNodes {
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", googleNode)
}
out, _ = sjson.SetRawBytes(out, "tools", toolsNode)
} }
} }

View File

@@ -499,6 +499,16 @@ func shortenToolNameIfNeeded(name string) string {
return name[:limit] return name[:limit]
} }
func ensureKiroInputSchema(parameters interface{}) interface{} {
if parameters != nil {
return parameters
}
return map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
}
}
// convertClaudeToolsToKiro converts Claude tools to Kiro format // convertClaudeToolsToKiro converts Claude tools to Kiro format
func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper { func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper {
var kiroTools []KiroToolWrapper var kiroTools []KiroToolWrapper
@@ -509,7 +519,12 @@ func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper {
for _, tool := range tools.Array() { for _, tool := range tools.Array() {
name := tool.Get("name").String() name := tool.Get("name").String()
description := tool.Get("description").String() description := tool.Get("description").String()
inputSchema := tool.Get("input_schema").Value() inputSchemaResult := tool.Get("input_schema")
var inputSchema interface{}
if inputSchemaResult.Exists() && inputSchemaResult.Type != gjson.Null {
inputSchema = inputSchemaResult.Value()
}
inputSchema = ensureKiroInputSchema(inputSchema)
// Shorten tool name if it exceeds 64 characters (common with MCP tools) // Shorten tool name if it exceeds 64 characters (common with MCP tools)
originalName := name originalName := name

View File

@@ -314,7 +314,7 @@ func ConvertOpenAIToolsToKiroFormat(tools []map[string]interface{}) []KiroToolWr
name := kirocommon.GetString(fn, "name") name := kirocommon.GetString(fn, "name")
description := kirocommon.GetString(fn, "description") description := kirocommon.GetString(fn, "description")
parameters := fn["parameters"] parameters := ensureKiroInputSchema(fn["parameters"])
if name == "" { if name == "" {
continue continue
@@ -368,4 +368,4 @@ func ConvertClaudeToolUseToOpenAI(toolUseID, toolName string, input map[string]i
// LogStreamEvent logs a streaming event for debugging // LogStreamEvent logs a streaming event for debugging
func LogStreamEvent(eventType, data string) { func LogStreamEvent(eventType, data string) {
log.Debugf("kiro-openai: stream event type=%s, data_len=%d", eventType, len(data)) log.Debugf("kiro-openai: stream event type=%s, data_len=%d", eventType, len(data))
} }

View File

@@ -381,6 +381,16 @@ func shortenToolNameIfNeeded(name string) string {
return name[:limit] return name[:limit]
} }
func ensureKiroInputSchema(parameters interface{}) interface{} {
if parameters != nil {
return parameters
}
return map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
}
}
// convertOpenAIToolsToKiro converts OpenAI tools to Kiro format // convertOpenAIToolsToKiro converts OpenAI tools to Kiro format
func convertOpenAIToolsToKiro(tools gjson.Result) []KiroToolWrapper { func convertOpenAIToolsToKiro(tools gjson.Result) []KiroToolWrapper {
var kiroTools []KiroToolWrapper var kiroTools []KiroToolWrapper
@@ -401,7 +411,12 @@ func convertOpenAIToolsToKiro(tools gjson.Result) []KiroToolWrapper {
name := fn.Get("name").String() name := fn.Get("name").String()
description := fn.Get("description").String() description := fn.Get("description").String()
parameters := fn.Get("parameters").Value() parametersResult := fn.Get("parameters")
var parameters interface{}
if parametersResult.Exists() && parametersResult.Type != gjson.Null {
parameters = parametersResult.Value()
}
parameters = ensureKiroInputSchema(parameters)
// Shorten tool name if it exceeds 64 characters (common with MCP tools) // Shorten tool name if it exceeds 64 characters (common with MCP tools)
originalName := name originalName := name

View File

@@ -181,11 +181,11 @@ func TestConvertClaudeRequestToOpenAI_ThinkingToReasoningContent(t *testing.T) {
result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false) result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false)
resultJSON := gjson.ParseBytes(result) resultJSON := gjson.ParseBytes(result)
// Find the relevant message (skip system message at index 0) // Find the relevant message
messages := resultJSON.Get("messages").Array() messages := resultJSON.Get("messages").Array()
if len(messages) < 2 { if len(messages) < 1 {
if tt.wantHasReasoningContent || tt.wantHasContent { if tt.wantHasReasoningContent || tt.wantHasContent {
t.Fatalf("Expected at least 2 messages (system + user/assistant), got %d", len(messages)) t.Fatalf("Expected at least 1 message, got %d", len(messages))
} }
return return
} }
@@ -272,15 +272,15 @@ func TestConvertClaudeRequestToOpenAI_ThinkingOnlyMessagePreserved(t *testing.T)
messages := resultJSON.Get("messages").Array() messages := resultJSON.Get("messages").Array()
// Should have: system (auto-added) + user + assistant (thinking-only) + user = 4 messages // Should have: user + assistant (thinking-only) + user = 3 messages
if len(messages) != 4 { if len(messages) != 3 {
t.Fatalf("Expected 4 messages, got %d. Messages: %v", len(messages), resultJSON.Get("messages").Raw) t.Fatalf("Expected 3 messages, got %d. Messages: %v", len(messages), resultJSON.Get("messages").Raw)
} }
// Check the assistant message (index 2) has reasoning_content // Check the assistant message (index 1) has reasoning_content
assistantMsg := messages[2] assistantMsg := messages[1]
if assistantMsg.Get("role").String() != "assistant" { if assistantMsg.Get("role").String() != "assistant" {
t.Errorf("Expected message[2] to be assistant, got %s", assistantMsg.Get("role").String()) t.Errorf("Expected message[1] to be assistant, got %s", assistantMsg.Get("role").String())
} }
if !assistantMsg.Get("reasoning_content").Exists() { if !assistantMsg.Get("reasoning_content").Exists() {
@@ -292,6 +292,104 @@ func TestConvertClaudeRequestToOpenAI_ThinkingOnlyMessagePreserved(t *testing.T)
} }
} }
func TestConvertClaudeRequestToOpenAI_SystemMessageScenarios(t *testing.T) {
tests := []struct {
name string
inputJSON string
wantHasSys bool
wantSysText string
}{
{
name: "No system field",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{"role": "user", "content": "hello"}]
}`,
wantHasSys: false,
},
{
name: "Empty string system field",
inputJSON: `{
"model": "claude-3-opus",
"system": "",
"messages": [{"role": "user", "content": "hello"}]
}`,
wantHasSys: false,
},
{
name: "String system field",
inputJSON: `{
"model": "claude-3-opus",
"system": "Be helpful",
"messages": [{"role": "user", "content": "hello"}]
}`,
wantHasSys: true,
wantSysText: "Be helpful",
},
{
name: "Array system field with text",
inputJSON: `{
"model": "claude-3-opus",
"system": [{"type": "text", "text": "Array system"}],
"messages": [{"role": "user", "content": "hello"}]
}`,
wantHasSys: true,
wantSysText: "Array system",
},
{
name: "Array system field with multiple text blocks",
inputJSON: `{
"model": "claude-3-opus",
"system": [
{"type": "text", "text": "Block 1"},
{"type": "text", "text": "Block 2"}
],
"messages": [{"role": "user", "content": "hello"}]
}`,
wantHasSys: true,
wantSysText: "Block 2", // We will update the test logic to check all blocks or specifically the second one
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
hasSys := false
var sysMsg gjson.Result
if len(messages) > 0 && messages[0].Get("role").String() == "system" {
hasSys = true
sysMsg = messages[0]
}
if hasSys != tt.wantHasSys {
t.Errorf("got hasSystem = %v, want %v", hasSys, tt.wantHasSys)
}
if tt.wantHasSys {
// Check content - it could be string or array in OpenAI
content := sysMsg.Get("content")
var gotText string
if content.IsArray() {
arr := content.Array()
if len(arr) > 0 {
// Get the last element's text for validation
gotText = arr[len(arr)-1].Get("text").String()
}
} else {
gotText = content.String()
}
if tt.wantSysText != "" && gotText != tt.wantSysText {
t.Errorf("got system text = %q, want %q", gotText, tt.wantSysText)
}
}
})
}
}
func TestConvertClaudeRequestToOpenAI_ToolResultOrderAndContent(t *testing.T) { func TestConvertClaudeRequestToOpenAI_ToolResultOrderAndContent(t *testing.T) {
inputJSON := `{ inputJSON := `{
"model": "claude-3-opus", "model": "claude-3-opus",
@@ -318,39 +416,35 @@ func TestConvertClaudeRequestToOpenAI_ToolResultOrderAndContent(t *testing.T) {
messages := resultJSON.Get("messages").Array() messages := resultJSON.Get("messages").Array()
// OpenAI requires: tool messages MUST immediately follow assistant(tool_calls). // OpenAI requires: tool messages MUST immediately follow assistant(tool_calls).
// Correct order: system + assistant(tool_calls) + tool(result) + user(before+after) // Correct order: assistant(tool_calls) + tool(result) + user(before+after)
if len(messages) != 4 { if len(messages) != 3 {
t.Fatalf("Expected 4 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw) t.Fatalf("Expected 3 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
} }
if messages[0].Get("role").String() != "system" { if messages[0].Get("role").String() != "assistant" || !messages[0].Get("tool_calls").Exists() {
t.Fatalf("Expected messages[0] to be system, got %s", messages[0].Get("role").String()) t.Fatalf("Expected messages[0] to be assistant tool_calls, got %s: %s", messages[0].Get("role").String(), messages[0].Raw)
}
if messages[1].Get("role").String() != "assistant" || !messages[1].Get("tool_calls").Exists() {
t.Fatalf("Expected messages[1] to be assistant tool_calls, got %s: %s", messages[1].Get("role").String(), messages[1].Raw)
} }
// tool message MUST immediately follow assistant(tool_calls) per OpenAI spec // tool message MUST immediately follow assistant(tool_calls) per OpenAI spec
if messages[2].Get("role").String() != "tool" { if messages[1].Get("role").String() != "tool" {
t.Fatalf("Expected messages[2] to be tool (must follow tool_calls), got %s", messages[2].Get("role").String()) t.Fatalf("Expected messages[1] to be tool (must follow tool_calls), got %s", messages[1].Get("role").String())
} }
if got := messages[2].Get("tool_call_id").String(); got != "call_1" { if got := messages[1].Get("tool_call_id").String(); got != "call_1" {
t.Fatalf("Expected tool_call_id %q, got %q", "call_1", got) t.Fatalf("Expected tool_call_id %q, got %q", "call_1", got)
} }
if got := messages[2].Get("content").String(); got != "tool ok" { if got := messages[1].Get("content").String(); got != "tool ok" {
t.Fatalf("Expected tool content %q, got %q", "tool ok", got) t.Fatalf("Expected tool content %q, got %q", "tool ok", got)
} }
// User message comes after tool message // User message comes after tool message
if messages[3].Get("role").String() != "user" { if messages[2].Get("role").String() != "user" {
t.Fatalf("Expected messages[3] to be user, got %s", messages[3].Get("role").String()) t.Fatalf("Expected messages[2] to be user, got %s", messages[2].Get("role").String())
} }
// User message should contain both "before" and "after" text // User message should contain both "before" and "after" text
if got := messages[3].Get("content.0.text").String(); got != "before" { if got := messages[2].Get("content.0.text").String(); got != "before" {
t.Fatalf("Expected user text[0] %q, got %q", "before", got) t.Fatalf("Expected user text[0] %q, got %q", "before", got)
} }
if got := messages[3].Get("content.1.text").String(); got != "after" { if got := messages[2].Get("content.1.text").String(); got != "after" {
t.Fatalf("Expected user text[1] %q, got %q", "after", got) t.Fatalf("Expected user text[1] %q, got %q", "after", got)
} }
} }
@@ -378,16 +472,16 @@ func TestConvertClaudeRequestToOpenAI_ToolResultObjectContent(t *testing.T) {
resultJSON := gjson.ParseBytes(result) resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array() messages := resultJSON.Get("messages").Array()
// system + assistant(tool_calls) + tool(result) // assistant(tool_calls) + tool(result)
if len(messages) != 3 { if len(messages) != 2 {
t.Fatalf("Expected 3 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw) t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
} }
if messages[2].Get("role").String() != "tool" { if messages[1].Get("role").String() != "tool" {
t.Fatalf("Expected messages[2] to be tool, got %s", messages[2].Get("role").String()) t.Fatalf("Expected messages[1] to be tool, got %s", messages[1].Get("role").String())
} }
toolContent := messages[2].Get("content").String() toolContent := messages[1].Get("content").String()
parsed := gjson.Parse(toolContent) parsed := gjson.Parse(toolContent)
if parsed.Get("foo").String() != "bar" { if parsed.Get("foo").String() != "bar" {
t.Fatalf("Expected tool content JSON foo=bar, got %q", toolContent) t.Fatalf("Expected tool content JSON foo=bar, got %q", toolContent)
@@ -414,18 +508,14 @@ func TestConvertClaudeRequestToOpenAI_AssistantTextToolUseTextOrder(t *testing.T
messages := resultJSON.Get("messages").Array() messages := resultJSON.Get("messages").Array()
// New behavior: content + tool_calls unified in single assistant message // New behavior: content + tool_calls unified in single assistant message
// Expect: system + assistant(content[pre,post] + tool_calls) // Expect: assistant(content[pre,post] + tool_calls)
if len(messages) != 2 { if len(messages) != 1 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw) t.Fatalf("Expected 1 message, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
} }
if messages[0].Get("role").String() != "system" { assistantMsg := messages[0]
t.Fatalf("Expected messages[0] to be system, got %s", messages[0].Get("role").String())
}
assistantMsg := messages[1]
if assistantMsg.Get("role").String() != "assistant" { if assistantMsg.Get("role").String() != "assistant" {
t.Fatalf("Expected messages[1] to be assistant, got %s", assistantMsg.Get("role").String()) t.Fatalf("Expected messages[0] to be assistant, got %s", assistantMsg.Get("role").String())
} }
// Should have both content and tool_calls in same message // Should have both content and tool_calls in same message
@@ -470,14 +560,14 @@ func TestConvertClaudeRequestToOpenAI_AssistantThinkingToolUseThinkingSplit(t *t
messages := resultJSON.Get("messages").Array() messages := resultJSON.Get("messages").Array()
// New behavior: all content, thinking, and tool_calls unified in single assistant message // New behavior: all content, thinking, and tool_calls unified in single assistant message
// Expect: system + assistant(content[pre,post] + tool_calls + reasoning_content[t1+t2]) // Expect: assistant(content[pre,post] + tool_calls + reasoning_content[t1+t2])
if len(messages) != 2 { if len(messages) != 1 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw) t.Fatalf("Expected 1 message, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
} }
assistantMsg := messages[1] assistantMsg := messages[0]
if assistantMsg.Get("role").String() != "assistant" { if assistantMsg.Get("role").String() != "assistant" {
t.Fatalf("Expected messages[1] to be assistant, got %s", assistantMsg.Get("role").String()) t.Fatalf("Expected messages[0] to be assistant, got %s", assistantMsg.Get("role").String())
} }
// Should have content with both pre and post // Should have content with both pre and post

View File

@@ -167,6 +167,16 @@ func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]an
"virtual_parent_id": primary.ID, "virtual_parent_id": primary.ID,
"type": metadata["type"], "type": metadata["type"],
} }
if v, ok := metadata["disable_cooling"]; ok {
metadataCopy["disable_cooling"] = v
} else if v, ok := metadata["disable-cooling"]; ok {
metadataCopy["disable_cooling"] = v
}
if v, ok := metadata["request_retry"]; ok {
metadataCopy["request_retry"] = v
} else if v, ok := metadata["request-retry"]; ok {
metadataCopy["request_retry"] = v
}
proxy := strings.TrimSpace(primary.ProxyURL) proxy := strings.TrimSpace(primary.ProxyURL)
if proxy != "" { if proxy != "" {
metadataCopy["proxy_url"] = proxy metadataCopy["proxy_url"] = proxy

View File

@@ -69,10 +69,12 @@ func TestFileSynthesizer_Synthesize_ValidAuthFile(t *testing.T) {
// Create a valid auth file // Create a valid auth file
authData := map[string]any{ authData := map[string]any{
"type": "claude", "type": "claude",
"email": "test@example.com", "email": "test@example.com",
"proxy_url": "http://proxy.local", "proxy_url": "http://proxy.local",
"prefix": "test-prefix", "prefix": "test-prefix",
"disable_cooling": true,
"request_retry": 2,
} }
data, _ := json.Marshal(authData) data, _ := json.Marshal(authData)
err := os.WriteFile(filepath.Join(tempDir, "claude-auth.json"), data, 0644) err := os.WriteFile(filepath.Join(tempDir, "claude-auth.json"), data, 0644)
@@ -108,6 +110,12 @@ func TestFileSynthesizer_Synthesize_ValidAuthFile(t *testing.T) {
if auths[0].ProxyURL != "http://proxy.local" { if auths[0].ProxyURL != "http://proxy.local" {
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL) t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
} }
if v, ok := auths[0].Metadata["disable_cooling"].(bool); !ok || !v {
t.Errorf("expected disable_cooling true, got %v", auths[0].Metadata["disable_cooling"])
}
if v, ok := auths[0].Metadata["request_retry"].(float64); !ok || int(v) != 2 {
t.Errorf("expected request_retry 2, got %v", auths[0].Metadata["request_retry"])
}
if auths[0].Status != coreauth.StatusActive { if auths[0].Status != coreauth.StatusActive {
t.Errorf("expected status active, got %s", auths[0].Status) t.Errorf("expected status active, got %s", auths[0].Status)
} }
@@ -336,9 +344,11 @@ func TestSynthesizeGeminiVirtualAuths_MultiProject(t *testing.T) {
}, },
} }
metadata := map[string]any{ metadata := map[string]any{
"project_id": "project-a, project-b, project-c", "project_id": "project-a, project-b, project-c",
"email": "test@example.com", "email": "test@example.com",
"type": "gemini", "type": "gemini",
"request_retry": 2,
"disable_cooling": true,
} }
virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now) virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now)
@@ -376,6 +386,12 @@ func TestSynthesizeGeminiVirtualAuths_MultiProject(t *testing.T) {
if v.ProxyURL != "http://proxy.local" { if v.ProxyURL != "http://proxy.local" {
t.Errorf("expected proxy_url http://proxy.local, got %s", v.ProxyURL) t.Errorf("expected proxy_url http://proxy.local, got %s", v.ProxyURL)
} }
if vv, ok := v.Metadata["disable_cooling"].(bool); !ok || !vv {
t.Errorf("expected disable_cooling true, got %v", v.Metadata["disable_cooling"])
}
if vv, ok := v.Metadata["request_retry"].(int); !ok || vv != 2 {
t.Errorf("expected request_retry 2, got %v", v.Metadata["request_retry"])
}
if v.Attributes["runtime_only"] != "true" { if v.Attributes["runtime_only"] != "true" {
t.Error("expected runtime_only=true") t.Error("expected runtime_only=true")
} }

View File

@@ -128,8 +128,23 @@ func (h *ClaudeCodeAPIHandler) ClaudeCountTokens(c *gin.Context) {
// Parameters: // Parameters:
// - c: The Gin context for the request. // - c: The Gin context for the request.
func (h *ClaudeCodeAPIHandler) ClaudeModels(c *gin.Context) { func (h *ClaudeCodeAPIHandler) ClaudeModels(c *gin.Context) {
models := h.Models()
firstID := ""
lastID := ""
if len(models) > 0 {
if id, ok := models[0]["id"].(string); ok {
firstID = id
}
if id, ok := models[len(models)-1]["id"].(string); ok {
lastID = id
}
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"data": h.Models(), "data": models,
"has_more": false,
"first_id": firstID,
"last_id": lastID,
}) })
} }

View File

@@ -60,8 +60,12 @@ func (h *GeminiAPIHandler) GeminiModels(c *gin.Context) {
if !strings.HasPrefix(name, "models/") { if !strings.HasPrefix(name, "models/") {
normalizedModel["name"] = "models/" + name normalizedModel["name"] = "models/" + name
} }
normalizedModel["displayName"] = name if displayName, _ := normalizedModel["displayName"].(string); displayName == "" {
normalizedModel["description"] = name normalizedModel["displayName"] = name
}
if description, _ := normalizedModel["description"].(string); description == "" {
normalizedModel["description"] = name
}
} }
if _, ok := normalizedModel["supportedGenerationMethods"]; !ok { if _, ok := normalizedModel["supportedGenerationMethods"]; !ok {
normalizedModel["supportedGenerationMethods"] = defaultMethods normalizedModel["supportedGenerationMethods"] = defaultMethods

View File

@@ -2,15 +2,13 @@ package auth
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"io"
"net" "net"
"net/http" "net/http"
"net/url"
"strings" "strings"
"time" "time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/antigravity"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser" "github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc" "github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
@@ -19,20 +17,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const (
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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"
antigravityCallbackPort = 51121
)
var antigravityScopes = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
"https://www.googleapis.com/auth/cclog",
"https://www.googleapis.com/auth/experimentsandconfigs",
}
// AntigravityAuthenticator implements OAuth login for the antigravity provider. // AntigravityAuthenticator implements OAuth login for the antigravity provider.
type AntigravityAuthenticator struct{} type AntigravityAuthenticator struct{}
@@ -60,12 +44,12 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
opts = &LoginOptions{} opts = &LoginOptions{}
} }
callbackPort := antigravityCallbackPort callbackPort := antigravity.CallbackPort
if opts.CallbackPort > 0 { if opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort callbackPort = opts.CallbackPort
} }
httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{}) authSvc := antigravity.NewAntigravityAuth(cfg, nil)
state, err := misc.GenerateRandomState() state, err := misc.GenerateRandomState()
if err != nil { if err != nil {
@@ -83,7 +67,7 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
}() }()
redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", port) redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", port)
authURL := buildAntigravityAuthURL(redirectURI, state) authURL := authSvc.BuildAuthURL(state, redirectURI)
if !opts.NoBrowser { if !opts.NoBrowser {
fmt.Println("Opening browser for antigravity authentication") fmt.Println("Opening browser for antigravity authentication")
@@ -164,22 +148,29 @@ waitForCallback:
return nil, fmt.Errorf("antigravity: missing authorization code") return nil, fmt.Errorf("antigravity: missing authorization code")
} }
tokenResp, errToken := exchangeAntigravityCode(ctx, cbRes.Code, redirectURI, httpClient) tokenResp, errToken := authSvc.ExchangeCodeForTokens(ctx, cbRes.Code, redirectURI)
if errToken != nil { if errToken != nil {
return nil, fmt.Errorf("antigravity: token exchange failed: %w", errToken) return nil, fmt.Errorf("antigravity: token exchange failed: %w", errToken)
} }
email := "" accessToken := strings.TrimSpace(tokenResp.AccessToken)
if tokenResp.AccessToken != "" { if accessToken == "" {
if info, errInfo := fetchAntigravityUserInfo(ctx, tokenResp.AccessToken, httpClient); errInfo == nil && strings.TrimSpace(info.Email) != "" { return nil, fmt.Errorf("antigravity: token exchange returned empty access token")
email = strings.TrimSpace(info.Email) }
}
email, errInfo := authSvc.FetchUserInfo(ctx, accessToken)
if errInfo != nil {
return nil, fmt.Errorf("antigravity: fetch user info failed: %w", errInfo)
}
email = strings.TrimSpace(email)
if email == "" {
return nil, fmt.Errorf("antigravity: empty email returned from user info")
} }
// Fetch project ID via loadCodeAssist (same approach as Gemini CLI) // Fetch project ID via loadCodeAssist (same approach as Gemini CLI)
projectID := "" projectID := ""
if tokenResp.AccessToken != "" { if accessToken != "" {
fetchedProjectID, errProject := fetchAntigravityProjectID(ctx, tokenResp.AccessToken, httpClient) fetchedProjectID, errProject := authSvc.FetchProjectID(ctx, accessToken)
if errProject != nil { if errProject != nil {
log.Warnf("antigravity: failed to fetch project ID: %v", errProject) log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
} else { } else {
@@ -204,7 +195,7 @@ waitForCallback:
metadata["project_id"] = projectID metadata["project_id"] = projectID
} }
fileName := sanitizeAntigravityFileName(email) fileName := antigravity.CredentialFileName(email)
label := email label := email
if label == "" { if label == "" {
label = "antigravity" label = "antigravity"
@@ -231,7 +222,7 @@ type callbackResult struct {
func startAntigravityCallbackServer(port int) (*http.Server, int, <-chan callbackResult, error) { func startAntigravityCallbackServer(port int) (*http.Server, int, <-chan callbackResult, error) {
if port <= 0 { if port <= 0 {
port = antigravityCallbackPort port = antigravity.CallbackPort
} }
addr := fmt.Sprintf(":%d", port) addr := fmt.Sprintf(":%d", port)
listener, err := net.Listen("tcp", addr) listener, err := net.Listen("tcp", addr)
@@ -267,309 +258,9 @@ func startAntigravityCallbackServer(port int) (*http.Server, int, <-chan callbac
return srv, port, resultCh, nil return srv, port, resultCh, nil
} }
type antigravityTokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
TokenType string `json:"token_type"`
}
func exchangeAntigravityCode(ctx context.Context, code, redirectURI string, httpClient *http.Client) (*antigravityTokenResponse, error) {
data := url.Values{}
data.Set("code", code)
data.Set("client_id", antigravityClientID)
data.Set("client_secret", antigravityClientSecret)
data.Set("redirect_uri", redirectURI)
data.Set("grant_type", "authorization_code")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(data.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, errDo := httpClient.Do(req)
if errDo != nil {
return nil, errDo
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity token exchange: close body error: %v", errClose)
}
}()
var token antigravityTokenResponse
if errDecode := json.NewDecoder(resp.Body).Decode(&token); errDecode != nil {
return nil, errDecode
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
return nil, fmt.Errorf("oauth token exchange failed: status %d", resp.StatusCode)
}
return &token, nil
}
type antigravityUserInfo struct {
Email string `json:"email"`
}
func fetchAntigravityUserInfo(ctx context.Context, accessToken string, httpClient *http.Client) (*antigravityUserInfo, error) {
if strings.TrimSpace(accessToken) == "" {
return &antigravityUserInfo{}, nil
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, errDo := httpClient.Do(req)
if errDo != nil {
return nil, errDo
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity userinfo: close body error: %v", errClose)
}
}()
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
return &antigravityUserInfo{}, nil
}
var info antigravityUserInfo
if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil {
return nil, errDecode
}
return &info, nil
}
func buildAntigravityAuthURL(redirectURI, state string) string {
params := url.Values{}
params.Set("access_type", "offline")
params.Set("client_id", antigravityClientID)
params.Set("prompt", "consent")
params.Set("redirect_uri", redirectURI)
params.Set("response_type", "code")
params.Set("scope", strings.Join(antigravityScopes, " "))
params.Set("state", state)
return "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode()
}
func sanitizeAntigravityFileName(email string) string {
if strings.TrimSpace(email) == "" {
return "antigravity.json"
}
replacer := strings.NewReplacer("@", "_", ".", "_")
return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email))
}
// Antigravity API constants for project discovery
const (
antigravityAPIEndpoint = "https://cloudcode-pa.googleapis.com"
antigravityAPIVersion = "v1internal"
antigravityAPIUserAgent = "google-api-nodejs-client/9.15.1"
antigravityAPIClient = "google-cloud-sdk vscode_cloudshelleditor/0.1"
antigravityClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}`
)
// FetchAntigravityProjectID exposes project discovery for external callers. // FetchAntigravityProjectID exposes project discovery for external callers.
func FetchAntigravityProjectID(ctx context.Context, accessToken string, httpClient *http.Client) (string, error) { func FetchAntigravityProjectID(ctx context.Context, accessToken string, httpClient *http.Client) (string, error) {
return fetchAntigravityProjectID(ctx, accessToken, httpClient) cfg := &config.Config{}
} authSvc := antigravity.NewAntigravityAuth(cfg, httpClient)
return authSvc.FetchProjectID(ctx, accessToken)
// fetchAntigravityProjectID retrieves the project ID for the authenticated user via loadCodeAssist.
// This uses the same approach as Gemini CLI to get the cloudaicompanionProject.
func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClient *http.Client) (string, error) {
// Call loadCodeAssist to get the project
loadReqBody := map[string]any{
"metadata": map[string]string{
"ideType": "ANTIGRAVITY",
"platform": "PLATFORM_UNSPECIFIED",
"pluginType": "GEMINI",
},
}
rawBody, errMarshal := json.Marshal(loadReqBody)
if errMarshal != nil {
return "", fmt.Errorf("marshal request body: %w", errMarshal)
}
endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", antigravityAPIEndpoint, antigravityAPIVersion)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
if err != nil {
return "", fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", antigravityAPIUserAgent)
req.Header.Set("X-Goog-Api-Client", antigravityAPIClient)
req.Header.Set("Client-Metadata", antigravityClientMetadata)
resp, errDo := httpClient.Do(req)
if errDo != nil {
return "", fmt.Errorf("execute request: %w", errDo)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity loadCodeAssist: close body error: %v", errClose)
}
}()
bodyBytes, errRead := io.ReadAll(resp.Body)
if errRead != nil {
return "", fmt.Errorf("read response: %w", errRead)
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
}
var loadResp map[string]any
if errDecode := json.Unmarshal(bodyBytes, &loadResp); errDecode != nil {
return "", fmt.Errorf("decode response: %w", errDecode)
}
// Extract projectID from response
projectID := ""
if id, ok := loadResp["cloudaicompanionProject"].(string); ok {
projectID = strings.TrimSpace(id)
}
if projectID == "" {
if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok {
if id, okID := projectMap["id"].(string); okID {
projectID = strings.TrimSpace(id)
}
}
}
if projectID == "" {
tierID := "legacy-tier"
if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers {
for _, rawTier := range tiers {
tier, okTier := rawTier.(map[string]any)
if !okTier {
continue
}
if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault {
if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" {
tierID = strings.TrimSpace(id)
break
}
}
}
}
projectID, err = antigravityOnboardUser(ctx, accessToken, tierID, httpClient)
if err != nil {
return "", err
}
return projectID, nil
}
return projectID, nil
}
// antigravityOnboardUser attempts to fetch the project ID via onboardUser by polling for completion.
// It returns an empty string when the operation times out or completes without a project ID.
func antigravityOnboardUser(ctx context.Context, accessToken, tierID string, httpClient *http.Client) (string, error) {
if httpClient == nil {
httpClient = http.DefaultClient
}
fmt.Println("Antigravity: onboarding user...", tierID)
requestBody := map[string]any{
"tierId": tierID,
"metadata": map[string]string{
"ideType": "ANTIGRAVITY",
"platform": "PLATFORM_UNSPECIFIED",
"pluginType": "GEMINI",
},
}
rawBody, errMarshal := json.Marshal(requestBody)
if errMarshal != nil {
return "", fmt.Errorf("marshal request body: %w", errMarshal)
}
maxAttempts := 5
for attempt := 1; attempt <= maxAttempts; attempt++ {
log.Debugf("Polling attempt %d/%d", attempt, maxAttempts)
reqCtx := ctx
var cancel context.CancelFunc
if reqCtx == nil {
reqCtx = context.Background()
}
reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second)
endpointURL := fmt.Sprintf("%s/%s:onboardUser", antigravityAPIEndpoint, antigravityAPIVersion)
req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
if errRequest != nil {
cancel()
return "", fmt.Errorf("create request: %w", errRequest)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", antigravityAPIUserAgent)
req.Header.Set("X-Goog-Api-Client", antigravityAPIClient)
req.Header.Set("Client-Metadata", antigravityClientMetadata)
resp, errDo := httpClient.Do(req)
if errDo != nil {
cancel()
return "", fmt.Errorf("execute request: %w", errDo)
}
bodyBytes, errRead := io.ReadAll(resp.Body)
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("close body error: %v", errClose)
}
cancel()
if errRead != nil {
return "", fmt.Errorf("read response: %w", errRead)
}
if resp.StatusCode == http.StatusOK {
var data map[string]any
if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil {
return "", fmt.Errorf("decode response: %w", errDecode)
}
if done, okDone := data["done"].(bool); okDone && done {
projectID := ""
if responseData, okResp := data["response"].(map[string]any); okResp {
switch projectValue := responseData["cloudaicompanionProject"].(type) {
case map[string]any:
if id, okID := projectValue["id"].(string); okID {
projectID = strings.TrimSpace(id)
}
case string:
projectID = strings.TrimSpace(projectValue)
}
}
if projectID != "" {
log.Infof("Successfully fetched project_id: %s", projectID)
return projectID, nil
}
return "", fmt.Errorf("no project_id in response")
}
time.Sleep(2 * time.Second)
continue
}
responsePreview := strings.TrimSpace(string(bodyBytes))
if len(responsePreview) > 500 {
responsePreview = responsePreview[:500]
}
responseErr := responsePreview
if len(responseErr) > 200 {
responseErr = responseErr[:200]
}
return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr)
}
return "", nil
} }

View File

@@ -73,9 +73,7 @@ func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (str
return "", fmt.Errorf("auth filestore: marshal metadata failed: %w", errMarshal) return "", fmt.Errorf("auth filestore: marshal metadata failed: %w", errMarshal)
} }
if existing, errRead := os.ReadFile(path); errRead == nil { if existing, errRead := os.ReadFile(path); errRead == nil {
// Use metadataEqualIgnoringTimestamps to skip writes when only timestamp fields change. if jsonEqual(existing, raw) {
// This prevents the token refresh loop caused by timestamp/expired/expires_in changes.
if metadataEqualIgnoringTimestamps(existing, raw, auth.Provider) {
return path, nil return path, nil
} }
file, errOpen := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC, 0o600) file, errOpen := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC, 0o600)
@@ -308,8 +306,7 @@ func (s *FileTokenStore) baseDirSnapshot() string {
return s.baseDir return s.baseDir
} }
// DEPRECATED: Use metadataEqualIgnoringTimestamps for comparing auth metadata. // jsonEqual compares two JSON blobs by parsing them into Go objects and deep comparing.
// This function is kept for backward compatibility but can cause refresh loops.
func jsonEqual(a, b []byte) bool { func jsonEqual(a, b []byte) bool {
var objA any var objA any
var objB any var objB any
@@ -322,41 +319,6 @@ func jsonEqual(a, b []byte) bool {
return deepEqualJSON(objA, objB) return deepEqualJSON(objA, objB)
} }
// metadataEqualIgnoringTimestamps compares two metadata JSON blobs,
// ignoring fields that change on every refresh but don't affect functionality.
// This prevents unnecessary file writes that would trigger watcher events and
// create refresh loops.
// The provider parameter controls whether access_token is ignored: providers like
// Google OAuth (gemini, gemini-cli) can re-fetch tokens when needed, while others
// like iFlow require the refreshed token to be persisted.
func metadataEqualIgnoringTimestamps(a, b []byte, provider string) bool {
var objA, objB map[string]any
if err := json.Unmarshal(a, &objA); err != nil {
return false
}
if err := json.Unmarshal(b, &objB); err != nil {
return false
}
// Fields to ignore: these change on every refresh but don't affect authentication logic.
// - timestamp, expired, expires_in, last_refresh: time-related fields that change on refresh
ignoredFields := []string{"timestamp", "expired", "expires_in", "last_refresh"}
// For providers that can re-fetch tokens when needed (e.g., Google OAuth),
// we ignore access_token to avoid unnecessary file writes.
switch provider {
case "gemini", "gemini-cli", "antigravity":
ignoredFields = append(ignoredFields, "access_token")
}
for _, field := range ignoredFields {
delete(objA, field)
delete(objB, field)
}
return deepEqualJSON(objA, objB)
}
func deepEqualJSON(a, b any) bool { func deepEqualJSON(a, b any) bool {
switch valA := a.(type) { switch valA := a.(type) {
case map[string]any: case map[string]any:

View File

@@ -61,6 +61,15 @@ func SetQuotaCooldownDisabled(disable bool) {
quotaCooldownDisabled.Store(disable) quotaCooldownDisabled.Store(disable)
} }
func quotaCooldownDisabledForAuth(auth *Auth) bool {
if auth != nil {
if override, ok := auth.DisableCoolingOverride(); ok {
return override
}
}
return quotaCooldownDisabled.Load()
}
// Result captures execution outcome used to adjust auth state. // Result captures execution outcome used to adjust auth state.
type Result struct { type Result struct {
// AuthID references the auth that produced this result. // AuthID references the auth that produced this result.
@@ -468,20 +477,16 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"} return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
} }
retryTimes, maxWait := m.retrySettings() _, maxWait := m.retrySettings()
attempts := retryTimes + 1
if attempts < 1 {
attempts = 1
}
var lastErr error var lastErr error
for attempt := 0; attempt < attempts; attempt++ { for attempt := 0; ; attempt++ {
resp, errExec := m.executeMixedOnce(ctx, normalized, req, opts) resp, errExec := m.executeMixedOnce(ctx, normalized, req, opts)
if errExec == nil { if errExec == nil {
return resp, nil return resp, nil
} }
lastErr = errExec lastErr = errExec
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait) wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, normalized, req.Model, maxWait)
if !shouldRetry { if !shouldRetry {
break break
} }
@@ -503,20 +508,16 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"} return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
} }
retryTimes, maxWait := m.retrySettings() _, maxWait := m.retrySettings()
attempts := retryTimes + 1
if attempts < 1 {
attempts = 1
}
var lastErr error var lastErr error
for attempt := 0; attempt < attempts; attempt++ { for attempt := 0; ; attempt++ {
resp, errExec := m.executeCountMixedOnce(ctx, normalized, req, opts) resp, errExec := m.executeCountMixedOnce(ctx, normalized, req, opts)
if errExec == nil { if errExec == nil {
return resp, nil return resp, nil
} }
lastErr = errExec lastErr = errExec
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait) wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, normalized, req.Model, maxWait)
if !shouldRetry { if !shouldRetry {
break break
} }
@@ -538,20 +539,16 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"} return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
} }
retryTimes, maxWait := m.retrySettings() _, maxWait := m.retrySettings()
attempts := retryTimes + 1
if attempts < 1 {
attempts = 1
}
var lastErr error var lastErr error
for attempt := 0; attempt < attempts; attempt++ { for attempt := 0; ; attempt++ {
chunks, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts) chunks, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
if errStream == nil { if errStream == nil {
return chunks, nil return chunks, nil
} }
lastErr = errStream lastErr = errStream
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, normalized, req.Model, maxWait) wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, normalized, req.Model, maxWait)
if !shouldRetry { if !shouldRetry {
break break
} }
@@ -1034,11 +1031,15 @@ func (m *Manager) retrySettings() (int, time.Duration) {
return int(m.requestRetry.Load()), time.Duration(m.maxRetryInterval.Load()) return int(m.requestRetry.Load()), time.Duration(m.maxRetryInterval.Load())
} }
func (m *Manager) closestCooldownWait(providers []string, model string) (time.Duration, bool) { func (m *Manager) closestCooldownWait(providers []string, model string, attempt int) (time.Duration, bool) {
if m == nil || len(providers) == 0 { if m == nil || len(providers) == 0 {
return 0, false return 0, false
} }
now := time.Now() now := time.Now()
defaultRetry := int(m.requestRetry.Load())
if defaultRetry < 0 {
defaultRetry = 0
}
providerSet := make(map[string]struct{}, len(providers)) providerSet := make(map[string]struct{}, len(providers))
for i := range providers { for i := range providers {
key := strings.TrimSpace(strings.ToLower(providers[i])) key := strings.TrimSpace(strings.ToLower(providers[i]))
@@ -1061,6 +1062,16 @@ func (m *Manager) closestCooldownWait(providers []string, model string) (time.Du
if _, ok := providerSet[providerKey]; !ok { if _, ok := providerSet[providerKey]; !ok {
continue continue
} }
effectiveRetry := defaultRetry
if override, ok := auth.RequestRetryOverride(); ok {
effectiveRetry = override
}
if effectiveRetry < 0 {
effectiveRetry = 0
}
if attempt >= effectiveRetry {
continue
}
blocked, reason, next := isAuthBlockedForModel(auth, model, now) blocked, reason, next := isAuthBlockedForModel(auth, model, now)
if !blocked || next.IsZero() || reason == blockReasonDisabled { if !blocked || next.IsZero() || reason == blockReasonDisabled {
continue continue
@@ -1077,8 +1088,8 @@ func (m *Manager) closestCooldownWait(providers []string, model string) (time.Du
return minWait, found return minWait, found
} }
func (m *Manager) shouldRetryAfterError(err error, attempt, maxAttempts int, providers []string, model string, maxWait time.Duration) (time.Duration, bool) { func (m *Manager) shouldRetryAfterError(err error, attempt int, providers []string, model string, maxWait time.Duration) (time.Duration, bool) {
if err == nil || attempt >= maxAttempts-1 { if err == nil {
return 0, false return 0, false
} }
if maxWait <= 0 { if maxWait <= 0 {
@@ -1087,7 +1098,7 @@ func (m *Manager) shouldRetryAfterError(err error, attempt, maxAttempts int, pro
if status := statusCodeFromError(err); status == http.StatusOK { if status := statusCodeFromError(err); status == http.StatusOK {
return 0, false return 0, false
} }
wait, found := m.closestCooldownWait(providers, model) wait, found := m.closestCooldownWait(providers, model, attempt)
if !found || wait > maxWait { if !found || wait > maxWait {
return 0, false return 0, false
} }
@@ -1176,7 +1187,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
if result.RetryAfter != nil { if result.RetryAfter != nil {
next = now.Add(*result.RetryAfter) next = now.Add(*result.RetryAfter)
} else { } else {
cooldown, nextLevel := nextQuotaCooldown(backoffLevel) cooldown, nextLevel := nextQuotaCooldown(backoffLevel, quotaCooldownDisabledForAuth(auth))
if cooldown > 0 { if cooldown > 0 {
next = now.Add(cooldown) next = now.Add(cooldown)
} }
@@ -1193,7 +1204,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
shouldSuspendModel = true shouldSuspendModel = true
setModelQuota = true setModelQuota = true
case 408, 500, 502, 503, 504: case 408, 500, 502, 503, 504:
if quotaCooldownDisabled.Load() { if quotaCooldownDisabledForAuth(auth) {
state.NextRetryAfter = time.Time{} state.NextRetryAfter = time.Time{}
} else { } else {
next := now.Add(1 * time.Minute) next := now.Add(1 * time.Minute)
@@ -1439,7 +1450,7 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
if retryAfter != nil { if retryAfter != nil {
next = now.Add(*retryAfter) next = now.Add(*retryAfter)
} else { } else {
cooldown, nextLevel := nextQuotaCooldown(auth.Quota.BackoffLevel) cooldown, nextLevel := nextQuotaCooldown(auth.Quota.BackoffLevel, quotaCooldownDisabledForAuth(auth))
if cooldown > 0 { if cooldown > 0 {
next = now.Add(cooldown) next = now.Add(cooldown)
} }
@@ -1449,7 +1460,7 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
auth.NextRetryAfter = next auth.NextRetryAfter = next
case 408, 500, 502, 503, 504: case 408, 500, 502, 503, 504:
auth.StatusMessage = "transient upstream error" auth.StatusMessage = "transient upstream error"
if quotaCooldownDisabled.Load() { if quotaCooldownDisabledForAuth(auth) {
auth.NextRetryAfter = time.Time{} auth.NextRetryAfter = time.Time{}
} else { } else {
auth.NextRetryAfter = now.Add(1 * time.Minute) auth.NextRetryAfter = now.Add(1 * time.Minute)
@@ -1462,11 +1473,11 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
} }
// nextQuotaCooldown returns the next cooldown duration and updated backoff level for repeated quota errors. // nextQuotaCooldown returns the next cooldown duration and updated backoff level for repeated quota errors.
func nextQuotaCooldown(prevLevel int) (time.Duration, int) { func nextQuotaCooldown(prevLevel int, disableCooling bool) (time.Duration, int) {
if prevLevel < 0 { if prevLevel < 0 {
prevLevel = 0 prevLevel = 0
} }
if quotaCooldownDisabled.Load() { if disableCooling {
return 0, prevLevel return 0, prevLevel
} }
cooldown := quotaBackoffBase * time.Duration(1<<prevLevel) cooldown := quotaBackoffBase * time.Duration(1<<prevLevel)
@@ -1642,6 +1653,9 @@ func (m *Manager) persist(ctx context.Context, auth *Auth) error {
if m.store == nil || auth == nil { if m.store == nil || auth == nil {
return nil return nil
} }
if shouldSkipPersist(ctx) {
return nil
}
if auth.Attributes != nil { if auth.Attributes != nil {
if v := strings.ToLower(strings.TrimSpace(auth.Attributes["runtime_only"])); v == "true" { if v := strings.ToLower(strings.TrimSpace(auth.Attributes["runtime_only"])); v == "true" {
return nil return nil

View File

@@ -0,0 +1,97 @@
package auth
import (
"context"
"testing"
"time"
)
func TestManager_ShouldRetryAfterError_RespectsAuthRequestRetryOverride(t *testing.T) {
m := NewManager(nil, nil, nil)
m.SetRetryConfig(3, 30*time.Second)
model := "test-model"
next := time.Now().Add(5 * time.Second)
auth := &Auth{
ID: "auth-1",
Provider: "claude",
Metadata: map[string]any{
"request_retry": float64(0),
},
ModelStates: map[string]*ModelState{
model: {
Unavailable: true,
Status: StatusError,
NextRetryAfter: next,
},
},
}
if _, errRegister := m.Register(context.Background(), auth); errRegister != nil {
t.Fatalf("register auth: %v", errRegister)
}
_, maxWait := m.retrySettings()
wait, shouldRetry := m.shouldRetryAfterError(&Error{HTTPStatus: 500, Message: "boom"}, 0, []string{"claude"}, model, maxWait)
if shouldRetry {
t.Fatalf("expected shouldRetry=false for request_retry=0, got true (wait=%v)", wait)
}
auth.Metadata["request_retry"] = float64(1)
if _, errUpdate := m.Update(context.Background(), auth); errUpdate != nil {
t.Fatalf("update auth: %v", errUpdate)
}
wait, shouldRetry = m.shouldRetryAfterError(&Error{HTTPStatus: 500, Message: "boom"}, 0, []string{"claude"}, model, maxWait)
if !shouldRetry {
t.Fatalf("expected shouldRetry=true for request_retry=1, got false")
}
if wait <= 0 {
t.Fatalf("expected wait > 0, got %v", wait)
}
_, shouldRetry = m.shouldRetryAfterError(&Error{HTTPStatus: 500, Message: "boom"}, 1, []string{"claude"}, model, maxWait)
if shouldRetry {
t.Fatalf("expected shouldRetry=false on attempt=1 for request_retry=1, got true")
}
}
func TestManager_MarkResult_RespectsAuthDisableCoolingOverride(t *testing.T) {
prev := quotaCooldownDisabled.Load()
quotaCooldownDisabled.Store(false)
t.Cleanup(func() { quotaCooldownDisabled.Store(prev) })
m := NewManager(nil, nil, nil)
auth := &Auth{
ID: "auth-1",
Provider: "claude",
Metadata: map[string]any{
"disable_cooling": true,
},
}
if _, errRegister := m.Register(context.Background(), auth); errRegister != nil {
t.Fatalf("register auth: %v", errRegister)
}
model := "test-model"
m.MarkResult(context.Background(), Result{
AuthID: "auth-1",
Provider: "claude",
Model: model,
Success: false,
Error: &Error{HTTPStatus: 500, Message: "boom"},
})
updated, ok := m.GetByID("auth-1")
if !ok || updated == nil {
t.Fatalf("expected auth to be present")
}
state := updated.ModelStates[model]
if state == nil {
t.Fatalf("expected model state to be present")
}
if !state.NextRetryAfter.IsZero() {
t.Fatalf("expected NextRetryAfter to be zero when disable_cooling=true, got %v", state.NextRetryAfter)
}
}

View File

@@ -0,0 +1,24 @@
package auth
import "context"
type skipPersistContextKey struct{}
// WithSkipPersist returns a derived context that disables persistence for Manager Update/Register calls.
// It is intended for code paths that are reacting to file watcher events, where the file on disk is
// already the source of truth and persisting again would create a write-back loop.
func WithSkipPersist(ctx context.Context) context.Context {
if ctx == nil {
ctx = context.Background()
}
return context.WithValue(ctx, skipPersistContextKey{}, true)
}
func shouldSkipPersist(ctx context.Context) bool {
if ctx == nil {
return false
}
v := ctx.Value(skipPersistContextKey{})
enabled, ok := v.(bool)
return ok && enabled
}

View File

@@ -0,0 +1,62 @@
package auth
import (
"context"
"sync/atomic"
"testing"
)
type countingStore struct {
saveCount atomic.Int32
}
func (s *countingStore) List(context.Context) ([]*Auth, error) { return nil, nil }
func (s *countingStore) Save(context.Context, *Auth) (string, error) {
s.saveCount.Add(1)
return "", nil
}
func (s *countingStore) Delete(context.Context, string) error { return nil }
func TestWithSkipPersist_DisablesUpdatePersistence(t *testing.T) {
store := &countingStore{}
mgr := NewManager(store, nil, nil)
auth := &Auth{
ID: "auth-1",
Provider: "antigravity",
Metadata: map[string]any{"type": "antigravity"},
}
if _, err := mgr.Update(context.Background(), auth); err != nil {
t.Fatalf("Update returned error: %v", err)
}
if got := store.saveCount.Load(); got != 1 {
t.Fatalf("expected 1 Save call, got %d", got)
}
ctxSkip := WithSkipPersist(context.Background())
if _, err := mgr.Update(ctxSkip, auth); err != nil {
t.Fatalf("Update(skipPersist) returned error: %v", err)
}
if got := store.saveCount.Load(); got != 1 {
t.Fatalf("expected Save call count to remain 1, got %d", got)
}
}
func TestWithSkipPersist_DisablesRegisterPersistence(t *testing.T) {
store := &countingStore{}
mgr := NewManager(store, nil, nil)
auth := &Auth{
ID: "auth-1",
Provider: "antigravity",
Metadata: map[string]any{"type": "antigravity"},
}
if _, err := mgr.Register(WithSkipPersist(context.Background()), auth); err != nil {
t.Fatalf("Register(skipPersist) returned error: %v", err)
}
if got := store.saveCount.Load(); got != 0 {
t.Fatalf("expected 0 Save calls, got %d", got)
}
}

View File

@@ -194,6 +194,108 @@ func (a *Auth) ProxyInfo() string {
return "via proxy" return "via proxy"
} }
// DisableCoolingOverride returns the auth-file scoped disable_cooling override when present.
// The value is read from metadata key "disable_cooling" (or legacy "disable-cooling").
func (a *Auth) DisableCoolingOverride() (bool, bool) {
if a == nil || a.Metadata == nil {
return false, false
}
if val, ok := a.Metadata["disable_cooling"]; ok {
if parsed, okParse := parseBoolAny(val); okParse {
return parsed, true
}
}
if val, ok := a.Metadata["disable-cooling"]; ok {
if parsed, okParse := parseBoolAny(val); okParse {
return parsed, true
}
}
return false, false
}
// RequestRetryOverride returns the auth-file scoped request_retry override when present.
// The value is read from metadata key "request_retry" (or legacy "request-retry").
func (a *Auth) RequestRetryOverride() (int, bool) {
if a == nil || a.Metadata == nil {
return 0, false
}
if val, ok := a.Metadata["request_retry"]; ok {
if parsed, okParse := parseIntAny(val); okParse {
if parsed < 0 {
parsed = 0
}
return parsed, true
}
}
if val, ok := a.Metadata["request-retry"]; ok {
if parsed, okParse := parseIntAny(val); okParse {
if parsed < 0 {
parsed = 0
}
return parsed, true
}
}
return 0, false
}
func parseBoolAny(val any) (bool, bool) {
switch typed := val.(type) {
case bool:
return typed, true
case string:
trimmed := strings.TrimSpace(typed)
if trimmed == "" {
return false, false
}
parsed, err := strconv.ParseBool(trimmed)
if err != nil {
return false, false
}
return parsed, true
case float64:
return typed != 0, true
case json.Number:
parsed, err := typed.Int64()
if err != nil {
return false, false
}
return parsed != 0, true
default:
return false, false
}
}
func parseIntAny(val any) (int, bool) {
switch typed := val.(type) {
case int:
return typed, true
case int32:
return int(typed), true
case int64:
return int(typed), true
case float64:
return int(typed), true
case json.Number:
parsed, err := typed.Int64()
if err != nil {
return 0, false
}
return int(parsed), true
case string:
trimmed := strings.TrimSpace(typed)
if trimmed == "" {
return 0, false
}
parsed, err := strconv.Atoi(trimmed)
if err != nil {
return 0, false
}
return parsed, true
default:
return 0, false
}
}
func (a *Auth) AccountInfo() (string, string) { func (a *Auth) AccountInfo() (string, string) {
if a == nil { if a == nil {
return "", "" return "", ""

View File

@@ -135,6 +135,7 @@ func (s *Service) ensureAuthUpdateQueue(ctx context.Context) {
} }
func (s *Service) consumeAuthUpdates(ctx context.Context) { func (s *Service) consumeAuthUpdates(ctx context.Context) {
ctx = coreauth.WithSkipPersist(ctx)
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():