mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-26 05:26:11 +00:00
Compare commits
42 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
851712a49e | ||
|
|
9e34323a40 | ||
|
|
70897247b2 | ||
|
|
9c341f5aa5 | ||
|
|
e3e741d0be | ||
|
|
7c7c5fd967 | ||
|
|
fe8c7a62aa | ||
|
|
2af4a8dc12 | ||
|
|
0f53b952b2 | ||
|
|
7b2ae7377a | ||
|
|
c2ab288c7d | ||
|
|
dbb433fcf8 | ||
|
|
2abf00b5a6 | ||
|
|
275839e5c9 | ||
|
|
f30ffd5f5e | ||
|
|
bc9a24d705 | ||
|
|
2c879f13ef | ||
|
|
07b4a08979 | ||
|
|
497339f055 | ||
|
|
7f612bb069 | ||
|
|
5743b78694 | ||
|
|
2e6a2b655c | ||
|
|
cb47ac21bf | ||
|
|
a1394b4596 | ||
|
|
9e97948f03 | ||
|
|
8f780e7280 | ||
|
|
46c6fb1e7a | ||
|
|
9f9fec5d4c | ||
|
|
e95be10485 | ||
|
|
f3d58fa0ce | ||
|
|
8c0eaa1f71 | ||
|
|
405df58f72 | ||
|
|
e7f13aa008 | ||
|
|
7cb6a9b89a | ||
|
|
9aa5344c29 | ||
|
|
8ba0ebbd2a | ||
|
|
c65407ab9f | ||
|
|
9e59685212 | ||
|
|
4a4dfaa910 | ||
|
|
0d6ecb0191 | ||
|
|
f16461bfe7 | ||
|
|
8c7c446f33 |
2
go.mod
2
go.mod
@@ -40,6 +40,7 @@ require (
|
|||||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
github.com/emirpasic/gods v1.18.1 // indirect
|
github.com/emirpasic/gods v1.18.1 // indirect
|
||||||
|
github.com/fxamacker/cbor/v2 v2.9.0 // indirect
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||||
github.com/go-git/gcfg/v2 v2.0.2 // indirect
|
github.com/go-git/gcfg/v2 v2.0.2 // indirect
|
||||||
@@ -69,6 +70,7 @@ require (
|
|||||||
github.com/tidwall/pretty v1.2.0 // indirect
|
github.com/tidwall/pretty v1.2.0 // indirect
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||||
|
github.com/x448/float16 v0.8.4 // indirect
|
||||||
golang.org/x/arch v0.8.0 // indirect
|
golang.org/x/arch v0.8.0 // indirect
|
||||||
golang.org/x/sys v0.38.0 // indirect
|
golang.org/x/sys v0.38.0 // indirect
|
||||||
golang.org/x/text v0.31.0 // indirect
|
golang.org/x/text v0.31.0 // indirect
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -35,6 +35,8 @@ github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc
|
|||||||
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
||||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||||
|
github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM=
|
||||||
|
github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||||
@@ -157,6 +159,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
|
|||||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||||
|
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||||
|
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/fxamacker/cbor/v2"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
@@ -70,7 +71,7 @@ type apiCallResponse struct {
|
|||||||
// - Authorization: Bearer <key>
|
// - Authorization: Bearer <key>
|
||||||
// - X-Management-Key: <key>
|
// - X-Management-Key: <key>
|
||||||
//
|
//
|
||||||
// Request JSON:
|
// Request JSON (supports both application/json and application/cbor):
|
||||||
// - auth_index / authIndex / AuthIndex (optional):
|
// - auth_index / authIndex / AuthIndex (optional):
|
||||||
// The credential "auth_index" from GET /v0/management/auth-files (or other endpoints returning it).
|
// The credential "auth_index" from GET /v0/management/auth-files (or other endpoints returning it).
|
||||||
// If omitted or not found, credential-specific proxy/token substitution is skipped.
|
// If omitted or not found, credential-specific proxy/token substitution is skipped.
|
||||||
@@ -90,10 +91,12 @@ type apiCallResponse struct {
|
|||||||
// 2. Global config proxy-url
|
// 2. Global config proxy-url
|
||||||
// 3. Direct connect (environment proxies are not used)
|
// 3. Direct connect (environment proxies are not used)
|
||||||
//
|
//
|
||||||
// Response JSON (returned with HTTP 200 when the APICall itself succeeds):
|
// Response (returned with HTTP 200 when the APICall itself succeeds):
|
||||||
// - status_code: Upstream HTTP status code.
|
//
|
||||||
// - header: Upstream response headers.
|
// Format matches request Content-Type (application/json or application/cbor)
|
||||||
// - body: Upstream response body as string.
|
// - status_code: Upstream HTTP status code.
|
||||||
|
// - header: Upstream response headers.
|
||||||
|
// - body: Upstream response body as string.
|
||||||
//
|
//
|
||||||
// Example:
|
// Example:
|
||||||
//
|
//
|
||||||
@@ -107,10 +110,28 @@ type apiCallResponse struct {
|
|||||||
// -H "Content-Type: application/json" \
|
// -H "Content-Type: application/json" \
|
||||||
// -d '{"auth_index":"<AUTH_INDEX>","method":"POST","url":"https://api.example.com/v1/fetchAvailableModels","header":{"Authorization":"Bearer $TOKEN$","Content-Type":"application/json","User-Agent":"cliproxyapi"},"data":"{}"}'
|
// -d '{"auth_index":"<AUTH_INDEX>","method":"POST","url":"https://api.example.com/v1/fetchAvailableModels","header":{"Authorization":"Bearer $TOKEN$","Content-Type":"application/json","User-Agent":"cliproxyapi"},"data":"{}"}'
|
||||||
func (h *Handler) APICall(c *gin.Context) {
|
func (h *Handler) APICall(c *gin.Context) {
|
||||||
|
// Detect content type
|
||||||
|
contentType := strings.ToLower(strings.TrimSpace(c.GetHeader("Content-Type")))
|
||||||
|
isCBOR := strings.Contains(contentType, "application/cbor")
|
||||||
|
|
||||||
var body apiCallRequest
|
var body apiCallRequest
|
||||||
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
// Parse request body based on content type
|
||||||
return
|
if isCBOR {
|
||||||
|
rawBody, errRead := io.ReadAll(c.Request.Body)
|
||||||
|
if errRead != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if errUnmarshal := cbor.Unmarshal(rawBody, &body); errUnmarshal != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid cbor body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
method := strings.ToUpper(strings.TrimSpace(body.Method))
|
method := strings.ToUpper(strings.TrimSpace(body.Method))
|
||||||
@@ -209,11 +230,23 @@ func (h *Handler) APICall(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, apiCallResponse{
|
response := apiCallResponse{
|
||||||
StatusCode: resp.StatusCode,
|
StatusCode: resp.StatusCode,
|
||||||
Header: resp.Header,
|
Header: resp.Header,
|
||||||
Body: string(respBody),
|
Body: string(respBody),
|
||||||
})
|
}
|
||||||
|
|
||||||
|
// Return response in the same format as the request
|
||||||
|
if isCBOR {
|
||||||
|
cborData, errMarshal := cbor.Marshal(response)
|
||||||
|
if errMarshal != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to encode cbor response"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Data(http.StatusOK, "application/cbor", cborData)
|
||||||
|
} else {
|
||||||
|
c.JSON(http.StatusOK, response)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func firstNonEmptyString(values ...*string) string {
|
func firstNonEmptyString(values ...*string) string {
|
||||||
|
|||||||
149
internal/api/handlers/management/api_tools_cbor_test.go
Normal file
149
internal/api/handlers/management/api_tools_cbor_test.go
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/fxamacker/cbor/v2"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAPICall_CBOR_Support(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
// Create a test handler
|
||||||
|
h := &Handler{}
|
||||||
|
|
||||||
|
// Create test request data
|
||||||
|
reqData := apiCallRequest{
|
||||||
|
Method: "GET",
|
||||||
|
URL: "https://httpbin.org/get",
|
||||||
|
Header: map[string]string{
|
||||||
|
"User-Agent": "test-client",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("JSON request and response", func(t *testing.T) {
|
||||||
|
// Marshal request as JSON
|
||||||
|
jsonData, err := json.Marshal(reqData)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to marshal JSON: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create HTTP request
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v0/management/api-call", bytes.NewReader(jsonData))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
// Create response recorder
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Create Gin context
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
// Call handler
|
||||||
|
h.APICall(c)
|
||||||
|
|
||||||
|
// Verify response
|
||||||
|
if w.Code != http.StatusOK && w.Code != http.StatusBadGateway {
|
||||||
|
t.Logf("Response status: %d", w.Code)
|
||||||
|
t.Logf("Response body: %s", w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check content type
|
||||||
|
contentType := w.Header().Get("Content-Type")
|
||||||
|
if w.Code == http.StatusOK && !contains(contentType, "application/json") {
|
||||||
|
t.Errorf("Expected JSON response, got: %s", contentType)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CBOR request and response", func(t *testing.T) {
|
||||||
|
// Marshal request as CBOR
|
||||||
|
cborData, err := cbor.Marshal(reqData)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to marshal CBOR: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create HTTP request
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v0/management/api-call", bytes.NewReader(cborData))
|
||||||
|
req.Header.Set("Content-Type", "application/cbor")
|
||||||
|
|
||||||
|
// Create response recorder
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Create Gin context
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
// Call handler
|
||||||
|
h.APICall(c)
|
||||||
|
|
||||||
|
// Verify response
|
||||||
|
if w.Code != http.StatusOK && w.Code != http.StatusBadGateway {
|
||||||
|
t.Logf("Response status: %d", w.Code)
|
||||||
|
t.Logf("Response body: %s", w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check content type
|
||||||
|
contentType := w.Header().Get("Content-Type")
|
||||||
|
if w.Code == http.StatusOK && !contains(contentType, "application/cbor") {
|
||||||
|
t.Errorf("Expected CBOR response, got: %s", contentType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to decode CBOR response
|
||||||
|
if w.Code == http.StatusOK {
|
||||||
|
var response apiCallResponse
|
||||||
|
if err := cbor.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||||
|
t.Errorf("Failed to unmarshal CBOR response: %v", err)
|
||||||
|
} else {
|
||||||
|
t.Logf("CBOR response decoded successfully: status_code=%d", response.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CBOR encoding and decoding consistency", func(t *testing.T) {
|
||||||
|
// Test data
|
||||||
|
testReq := apiCallRequest{
|
||||||
|
Method: "POST",
|
||||||
|
URL: "https://example.com/api",
|
||||||
|
Header: map[string]string{
|
||||||
|
"Authorization": "Bearer $TOKEN$",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
Data: `{"key":"value"}`,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode to CBOR
|
||||||
|
cborData, err := cbor.Marshal(testReq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to marshal to CBOR: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode from CBOR
|
||||||
|
var decoded apiCallRequest
|
||||||
|
if err := cbor.Unmarshal(cborData, &decoded); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal from CBOR: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify fields
|
||||||
|
if decoded.Method != testReq.Method {
|
||||||
|
t.Errorf("Method mismatch: got %s, want %s", decoded.Method, testReq.Method)
|
||||||
|
}
|
||||||
|
if decoded.URL != testReq.URL {
|
||||||
|
t.Errorf("URL mismatch: got %s, want %s", decoded.URL, testReq.URL)
|
||||||
|
}
|
||||||
|
if decoded.Data != testReq.Data {
|
||||||
|
t.Errorf("Data mismatch: got %s, want %s", decoded.Data, testReq.Data)
|
||||||
|
}
|
||||||
|
if len(decoded.Header) != len(testReq.Header) {
|
||||||
|
t.Errorf("Header count mismatch: got %d, want %d", len(decoded.Header), len(testReq.Header))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func contains(s, substr string) bool {
|
||||||
|
return len(s) > 0 && len(substr) > 0 && (s == substr || len(s) >= len(substr) && s[:len(substr)] == substr || bytes.Contains([]byte(s), []byte(substr)))
|
||||||
|
}
|
||||||
@@ -3,10 +3,10 @@ package management
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/hex"
|
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/antigravity"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||||
@@ -236,14 +237,6 @@ func stopForwarderInstance(port int, forwarder *callbackForwarder) {
|
|||||||
log.Infof("callback forwarder on port %d stopped", port)
|
log.Infof("callback forwarder on port %d stopped", port)
|
||||||
}
|
}
|
||||||
|
|
||||||
func sanitizeAntigravityFileName(email string) string {
|
|
||||||
if strings.TrimSpace(email) == "" {
|
|
||||||
return "antigravity.json"
|
|
||||||
}
|
|
||||||
replacer := strings.NewReplacer("@", "_", ".", "_")
|
|
||||||
return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) managementCallbackURL(path string) (string, error) {
|
func (h *Handler) managementCallbackURL(path string) (string, error) {
|
||||||
if h == nil || h.cfg == nil || h.cfg.Port <= 0 {
|
if h == nil || h.cfg == nil || h.cfg.Port <= 0 {
|
||||||
return "", fmt.Errorf("server port is not configured")
|
return "", fmt.Errorf("server port is not configured")
|
||||||
@@ -985,67 +978,14 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
rawCode := resultMap["code"]
|
rawCode := resultMap["code"]
|
||||||
code := strings.Split(rawCode, "#")[0]
|
code := strings.Split(rawCode, "#")[0]
|
||||||
|
|
||||||
// Exchange code for tokens (replicate logic using updated redirect_uri)
|
// Exchange code for tokens using internal auth service
|
||||||
// Extract client_id from the modified auth URL
|
bundle, errExchange := anthropicAuth.ExchangeCodeForTokens(ctx, code, state, pkceCodes)
|
||||||
clientID := ""
|
if errExchange != nil {
|
||||||
if u2, errP := url.Parse(authURL); errP == nil {
|
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errExchange)
|
||||||
clientID = u2.Query().Get("client_id")
|
|
||||||
}
|
|
||||||
// Build request
|
|
||||||
bodyMap := map[string]any{
|
|
||||||
"code": code,
|
|
||||||
"state": state,
|
|
||||||
"grant_type": "authorization_code",
|
|
||||||
"client_id": clientID,
|
|
||||||
"redirect_uri": "http://localhost:54545/callback",
|
|
||||||
"code_verifier": pkceCodes.CodeVerifier,
|
|
||||||
}
|
|
||||||
bodyJSON, _ := json.Marshal(bodyMap)
|
|
||||||
|
|
||||||
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
|
||||||
req, _ := http.NewRequestWithContext(ctx, "POST", "https://console.anthropic.com/v1/oauth/token", strings.NewReader(string(bodyJSON)))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Accept", "application/json")
|
|
||||||
resp, errDo := httpClient.Do(req)
|
|
||||||
if errDo != nil {
|
|
||||||
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
|
|
||||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||||
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() {
|
|
||||||
if errClose := resp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("failed to close response body: %v", errClose)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
respBody, _ := io.ReadAll(resp.Body)
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
|
||||||
SetOAuthSessionError(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var tResp struct {
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
RefreshToken string `json:"refresh_token"`
|
|
||||||
ExpiresIn int `json:"expires_in"`
|
|
||||||
Account struct {
|
|
||||||
EmailAddress string `json:"email_address"`
|
|
||||||
} `json:"account"`
|
|
||||||
}
|
|
||||||
if errU := json.Unmarshal(respBody, &tResp); errU != nil {
|
|
||||||
log.Errorf("failed to parse token response: %v", errU)
|
|
||||||
SetOAuthSessionError(state, "Failed to parse token response")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
bundle := &claude.ClaudeAuthBundle{
|
|
||||||
TokenData: claude.ClaudeTokenData{
|
|
||||||
AccessToken: tResp.AccessToken,
|
|
||||||
RefreshToken: tResp.RefreshToken,
|
|
||||||
Email: tResp.Account.EmailAddress,
|
|
||||||
Expire: time.Now().Add(time.Duration(tResp.ExpiresIn) * time.Second).Format(time.RFC3339),
|
|
||||||
},
|
|
||||||
LastRefresh: time.Now().Format(time.RFC3339),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create token storage
|
// Create token storage
|
||||||
tokenStorage := anthropicAuth.CreateTokenStorage(bundle)
|
tokenStorage := anthropicAuth.CreateTokenStorage(bundle)
|
||||||
@@ -1085,17 +1025,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
|
|
||||||
fmt.Println("Initializing Google authentication...")
|
fmt.Println("Initializing Google authentication...")
|
||||||
|
|
||||||
// OAuth2 configuration (mirrors internal/auth/gemini)
|
// OAuth2 configuration using exported constants from internal/auth/gemini
|
||||||
conf := &oauth2.Config{
|
conf := &oauth2.Config{
|
||||||
ClientID: "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com",
|
ClientID: geminiAuth.ClientID,
|
||||||
ClientSecret: "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",
|
ClientSecret: geminiAuth.ClientSecret,
|
||||||
RedirectURL: "http://localhost:8085/oauth2callback",
|
RedirectURL: fmt.Sprintf("http://localhost:%d/oauth2callback", geminiAuth.DefaultCallbackPort),
|
||||||
Scopes: []string{
|
Scopes: geminiAuth.Scopes,
|
||||||
"https://www.googleapis.com/auth/cloud-platform",
|
Endpoint: google.Endpoint,
|
||||||
"https://www.googleapis.com/auth/userinfo.email",
|
|
||||||
"https://www.googleapis.com/auth/userinfo.profile",
|
|
||||||
},
|
|
||||||
Endpoint: google.Endpoint,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build authorization URL and return it immediately
|
// Build authorization URL and return it immediately
|
||||||
@@ -1217,13 +1153,9 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
|
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
|
||||||
ifToken["client_id"] = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
ifToken["client_id"] = geminiAuth.ClientID
|
||||||
ifToken["client_secret"] = "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"
|
ifToken["client_secret"] = geminiAuth.ClientSecret
|
||||||
ifToken["scopes"] = []string{
|
ifToken["scopes"] = geminiAuth.Scopes
|
||||||
"https://www.googleapis.com/auth/cloud-platform",
|
|
||||||
"https://www.googleapis.com/auth/userinfo.email",
|
|
||||||
"https://www.googleapis.com/auth/userinfo.profile",
|
|
||||||
}
|
|
||||||
ifToken["universe_domain"] = "googleapis.com"
|
ifToken["universe_domain"] = "googleapis.com"
|
||||||
|
|
||||||
ts := geminiAuth.GeminiTokenStorage{
|
ts := geminiAuth.GeminiTokenStorage{
|
||||||
@@ -1410,73 +1342,25 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("Authorization code received, exchanging for tokens...")
|
log.Debug("Authorization code received, exchanging for tokens...")
|
||||||
// Extract client_id from authURL
|
// Exchange code for tokens using internal auth service
|
||||||
clientID := ""
|
bundle, errExchange := openaiAuth.ExchangeCodeForTokens(ctx, code, pkceCodes)
|
||||||
if u2, errP := url.Parse(authURL); errP == nil {
|
if errExchange != nil {
|
||||||
clientID = u2.Query().Get("client_id")
|
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errExchange)
|
||||||
}
|
|
||||||
// Exchange code for tokens with redirect equal to mgmtRedirect
|
|
||||||
form := url.Values{
|
|
||||||
"grant_type": {"authorization_code"},
|
|
||||||
"client_id": {clientID},
|
|
||||||
"code": {code},
|
|
||||||
"redirect_uri": {"http://localhost:1455/auth/callback"},
|
|
||||||
"code_verifier": {pkceCodes.CodeVerifier},
|
|
||||||
}
|
|
||||||
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
|
||||||
req, _ := http.NewRequestWithContext(ctx, "POST", "https://auth.openai.com/oauth/token", strings.NewReader(form.Encode()))
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
req.Header.Set("Accept", "application/json")
|
|
||||||
resp, errDo := httpClient.Do(req)
|
|
||||||
if errDo != nil {
|
|
||||||
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
|
|
||||||
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
||||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() { _ = resp.Body.Close() }()
|
|
||||||
respBody, _ := io.ReadAll(resp.Body)
|
// Extract additional info for filename generation
|
||||||
if resp.StatusCode != http.StatusOK {
|
claims, _ := codex.ParseJWTToken(bundle.TokenData.IDToken)
|
||||||
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode))
|
|
||||||
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var tokenResp struct {
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
RefreshToken string `json:"refresh_token"`
|
|
||||||
IDToken string `json:"id_token"`
|
|
||||||
ExpiresIn int `json:"expires_in"`
|
|
||||||
}
|
|
||||||
if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
|
|
||||||
SetOAuthSessionError(state, "Failed to parse token response")
|
|
||||||
log.Errorf("failed to parse token response: %v", errU)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
claims, _ := codex.ParseJWTToken(tokenResp.IDToken)
|
|
||||||
email := ""
|
|
||||||
accountID := ""
|
|
||||||
planType := ""
|
planType := ""
|
||||||
if claims != nil {
|
|
||||||
email = claims.GetUserEmail()
|
|
||||||
accountID = claims.GetAccountID()
|
|
||||||
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
|
|
||||||
}
|
|
||||||
hashAccountID := ""
|
hashAccountID := ""
|
||||||
if accountID != "" {
|
if claims != nil {
|
||||||
digest := sha256.Sum256([]byte(accountID))
|
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
|
||||||
hashAccountID = hex.EncodeToString(digest[:])[:8]
|
if accountID := claims.GetAccountID(); accountID != "" {
|
||||||
}
|
digest := sha256.Sum256([]byte(accountID))
|
||||||
// Build bundle compatible with existing storage
|
hashAccountID = hex.EncodeToString(digest[:])[:8]
|
||||||
bundle := &codex.CodexAuthBundle{
|
}
|
||||||
TokenData: codex.CodexTokenData{
|
|
||||||
IDToken: tokenResp.IDToken,
|
|
||||||
AccessToken: tokenResp.AccessToken,
|
|
||||||
RefreshToken: tokenResp.RefreshToken,
|
|
||||||
AccountID: accountID,
|
|
||||||
Email: email,
|
|
||||||
Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
|
|
||||||
},
|
|
||||||
LastRefresh: time.Now().Format(time.RFC3339),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create token storage and persist
|
// Create token storage and persist
|
||||||
@@ -1511,23 +1395,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||||
const (
|
|
||||||
antigravityCallbackPort = 51121
|
|
||||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
|
||||||
antigravityClientSecret = "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"
|
|
||||||
)
|
|
||||||
var antigravityScopes = []string{
|
|
||||||
"https://www.googleapis.com/auth/cloud-platform",
|
|
||||||
"https://www.googleapis.com/auth/userinfo.email",
|
|
||||||
"https://www.googleapis.com/auth/userinfo.profile",
|
|
||||||
"https://www.googleapis.com/auth/cclog",
|
|
||||||
"https://www.googleapis.com/auth/experimentsandconfigs",
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
fmt.Println("Initializing Antigravity authentication...")
|
fmt.Println("Initializing Antigravity authentication...")
|
||||||
|
|
||||||
|
authSvc := antigravity.NewAntigravityAuth(h.cfg, nil)
|
||||||
|
|
||||||
state, errState := misc.GenerateRandomState()
|
state, errState := misc.GenerateRandomState()
|
||||||
if errState != nil {
|
if errState != nil {
|
||||||
log.Errorf("Failed to generate state parameter: %v", errState)
|
log.Errorf("Failed to generate state parameter: %v", errState)
|
||||||
@@ -1535,17 +1408,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", antigravityCallbackPort)
|
redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", antigravity.CallbackPort)
|
||||||
|
authURL := authSvc.BuildAuthURL(state, redirectURI)
|
||||||
params := url.Values{}
|
|
||||||
params.Set("access_type", "offline")
|
|
||||||
params.Set("client_id", antigravityClientID)
|
|
||||||
params.Set("prompt", "consent")
|
|
||||||
params.Set("redirect_uri", redirectURI)
|
|
||||||
params.Set("response_type", "code")
|
|
||||||
params.Set("scope", strings.Join(antigravityScopes, " "))
|
|
||||||
params.Set("state", state)
|
|
||||||
authURL := "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode()
|
|
||||||
|
|
||||||
RegisterOAuthSession(state, "antigravity")
|
RegisterOAuthSession(state, "antigravity")
|
||||||
|
|
||||||
@@ -1559,7 +1423,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
var errStart error
|
var errStart error
|
||||||
if forwarder, errStart = startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil {
|
if forwarder, errStart = startCallbackForwarder(antigravity.CallbackPort, "antigravity", targetURL); errStart != nil {
|
||||||
log.WithError(errStart).Error("failed to start antigravity callback forwarder")
|
log.WithError(errStart).Error("failed to start antigravity callback forwarder")
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||||
return
|
return
|
||||||
@@ -1568,7 +1432,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if isWebUI {
|
if isWebUI {
|
||||||
defer stopCallbackForwarderInstance(antigravityCallbackPort, forwarder)
|
defer stopCallbackForwarderInstance(antigravity.CallbackPort, forwarder)
|
||||||
}
|
}
|
||||||
|
|
||||||
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state))
|
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state))
|
||||||
@@ -1608,93 +1472,36 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
time.Sleep(500 * time.Millisecond)
|
time.Sleep(500 * time.Millisecond)
|
||||||
}
|
}
|
||||||
|
|
||||||
httpClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
tokenResp, errToken := authSvc.ExchangeCodeForTokens(ctx, authCode, redirectURI)
|
||||||
form := url.Values{}
|
if errToken != nil {
|
||||||
form.Set("code", authCode)
|
log.Errorf("Failed to exchange token: %v", errToken)
|
||||||
form.Set("client_id", antigravityClientID)
|
|
||||||
form.Set("client_secret", antigravityClientSecret)
|
|
||||||
form.Set("redirect_uri", redirectURI)
|
|
||||||
form.Set("grant_type", "authorization_code")
|
|
||||||
|
|
||||||
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
|
|
||||||
if errNewRequest != nil {
|
|
||||||
log.Errorf("Failed to build token request: %v", errNewRequest)
|
|
||||||
SetOAuthSessionError(state, "Failed to build token request")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
resp, errDo := httpClient.Do(req)
|
|
||||||
if errDo != nil {
|
|
||||||
log.Errorf("Failed to execute token request: %v", errDo)
|
|
||||||
SetOAuthSessionError(state, "Failed to exchange token")
|
SetOAuthSessionError(state, "Failed to exchange token")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() {
|
|
||||||
if errClose := resp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("antigravity token exchange close error: %v", errClose)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
accessToken := strings.TrimSpace(tokenResp.AccessToken)
|
||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
if accessToken == "" {
|
||||||
log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
log.Error("antigravity: token exchange returned empty access token")
|
||||||
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode))
|
SetOAuthSessionError(state, "Failed to exchange token")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var tokenResp struct {
|
email, errInfo := authSvc.FetchUserInfo(ctx, accessToken)
|
||||||
AccessToken string `json:"access_token"`
|
if errInfo != nil {
|
||||||
RefreshToken string `json:"refresh_token"`
|
log.Errorf("Failed to fetch user info: %v", errInfo)
|
||||||
ExpiresIn int64 `json:"expires_in"`
|
SetOAuthSessionError(state, "Failed to fetch user info")
|
||||||
TokenType string `json:"token_type"`
|
|
||||||
}
|
|
||||||
if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil {
|
|
||||||
log.Errorf("Failed to parse token response: %v", errDecode)
|
|
||||||
SetOAuthSessionError(state, "Failed to parse token response")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
email = strings.TrimSpace(email)
|
||||||
email := ""
|
if email == "" {
|
||||||
if strings.TrimSpace(tokenResp.AccessToken) != "" {
|
log.Error("antigravity: user info returned empty email")
|
||||||
infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
SetOAuthSessionError(state, "Failed to fetch user info")
|
||||||
if errInfoReq != nil {
|
return
|
||||||
log.Errorf("Failed to build user info request: %v", errInfoReq)
|
|
||||||
SetOAuthSessionError(state, "Failed to build user info request")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
|
|
||||||
|
|
||||||
infoResp, errInfo := httpClient.Do(infoReq)
|
|
||||||
if errInfo != nil {
|
|
||||||
log.Errorf("Failed to execute user info request: %v", errInfo)
|
|
||||||
SetOAuthSessionError(state, "Failed to execute user info request")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if errClose := infoResp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("antigravity user info close error: %v", errClose)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if infoResp.StatusCode >= http.StatusOK && infoResp.StatusCode < http.StatusMultipleChoices {
|
|
||||||
var infoPayload struct {
|
|
||||||
Email string `json:"email"`
|
|
||||||
}
|
|
||||||
if errDecodeInfo := json.NewDecoder(infoResp.Body).Decode(&infoPayload); errDecodeInfo == nil {
|
|
||||||
email = strings.TrimSpace(infoPayload.Email)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
bodyBytes, _ := io.ReadAll(infoResp.Body)
|
|
||||||
log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes))
|
|
||||||
SetOAuthSessionError(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
projectID := ""
|
projectID := ""
|
||||||
if strings.TrimSpace(tokenResp.AccessToken) != "" {
|
if accessToken != "" {
|
||||||
fetchedProjectID, errProject := sdkAuth.FetchAntigravityProjectID(ctx, tokenResp.AccessToken, httpClient)
|
fetchedProjectID, errProject := authSvc.FetchProjectID(ctx, accessToken)
|
||||||
if errProject != nil {
|
if errProject != nil {
|
||||||
log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
|
log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
|
||||||
} else {
|
} else {
|
||||||
@@ -1719,7 +1526,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
metadata["project_id"] = projectID
|
metadata["project_id"] = projectID
|
||||||
}
|
}
|
||||||
|
|
||||||
fileName := sanitizeAntigravityFileName(email)
|
fileName := antigravity.CredentialFileName(email)
|
||||||
label := strings.TrimSpace(email)
|
label := strings.TrimSpace(email)
|
||||||
if label == "" {
|
if label == "" {
|
||||||
label = "antigravity"
|
label = "antigravity"
|
||||||
|
|||||||
344
internal/auth/antigravity/auth.go
Normal file
344
internal/auth/antigravity/auth.go
Normal file
@@ -0,0 +1,344 @@
|
|||||||
|
// Package antigravity provides OAuth2 authentication functionality for the Antigravity provider.
|
||||||
|
package antigravity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TokenResponse represents OAuth token response from Google
|
||||||
|
type TokenResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// userInfo represents Google user profile
|
||||||
|
type userInfo struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AntigravityAuth handles Antigravity OAuth authentication
|
||||||
|
type AntigravityAuth struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAntigravityAuth creates a new Antigravity auth service.
|
||||||
|
func NewAntigravityAuth(cfg *config.Config, httpClient *http.Client) *AntigravityAuth {
|
||||||
|
if httpClient != nil {
|
||||||
|
return &AntigravityAuth{httpClient: httpClient}
|
||||||
|
}
|
||||||
|
if cfg == nil {
|
||||||
|
cfg = &config.Config{}
|
||||||
|
}
|
||||||
|
return &AntigravityAuth{
|
||||||
|
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildAuthURL generates the OAuth authorization URL.
|
||||||
|
func (o *AntigravityAuth) BuildAuthURL(state, redirectURI string) string {
|
||||||
|
if strings.TrimSpace(redirectURI) == "" {
|
||||||
|
redirectURI = fmt.Sprintf("http://localhost:%d/oauth-callback", CallbackPort)
|
||||||
|
}
|
||||||
|
params := url.Values{}
|
||||||
|
params.Set("access_type", "offline")
|
||||||
|
params.Set("client_id", ClientID)
|
||||||
|
params.Set("prompt", "consent")
|
||||||
|
params.Set("redirect_uri", redirectURI)
|
||||||
|
params.Set("response_type", "code")
|
||||||
|
params.Set("scope", strings.Join(Scopes, " "))
|
||||||
|
params.Set("state", state)
|
||||||
|
return AuthEndpoint + "?" + params.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExchangeCodeForTokens exchanges authorization code for access and refresh tokens
|
||||||
|
func (o *AntigravityAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectURI string) (*TokenResponse, error) {
|
||||||
|
data := url.Values{}
|
||||||
|
data.Set("code", code)
|
||||||
|
data.Set("client_id", ClientID)
|
||||||
|
data.Set("client_secret", ClientSecret)
|
||||||
|
data.Set("redirect_uri", redirectURI)
|
||||||
|
data.Set("grant_type", "authorization_code")
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenEndpoint, strings.NewReader(data.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("antigravity token exchange: create request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
resp, errDo := o.httpClient.Do(req)
|
||||||
|
if errDo != nil {
|
||||||
|
return nil, fmt.Errorf("antigravity token exchange: execute request: %w", errDo)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("antigravity token exchange: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
bodyBytes, errRead := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
|
||||||
|
if errRead != nil {
|
||||||
|
return nil, fmt.Errorf("antigravity token exchange: read response: %w", errRead)
|
||||||
|
}
|
||||||
|
body := strings.TrimSpace(string(bodyBytes))
|
||||||
|
if body == "" {
|
||||||
|
return nil, fmt.Errorf("antigravity token exchange: request failed: status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("antigravity token exchange: request failed: status %d: %s", resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var token TokenResponse
|
||||||
|
if errDecode := json.NewDecoder(resp.Body).Decode(&token); errDecode != nil {
|
||||||
|
return nil, fmt.Errorf("antigravity token exchange: decode response: %w", errDecode)
|
||||||
|
}
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchUserInfo retrieves user email from Google
|
||||||
|
func (o *AntigravityAuth) FetchUserInfo(ctx context.Context, accessToken string) (string, error) {
|
||||||
|
accessToken = strings.TrimSpace(accessToken)
|
||||||
|
if accessToken == "" {
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: missing access token")
|
||||||
|
}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, UserInfoEndpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: create request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
|
||||||
|
resp, errDo := o.httpClient.Do(req)
|
||||||
|
if errDo != nil {
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: execute request: %w", errDo)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("antigravity userinfo: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
bodyBytes, errRead := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
|
||||||
|
if errRead != nil {
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: read response: %w", errRead)
|
||||||
|
}
|
||||||
|
body := strings.TrimSpace(string(bodyBytes))
|
||||||
|
if body == "" {
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: request failed: status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: request failed: status %d: %s", resp.StatusCode, body)
|
||||||
|
}
|
||||||
|
var info userInfo
|
||||||
|
if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil {
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: decode response: %w", errDecode)
|
||||||
|
}
|
||||||
|
email := strings.TrimSpace(info.Email)
|
||||||
|
if email == "" {
|
||||||
|
return "", fmt.Errorf("antigravity userinfo: response missing email")
|
||||||
|
}
|
||||||
|
return email, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchProjectID retrieves the project ID for the authenticated user via loadCodeAssist
|
||||||
|
func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string) (string, error) {
|
||||||
|
loadReqBody := map[string]any{
|
||||||
|
"metadata": map[string]string{
|
||||||
|
"ideType": "ANTIGRAVITY",
|
||||||
|
"platform": "PLATFORM_UNSPECIFIED",
|
||||||
|
"pluginType": "GEMINI",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rawBody, errMarshal := json.Marshal(loadReqBody)
|
||||||
|
if errMarshal != nil {
|
||||||
|
return "", fmt.Errorf("marshal request body: %w", errMarshal)
|
||||||
|
}
|
||||||
|
|
||||||
|
endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", APIEndpoint, APIVersion)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("create request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("User-Agent", APIUserAgent)
|
||||||
|
req.Header.Set("X-Goog-Api-Client", APIClient)
|
||||||
|
req.Header.Set("Client-Metadata", ClientMetadata)
|
||||||
|
|
||||||
|
resp, errDo := o.httpClient.Do(req)
|
||||||
|
if errDo != nil {
|
||||||
|
return "", fmt.Errorf("execute request: %w", errDo)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("antigravity loadCodeAssist: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
bodyBytes, errRead := io.ReadAll(resp.Body)
|
||||||
|
if errRead != nil {
|
||||||
|
return "", fmt.Errorf("read response: %w", errRead)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var loadResp map[string]any
|
||||||
|
if errDecode := json.Unmarshal(bodyBytes, &loadResp); errDecode != nil {
|
||||||
|
return "", fmt.Errorf("decode response: %w", errDecode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract projectID from response
|
||||||
|
projectID := ""
|
||||||
|
if id, ok := loadResp["cloudaicompanionProject"].(string); ok {
|
||||||
|
projectID = strings.TrimSpace(id)
|
||||||
|
}
|
||||||
|
if projectID == "" {
|
||||||
|
if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok {
|
||||||
|
if id, okID := projectMap["id"].(string); okID {
|
||||||
|
projectID = strings.TrimSpace(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if projectID == "" {
|
||||||
|
tierID := "legacy-tier"
|
||||||
|
if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers {
|
||||||
|
for _, rawTier := range tiers {
|
||||||
|
tier, okTier := rawTier.(map[string]any)
|
||||||
|
if !okTier {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault {
|
||||||
|
if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" {
|
||||||
|
tierID = strings.TrimSpace(id)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
projectID, err = o.OnboardUser(ctx, accessToken, tierID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return projectID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return projectID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnboardUser attempts to fetch the project ID via onboardUser by polling for completion
|
||||||
|
func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) {
|
||||||
|
log.Infof("Antigravity: onboarding user with tier: %s", tierID)
|
||||||
|
requestBody := map[string]any{
|
||||||
|
"tierId": tierID,
|
||||||
|
"metadata": map[string]string{
|
||||||
|
"ideType": "ANTIGRAVITY",
|
||||||
|
"platform": "PLATFORM_UNSPECIFIED",
|
||||||
|
"pluginType": "GEMINI",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rawBody, errMarshal := json.Marshal(requestBody)
|
||||||
|
if errMarshal != nil {
|
||||||
|
return "", fmt.Errorf("marshal request body: %w", errMarshal)
|
||||||
|
}
|
||||||
|
|
||||||
|
maxAttempts := 5
|
||||||
|
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||||
|
log.Debugf("Polling attempt %d/%d", attempt, maxAttempts)
|
||||||
|
|
||||||
|
reqCtx := ctx
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
if reqCtx == nil {
|
||||||
|
reqCtx = context.Background()
|
||||||
|
}
|
||||||
|
reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second)
|
||||||
|
|
||||||
|
endpointURL := fmt.Sprintf("%s/%s:onboardUser", APIEndpoint, APIVersion)
|
||||||
|
req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
|
||||||
|
if errRequest != nil {
|
||||||
|
cancel()
|
||||||
|
return "", fmt.Errorf("create request: %w", errRequest)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("User-Agent", APIUserAgent)
|
||||||
|
req.Header.Set("X-Goog-Api-Client", APIClient)
|
||||||
|
req.Header.Set("Client-Metadata", ClientMetadata)
|
||||||
|
|
||||||
|
resp, errDo := o.httpClient.Do(req)
|
||||||
|
if errDo != nil {
|
||||||
|
cancel()
|
||||||
|
return "", fmt.Errorf("execute request: %w", errDo)
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, errRead := io.ReadAll(resp.Body)
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if errRead != nil {
|
||||||
|
return "", fmt.Errorf("read response: %w", errRead)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusOK {
|
||||||
|
var data map[string]any
|
||||||
|
if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil {
|
||||||
|
return "", fmt.Errorf("decode response: %w", errDecode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if done, okDone := data["done"].(bool); okDone && done {
|
||||||
|
projectID := ""
|
||||||
|
if responseData, okResp := data["response"].(map[string]any); okResp {
|
||||||
|
switch projectValue := responseData["cloudaicompanionProject"].(type) {
|
||||||
|
case map[string]any:
|
||||||
|
if id, okID := projectValue["id"].(string); okID {
|
||||||
|
projectID = strings.TrimSpace(id)
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
projectID = strings.TrimSpace(projectValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if projectID != "" {
|
||||||
|
log.Infof("Successfully fetched project_id: %s", projectID)
|
||||||
|
return projectID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("no project_id in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
responsePreview := strings.TrimSpace(string(bodyBytes))
|
||||||
|
if len(responsePreview) > 500 {
|
||||||
|
responsePreview = responsePreview[:500]
|
||||||
|
}
|
||||||
|
|
||||||
|
responseErr := responsePreview
|
||||||
|
if len(responseErr) > 200 {
|
||||||
|
responseErr = responseErr[:200]
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
34
internal/auth/antigravity/constants.go
Normal file
34
internal/auth/antigravity/constants.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
// Package antigravity provides OAuth2 authentication functionality for the Antigravity provider.
|
||||||
|
package antigravity
|
||||||
|
|
||||||
|
// OAuth client credentials and configuration
|
||||||
|
const (
|
||||||
|
ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||||
|
ClientSecret = "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"
|
||||||
|
CallbackPort = 51121
|
||||||
|
)
|
||||||
|
|
||||||
|
// Scopes defines the OAuth scopes required for Antigravity authentication
|
||||||
|
var Scopes = []string{
|
||||||
|
"https://www.googleapis.com/auth/cloud-platform",
|
||||||
|
"https://www.googleapis.com/auth/userinfo.email",
|
||||||
|
"https://www.googleapis.com/auth/userinfo.profile",
|
||||||
|
"https://www.googleapis.com/auth/cclog",
|
||||||
|
"https://www.googleapis.com/auth/experimentsandconfigs",
|
||||||
|
}
|
||||||
|
|
||||||
|
// OAuth2 endpoints for Google authentication
|
||||||
|
const (
|
||||||
|
TokenEndpoint = "https://oauth2.googleapis.com/token"
|
||||||
|
AuthEndpoint = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||||
|
UserInfoEndpoint = "https://www.googleapis.com/oauth2/v1/userinfo?alt=json"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Antigravity API configuration
|
||||||
|
const (
|
||||||
|
APIEndpoint = "https://cloudcode-pa.googleapis.com"
|
||||||
|
APIVersion = "v1internal"
|
||||||
|
APIUserAgent = "google-api-nodejs-client/9.15.1"
|
||||||
|
APIClient = "google-cloud-sdk vscode_cloudshelleditor/0.1"
|
||||||
|
ClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}`
|
||||||
|
)
|
||||||
16
internal/auth/antigravity/filename.go
Normal file
16
internal/auth/antigravity/filename.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package antigravity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CredentialFileName returns the filename used to persist Antigravity credentials.
|
||||||
|
// It uses the email as a suffix to disambiguate accounts.
|
||||||
|
func CredentialFileName(email string) string {
|
||||||
|
email = strings.TrimSpace(email)
|
||||||
|
if email == "" {
|
||||||
|
return "antigravity.json"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("antigravity-%s.json", email)
|
||||||
|
}
|
||||||
@@ -18,11 +18,12 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// OAuth configuration constants for Claude/Anthropic
|
||||||
const (
|
const (
|
||||||
anthropicAuthURL = "https://claude.ai/oauth/authorize"
|
AuthURL = "https://claude.ai/oauth/authorize"
|
||||||
anthropicTokenURL = "https://console.anthropic.com/v1/oauth/token"
|
TokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||||
anthropicClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||||
redirectURI = "http://localhost:54545/callback"
|
RedirectURI = "http://localhost:54545/callback"
|
||||||
)
|
)
|
||||||
|
|
||||||
// tokenResponse represents the response structure from Anthropic's OAuth token endpoint.
|
// tokenResponse represents the response structure from Anthropic's OAuth token endpoint.
|
||||||
@@ -82,16 +83,16 @@ func (o *ClaudeAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string
|
|||||||
|
|
||||||
params := url.Values{
|
params := url.Values{
|
||||||
"code": {"true"},
|
"code": {"true"},
|
||||||
"client_id": {anthropicClientID},
|
"client_id": {ClientID},
|
||||||
"response_type": {"code"},
|
"response_type": {"code"},
|
||||||
"redirect_uri": {redirectURI},
|
"redirect_uri": {RedirectURI},
|
||||||
"scope": {"org:create_api_key user:profile user:inference"},
|
"scope": {"org:create_api_key user:profile user:inference"},
|
||||||
"code_challenge": {pkceCodes.CodeChallenge},
|
"code_challenge": {pkceCodes.CodeChallenge},
|
||||||
"code_challenge_method": {"S256"},
|
"code_challenge_method": {"S256"},
|
||||||
"state": {state},
|
"state": {state},
|
||||||
}
|
}
|
||||||
|
|
||||||
authURL := fmt.Sprintf("%s?%s", anthropicAuthURL, params.Encode())
|
authURL := fmt.Sprintf("%s?%s", AuthURL, params.Encode())
|
||||||
return authURL, state, nil
|
return authURL, state, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -137,8 +138,8 @@ func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state stri
|
|||||||
"code": newCode,
|
"code": newCode,
|
||||||
"state": state,
|
"state": state,
|
||||||
"grant_type": "authorization_code",
|
"grant_type": "authorization_code",
|
||||||
"client_id": anthropicClientID,
|
"client_id": ClientID,
|
||||||
"redirect_uri": redirectURI,
|
"redirect_uri": RedirectURI,
|
||||||
"code_verifier": pkceCodes.CodeVerifier,
|
"code_verifier": pkceCodes.CodeVerifier,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -154,7 +155,7 @@ func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state stri
|
|||||||
|
|
||||||
// log.Debugf("Token exchange request: %s", string(jsonBody))
|
// log.Debugf("Token exchange request: %s", string(jsonBody))
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody)))
|
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(string(jsonBody)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||||
}
|
}
|
||||||
@@ -221,7 +222,7 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
|
|||||||
}
|
}
|
||||||
|
|
||||||
reqBody := map[string]interface{}{
|
reqBody := map[string]interface{}{
|
||||||
"client_id": anthropicClientID,
|
"client_id": ClientID,
|
||||||
"grant_type": "refresh_token",
|
"grant_type": "refresh_token",
|
||||||
"refresh_token": refreshToken,
|
"refresh_token": refreshToken,
|
||||||
}
|
}
|
||||||
@@ -231,7 +232,7 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
|
|||||||
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", anthropicTokenURL, strings.NewReader(string(jsonBody)))
|
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(string(jsonBody)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,11 +19,12 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// OAuth configuration constants for OpenAI Codex
|
||||||
const (
|
const (
|
||||||
openaiAuthURL = "https://auth.openai.com/oauth/authorize"
|
AuthURL = "https://auth.openai.com/oauth/authorize"
|
||||||
openaiTokenURL = "https://auth.openai.com/oauth/token"
|
TokenURL = "https://auth.openai.com/oauth/token"
|
||||||
openaiClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||||
redirectURI = "http://localhost:1455/auth/callback"
|
RedirectURI = "http://localhost:1455/auth/callback"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CodexAuth handles the OpenAI OAuth2 authentication flow.
|
// CodexAuth handles the OpenAI OAuth2 authentication flow.
|
||||||
@@ -50,9 +51,9 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
params := url.Values{
|
params := url.Values{
|
||||||
"client_id": {openaiClientID},
|
"client_id": {ClientID},
|
||||||
"response_type": {"code"},
|
"response_type": {"code"},
|
||||||
"redirect_uri": {redirectURI},
|
"redirect_uri": {RedirectURI},
|
||||||
"scope": {"openid email profile offline_access"},
|
"scope": {"openid email profile offline_access"},
|
||||||
"state": {state},
|
"state": {state},
|
||||||
"code_challenge": {pkceCodes.CodeChallenge},
|
"code_challenge": {pkceCodes.CodeChallenge},
|
||||||
@@ -62,7 +63,7 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
|
|||||||
"codex_cli_simplified_flow": {"true"},
|
"codex_cli_simplified_flow": {"true"},
|
||||||
}
|
}
|
||||||
|
|
||||||
authURL := fmt.Sprintf("%s?%s", openaiAuthURL, params.Encode())
|
authURL := fmt.Sprintf("%s?%s", AuthURL, params.Encode())
|
||||||
return authURL, nil
|
return authURL, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,13 +78,13 @@ func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkce
|
|||||||
// Prepare token exchange request
|
// Prepare token exchange request
|
||||||
data := url.Values{
|
data := url.Values{
|
||||||
"grant_type": {"authorization_code"},
|
"grant_type": {"authorization_code"},
|
||||||
"client_id": {openaiClientID},
|
"client_id": {ClientID},
|
||||||
"code": {code},
|
"code": {code},
|
||||||
"redirect_uri": {redirectURI},
|
"redirect_uri": {RedirectURI},
|
||||||
"code_verifier": {pkceCodes.CodeVerifier},
|
"code_verifier": {pkceCodes.CodeVerifier},
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode()))
|
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(data.Encode()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||||
}
|
}
|
||||||
@@ -163,13 +164,13 @@ func (o *CodexAuth) RefreshTokens(ctx context.Context, refreshToken string) (*Co
|
|||||||
}
|
}
|
||||||
|
|
||||||
data := url.Values{
|
data := url.Values{
|
||||||
"client_id": {openaiClientID},
|
"client_id": {ClientID},
|
||||||
"grant_type": {"refresh_token"},
|
"grant_type": {"refresh_token"},
|
||||||
"refresh_token": {refreshToken},
|
"refresh_token": {refreshToken},
|
||||||
"scope": {"openid profile email"},
|
"scope": {"openid profile email"},
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode()))
|
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(data.Encode()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,19 +28,19 @@ import (
|
|||||||
"golang.org/x/oauth2/google"
|
"golang.org/x/oauth2/google"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// OAuth configuration constants for Gemini
|
||||||
const (
|
const (
|
||||||
geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
ClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||||
geminiOauthClientSecret = "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"
|
ClientSecret = "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"
|
||||||
geminiDefaultCallbackPort = 8085
|
DefaultCallbackPort = 8085
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
// OAuth scopes for Gemini authentication
|
||||||
geminiOauthScopes = []string{
|
var Scopes = []string{
|
||||||
"https://www.googleapis.com/auth/cloud-platform",
|
"https://www.googleapis.com/auth/cloud-platform",
|
||||||
"https://www.googleapis.com/auth/userinfo.email",
|
"https://www.googleapis.com/auth/userinfo.email",
|
||||||
"https://www.googleapis.com/auth/userinfo.profile",
|
"https://www.googleapis.com/auth/userinfo.profile",
|
||||||
}
|
}
|
||||||
)
|
|
||||||
|
|
||||||
// GeminiAuth provides methods for handling the Gemini OAuth2 authentication flow.
|
// GeminiAuth provides methods for handling the Gemini OAuth2 authentication flow.
|
||||||
// It encapsulates the logic for obtaining, storing, and refreshing authentication tokens
|
// It encapsulates the logic for obtaining, storing, and refreshing authentication tokens
|
||||||
@@ -74,7 +74,7 @@ func NewGeminiAuth() *GeminiAuth {
|
|||||||
// - *http.Client: An HTTP client configured with authentication
|
// - *http.Client: An HTTP client configured with authentication
|
||||||
// - error: An error if the client configuration fails, nil otherwise
|
// - error: An error if the client configuration fails, nil otherwise
|
||||||
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
|
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
|
||||||
callbackPort := geminiDefaultCallbackPort
|
callbackPort := DefaultCallbackPort
|
||||||
if opts != nil && opts.CallbackPort > 0 {
|
if opts != nil && opts.CallbackPort > 0 {
|
||||||
callbackPort = opts.CallbackPort
|
callbackPort = opts.CallbackPort
|
||||||
}
|
}
|
||||||
@@ -112,10 +112,10 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
|
|||||||
|
|
||||||
// Configure the OAuth2 client.
|
// Configure the OAuth2 client.
|
||||||
conf := &oauth2.Config{
|
conf := &oauth2.Config{
|
||||||
ClientID: geminiOauthClientID,
|
ClientID: ClientID,
|
||||||
ClientSecret: geminiOauthClientSecret,
|
ClientSecret: ClientSecret,
|
||||||
RedirectURL: callbackURL, // This will be used by the local server.
|
RedirectURL: callbackURL, // This will be used by the local server.
|
||||||
Scopes: geminiOauthScopes,
|
Scopes: Scopes,
|
||||||
Endpoint: google.Endpoint,
|
Endpoint: google.Endpoint,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -198,9 +198,9 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
|
|||||||
}
|
}
|
||||||
|
|
||||||
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
|
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
|
||||||
ifToken["client_id"] = geminiOauthClientID
|
ifToken["client_id"] = ClientID
|
||||||
ifToken["client_secret"] = geminiOauthClientSecret
|
ifToken["client_secret"] = ClientSecret
|
||||||
ifToken["scopes"] = geminiOauthScopes
|
ifToken["scopes"] = Scopes
|
||||||
ifToken["universe_domain"] = "googleapis.com"
|
ifToken["universe_domain"] = "googleapis.com"
|
||||||
|
|
||||||
ts := GeminiTokenStorage{
|
ts := GeminiTokenStorage{
|
||||||
@@ -226,7 +226,7 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
|
|||||||
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow
|
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow
|
||||||
// - error: An error if the token acquisition fails, nil otherwise
|
// - error: An error if the token acquisition fails, nil otherwise
|
||||||
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
|
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
|
||||||
callbackPort := geminiDefaultCallbackPort
|
callbackPort := DefaultCallbackPort
|
||||||
if opts != nil && opts.CallbackPort > 0 {
|
if opts != nil && opts.CallbackPort > 0 {
|
||||||
callbackPort = opts.CallbackPort
|
callbackPort = opts.CallbackPort
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -49,6 +50,14 @@ func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
resolvedBaseDir, err := util.ResolveAuthDir(baseDir)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("refresh manager: failed to resolve auth directory %s: %v", baseDir, err)
|
||||||
|
}
|
||||||
|
if resolvedBaseDir != "" {
|
||||||
|
baseDir = resolvedBaseDir
|
||||||
|
}
|
||||||
|
|
||||||
// 创建 token 存储库
|
// 创建 token 存储库
|
||||||
repo := NewFileTokenRepository(baseDir)
|
repo := NewFileTokenRepository(baseDir)
|
||||||
|
|
||||||
|
|||||||
@@ -1042,10 +1042,10 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
|||||||
"owned_by": model.OwnedBy,
|
"owned_by": model.OwnedBy,
|
||||||
}
|
}
|
||||||
if model.Created > 0 {
|
if model.Created > 0 {
|
||||||
result["created"] = model.Created
|
result["created_at"] = model.Created
|
||||||
}
|
}
|
||||||
if model.Type != "" {
|
if model.Type != "" {
|
||||||
result["type"] = model.Type
|
result["type"] = "model"
|
||||||
}
|
}
|
||||||
if model.DisplayName != "" {
|
if model.DisplayName != "" {
|
||||||
result["display_name"] = model.DisplayName
|
result["display_name"] = model.DisplayName
|
||||||
|
|||||||
@@ -148,87 +148,108 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
|
||||||
var lastStatus int
|
attempts := antigravityRetryAttempts(auth, e.cfg)
|
||||||
var lastBody []byte
|
|
||||||
var lastErr error
|
|
||||||
|
|
||||||
for idx, baseURL := range baseURLs {
|
attemptLoop:
|
||||||
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, false, opts.Alt, baseURL)
|
for attempt := 0; attempt < attempts; attempt++ {
|
||||||
if errReq != nil {
|
var lastStatus int
|
||||||
err = errReq
|
var lastBody []byte
|
||||||
return resp, err
|
var lastErr error
|
||||||
}
|
|
||||||
|
|
||||||
httpResp, errDo := httpClient.Do(httpReq)
|
for idx, baseURL := range baseURLs {
|
||||||
if errDo != nil {
|
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, false, opts.Alt, baseURL)
|
||||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
if errReq != nil {
|
||||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
err = errReq
|
||||||
return resp, errDo
|
return resp, err
|
||||||
}
|
}
|
||||||
lastStatus = 0
|
|
||||||
lastBody = nil
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
lastErr = errDo
|
if errDo != nil {
|
||||||
if idx+1 < len(baseURLs) {
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||||
continue
|
return resp, errDo
|
||||||
|
}
|
||||||
|
lastStatus = 0
|
||||||
|
lastBody = nil
|
||||||
|
lastErr = errDo
|
||||||
|
if idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = errDo
|
||||||
|
return resp, err
|
||||||
}
|
}
|
||||||
err = errDo
|
|
||||||
return resp, err
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
|
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||||
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
if errRead != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
|
err = errRead
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||||
|
|
||||||
|
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||||
|
log.Debugf("antigravity executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), bodyBytes))
|
||||||
|
lastStatus = httpResp.StatusCode
|
||||||
|
lastBody = append([]byte(nil), bodyBytes...)
|
||||||
|
lastErr = nil
|
||||||
|
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) {
|
||||||
|
if idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if attempt+1 < attempts {
|
||||||
|
delay := antigravityNoCapacityRetryDelay(attempt)
|
||||||
|
log.Debugf("antigravity executor: no capacity for model %s, retrying in %s (attempt %d/%d)", baseModel, delay, attempt+1, attempts)
|
||||||
|
if errWait := antigravityWait(ctx, delay); errWait != nil {
|
||||||
|
return resp, errWait
|
||||||
|
}
|
||||||
|
continue attemptLoop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||||
|
if httpResp.StatusCode == http.StatusTooManyRequests {
|
||||||
|
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
|
||||||
|
sErr.retryAfter = retryAfter
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = sErr
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
||||||
|
var param any
|
||||||
|
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bodyBytes, ¶m)
|
||||||
|
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||||
|
reporter.ensurePublished(ctx)
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
switch {
|
||||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
case lastStatus != 0:
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
|
||||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
if lastStatus == http.StatusTooManyRequests {
|
||||||
}
|
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
|
||||||
if errRead != nil {
|
|
||||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
|
||||||
err = errRead
|
|
||||||
return resp, err
|
|
||||||
}
|
|
||||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
|
||||||
|
|
||||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
|
||||||
log.Debugf("antigravity executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), bodyBytes))
|
|
||||||
lastStatus = httpResp.StatusCode
|
|
||||||
lastBody = append([]byte(nil), bodyBytes...)
|
|
||||||
lastErr = nil
|
|
||||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
|
||||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
|
||||||
if httpResp.StatusCode == http.StatusTooManyRequests {
|
|
||||||
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
|
|
||||||
sErr.retryAfter = retryAfter
|
sErr.retryAfter = retryAfter
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = sErr
|
err = sErr
|
||||||
return resp, err
|
case lastErr != nil:
|
||||||
|
err = lastErr
|
||||||
|
default:
|
||||||
|
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
||||||
}
|
}
|
||||||
|
return resp, err
|
||||||
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
|
||||||
var param any
|
|
||||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bodyBytes, ¶m)
|
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
|
||||||
reporter.ensurePublished(ctx)
|
|
||||||
return resp, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
|
||||||
case lastStatus != 0:
|
|
||||||
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
|
|
||||||
if lastStatus == http.StatusTooManyRequests {
|
|
||||||
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
|
|
||||||
sErr.retryAfter = retryAfter
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = sErr
|
|
||||||
case lastErr != nil:
|
|
||||||
err = lastErr
|
|
||||||
default:
|
|
||||||
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
|
||||||
}
|
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -268,150 +289,171 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
|||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
|
||||||
var lastStatus int
|
attempts := antigravityRetryAttempts(auth, e.cfg)
|
||||||
var lastBody []byte
|
|
||||||
var lastErr error
|
|
||||||
|
|
||||||
for idx, baseURL := range baseURLs {
|
attemptLoop:
|
||||||
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL)
|
for attempt := 0; attempt < attempts; attempt++ {
|
||||||
if errReq != nil {
|
var lastStatus int
|
||||||
err = errReq
|
var lastBody []byte
|
||||||
return resp, err
|
var lastErr error
|
||||||
}
|
|
||||||
|
|
||||||
httpResp, errDo := httpClient.Do(httpReq)
|
for idx, baseURL := range baseURLs {
|
||||||
if errDo != nil {
|
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL)
|
||||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
if errReq != nil {
|
||||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
err = errReq
|
||||||
return resp, errDo
|
return resp, err
|
||||||
}
|
}
|
||||||
lastStatus = 0
|
|
||||||
lastBody = nil
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
lastErr = errDo
|
if errDo != nil {
|
||||||
if idx+1 < len(baseURLs) {
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||||
continue
|
return resp, errDo
|
||||||
}
|
|
||||||
err = errDo
|
|
||||||
return resp, err
|
|
||||||
}
|
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
|
||||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
|
||||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
|
||||||
}
|
|
||||||
if errRead != nil {
|
|
||||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
|
||||||
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
|
|
||||||
err = errRead
|
|
||||||
return resp, err
|
|
||||||
}
|
|
||||||
if errCtx := ctx.Err(); errCtx != nil {
|
|
||||||
err = errCtx
|
|
||||||
return resp, err
|
|
||||||
}
|
}
|
||||||
lastStatus = 0
|
lastStatus = 0
|
||||||
lastBody = nil
|
lastBody = nil
|
||||||
lastErr = errRead
|
lastErr = errDo
|
||||||
if idx+1 < len(baseURLs) {
|
if idx+1 < len(baseURLs) {
|
||||||
log.Debugf("antigravity executor: read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
err = errRead
|
err = errDo
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
lastStatus = httpResp.StatusCode
|
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||||
lastBody = append([]byte(nil), bodyBytes...)
|
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||||
lastErr = nil
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
}
|
||||||
continue
|
if errRead != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
|
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
|
||||||
|
err = errRead
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
if errCtx := ctx.Err(); errCtx != nil {
|
||||||
|
err = errCtx
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
lastStatus = 0
|
||||||
|
lastBody = nil
|
||||||
|
lastErr = errRead
|
||||||
|
if idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = errRead
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||||
|
lastStatus = httpResp.StatusCode
|
||||||
|
lastBody = append([]byte(nil), bodyBytes...)
|
||||||
|
lastErr = nil
|
||||||
|
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) {
|
||||||
|
if idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if attempt+1 < attempts {
|
||||||
|
delay := antigravityNoCapacityRetryDelay(attempt)
|
||||||
|
log.Debugf("antigravity executor: no capacity for model %s, retrying in %s (attempt %d/%d)", baseModel, delay, attempt+1, attempts)
|
||||||
|
if errWait := antigravityWait(ctx, delay); errWait != nil {
|
||||||
|
return resp, errWait
|
||||||
|
}
|
||||||
|
continue attemptLoop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||||
|
if httpResp.StatusCode == http.StatusTooManyRequests {
|
||||||
|
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
|
||||||
|
sErr.retryAfter = retryAfter
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = sErr
|
||||||
|
return resp, err
|
||||||
}
|
}
|
||||||
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
|
||||||
if httpResp.StatusCode == http.StatusTooManyRequests {
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
|
go func(resp *http.Response) {
|
||||||
|
defer close(out)
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Buffer(nil, streamScannerBuffer)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Bytes()
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||||
|
|
||||||
|
// Filter usage metadata for all models
|
||||||
|
// Only retain usage statistics in the terminal chunk
|
||||||
|
line = FilterSSEUsageMetadata(line)
|
||||||
|
|
||||||
|
payload := jsonPayload(line)
|
||||||
|
if payload == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if detail, ok := parseAntigravityStreamUsage(payload); ok {
|
||||||
|
reporter.publish(ctx, detail)
|
||||||
|
}
|
||||||
|
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Payload: payload}
|
||||||
|
}
|
||||||
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
|
reporter.publishFailure(ctx)
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
|
} else {
|
||||||
|
reporter.ensurePublished(ctx)
|
||||||
|
}
|
||||||
|
}(httpResp)
|
||||||
|
|
||||||
|
var buffer bytes.Buffer
|
||||||
|
for chunk := range out {
|
||||||
|
if chunk.Err != nil {
|
||||||
|
return resp, chunk.Err
|
||||||
|
}
|
||||||
|
if len(chunk.Payload) > 0 {
|
||||||
|
_, _ = buffer.Write(chunk.Payload)
|
||||||
|
_, _ = buffer.Write([]byte("\n"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resp = cliproxyexecutor.Response{Payload: e.convertStreamToNonStream(buffer.Bytes())}
|
||||||
|
|
||||||
|
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
||||||
|
var param any
|
||||||
|
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, resp.Payload, ¶m)
|
||||||
|
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||||
|
reporter.ensurePublished(ctx)
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case lastStatus != 0:
|
||||||
|
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
|
||||||
|
if lastStatus == http.StatusTooManyRequests {
|
||||||
|
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
|
||||||
sErr.retryAfter = retryAfter
|
sErr.retryAfter = retryAfter
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = sErr
|
err = sErr
|
||||||
return resp, err
|
case lastErr != nil:
|
||||||
|
err = lastErr
|
||||||
|
default:
|
||||||
|
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
||||||
}
|
}
|
||||||
|
return resp, err
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
|
||||||
go func(resp *http.Response) {
|
|
||||||
defer close(out)
|
|
||||||
defer func() {
|
|
||||||
if errClose := resp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Buffer(nil, streamScannerBuffer)
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Bytes()
|
|
||||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
|
||||||
|
|
||||||
// Filter usage metadata for all models
|
|
||||||
// Only retain usage statistics in the terminal chunk
|
|
||||||
line = FilterSSEUsageMetadata(line)
|
|
||||||
|
|
||||||
payload := jsonPayload(line)
|
|
||||||
if payload == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if detail, ok := parseAntigravityStreamUsage(payload); ok {
|
|
||||||
reporter.publish(ctx, detail)
|
|
||||||
}
|
|
||||||
|
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: payload}
|
|
||||||
}
|
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
|
||||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
|
||||||
reporter.publishFailure(ctx)
|
|
||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
|
||||||
} else {
|
|
||||||
reporter.ensurePublished(ctx)
|
|
||||||
}
|
|
||||||
}(httpResp)
|
|
||||||
|
|
||||||
var buffer bytes.Buffer
|
|
||||||
for chunk := range out {
|
|
||||||
if chunk.Err != nil {
|
|
||||||
return resp, chunk.Err
|
|
||||||
}
|
|
||||||
if len(chunk.Payload) > 0 {
|
|
||||||
_, _ = buffer.Write(chunk.Payload)
|
|
||||||
_, _ = buffer.Write([]byte("\n"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
resp = cliproxyexecutor.Response{Payload: e.convertStreamToNonStream(buffer.Bytes())}
|
|
||||||
|
|
||||||
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
|
||||||
var param any
|
|
||||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, resp.Payload, ¶m)
|
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
|
||||||
reporter.ensurePublished(ctx)
|
|
||||||
|
|
||||||
return resp, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
|
||||||
case lastStatus != 0:
|
|
||||||
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
|
|
||||||
if lastStatus == http.StatusTooManyRequests {
|
|
||||||
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
|
|
||||||
sErr.retryAfter = retryAfter
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = sErr
|
|
||||||
case lastErr != nil:
|
|
||||||
err = lastErr
|
|
||||||
default:
|
|
||||||
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
|
||||||
}
|
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -635,139 +677,160 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
|||||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
|
||||||
var lastStatus int
|
attempts := antigravityRetryAttempts(auth, e.cfg)
|
||||||
var lastBody []byte
|
|
||||||
var lastErr error
|
|
||||||
|
|
||||||
for idx, baseURL := range baseURLs {
|
attemptLoop:
|
||||||
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL)
|
for attempt := 0; attempt < attempts; attempt++ {
|
||||||
if errReq != nil {
|
var lastStatus int
|
||||||
err = errReq
|
var lastBody []byte
|
||||||
return nil, err
|
var lastErr error
|
||||||
}
|
|
||||||
httpResp, errDo := httpClient.Do(httpReq)
|
for idx, baseURL := range baseURLs {
|
||||||
if errDo != nil {
|
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL)
|
||||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
if errReq != nil {
|
||||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
err = errReq
|
||||||
return nil, errDo
|
return nil, err
|
||||||
}
|
}
|
||||||
lastStatus = 0
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
lastBody = nil
|
if errDo != nil {
|
||||||
lastErr = errDo
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
if idx+1 < len(baseURLs) {
|
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||||
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
return nil, errDo
|
||||||
continue
|
|
||||||
}
|
|
||||||
err = errDo
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
|
||||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
|
||||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
|
||||||
}
|
|
||||||
if errRead != nil {
|
|
||||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
|
||||||
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
|
|
||||||
err = errRead
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if errCtx := ctx.Err(); errCtx != nil {
|
|
||||||
err = errCtx
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
lastStatus = 0
|
lastStatus = 0
|
||||||
lastBody = nil
|
lastBody = nil
|
||||||
lastErr = errRead
|
lastErr = errDo
|
||||||
if idx+1 < len(baseURLs) {
|
if idx+1 < len(baseURLs) {
|
||||||
log.Debugf("antigravity executor: read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
err = errRead
|
err = errDo
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
lastStatus = httpResp.StatusCode
|
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||||
lastBody = append([]byte(nil), bodyBytes...)
|
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||||
lastErr = nil
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
}
|
||||||
continue
|
if errRead != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
|
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
|
||||||
|
err = errRead
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if errCtx := ctx.Err(); errCtx != nil {
|
||||||
|
err = errCtx
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
lastStatus = 0
|
||||||
|
lastBody = nil
|
||||||
|
lastErr = errRead
|
||||||
|
if idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = errRead
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||||
|
lastStatus = httpResp.StatusCode
|
||||||
|
lastBody = append([]byte(nil), bodyBytes...)
|
||||||
|
lastErr = nil
|
||||||
|
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) {
|
||||||
|
if idx+1 < len(baseURLs) {
|
||||||
|
log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if attempt+1 < attempts {
|
||||||
|
delay := antigravityNoCapacityRetryDelay(attempt)
|
||||||
|
log.Debugf("antigravity executor: no capacity for model %s, retrying in %s (attempt %d/%d)", baseModel, delay, attempt+1, attempts)
|
||||||
|
if errWait := antigravityWait(ctx, delay); errWait != nil {
|
||||||
|
return nil, errWait
|
||||||
|
}
|
||||||
|
continue attemptLoop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||||
|
if httpResp.StatusCode == http.StatusTooManyRequests {
|
||||||
|
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
|
||||||
|
sErr.retryAfter = retryAfter
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = sErr
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
|
||||||
if httpResp.StatusCode == http.StatusTooManyRequests {
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
|
stream = out
|
||||||
|
go func(resp *http.Response) {
|
||||||
|
defer close(out)
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Buffer(nil, streamScannerBuffer)
|
||||||
|
var param any
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Bytes()
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||||
|
|
||||||
|
// Filter usage metadata for all models
|
||||||
|
// Only retain usage statistics in the terminal chunk
|
||||||
|
line = FilterSSEUsageMetadata(line)
|
||||||
|
|
||||||
|
payload := jsonPayload(line)
|
||||||
|
if payload == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if detail, ok := parseAntigravityStreamUsage(payload); ok {
|
||||||
|
reporter.publish(ctx, detail)
|
||||||
|
}
|
||||||
|
|
||||||
|
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(payload), ¶m)
|
||||||
|
for i := range chunks {
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, []byte("[DONE]"), ¶m)
|
||||||
|
for i := range tail {
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(tail[i])}
|
||||||
|
}
|
||||||
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
|
reporter.publishFailure(ctx)
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
|
} else {
|
||||||
|
reporter.ensurePublished(ctx)
|
||||||
|
}
|
||||||
|
}(httpResp)
|
||||||
|
return stream, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case lastStatus != 0:
|
||||||
|
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
|
||||||
|
if lastStatus == http.StatusTooManyRequests {
|
||||||
|
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
|
||||||
sErr.retryAfter = retryAfter
|
sErr.retryAfter = retryAfter
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = sErr
|
err = sErr
|
||||||
return nil, err
|
case lastErr != nil:
|
||||||
|
err = lastErr
|
||||||
|
default:
|
||||||
|
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
||||||
}
|
}
|
||||||
|
return nil, err
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
|
||||||
stream = out
|
|
||||||
go func(resp *http.Response) {
|
|
||||||
defer close(out)
|
|
||||||
defer func() {
|
|
||||||
if errClose := resp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
scanner.Buffer(nil, streamScannerBuffer)
|
|
||||||
var param any
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Bytes()
|
|
||||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
|
||||||
|
|
||||||
// Filter usage metadata for all models
|
|
||||||
// Only retain usage statistics in the terminal chunk
|
|
||||||
line = FilterSSEUsageMetadata(line)
|
|
||||||
|
|
||||||
payload := jsonPayload(line)
|
|
||||||
if payload == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if detail, ok := parseAntigravityStreamUsage(payload); ok {
|
|
||||||
reporter.publish(ctx, detail)
|
|
||||||
}
|
|
||||||
|
|
||||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(payload), ¶m)
|
|
||||||
for i := range chunks {
|
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, []byte("[DONE]"), ¶m)
|
|
||||||
for i := range tail {
|
|
||||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(tail[i])}
|
|
||||||
}
|
|
||||||
if errScan := scanner.Err(); errScan != nil {
|
|
||||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
|
||||||
reporter.publishFailure(ctx)
|
|
||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
|
||||||
} else {
|
|
||||||
reporter.ensurePublished(ctx)
|
|
||||||
}
|
|
||||||
}(httpResp)
|
|
||||||
return stream, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
|
||||||
case lastStatus != 0:
|
|
||||||
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
|
|
||||||
if lastStatus == http.StatusTooManyRequests {
|
|
||||||
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
|
|
||||||
sErr.retryAfter = retryAfter
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = sErr
|
|
||||||
case lastErr != nil:
|
|
||||||
err = lastErr
|
|
||||||
default:
|
|
||||||
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
|
||||||
}
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -997,7 +1060,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
|||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
modelConfig := registry.GetAntigravityModelConfig()
|
modelConfig := registry.GetAntigravityModelConfig()
|
||||||
models := make([]*registry.ModelInfo, 0, len(result.Map()))
|
models := make([]*registry.ModelInfo, 0, len(result.Map()))
|
||||||
for originalName := range result.Map() {
|
for originalName, modelData := range result.Map() {
|
||||||
modelID := strings.TrimSpace(originalName)
|
modelID := strings.TrimSpace(originalName)
|
||||||
if modelID == "" {
|
if modelID == "" {
|
||||||
continue
|
continue
|
||||||
@@ -1007,12 +1070,18 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
modelCfg := modelConfig[modelID]
|
modelCfg := modelConfig[modelID]
|
||||||
modelName := modelID
|
|
||||||
|
// Extract displayName from upstream response, fallback to modelID
|
||||||
|
displayName := modelData.Get("displayName").String()
|
||||||
|
if displayName == "" {
|
||||||
|
displayName = modelID
|
||||||
|
}
|
||||||
|
|
||||||
modelInfo := ®istry.ModelInfo{
|
modelInfo := ®istry.ModelInfo{
|
||||||
ID: modelID,
|
ID: modelID,
|
||||||
Name: modelName,
|
Name: modelID,
|
||||||
Description: modelID,
|
Description: displayName,
|
||||||
DisplayName: modelID,
|
DisplayName: displayName,
|
||||||
Version: modelID,
|
Version: modelID,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: now,
|
Created: now,
|
||||||
@@ -1378,14 +1447,70 @@ func resolveUserAgent(auth *cliproxyauth.Auth) string {
|
|||||||
return defaultAntigravityAgent
|
return defaultAntigravityAgent
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func antigravityRetryAttempts(auth *cliproxyauth.Auth, cfg *config.Config) int {
|
||||||
|
retry := 0
|
||||||
|
if cfg != nil {
|
||||||
|
retry = cfg.RequestRetry
|
||||||
|
}
|
||||||
|
if auth != nil {
|
||||||
|
if override, ok := auth.RequestRetryOverride(); ok {
|
||||||
|
retry = override
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if retry < 0 {
|
||||||
|
retry = 0
|
||||||
|
}
|
||||||
|
attempts := retry + 1
|
||||||
|
if attempts < 1 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return attempts
|
||||||
|
}
|
||||||
|
|
||||||
|
func antigravityShouldRetryNoCapacity(statusCode int, body []byte) bool {
|
||||||
|
if statusCode != http.StatusServiceUnavailable {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(body) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
msg := strings.ToLower(string(body))
|
||||||
|
return strings.Contains(msg, "no capacity available")
|
||||||
|
}
|
||||||
|
|
||||||
|
func antigravityNoCapacityRetryDelay(attempt int) time.Duration {
|
||||||
|
if attempt < 0 {
|
||||||
|
attempt = 0
|
||||||
|
}
|
||||||
|
delay := time.Duration(attempt+1) * 250 * time.Millisecond
|
||||||
|
if delay > 2*time.Second {
|
||||||
|
delay = 2 * time.Second
|
||||||
|
}
|
||||||
|
return delay
|
||||||
|
}
|
||||||
|
|
||||||
|
func antigravityWait(ctx context.Context, wait time.Duration) error {
|
||||||
|
if wait <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
timer := time.NewTimer(wait)
|
||||||
|
defer timer.Stop()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-timer.C:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string {
|
func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string {
|
||||||
if base := resolveCustomAntigravityBaseURL(auth); base != "" {
|
if base := resolveCustomAntigravityBaseURL(auth); base != "" {
|
||||||
return []string{base}
|
return []string{base}
|
||||||
}
|
}
|
||||||
return []string{
|
return []string{
|
||||||
antigravitySandboxBaseURLDaily,
|
|
||||||
antigravityBaseURLDaily,
|
antigravityBaseURLDaily,
|
||||||
antigravityBaseURLProd,
|
antigravitySandboxBaseURLDaily,
|
||||||
|
// antigravityBaseURLProd,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -163,7 +163,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("response body close error: %v", errClose)
|
log.Errorf("response body close error: %v", errClose)
|
||||||
@@ -295,7 +295,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("response body close error: %v", errClose)
|
log.Errorf("response body close error: %v", errClose)
|
||||||
}
|
}
|
||||||
@@ -733,6 +733,11 @@ func applyClaudeToolPrefix(body []byte, prefix string) []byte {
|
|||||||
|
|
||||||
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
|
if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
|
||||||
tools.ForEach(func(index, tool gjson.Result) bool {
|
tools.ForEach(func(index, tool gjson.Result) bool {
|
||||||
|
// Skip built-in tools (web_search, code_execution, etc.) which have
|
||||||
|
// a "type" field and require their name to remain unchanged.
|
||||||
|
if tool.Get("type").Exists() && tool.Get("type").String() != "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
name := tool.Get("name").String()
|
name := tool.Get("name").String()
|
||||||
if name == "" || strings.HasPrefix(name, prefix) {
|
if name == "" || strings.HasPrefix(name, prefix) {
|
||||||
return true
|
return true
|
||||||
|
|||||||
@@ -25,6 +25,18 @@ func TestApplyClaudeToolPrefix(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
|
||||||
|
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"},{"name":"my_custom_tool","input_schema":{"type":"object"}}]}`)
|
||||||
|
out := applyClaudeToolPrefix(input, "proxy_")
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "web_search" {
|
||||||
|
t.Fatalf("built-in tool name should not be prefixed: tools.0.name = %q, want %q", got, "web_search")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_my_custom_tool" {
|
||||||
|
t.Fatalf("custom tool should be prefixed: tools.1.name = %q, want %q", got, "proxy_my_custom_tool")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
|
||||||
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
|
input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
|
||||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||||
|
|||||||
@@ -150,7 +150,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -265,7 +265,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
return nil, readErr
|
return nil, readErr
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -227,7 +227,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
|
|
||||||
lastStatus = httpResp.StatusCode
|
lastStatus = httpResp.StatusCode
|
||||||
lastBody = append([]byte(nil), data...)
|
lastBody = append([]byte(nil), data...)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
if httpResp.StatusCode == 429 {
|
if httpResp.StatusCode == 429 {
|
||||||
if idx+1 < len(models) {
|
if idx+1 < len(models) {
|
||||||
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
|
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
|
||||||
@@ -360,7 +360,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
lastStatus = httpResp.StatusCode
|
lastStatus = httpResp.StatusCode
|
||||||
lastBody = append([]byte(nil), data...)
|
lastBody = append([]byte(nil), data...)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
if httpResp.StatusCode == 429 {
|
if httpResp.StatusCode == 429 {
|
||||||
if idx+1 < len(models) {
|
if idx+1 < len(models) {
|
||||||
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
|
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
|
||||||
|
|||||||
@@ -188,7 +188,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -282,7 +282,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("gemini executor: close response body error: %v", errClose)
|
log.Errorf("gemini executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
@@ -402,7 +402,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
log.Debugf("request error, error status: %d, error body: %s", resp.StatusCode, summarizeErrorBody(resp.Header.Get("Content-Type"), data))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", resp.StatusCode, summarizeErrorBody(resp.Header.Get("Content-Type"), data))
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(data)}
|
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(data)}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -389,7 +389,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -503,7 +503,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -601,7 +601,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
@@ -725,7 +725,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
@@ -838,7 +838,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
}
|
}
|
||||||
data, errRead := io.ReadAll(httpResp.Body)
|
data, errRead := io.ReadAll(httpResp.Body)
|
||||||
@@ -922,7 +922,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
}
|
}
|
||||||
data, errRead := io.ReadAll(httpResp.Body)
|
data, errRead := io.ReadAll(httpResp.Body)
|
||||||
|
|||||||
@@ -142,7 +142,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("iflow request error: status %d body %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -244,7 +244,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
log.Errorf("iflow executor: close response body error: %v", errClose)
|
log.Errorf("iflow executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
log.Debugf("iflow streaming error: status %d body %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -791,28 +791,28 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
|||||||
_ = httpResp.Body.Close()
|
_ = httpResp.Body.Close()
|
||||||
appendAPIResponseChunk(ctx, e.cfg, respBody)
|
appendAPIResponseChunk(ctx, e.cfg, respBody)
|
||||||
|
|
||||||
if attempt < maxRetries {
|
log.Warnf("kiro: received 401 error, attempting token refresh")
|
||||||
log.Warnf("kiro: received 401 error, attempting token refresh and retry (attempt %d/%d)", attempt+1, maxRetries+1)
|
refreshedAuth, refreshErr := e.Refresh(ctx, auth)
|
||||||
|
if refreshErr != nil {
|
||||||
|
log.Errorf("kiro: token refresh failed: %v", refreshErr)
|
||||||
|
return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
|
||||||
|
}
|
||||||
|
|
||||||
refreshedAuth, refreshErr := e.Refresh(ctx, auth)
|
if refreshedAuth != nil {
|
||||||
if refreshErr != nil {
|
auth = refreshedAuth
|
||||||
log.Errorf("kiro: token refresh failed: %v", refreshErr)
|
// Persist the refreshed auth to file so subsequent requests use it
|
||||||
return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
|
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
|
||||||
|
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
|
||||||
|
// Continue anyway - the token is valid for this request
|
||||||
}
|
}
|
||||||
|
accessToken, profileArn = kiroCredentials(auth)
|
||||||
if refreshedAuth != nil {
|
// Rebuild payload with new profile ARN if changed
|
||||||
auth = refreshedAuth
|
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
||||||
// Persist the refreshed auth to file so subsequent requests use it
|
if attempt < maxRetries {
|
||||||
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
|
log.Infof("kiro: token refreshed successfully, retrying request (attempt %d/%d)", attempt+1, maxRetries+1)
|
||||||
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
|
|
||||||
// Continue anyway - the token is valid for this request
|
|
||||||
}
|
|
||||||
accessToken, profileArn = kiroCredentials(auth)
|
|
||||||
// Rebuild payload with new profile ARN if changed
|
|
||||||
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
|
||||||
log.Infof("kiro: token refreshed successfully, retrying request")
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
log.Infof("kiro: token refreshed successfully, no retries remaining")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Warnf("kiro request error, status: 401, body: %s", summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody))
|
log.Warnf("kiro request error, status: 401, body: %s", summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody))
|
||||||
@@ -1199,28 +1199,28 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
|
|||||||
_ = httpResp.Body.Close()
|
_ = httpResp.Body.Close()
|
||||||
appendAPIResponseChunk(ctx, e.cfg, respBody)
|
appendAPIResponseChunk(ctx, e.cfg, respBody)
|
||||||
|
|
||||||
if attempt < maxRetries {
|
log.Warnf("kiro: stream received 401 error, attempting token refresh")
|
||||||
log.Warnf("kiro: stream received 401 error, attempting token refresh and retry (attempt %d/%d)", attempt+1, maxRetries+1)
|
refreshedAuth, refreshErr := e.Refresh(ctx, auth)
|
||||||
|
if refreshErr != nil {
|
||||||
|
log.Errorf("kiro: token refresh failed: %v", refreshErr)
|
||||||
|
return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
|
||||||
|
}
|
||||||
|
|
||||||
refreshedAuth, refreshErr := e.Refresh(ctx, auth)
|
if refreshedAuth != nil {
|
||||||
if refreshErr != nil {
|
auth = refreshedAuth
|
||||||
log.Errorf("kiro: token refresh failed: %v", refreshErr)
|
// Persist the refreshed auth to file so subsequent requests use it
|
||||||
return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
|
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
|
||||||
|
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
|
||||||
|
// Continue anyway - the token is valid for this request
|
||||||
}
|
}
|
||||||
|
accessToken, profileArn = kiroCredentials(auth)
|
||||||
if refreshedAuth != nil {
|
// Rebuild payload with new profile ARN if changed
|
||||||
auth = refreshedAuth
|
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
||||||
// Persist the refreshed auth to file so subsequent requests use it
|
if attempt < maxRetries {
|
||||||
if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
|
log.Infof("kiro: token refreshed successfully, retrying stream request (attempt %d/%d)", attempt+1, maxRetries+1)
|
||||||
log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
|
|
||||||
// Continue anyway - the token is valid for this request
|
|
||||||
}
|
|
||||||
accessToken, profileArn = kiroCredentials(auth)
|
|
||||||
// Rebuild payload with new profile ARN if changed
|
|
||||||
kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
|
|
||||||
log.Infof("kiro: token refreshed successfully, retrying stream request")
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
log.Infof("kiro: token refreshed successfully, no retries remaining")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Warnf("kiro stream error, status: 401, body: %s", string(respBody))
|
log.Warnf("kiro stream error, status: 401, body: %s", string(respBody))
|
||||||
|
|||||||
@@ -12,7 +12,10 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -332,6 +335,12 @@ func summarizeErrorBody(contentType string, body []byte) string {
|
|||||||
}
|
}
|
||||||
return "[html body omitted]"
|
return "[html body omitted]"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Try to extract error message from JSON response
|
||||||
|
if message := extractJSONErrorMessage(body); message != "" {
|
||||||
|
return message
|
||||||
|
}
|
||||||
|
|
||||||
return string(body)
|
return string(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -358,3 +367,25 @@ func extractHTMLTitle(body []byte) string {
|
|||||||
}
|
}
|
||||||
return strings.Join(strings.Fields(title), " ")
|
return strings.Join(strings.Fields(title), " ")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractJSONErrorMessage attempts to extract error.message from JSON error responses
|
||||||
|
func extractJSONErrorMessage(body []byte) string {
|
||||||
|
result := gjson.GetBytes(body, "error.message")
|
||||||
|
if result.Exists() && result.String() != "" {
|
||||||
|
return result.String()
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// logWithRequestID returns a logrus Entry with request_id field populated from context.
|
||||||
|
// If no request ID is found in context, it returns the standard logger.
|
||||||
|
func logWithRequestID(ctx context.Context) *log.Entry {
|
||||||
|
if ctx == nil {
|
||||||
|
return log.NewEntry(log.StandardLogger())
|
||||||
|
}
|
||||||
|
requestID := logging.GetRequestID(ctx)
|
||||||
|
if requestID == "" {
|
||||||
|
return log.NewEntry(log.StandardLogger())
|
||||||
|
}
|
||||||
|
return log.WithField("request_id", requestID)
|
||||||
|
}
|
||||||
|
|||||||
@@ -146,7 +146,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -239,7 +239,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("openai compat executor: close response body error: %v", errClose)
|
log.Errorf("openai compat executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -133,7 +133,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
@@ -222,7 +222,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("qwen executor: close response body error: %v", errClose)
|
log.Errorf("qwen executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -305,12 +305,12 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// tools -> request.tools[0].functionDeclarations + request.tools[0].googleSearch passthrough
|
// tools -> request.tools[].functionDeclarations + request.tools[].googleSearch passthrough
|
||||||
tools := gjson.GetBytes(rawJSON, "tools")
|
tools := gjson.GetBytes(rawJSON, "tools")
|
||||||
if tools.IsArray() && len(tools.Array()) > 0 {
|
if tools.IsArray() && len(tools.Array()) > 0 {
|
||||||
toolNode := []byte(`{}`)
|
functionToolNode := []byte(`{}`)
|
||||||
hasTool := false
|
|
||||||
hasFunction := false
|
hasFunction := false
|
||||||
|
googleSearchNodes := make([][]byte, 0)
|
||||||
for _, t := range tools.Array() {
|
for _, t := range tools.Array() {
|
||||||
if t.Get("type").String() == "function" {
|
if t.Get("type").String() == "function" {
|
||||||
fn := t.Get("function")
|
fn := t.Get("function")
|
||||||
@@ -349,31 +349,37 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
}
|
}
|
||||||
fnRaw, _ = sjson.Delete(fnRaw, "strict")
|
fnRaw, _ = sjson.Delete(fnRaw, "strict")
|
||||||
if !hasFunction {
|
if !hasFunction {
|
||||||
toolNode, _ = sjson.SetRawBytes(toolNode, "functionDeclarations", []byte("[]"))
|
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
|
||||||
}
|
}
|
||||||
tmp, errSet := sjson.SetRawBytes(toolNode, "functionDeclarations.-1", []byte(fnRaw))
|
tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw))
|
||||||
if errSet != nil {
|
if errSet != nil {
|
||||||
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
|
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
toolNode = tmp
|
functionToolNode = tmp
|
||||||
hasFunction = true
|
hasFunction = true
|
||||||
hasTool = true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if gs := t.Get("google_search"); gs.Exists() {
|
if gs := t.Get("google_search"); gs.Exists() {
|
||||||
|
googleToolNode := []byte(`{}`)
|
||||||
var errSet error
|
var errSet error
|
||||||
toolNode, errSet = sjson.SetRawBytes(toolNode, "googleSearch", []byte(gs.Raw))
|
googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw))
|
||||||
if errSet != nil {
|
if errSet != nil {
|
||||||
log.Warnf("Failed to set googleSearch tool: %v", errSet)
|
log.Warnf("Failed to set googleSearch tool: %v", errSet)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
hasTool = true
|
googleSearchNodes = append(googleSearchNodes, googleToolNode)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if hasTool {
|
if hasFunction || len(googleSearchNodes) > 0 {
|
||||||
out, _ = sjson.SetRawBytes(out, "request.tools", []byte("[]"))
|
toolsNode := []byte("[]")
|
||||||
out, _ = sjson.SetRawBytes(out, "request.tools.0", toolNode)
|
if hasFunction {
|
||||||
|
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode)
|
||||||
|
}
|
||||||
|
for _, googleNode := range googleSearchNodes {
|
||||||
|
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", googleNode)
|
||||||
|
}
|
||||||
|
out, _ = sjson.SetRawBytes(out, "request.tools", toolsNode)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -283,12 +283,12 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// tools -> request.tools[0].functionDeclarations + request.tools[0].googleSearch passthrough
|
// tools -> request.tools[].functionDeclarations + request.tools[].googleSearch passthrough
|
||||||
tools := gjson.GetBytes(rawJSON, "tools")
|
tools := gjson.GetBytes(rawJSON, "tools")
|
||||||
if tools.IsArray() && len(tools.Array()) > 0 {
|
if tools.IsArray() && len(tools.Array()) > 0 {
|
||||||
toolNode := []byte(`{}`)
|
functionToolNode := []byte(`{}`)
|
||||||
hasTool := false
|
|
||||||
hasFunction := false
|
hasFunction := false
|
||||||
|
googleSearchNodes := make([][]byte, 0)
|
||||||
for _, t := range tools.Array() {
|
for _, t := range tools.Array() {
|
||||||
if t.Get("type").String() == "function" {
|
if t.Get("type").String() == "function" {
|
||||||
fn := t.Get("function")
|
fn := t.Get("function")
|
||||||
@@ -327,31 +327,37 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
|||||||
}
|
}
|
||||||
fnRaw, _ = sjson.Delete(fnRaw, "strict")
|
fnRaw, _ = sjson.Delete(fnRaw, "strict")
|
||||||
if !hasFunction {
|
if !hasFunction {
|
||||||
toolNode, _ = sjson.SetRawBytes(toolNode, "functionDeclarations", []byte("[]"))
|
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
|
||||||
}
|
}
|
||||||
tmp, errSet := sjson.SetRawBytes(toolNode, "functionDeclarations.-1", []byte(fnRaw))
|
tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw))
|
||||||
if errSet != nil {
|
if errSet != nil {
|
||||||
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
|
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
toolNode = tmp
|
functionToolNode = tmp
|
||||||
hasFunction = true
|
hasFunction = true
|
||||||
hasTool = true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if gs := t.Get("google_search"); gs.Exists() {
|
if gs := t.Get("google_search"); gs.Exists() {
|
||||||
|
googleToolNode := []byte(`{}`)
|
||||||
var errSet error
|
var errSet error
|
||||||
toolNode, errSet = sjson.SetRawBytes(toolNode, "googleSearch", []byte(gs.Raw))
|
googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw))
|
||||||
if errSet != nil {
|
if errSet != nil {
|
||||||
log.Warnf("Failed to set googleSearch tool: %v", errSet)
|
log.Warnf("Failed to set googleSearch tool: %v", errSet)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
hasTool = true
|
googleSearchNodes = append(googleSearchNodes, googleToolNode)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if hasTool {
|
if hasFunction || len(googleSearchNodes) > 0 {
|
||||||
out, _ = sjson.SetRawBytes(out, "request.tools", []byte("[]"))
|
toolsNode := []byte("[]")
|
||||||
out, _ = sjson.SetRawBytes(out, "request.tools.0", toolNode)
|
if hasFunction {
|
||||||
|
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode)
|
||||||
|
}
|
||||||
|
for _, googleNode := range googleSearchNodes {
|
||||||
|
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", googleNode)
|
||||||
|
}
|
||||||
|
out, _ = sjson.SetRawBytes(out, "request.tools", toolsNode)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -289,12 +289,12 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// tools -> tools[0].functionDeclarations + tools[0].googleSearch passthrough
|
// tools -> tools[].functionDeclarations + tools[].googleSearch passthrough
|
||||||
tools := gjson.GetBytes(rawJSON, "tools")
|
tools := gjson.GetBytes(rawJSON, "tools")
|
||||||
if tools.IsArray() && len(tools.Array()) > 0 {
|
if tools.IsArray() && len(tools.Array()) > 0 {
|
||||||
toolNode := []byte(`{}`)
|
functionToolNode := []byte(`{}`)
|
||||||
hasTool := false
|
|
||||||
hasFunction := false
|
hasFunction := false
|
||||||
|
googleSearchNodes := make([][]byte, 0)
|
||||||
for _, t := range tools.Array() {
|
for _, t := range tools.Array() {
|
||||||
if t.Get("type").String() == "function" {
|
if t.Get("type").String() == "function" {
|
||||||
fn := t.Get("function")
|
fn := t.Get("function")
|
||||||
@@ -333,31 +333,37 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
}
|
}
|
||||||
fnRaw, _ = sjson.Delete(fnRaw, "strict")
|
fnRaw, _ = sjson.Delete(fnRaw, "strict")
|
||||||
if !hasFunction {
|
if !hasFunction {
|
||||||
toolNode, _ = sjson.SetRawBytes(toolNode, "functionDeclarations", []byte("[]"))
|
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
|
||||||
}
|
}
|
||||||
tmp, errSet := sjson.SetRawBytes(toolNode, "functionDeclarations.-1", []byte(fnRaw))
|
tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw))
|
||||||
if errSet != nil {
|
if errSet != nil {
|
||||||
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
|
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
toolNode = tmp
|
functionToolNode = tmp
|
||||||
hasFunction = true
|
hasFunction = true
|
||||||
hasTool = true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if gs := t.Get("google_search"); gs.Exists() {
|
if gs := t.Get("google_search"); gs.Exists() {
|
||||||
|
googleToolNode := []byte(`{}`)
|
||||||
var errSet error
|
var errSet error
|
||||||
toolNode, errSet = sjson.SetRawBytes(toolNode, "googleSearch", []byte(gs.Raw))
|
googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw))
|
||||||
if errSet != nil {
|
if errSet != nil {
|
||||||
log.Warnf("Failed to set googleSearch tool: %v", errSet)
|
log.Warnf("Failed to set googleSearch tool: %v", errSet)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
hasTool = true
|
googleSearchNodes = append(googleSearchNodes, googleToolNode)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if hasTool {
|
if hasFunction || len(googleSearchNodes) > 0 {
|
||||||
out, _ = sjson.SetRawBytes(out, "tools", []byte("[]"))
|
toolsNode := []byte("[]")
|
||||||
out, _ = sjson.SetRawBytes(out, "tools.0", toolNode)
|
if hasFunction {
|
||||||
|
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode)
|
||||||
|
}
|
||||||
|
for _, googleNode := range googleSearchNodes {
|
||||||
|
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", googleNode)
|
||||||
|
}
|
||||||
|
out, _ = sjson.SetRawBytes(out, "tools", toolsNode)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -499,6 +499,16 @@ func shortenToolNameIfNeeded(name string) string {
|
|||||||
return name[:limit]
|
return name[:limit]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ensureKiroInputSchema(parameters interface{}) interface{} {
|
||||||
|
if parameters != nil {
|
||||||
|
return parameters
|
||||||
|
}
|
||||||
|
return map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// convertClaudeToolsToKiro converts Claude tools to Kiro format
|
// convertClaudeToolsToKiro converts Claude tools to Kiro format
|
||||||
func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper {
|
func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper {
|
||||||
var kiroTools []KiroToolWrapper
|
var kiroTools []KiroToolWrapper
|
||||||
@@ -509,7 +519,12 @@ func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper {
|
|||||||
for _, tool := range tools.Array() {
|
for _, tool := range tools.Array() {
|
||||||
name := tool.Get("name").String()
|
name := tool.Get("name").String()
|
||||||
description := tool.Get("description").String()
|
description := tool.Get("description").String()
|
||||||
inputSchema := tool.Get("input_schema").Value()
|
inputSchemaResult := tool.Get("input_schema")
|
||||||
|
var inputSchema interface{}
|
||||||
|
if inputSchemaResult.Exists() && inputSchemaResult.Type != gjson.Null {
|
||||||
|
inputSchema = inputSchemaResult.Value()
|
||||||
|
}
|
||||||
|
inputSchema = ensureKiroInputSchema(inputSchema)
|
||||||
|
|
||||||
// Shorten tool name if it exceeds 64 characters (common with MCP tools)
|
// Shorten tool name if it exceeds 64 characters (common with MCP tools)
|
||||||
originalName := name
|
originalName := name
|
||||||
|
|||||||
@@ -314,7 +314,7 @@ func ConvertOpenAIToolsToKiroFormat(tools []map[string]interface{}) []KiroToolWr
|
|||||||
|
|
||||||
name := kirocommon.GetString(fn, "name")
|
name := kirocommon.GetString(fn, "name")
|
||||||
description := kirocommon.GetString(fn, "description")
|
description := kirocommon.GetString(fn, "description")
|
||||||
parameters := fn["parameters"]
|
parameters := ensureKiroInputSchema(fn["parameters"])
|
||||||
|
|
||||||
if name == "" {
|
if name == "" {
|
||||||
continue
|
continue
|
||||||
@@ -368,4 +368,4 @@ func ConvertClaudeToolUseToOpenAI(toolUseID, toolName string, input map[string]i
|
|||||||
// LogStreamEvent logs a streaming event for debugging
|
// LogStreamEvent logs a streaming event for debugging
|
||||||
func LogStreamEvent(eventType, data string) {
|
func LogStreamEvent(eventType, data string) {
|
||||||
log.Debugf("kiro-openai: stream event type=%s, data_len=%d", eventType, len(data))
|
log.Debugf("kiro-openai: stream event type=%s, data_len=%d", eventType, len(data))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -381,6 +381,16 @@ func shortenToolNameIfNeeded(name string) string {
|
|||||||
return name[:limit]
|
return name[:limit]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ensureKiroInputSchema(parameters interface{}) interface{} {
|
||||||
|
if parameters != nil {
|
||||||
|
return parameters
|
||||||
|
}
|
||||||
|
return map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// convertOpenAIToolsToKiro converts OpenAI tools to Kiro format
|
// convertOpenAIToolsToKiro converts OpenAI tools to Kiro format
|
||||||
func convertOpenAIToolsToKiro(tools gjson.Result) []KiroToolWrapper {
|
func convertOpenAIToolsToKiro(tools gjson.Result) []KiroToolWrapper {
|
||||||
var kiroTools []KiroToolWrapper
|
var kiroTools []KiroToolWrapper
|
||||||
@@ -401,7 +411,12 @@ func convertOpenAIToolsToKiro(tools gjson.Result) []KiroToolWrapper {
|
|||||||
|
|
||||||
name := fn.Get("name").String()
|
name := fn.Get("name").String()
|
||||||
description := fn.Get("description").String()
|
description := fn.Get("description").String()
|
||||||
parameters := fn.Get("parameters").Value()
|
parametersResult := fn.Get("parameters")
|
||||||
|
var parameters interface{}
|
||||||
|
if parametersResult.Exists() && parametersResult.Type != gjson.Null {
|
||||||
|
parameters = parametersResult.Value()
|
||||||
|
}
|
||||||
|
parameters = ensureKiroInputSchema(parameters)
|
||||||
|
|
||||||
// Shorten tool name if it exceeds 64 characters (common with MCP tools)
|
// Shorten tool name if it exceeds 64 characters (common with MCP tools)
|
||||||
originalName := name
|
originalName := name
|
||||||
|
|||||||
@@ -181,11 +181,11 @@ func TestConvertClaudeRequestToOpenAI_ThinkingToReasoningContent(t *testing.T) {
|
|||||||
result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false)
|
result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false)
|
||||||
resultJSON := gjson.ParseBytes(result)
|
resultJSON := gjson.ParseBytes(result)
|
||||||
|
|
||||||
// Find the relevant message (skip system message at index 0)
|
// Find the relevant message
|
||||||
messages := resultJSON.Get("messages").Array()
|
messages := resultJSON.Get("messages").Array()
|
||||||
if len(messages) < 2 {
|
if len(messages) < 1 {
|
||||||
if tt.wantHasReasoningContent || tt.wantHasContent {
|
if tt.wantHasReasoningContent || tt.wantHasContent {
|
||||||
t.Fatalf("Expected at least 2 messages (system + user/assistant), got %d", len(messages))
|
t.Fatalf("Expected at least 1 message, got %d", len(messages))
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -272,15 +272,15 @@ func TestConvertClaudeRequestToOpenAI_ThinkingOnlyMessagePreserved(t *testing.T)
|
|||||||
|
|
||||||
messages := resultJSON.Get("messages").Array()
|
messages := resultJSON.Get("messages").Array()
|
||||||
|
|
||||||
// Should have: system (auto-added) + user + assistant (thinking-only) + user = 4 messages
|
// Should have: user + assistant (thinking-only) + user = 3 messages
|
||||||
if len(messages) != 4 {
|
if len(messages) != 3 {
|
||||||
t.Fatalf("Expected 4 messages, got %d. Messages: %v", len(messages), resultJSON.Get("messages").Raw)
|
t.Fatalf("Expected 3 messages, got %d. Messages: %v", len(messages), resultJSON.Get("messages").Raw)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the assistant message (index 2) has reasoning_content
|
// Check the assistant message (index 1) has reasoning_content
|
||||||
assistantMsg := messages[2]
|
assistantMsg := messages[1]
|
||||||
if assistantMsg.Get("role").String() != "assistant" {
|
if assistantMsg.Get("role").String() != "assistant" {
|
||||||
t.Errorf("Expected message[2] to be assistant, got %s", assistantMsg.Get("role").String())
|
t.Errorf("Expected message[1] to be assistant, got %s", assistantMsg.Get("role").String())
|
||||||
}
|
}
|
||||||
|
|
||||||
if !assistantMsg.Get("reasoning_content").Exists() {
|
if !assistantMsg.Get("reasoning_content").Exists() {
|
||||||
@@ -292,6 +292,104 @@ func TestConvertClaudeRequestToOpenAI_ThinkingOnlyMessagePreserved(t *testing.T)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToOpenAI_SystemMessageScenarios(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
inputJSON string
|
||||||
|
wantHasSys bool
|
||||||
|
wantSysText string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "No system field",
|
||||||
|
inputJSON: `{
|
||||||
|
"model": "claude-3-opus",
|
||||||
|
"messages": [{"role": "user", "content": "hello"}]
|
||||||
|
}`,
|
||||||
|
wantHasSys: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty string system field",
|
||||||
|
inputJSON: `{
|
||||||
|
"model": "claude-3-opus",
|
||||||
|
"system": "",
|
||||||
|
"messages": [{"role": "user", "content": "hello"}]
|
||||||
|
}`,
|
||||||
|
wantHasSys: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "String system field",
|
||||||
|
inputJSON: `{
|
||||||
|
"model": "claude-3-opus",
|
||||||
|
"system": "Be helpful",
|
||||||
|
"messages": [{"role": "user", "content": "hello"}]
|
||||||
|
}`,
|
||||||
|
wantHasSys: true,
|
||||||
|
wantSysText: "Be helpful",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Array system field with text",
|
||||||
|
inputJSON: `{
|
||||||
|
"model": "claude-3-opus",
|
||||||
|
"system": [{"type": "text", "text": "Array system"}],
|
||||||
|
"messages": [{"role": "user", "content": "hello"}]
|
||||||
|
}`,
|
||||||
|
wantHasSys: true,
|
||||||
|
wantSysText: "Array system",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Array system field with multiple text blocks",
|
||||||
|
inputJSON: `{
|
||||||
|
"model": "claude-3-opus",
|
||||||
|
"system": [
|
||||||
|
{"type": "text", "text": "Block 1"},
|
||||||
|
{"type": "text", "text": "Block 2"}
|
||||||
|
],
|
||||||
|
"messages": [{"role": "user", "content": "hello"}]
|
||||||
|
}`,
|
||||||
|
wantHasSys: true,
|
||||||
|
wantSysText: "Block 2", // We will update the test logic to check all blocks or specifically the second one
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false)
|
||||||
|
resultJSON := gjson.ParseBytes(result)
|
||||||
|
messages := resultJSON.Get("messages").Array()
|
||||||
|
|
||||||
|
hasSys := false
|
||||||
|
var sysMsg gjson.Result
|
||||||
|
if len(messages) > 0 && messages[0].Get("role").String() == "system" {
|
||||||
|
hasSys = true
|
||||||
|
sysMsg = messages[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasSys != tt.wantHasSys {
|
||||||
|
t.Errorf("got hasSystem = %v, want %v", hasSys, tt.wantHasSys)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.wantHasSys {
|
||||||
|
// Check content - it could be string or array in OpenAI
|
||||||
|
content := sysMsg.Get("content")
|
||||||
|
var gotText string
|
||||||
|
if content.IsArray() {
|
||||||
|
arr := content.Array()
|
||||||
|
if len(arr) > 0 {
|
||||||
|
// Get the last element's text for validation
|
||||||
|
gotText = arr[len(arr)-1].Get("text").String()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
gotText = content.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.wantSysText != "" && gotText != tt.wantSysText {
|
||||||
|
t.Errorf("got system text = %q, want %q", gotText, tt.wantSysText)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConvertClaudeRequestToOpenAI_ToolResultOrderAndContent(t *testing.T) {
|
func TestConvertClaudeRequestToOpenAI_ToolResultOrderAndContent(t *testing.T) {
|
||||||
inputJSON := `{
|
inputJSON := `{
|
||||||
"model": "claude-3-opus",
|
"model": "claude-3-opus",
|
||||||
@@ -318,39 +416,35 @@ func TestConvertClaudeRequestToOpenAI_ToolResultOrderAndContent(t *testing.T) {
|
|||||||
messages := resultJSON.Get("messages").Array()
|
messages := resultJSON.Get("messages").Array()
|
||||||
|
|
||||||
// OpenAI requires: tool messages MUST immediately follow assistant(tool_calls).
|
// OpenAI requires: tool messages MUST immediately follow assistant(tool_calls).
|
||||||
// Correct order: system + assistant(tool_calls) + tool(result) + user(before+after)
|
// Correct order: assistant(tool_calls) + tool(result) + user(before+after)
|
||||||
if len(messages) != 4 {
|
if len(messages) != 3 {
|
||||||
t.Fatalf("Expected 4 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
t.Fatalf("Expected 3 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||||
}
|
}
|
||||||
|
|
||||||
if messages[0].Get("role").String() != "system" {
|
if messages[0].Get("role").String() != "assistant" || !messages[0].Get("tool_calls").Exists() {
|
||||||
t.Fatalf("Expected messages[0] to be system, got %s", messages[0].Get("role").String())
|
t.Fatalf("Expected messages[0] to be assistant tool_calls, got %s: %s", messages[0].Get("role").String(), messages[0].Raw)
|
||||||
}
|
|
||||||
|
|
||||||
if messages[1].Get("role").String() != "assistant" || !messages[1].Get("tool_calls").Exists() {
|
|
||||||
t.Fatalf("Expected messages[1] to be assistant tool_calls, got %s: %s", messages[1].Get("role").String(), messages[1].Raw)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// tool message MUST immediately follow assistant(tool_calls) per OpenAI spec
|
// tool message MUST immediately follow assistant(tool_calls) per OpenAI spec
|
||||||
if messages[2].Get("role").String() != "tool" {
|
if messages[1].Get("role").String() != "tool" {
|
||||||
t.Fatalf("Expected messages[2] to be tool (must follow tool_calls), got %s", messages[2].Get("role").String())
|
t.Fatalf("Expected messages[1] to be tool (must follow tool_calls), got %s", messages[1].Get("role").String())
|
||||||
}
|
}
|
||||||
if got := messages[2].Get("tool_call_id").String(); got != "call_1" {
|
if got := messages[1].Get("tool_call_id").String(); got != "call_1" {
|
||||||
t.Fatalf("Expected tool_call_id %q, got %q", "call_1", got)
|
t.Fatalf("Expected tool_call_id %q, got %q", "call_1", got)
|
||||||
}
|
}
|
||||||
if got := messages[2].Get("content").String(); got != "tool ok" {
|
if got := messages[1].Get("content").String(); got != "tool ok" {
|
||||||
t.Fatalf("Expected tool content %q, got %q", "tool ok", got)
|
t.Fatalf("Expected tool content %q, got %q", "tool ok", got)
|
||||||
}
|
}
|
||||||
|
|
||||||
// User message comes after tool message
|
// User message comes after tool message
|
||||||
if messages[3].Get("role").String() != "user" {
|
if messages[2].Get("role").String() != "user" {
|
||||||
t.Fatalf("Expected messages[3] to be user, got %s", messages[3].Get("role").String())
|
t.Fatalf("Expected messages[2] to be user, got %s", messages[2].Get("role").String())
|
||||||
}
|
}
|
||||||
// User message should contain both "before" and "after" text
|
// User message should contain both "before" and "after" text
|
||||||
if got := messages[3].Get("content.0.text").String(); got != "before" {
|
if got := messages[2].Get("content.0.text").String(); got != "before" {
|
||||||
t.Fatalf("Expected user text[0] %q, got %q", "before", got)
|
t.Fatalf("Expected user text[0] %q, got %q", "before", got)
|
||||||
}
|
}
|
||||||
if got := messages[3].Get("content.1.text").String(); got != "after" {
|
if got := messages[2].Get("content.1.text").String(); got != "after" {
|
||||||
t.Fatalf("Expected user text[1] %q, got %q", "after", got)
|
t.Fatalf("Expected user text[1] %q, got %q", "after", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -378,16 +472,16 @@ func TestConvertClaudeRequestToOpenAI_ToolResultObjectContent(t *testing.T) {
|
|||||||
resultJSON := gjson.ParseBytes(result)
|
resultJSON := gjson.ParseBytes(result)
|
||||||
messages := resultJSON.Get("messages").Array()
|
messages := resultJSON.Get("messages").Array()
|
||||||
|
|
||||||
// system + assistant(tool_calls) + tool(result)
|
// assistant(tool_calls) + tool(result)
|
||||||
if len(messages) != 3 {
|
if len(messages) != 2 {
|
||||||
t.Fatalf("Expected 3 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||||
}
|
}
|
||||||
|
|
||||||
if messages[2].Get("role").String() != "tool" {
|
if messages[1].Get("role").String() != "tool" {
|
||||||
t.Fatalf("Expected messages[2] to be tool, got %s", messages[2].Get("role").String())
|
t.Fatalf("Expected messages[1] to be tool, got %s", messages[1].Get("role").String())
|
||||||
}
|
}
|
||||||
|
|
||||||
toolContent := messages[2].Get("content").String()
|
toolContent := messages[1].Get("content").String()
|
||||||
parsed := gjson.Parse(toolContent)
|
parsed := gjson.Parse(toolContent)
|
||||||
if parsed.Get("foo").String() != "bar" {
|
if parsed.Get("foo").String() != "bar" {
|
||||||
t.Fatalf("Expected tool content JSON foo=bar, got %q", toolContent)
|
t.Fatalf("Expected tool content JSON foo=bar, got %q", toolContent)
|
||||||
@@ -414,18 +508,14 @@ func TestConvertClaudeRequestToOpenAI_AssistantTextToolUseTextOrder(t *testing.T
|
|||||||
messages := resultJSON.Get("messages").Array()
|
messages := resultJSON.Get("messages").Array()
|
||||||
|
|
||||||
// New behavior: content + tool_calls unified in single assistant message
|
// New behavior: content + tool_calls unified in single assistant message
|
||||||
// Expect: system + assistant(content[pre,post] + tool_calls)
|
// Expect: assistant(content[pre,post] + tool_calls)
|
||||||
if len(messages) != 2 {
|
if len(messages) != 1 {
|
||||||
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
t.Fatalf("Expected 1 message, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||||
}
|
}
|
||||||
|
|
||||||
if messages[0].Get("role").String() != "system" {
|
assistantMsg := messages[0]
|
||||||
t.Fatalf("Expected messages[0] to be system, got %s", messages[0].Get("role").String())
|
|
||||||
}
|
|
||||||
|
|
||||||
assistantMsg := messages[1]
|
|
||||||
if assistantMsg.Get("role").String() != "assistant" {
|
if assistantMsg.Get("role").String() != "assistant" {
|
||||||
t.Fatalf("Expected messages[1] to be assistant, got %s", assistantMsg.Get("role").String())
|
t.Fatalf("Expected messages[0] to be assistant, got %s", assistantMsg.Get("role").String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should have both content and tool_calls in same message
|
// Should have both content and tool_calls in same message
|
||||||
@@ -470,14 +560,14 @@ func TestConvertClaudeRequestToOpenAI_AssistantThinkingToolUseThinkingSplit(t *t
|
|||||||
messages := resultJSON.Get("messages").Array()
|
messages := resultJSON.Get("messages").Array()
|
||||||
|
|
||||||
// New behavior: all content, thinking, and tool_calls unified in single assistant message
|
// New behavior: all content, thinking, and tool_calls unified in single assistant message
|
||||||
// Expect: system + assistant(content[pre,post] + tool_calls + reasoning_content[t1+t2])
|
// Expect: assistant(content[pre,post] + tool_calls + reasoning_content[t1+t2])
|
||||||
if len(messages) != 2 {
|
if len(messages) != 1 {
|
||||||
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
t.Fatalf("Expected 1 message, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
|
||||||
}
|
}
|
||||||
|
|
||||||
assistantMsg := messages[1]
|
assistantMsg := messages[0]
|
||||||
if assistantMsg.Get("role").String() != "assistant" {
|
if assistantMsg.Get("role").String() != "assistant" {
|
||||||
t.Fatalf("Expected messages[1] to be assistant, got %s", assistantMsg.Get("role").String())
|
t.Fatalf("Expected messages[0] to be assistant, got %s", assistantMsg.Get("role").String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should have content with both pre and post
|
// Should have content with both pre and post
|
||||||
|
|||||||
@@ -167,6 +167,16 @@ func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]an
|
|||||||
"virtual_parent_id": primary.ID,
|
"virtual_parent_id": primary.ID,
|
||||||
"type": metadata["type"],
|
"type": metadata["type"],
|
||||||
}
|
}
|
||||||
|
if v, ok := metadata["disable_cooling"]; ok {
|
||||||
|
metadataCopy["disable_cooling"] = v
|
||||||
|
} else if v, ok := metadata["disable-cooling"]; ok {
|
||||||
|
metadataCopy["disable_cooling"] = v
|
||||||
|
}
|
||||||
|
if v, ok := metadata["request_retry"]; ok {
|
||||||
|
metadataCopy["request_retry"] = v
|
||||||
|
} else if v, ok := metadata["request-retry"]; ok {
|
||||||
|
metadataCopy["request_retry"] = v
|
||||||
|
}
|
||||||
proxy := strings.TrimSpace(primary.ProxyURL)
|
proxy := strings.TrimSpace(primary.ProxyURL)
|
||||||
if proxy != "" {
|
if proxy != "" {
|
||||||
metadataCopy["proxy_url"] = proxy
|
metadataCopy["proxy_url"] = proxy
|
||||||
|
|||||||
@@ -69,10 +69,12 @@ func TestFileSynthesizer_Synthesize_ValidAuthFile(t *testing.T) {
|
|||||||
|
|
||||||
// Create a valid auth file
|
// Create a valid auth file
|
||||||
authData := map[string]any{
|
authData := map[string]any{
|
||||||
"type": "claude",
|
"type": "claude",
|
||||||
"email": "test@example.com",
|
"email": "test@example.com",
|
||||||
"proxy_url": "http://proxy.local",
|
"proxy_url": "http://proxy.local",
|
||||||
"prefix": "test-prefix",
|
"prefix": "test-prefix",
|
||||||
|
"disable_cooling": true,
|
||||||
|
"request_retry": 2,
|
||||||
}
|
}
|
||||||
data, _ := json.Marshal(authData)
|
data, _ := json.Marshal(authData)
|
||||||
err := os.WriteFile(filepath.Join(tempDir, "claude-auth.json"), data, 0644)
|
err := os.WriteFile(filepath.Join(tempDir, "claude-auth.json"), data, 0644)
|
||||||
@@ -108,6 +110,12 @@ func TestFileSynthesizer_Synthesize_ValidAuthFile(t *testing.T) {
|
|||||||
if auths[0].ProxyURL != "http://proxy.local" {
|
if auths[0].ProxyURL != "http://proxy.local" {
|
||||||
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
|
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
|
||||||
}
|
}
|
||||||
|
if v, ok := auths[0].Metadata["disable_cooling"].(bool); !ok || !v {
|
||||||
|
t.Errorf("expected disable_cooling true, got %v", auths[0].Metadata["disable_cooling"])
|
||||||
|
}
|
||||||
|
if v, ok := auths[0].Metadata["request_retry"].(float64); !ok || int(v) != 2 {
|
||||||
|
t.Errorf("expected request_retry 2, got %v", auths[0].Metadata["request_retry"])
|
||||||
|
}
|
||||||
if auths[0].Status != coreauth.StatusActive {
|
if auths[0].Status != coreauth.StatusActive {
|
||||||
t.Errorf("expected status active, got %s", auths[0].Status)
|
t.Errorf("expected status active, got %s", auths[0].Status)
|
||||||
}
|
}
|
||||||
@@ -336,9 +344,11 @@ func TestSynthesizeGeminiVirtualAuths_MultiProject(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
metadata := map[string]any{
|
metadata := map[string]any{
|
||||||
"project_id": "project-a, project-b, project-c",
|
"project_id": "project-a, project-b, project-c",
|
||||||
"email": "test@example.com",
|
"email": "test@example.com",
|
||||||
"type": "gemini",
|
"type": "gemini",
|
||||||
|
"request_retry": 2,
|
||||||
|
"disable_cooling": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now)
|
virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now)
|
||||||
@@ -376,6 +386,12 @@ func TestSynthesizeGeminiVirtualAuths_MultiProject(t *testing.T) {
|
|||||||
if v.ProxyURL != "http://proxy.local" {
|
if v.ProxyURL != "http://proxy.local" {
|
||||||
t.Errorf("expected proxy_url http://proxy.local, got %s", v.ProxyURL)
|
t.Errorf("expected proxy_url http://proxy.local, got %s", v.ProxyURL)
|
||||||
}
|
}
|
||||||
|
if vv, ok := v.Metadata["disable_cooling"].(bool); !ok || !vv {
|
||||||
|
t.Errorf("expected disable_cooling true, got %v", v.Metadata["disable_cooling"])
|
||||||
|
}
|
||||||
|
if vv, ok := v.Metadata["request_retry"].(int); !ok || vv != 2 {
|
||||||
|
t.Errorf("expected request_retry 2, got %v", v.Metadata["request_retry"])
|
||||||
|
}
|
||||||
if v.Attributes["runtime_only"] != "true" {
|
if v.Attributes["runtime_only"] != "true" {
|
||||||
t.Error("expected runtime_only=true")
|
t.Error("expected runtime_only=true")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -128,8 +128,23 @@ func (h *ClaudeCodeAPIHandler) ClaudeCountTokens(c *gin.Context) {
|
|||||||
// Parameters:
|
// Parameters:
|
||||||
// - c: The Gin context for the request.
|
// - c: The Gin context for the request.
|
||||||
func (h *ClaudeCodeAPIHandler) ClaudeModels(c *gin.Context) {
|
func (h *ClaudeCodeAPIHandler) ClaudeModels(c *gin.Context) {
|
||||||
|
models := h.Models()
|
||||||
|
firstID := ""
|
||||||
|
lastID := ""
|
||||||
|
if len(models) > 0 {
|
||||||
|
if id, ok := models[0]["id"].(string); ok {
|
||||||
|
firstID = id
|
||||||
|
}
|
||||||
|
if id, ok := models[len(models)-1]["id"].(string); ok {
|
||||||
|
lastID = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"data": h.Models(),
|
"data": models,
|
||||||
|
"has_more": false,
|
||||||
|
"first_id": firstID,
|
||||||
|
"last_id": lastID,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -60,8 +60,12 @@ func (h *GeminiAPIHandler) GeminiModels(c *gin.Context) {
|
|||||||
if !strings.HasPrefix(name, "models/") {
|
if !strings.HasPrefix(name, "models/") {
|
||||||
normalizedModel["name"] = "models/" + name
|
normalizedModel["name"] = "models/" + name
|
||||||
}
|
}
|
||||||
normalizedModel["displayName"] = name
|
if displayName, _ := normalizedModel["displayName"].(string); displayName == "" {
|
||||||
normalizedModel["description"] = name
|
normalizedModel["displayName"] = name
|
||||||
|
}
|
||||||
|
if description, _ := normalizedModel["description"].(string); description == "" {
|
||||||
|
normalizedModel["description"] = name
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if _, ok := normalizedModel["supportedGenerationMethods"]; !ok {
|
if _, ok := normalizedModel["supportedGenerationMethods"]; !ok {
|
||||||
normalizedModel["supportedGenerationMethods"] = defaultMethods
|
normalizedModel["supportedGenerationMethods"] = defaultMethods
|
||||||
|
|||||||
@@ -2,15 +2,13 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/antigravity"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
@@ -19,20 +17,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
|
||||||
antigravityClientSecret = "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"
|
|
||||||
antigravityCallbackPort = 51121
|
|
||||||
)
|
|
||||||
|
|
||||||
var antigravityScopes = []string{
|
|
||||||
"https://www.googleapis.com/auth/cloud-platform",
|
|
||||||
"https://www.googleapis.com/auth/userinfo.email",
|
|
||||||
"https://www.googleapis.com/auth/userinfo.profile",
|
|
||||||
"https://www.googleapis.com/auth/cclog",
|
|
||||||
"https://www.googleapis.com/auth/experimentsandconfigs",
|
|
||||||
}
|
|
||||||
|
|
||||||
// AntigravityAuthenticator implements OAuth login for the antigravity provider.
|
// AntigravityAuthenticator implements OAuth login for the antigravity provider.
|
||||||
type AntigravityAuthenticator struct{}
|
type AntigravityAuthenticator struct{}
|
||||||
|
|
||||||
@@ -60,12 +44,12 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
|
|||||||
opts = &LoginOptions{}
|
opts = &LoginOptions{}
|
||||||
}
|
}
|
||||||
|
|
||||||
callbackPort := antigravityCallbackPort
|
callbackPort := antigravity.CallbackPort
|
||||||
if opts.CallbackPort > 0 {
|
if opts.CallbackPort > 0 {
|
||||||
callbackPort = opts.CallbackPort
|
callbackPort = opts.CallbackPort
|
||||||
}
|
}
|
||||||
|
|
||||||
httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{})
|
authSvc := antigravity.NewAntigravityAuth(cfg, nil)
|
||||||
|
|
||||||
state, err := misc.GenerateRandomState()
|
state, err := misc.GenerateRandomState()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -83,7 +67,7 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", port)
|
redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", port)
|
||||||
authURL := buildAntigravityAuthURL(redirectURI, state)
|
authURL := authSvc.BuildAuthURL(state, redirectURI)
|
||||||
|
|
||||||
if !opts.NoBrowser {
|
if !opts.NoBrowser {
|
||||||
fmt.Println("Opening browser for antigravity authentication")
|
fmt.Println("Opening browser for antigravity authentication")
|
||||||
@@ -164,22 +148,29 @@ waitForCallback:
|
|||||||
return nil, fmt.Errorf("antigravity: missing authorization code")
|
return nil, fmt.Errorf("antigravity: missing authorization code")
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenResp, errToken := exchangeAntigravityCode(ctx, cbRes.Code, redirectURI, httpClient)
|
tokenResp, errToken := authSvc.ExchangeCodeForTokens(ctx, cbRes.Code, redirectURI)
|
||||||
if errToken != nil {
|
if errToken != nil {
|
||||||
return nil, fmt.Errorf("antigravity: token exchange failed: %w", errToken)
|
return nil, fmt.Errorf("antigravity: token exchange failed: %w", errToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
email := ""
|
accessToken := strings.TrimSpace(tokenResp.AccessToken)
|
||||||
if tokenResp.AccessToken != "" {
|
if accessToken == "" {
|
||||||
if info, errInfo := fetchAntigravityUserInfo(ctx, tokenResp.AccessToken, httpClient); errInfo == nil && strings.TrimSpace(info.Email) != "" {
|
return nil, fmt.Errorf("antigravity: token exchange returned empty access token")
|
||||||
email = strings.TrimSpace(info.Email)
|
}
|
||||||
}
|
|
||||||
|
email, errInfo := authSvc.FetchUserInfo(ctx, accessToken)
|
||||||
|
if errInfo != nil {
|
||||||
|
return nil, fmt.Errorf("antigravity: fetch user info failed: %w", errInfo)
|
||||||
|
}
|
||||||
|
email = strings.TrimSpace(email)
|
||||||
|
if email == "" {
|
||||||
|
return nil, fmt.Errorf("antigravity: empty email returned from user info")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch project ID via loadCodeAssist (same approach as Gemini CLI)
|
// Fetch project ID via loadCodeAssist (same approach as Gemini CLI)
|
||||||
projectID := ""
|
projectID := ""
|
||||||
if tokenResp.AccessToken != "" {
|
if accessToken != "" {
|
||||||
fetchedProjectID, errProject := fetchAntigravityProjectID(ctx, tokenResp.AccessToken, httpClient)
|
fetchedProjectID, errProject := authSvc.FetchProjectID(ctx, accessToken)
|
||||||
if errProject != nil {
|
if errProject != nil {
|
||||||
log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
|
log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
|
||||||
} else {
|
} else {
|
||||||
@@ -204,7 +195,7 @@ waitForCallback:
|
|||||||
metadata["project_id"] = projectID
|
metadata["project_id"] = projectID
|
||||||
}
|
}
|
||||||
|
|
||||||
fileName := sanitizeAntigravityFileName(email)
|
fileName := antigravity.CredentialFileName(email)
|
||||||
label := email
|
label := email
|
||||||
if label == "" {
|
if label == "" {
|
||||||
label = "antigravity"
|
label = "antigravity"
|
||||||
@@ -231,7 +222,7 @@ type callbackResult struct {
|
|||||||
|
|
||||||
func startAntigravityCallbackServer(port int) (*http.Server, int, <-chan callbackResult, error) {
|
func startAntigravityCallbackServer(port int) (*http.Server, int, <-chan callbackResult, error) {
|
||||||
if port <= 0 {
|
if port <= 0 {
|
||||||
port = antigravityCallbackPort
|
port = antigravity.CallbackPort
|
||||||
}
|
}
|
||||||
addr := fmt.Sprintf(":%d", port)
|
addr := fmt.Sprintf(":%d", port)
|
||||||
listener, err := net.Listen("tcp", addr)
|
listener, err := net.Listen("tcp", addr)
|
||||||
@@ -267,309 +258,9 @@ func startAntigravityCallbackServer(port int) (*http.Server, int, <-chan callbac
|
|||||||
return srv, port, resultCh, nil
|
return srv, port, resultCh, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type antigravityTokenResponse struct {
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
RefreshToken string `json:"refresh_token"`
|
|
||||||
ExpiresIn int64 `json:"expires_in"`
|
|
||||||
TokenType string `json:"token_type"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func exchangeAntigravityCode(ctx context.Context, code, redirectURI string, httpClient *http.Client) (*antigravityTokenResponse, error) {
|
|
||||||
data := url.Values{}
|
|
||||||
data.Set("code", code)
|
|
||||||
data.Set("client_id", antigravityClientID)
|
|
||||||
data.Set("client_secret", antigravityClientSecret)
|
|
||||||
data.Set("redirect_uri", redirectURI)
|
|
||||||
data.Set("grant_type", "authorization_code")
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(data.Encode()))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
|
|
||||||
resp, errDo := httpClient.Do(req)
|
|
||||||
if errDo != nil {
|
|
||||||
return nil, errDo
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if errClose := resp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("antigravity token exchange: close body error: %v", errClose)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
var token antigravityTokenResponse
|
|
||||||
if errDecode := json.NewDecoder(resp.Body).Decode(&token); errDecode != nil {
|
|
||||||
return nil, errDecode
|
|
||||||
}
|
|
||||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
|
||||||
return nil, fmt.Errorf("oauth token exchange failed: status %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
return &token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type antigravityUserInfo struct {
|
|
||||||
Email string `json:"email"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func fetchAntigravityUserInfo(ctx context.Context, accessToken string, httpClient *http.Client) (*antigravityUserInfo, error) {
|
|
||||||
if strings.TrimSpace(accessToken) == "" {
|
|
||||||
return &antigravityUserInfo{}, nil
|
|
||||||
}
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
|
||||||
|
|
||||||
resp, errDo := httpClient.Do(req)
|
|
||||||
if errDo != nil {
|
|
||||||
return nil, errDo
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if errClose := resp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("antigravity userinfo: close body error: %v", errClose)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
|
||||||
return &antigravityUserInfo{}, nil
|
|
||||||
}
|
|
||||||
var info antigravityUserInfo
|
|
||||||
if errDecode := json.NewDecoder(resp.Body).Decode(&info); errDecode != nil {
|
|
||||||
return nil, errDecode
|
|
||||||
}
|
|
||||||
return &info, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildAntigravityAuthURL(redirectURI, state string) string {
|
|
||||||
params := url.Values{}
|
|
||||||
params.Set("access_type", "offline")
|
|
||||||
params.Set("client_id", antigravityClientID)
|
|
||||||
params.Set("prompt", "consent")
|
|
||||||
params.Set("redirect_uri", redirectURI)
|
|
||||||
params.Set("response_type", "code")
|
|
||||||
params.Set("scope", strings.Join(antigravityScopes, " "))
|
|
||||||
params.Set("state", state)
|
|
||||||
return "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode()
|
|
||||||
}
|
|
||||||
|
|
||||||
func sanitizeAntigravityFileName(email string) string {
|
|
||||||
if strings.TrimSpace(email) == "" {
|
|
||||||
return "antigravity.json"
|
|
||||||
}
|
|
||||||
replacer := strings.NewReplacer("@", "_", ".", "_")
|
|
||||||
return fmt.Sprintf("antigravity-%s.json", replacer.Replace(email))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Antigravity API constants for project discovery
|
|
||||||
const (
|
|
||||||
antigravityAPIEndpoint = "https://cloudcode-pa.googleapis.com"
|
|
||||||
antigravityAPIVersion = "v1internal"
|
|
||||||
antigravityAPIUserAgent = "google-api-nodejs-client/9.15.1"
|
|
||||||
antigravityAPIClient = "google-cloud-sdk vscode_cloudshelleditor/0.1"
|
|
||||||
antigravityClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}`
|
|
||||||
)
|
|
||||||
|
|
||||||
// FetchAntigravityProjectID exposes project discovery for external callers.
|
// FetchAntigravityProjectID exposes project discovery for external callers.
|
||||||
func FetchAntigravityProjectID(ctx context.Context, accessToken string, httpClient *http.Client) (string, error) {
|
func FetchAntigravityProjectID(ctx context.Context, accessToken string, httpClient *http.Client) (string, error) {
|
||||||
return fetchAntigravityProjectID(ctx, accessToken, httpClient)
|
cfg := &config.Config{}
|
||||||
}
|
authSvc := antigravity.NewAntigravityAuth(cfg, httpClient)
|
||||||
|
return authSvc.FetchProjectID(ctx, accessToken)
|
||||||
// fetchAntigravityProjectID retrieves the project ID for the authenticated user via loadCodeAssist.
|
|
||||||
// This uses the same approach as Gemini CLI to get the cloudaicompanionProject.
|
|
||||||
func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClient *http.Client) (string, error) {
|
|
||||||
// Call loadCodeAssist to get the project
|
|
||||||
loadReqBody := map[string]any{
|
|
||||||
"metadata": map[string]string{
|
|
||||||
"ideType": "ANTIGRAVITY",
|
|
||||||
"platform": "PLATFORM_UNSPECIFIED",
|
|
||||||
"pluginType": "GEMINI",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
rawBody, errMarshal := json.Marshal(loadReqBody)
|
|
||||||
if errMarshal != nil {
|
|
||||||
return "", fmt.Errorf("marshal request body: %w", errMarshal)
|
|
||||||
}
|
|
||||||
|
|
||||||
endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", antigravityAPIEndpoint, antigravityAPIVersion)
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("create request: %w", err)
|
|
||||||
}
|
|
||||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("User-Agent", antigravityAPIUserAgent)
|
|
||||||
req.Header.Set("X-Goog-Api-Client", antigravityAPIClient)
|
|
||||||
req.Header.Set("Client-Metadata", antigravityClientMetadata)
|
|
||||||
|
|
||||||
resp, errDo := httpClient.Do(req)
|
|
||||||
if errDo != nil {
|
|
||||||
return "", fmt.Errorf("execute request: %w", errDo)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if errClose := resp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("antigravity loadCodeAssist: close body error: %v", errClose)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
bodyBytes, errRead := io.ReadAll(resp.Body)
|
|
||||||
if errRead != nil {
|
|
||||||
return "", fmt.Errorf("read response: %w", errRead)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
|
||||||
return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
|
|
||||||
}
|
|
||||||
|
|
||||||
var loadResp map[string]any
|
|
||||||
if errDecode := json.Unmarshal(bodyBytes, &loadResp); errDecode != nil {
|
|
||||||
return "", fmt.Errorf("decode response: %w", errDecode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract projectID from response
|
|
||||||
projectID := ""
|
|
||||||
if id, ok := loadResp["cloudaicompanionProject"].(string); ok {
|
|
||||||
projectID = strings.TrimSpace(id)
|
|
||||||
}
|
|
||||||
if projectID == "" {
|
|
||||||
if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok {
|
|
||||||
if id, okID := projectMap["id"].(string); okID {
|
|
||||||
projectID = strings.TrimSpace(id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if projectID == "" {
|
|
||||||
tierID := "legacy-tier"
|
|
||||||
if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers {
|
|
||||||
for _, rawTier := range tiers {
|
|
||||||
tier, okTier := rawTier.(map[string]any)
|
|
||||||
if !okTier {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault {
|
|
||||||
if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" {
|
|
||||||
tierID = strings.TrimSpace(id)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
projectID, err = antigravityOnboardUser(ctx, accessToken, tierID, httpClient)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return projectID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return projectID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// antigravityOnboardUser attempts to fetch the project ID via onboardUser by polling for completion.
|
|
||||||
// It returns an empty string when the operation times out or completes without a project ID.
|
|
||||||
func antigravityOnboardUser(ctx context.Context, accessToken, tierID string, httpClient *http.Client) (string, error) {
|
|
||||||
if httpClient == nil {
|
|
||||||
httpClient = http.DefaultClient
|
|
||||||
}
|
|
||||||
fmt.Println("Antigravity: onboarding user...", tierID)
|
|
||||||
requestBody := map[string]any{
|
|
||||||
"tierId": tierID,
|
|
||||||
"metadata": map[string]string{
|
|
||||||
"ideType": "ANTIGRAVITY",
|
|
||||||
"platform": "PLATFORM_UNSPECIFIED",
|
|
||||||
"pluginType": "GEMINI",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
rawBody, errMarshal := json.Marshal(requestBody)
|
|
||||||
if errMarshal != nil {
|
|
||||||
return "", fmt.Errorf("marshal request body: %w", errMarshal)
|
|
||||||
}
|
|
||||||
|
|
||||||
maxAttempts := 5
|
|
||||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
|
||||||
log.Debugf("Polling attempt %d/%d", attempt, maxAttempts)
|
|
||||||
|
|
||||||
reqCtx := ctx
|
|
||||||
var cancel context.CancelFunc
|
|
||||||
if reqCtx == nil {
|
|
||||||
reqCtx = context.Background()
|
|
||||||
}
|
|
||||||
reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second)
|
|
||||||
|
|
||||||
endpointURL := fmt.Sprintf("%s/%s:onboardUser", antigravityAPIEndpoint, antigravityAPIVersion)
|
|
||||||
req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
|
|
||||||
if errRequest != nil {
|
|
||||||
cancel()
|
|
||||||
return "", fmt.Errorf("create request: %w", errRequest)
|
|
||||||
}
|
|
||||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("User-Agent", antigravityAPIUserAgent)
|
|
||||||
req.Header.Set("X-Goog-Api-Client", antigravityAPIClient)
|
|
||||||
req.Header.Set("Client-Metadata", antigravityClientMetadata)
|
|
||||||
|
|
||||||
resp, errDo := httpClient.Do(req)
|
|
||||||
if errDo != nil {
|
|
||||||
cancel()
|
|
||||||
return "", fmt.Errorf("execute request: %w", errDo)
|
|
||||||
}
|
|
||||||
|
|
||||||
bodyBytes, errRead := io.ReadAll(resp.Body)
|
|
||||||
if errClose := resp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("close body error: %v", errClose)
|
|
||||||
}
|
|
||||||
cancel()
|
|
||||||
|
|
||||||
if errRead != nil {
|
|
||||||
return "", fmt.Errorf("read response: %w", errRead)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode == http.StatusOK {
|
|
||||||
var data map[string]any
|
|
||||||
if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil {
|
|
||||||
return "", fmt.Errorf("decode response: %w", errDecode)
|
|
||||||
}
|
|
||||||
|
|
||||||
if done, okDone := data["done"].(bool); okDone && done {
|
|
||||||
projectID := ""
|
|
||||||
if responseData, okResp := data["response"].(map[string]any); okResp {
|
|
||||||
switch projectValue := responseData["cloudaicompanionProject"].(type) {
|
|
||||||
case map[string]any:
|
|
||||||
if id, okID := projectValue["id"].(string); okID {
|
|
||||||
projectID = strings.TrimSpace(id)
|
|
||||||
}
|
|
||||||
case string:
|
|
||||||
projectID = strings.TrimSpace(projectValue)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if projectID != "" {
|
|
||||||
log.Infof("Successfully fetched project_id: %s", projectID)
|
|
||||||
return projectID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", fmt.Errorf("no project_id in response")
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(2 * time.Second)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
responsePreview := strings.TrimSpace(string(bodyBytes))
|
|
||||||
if len(responsePreview) > 500 {
|
|
||||||
responsePreview = responsePreview[:500]
|
|
||||||
}
|
|
||||||
|
|
||||||
responseErr := responsePreview
|
|
||||||
if len(responseErr) > 200 {
|
|
||||||
responseErr = responseErr[:200]
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -73,9 +73,7 @@ func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (str
|
|||||||
return "", fmt.Errorf("auth filestore: marshal metadata failed: %w", errMarshal)
|
return "", fmt.Errorf("auth filestore: marshal metadata failed: %w", errMarshal)
|
||||||
}
|
}
|
||||||
if existing, errRead := os.ReadFile(path); errRead == nil {
|
if existing, errRead := os.ReadFile(path); errRead == nil {
|
||||||
// Use metadataEqualIgnoringTimestamps to skip writes when only timestamp fields change.
|
if jsonEqual(existing, raw) {
|
||||||
// This prevents the token refresh loop caused by timestamp/expired/expires_in changes.
|
|
||||||
if metadataEqualIgnoringTimestamps(existing, raw, auth.Provider) {
|
|
||||||
return path, nil
|
return path, nil
|
||||||
}
|
}
|
||||||
file, errOpen := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC, 0o600)
|
file, errOpen := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC, 0o600)
|
||||||
@@ -308,8 +306,7 @@ func (s *FileTokenStore) baseDirSnapshot() string {
|
|||||||
return s.baseDir
|
return s.baseDir
|
||||||
}
|
}
|
||||||
|
|
||||||
// DEPRECATED: Use metadataEqualIgnoringTimestamps for comparing auth metadata.
|
// jsonEqual compares two JSON blobs by parsing them into Go objects and deep comparing.
|
||||||
// This function is kept for backward compatibility but can cause refresh loops.
|
|
||||||
func jsonEqual(a, b []byte) bool {
|
func jsonEqual(a, b []byte) bool {
|
||||||
var objA any
|
var objA any
|
||||||
var objB any
|
var objB any
|
||||||
@@ -322,41 +319,6 @@ func jsonEqual(a, b []byte) bool {
|
|||||||
return deepEqualJSON(objA, objB)
|
return deepEqualJSON(objA, objB)
|
||||||
}
|
}
|
||||||
|
|
||||||
// metadataEqualIgnoringTimestamps compares two metadata JSON blobs,
|
|
||||||
// ignoring fields that change on every refresh but don't affect functionality.
|
|
||||||
// This prevents unnecessary file writes that would trigger watcher events and
|
|
||||||
// create refresh loops.
|
|
||||||
// The provider parameter controls whether access_token is ignored: providers like
|
|
||||||
// Google OAuth (gemini, gemini-cli) can re-fetch tokens when needed, while others
|
|
||||||
// like iFlow require the refreshed token to be persisted.
|
|
||||||
func metadataEqualIgnoringTimestamps(a, b []byte, provider string) bool {
|
|
||||||
var objA, objB map[string]any
|
|
||||||
if err := json.Unmarshal(a, &objA); err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(b, &objB); err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fields to ignore: these change on every refresh but don't affect authentication logic.
|
|
||||||
// - timestamp, expired, expires_in, last_refresh: time-related fields that change on refresh
|
|
||||||
ignoredFields := []string{"timestamp", "expired", "expires_in", "last_refresh"}
|
|
||||||
|
|
||||||
// For providers that can re-fetch tokens when needed (e.g., Google OAuth),
|
|
||||||
// we ignore access_token to avoid unnecessary file writes.
|
|
||||||
switch provider {
|
|
||||||
case "gemini", "gemini-cli", "antigravity":
|
|
||||||
ignoredFields = append(ignoredFields, "access_token")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, field := range ignoredFields {
|
|
||||||
delete(objA, field)
|
|
||||||
delete(objB, field)
|
|
||||||
}
|
|
||||||
|
|
||||||
return deepEqualJSON(objA, objB)
|
|
||||||
}
|
|
||||||
|
|
||||||
func deepEqualJSON(a, b any) bool {
|
func deepEqualJSON(a, b any) bool {
|
||||||
switch valA := a.(type) {
|
switch valA := a.(type) {
|
||||||
case map[string]any:
|
case map[string]any:
|
||||||
|
|||||||
@@ -61,6 +61,15 @@ func SetQuotaCooldownDisabled(disable bool) {
|
|||||||
quotaCooldownDisabled.Store(disable)
|
quotaCooldownDisabled.Store(disable)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func quotaCooldownDisabledForAuth(auth *Auth) bool {
|
||||||
|
if auth != nil {
|
||||||
|
if override, ok := auth.DisableCoolingOverride(); ok {
|
||||||
|
return override
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return quotaCooldownDisabled.Load()
|
||||||
|
}
|
||||||
|
|
||||||
// Result captures execution outcome used to adjust auth state.
|
// Result captures execution outcome used to adjust auth state.
|
||||||
type Result struct {
|
type Result struct {
|
||||||
// AuthID references the auth that produced this result.
|
// AuthID references the auth that produced this result.
|
||||||
@@ -468,20 +477,16 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
|
|||||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
}
|
}
|
||||||
|
|
||||||
retryTimes, maxWait := m.retrySettings()
|
_, maxWait := m.retrySettings()
|
||||||
attempts := retryTimes + 1
|
|
||||||
if attempts < 1 {
|
|
||||||
attempts = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for attempt := 0; attempt < attempts; attempt++ {
|
for attempt := 0; ; attempt++ {
|
||||||
resp, errExec := m.executeMixedOnce(ctx, normalized, req, opts)
|
resp, errExec := m.executeMixedOnce(ctx, normalized, req, opts)
|
||||||
if errExec == nil {
|
if errExec == nil {
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
lastErr = errExec
|
lastErr = errExec
|
||||||
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait)
|
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, normalized, req.Model, maxWait)
|
||||||
if !shouldRetry {
|
if !shouldRetry {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -503,20 +508,16 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
|
|||||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
}
|
}
|
||||||
|
|
||||||
retryTimes, maxWait := m.retrySettings()
|
_, maxWait := m.retrySettings()
|
||||||
attempts := retryTimes + 1
|
|
||||||
if attempts < 1 {
|
|
||||||
attempts = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for attempt := 0; attempt < attempts; attempt++ {
|
for attempt := 0; ; attempt++ {
|
||||||
resp, errExec := m.executeCountMixedOnce(ctx, normalized, req, opts)
|
resp, errExec := m.executeCountMixedOnce(ctx, normalized, req, opts)
|
||||||
if errExec == nil {
|
if errExec == nil {
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
lastErr = errExec
|
lastErr = errExec
|
||||||
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait)
|
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, normalized, req.Model, maxWait)
|
||||||
if !shouldRetry {
|
if !shouldRetry {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -538,20 +539,16 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
|||||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
}
|
}
|
||||||
|
|
||||||
retryTimes, maxWait := m.retrySettings()
|
_, maxWait := m.retrySettings()
|
||||||
attempts := retryTimes + 1
|
|
||||||
if attempts < 1 {
|
|
||||||
attempts = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for attempt := 0; attempt < attempts; attempt++ {
|
for attempt := 0; ; attempt++ {
|
||||||
chunks, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
|
chunks, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
|
||||||
if errStream == nil {
|
if errStream == nil {
|
||||||
return chunks, nil
|
return chunks, nil
|
||||||
}
|
}
|
||||||
lastErr = errStream
|
lastErr = errStream
|
||||||
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, normalized, req.Model, maxWait)
|
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, normalized, req.Model, maxWait)
|
||||||
if !shouldRetry {
|
if !shouldRetry {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -1034,11 +1031,15 @@ func (m *Manager) retrySettings() (int, time.Duration) {
|
|||||||
return int(m.requestRetry.Load()), time.Duration(m.maxRetryInterval.Load())
|
return int(m.requestRetry.Load()), time.Duration(m.maxRetryInterval.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) closestCooldownWait(providers []string, model string) (time.Duration, bool) {
|
func (m *Manager) closestCooldownWait(providers []string, model string, attempt int) (time.Duration, bool) {
|
||||||
if m == nil || len(providers) == 0 {
|
if m == nil || len(providers) == 0 {
|
||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
defaultRetry := int(m.requestRetry.Load())
|
||||||
|
if defaultRetry < 0 {
|
||||||
|
defaultRetry = 0
|
||||||
|
}
|
||||||
providerSet := make(map[string]struct{}, len(providers))
|
providerSet := make(map[string]struct{}, len(providers))
|
||||||
for i := range providers {
|
for i := range providers {
|
||||||
key := strings.TrimSpace(strings.ToLower(providers[i]))
|
key := strings.TrimSpace(strings.ToLower(providers[i]))
|
||||||
@@ -1061,6 +1062,16 @@ func (m *Manager) closestCooldownWait(providers []string, model string) (time.Du
|
|||||||
if _, ok := providerSet[providerKey]; !ok {
|
if _, ok := providerSet[providerKey]; !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
effectiveRetry := defaultRetry
|
||||||
|
if override, ok := auth.RequestRetryOverride(); ok {
|
||||||
|
effectiveRetry = override
|
||||||
|
}
|
||||||
|
if effectiveRetry < 0 {
|
||||||
|
effectiveRetry = 0
|
||||||
|
}
|
||||||
|
if attempt >= effectiveRetry {
|
||||||
|
continue
|
||||||
|
}
|
||||||
blocked, reason, next := isAuthBlockedForModel(auth, model, now)
|
blocked, reason, next := isAuthBlockedForModel(auth, model, now)
|
||||||
if !blocked || next.IsZero() || reason == blockReasonDisabled {
|
if !blocked || next.IsZero() || reason == blockReasonDisabled {
|
||||||
continue
|
continue
|
||||||
@@ -1077,8 +1088,8 @@ func (m *Manager) closestCooldownWait(providers []string, model string) (time.Du
|
|||||||
return minWait, found
|
return minWait, found
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) shouldRetryAfterError(err error, attempt, maxAttempts int, providers []string, model string, maxWait time.Duration) (time.Duration, bool) {
|
func (m *Manager) shouldRetryAfterError(err error, attempt int, providers []string, model string, maxWait time.Duration) (time.Duration, bool) {
|
||||||
if err == nil || attempt >= maxAttempts-1 {
|
if err == nil {
|
||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
if maxWait <= 0 {
|
if maxWait <= 0 {
|
||||||
@@ -1087,7 +1098,7 @@ func (m *Manager) shouldRetryAfterError(err error, attempt, maxAttempts int, pro
|
|||||||
if status := statusCodeFromError(err); status == http.StatusOK {
|
if status := statusCodeFromError(err); status == http.StatusOK {
|
||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
wait, found := m.closestCooldownWait(providers, model)
|
wait, found := m.closestCooldownWait(providers, model, attempt)
|
||||||
if !found || wait > maxWait {
|
if !found || wait > maxWait {
|
||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
@@ -1176,7 +1187,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
|||||||
if result.RetryAfter != nil {
|
if result.RetryAfter != nil {
|
||||||
next = now.Add(*result.RetryAfter)
|
next = now.Add(*result.RetryAfter)
|
||||||
} else {
|
} else {
|
||||||
cooldown, nextLevel := nextQuotaCooldown(backoffLevel)
|
cooldown, nextLevel := nextQuotaCooldown(backoffLevel, quotaCooldownDisabledForAuth(auth))
|
||||||
if cooldown > 0 {
|
if cooldown > 0 {
|
||||||
next = now.Add(cooldown)
|
next = now.Add(cooldown)
|
||||||
}
|
}
|
||||||
@@ -1193,7 +1204,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
|||||||
shouldSuspendModel = true
|
shouldSuspendModel = true
|
||||||
setModelQuota = true
|
setModelQuota = true
|
||||||
case 408, 500, 502, 503, 504:
|
case 408, 500, 502, 503, 504:
|
||||||
if quotaCooldownDisabled.Load() {
|
if quotaCooldownDisabledForAuth(auth) {
|
||||||
state.NextRetryAfter = time.Time{}
|
state.NextRetryAfter = time.Time{}
|
||||||
} else {
|
} else {
|
||||||
next := now.Add(1 * time.Minute)
|
next := now.Add(1 * time.Minute)
|
||||||
@@ -1439,7 +1450,7 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
|
|||||||
if retryAfter != nil {
|
if retryAfter != nil {
|
||||||
next = now.Add(*retryAfter)
|
next = now.Add(*retryAfter)
|
||||||
} else {
|
} else {
|
||||||
cooldown, nextLevel := nextQuotaCooldown(auth.Quota.BackoffLevel)
|
cooldown, nextLevel := nextQuotaCooldown(auth.Quota.BackoffLevel, quotaCooldownDisabledForAuth(auth))
|
||||||
if cooldown > 0 {
|
if cooldown > 0 {
|
||||||
next = now.Add(cooldown)
|
next = now.Add(cooldown)
|
||||||
}
|
}
|
||||||
@@ -1449,7 +1460,7 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
|
|||||||
auth.NextRetryAfter = next
|
auth.NextRetryAfter = next
|
||||||
case 408, 500, 502, 503, 504:
|
case 408, 500, 502, 503, 504:
|
||||||
auth.StatusMessage = "transient upstream error"
|
auth.StatusMessage = "transient upstream error"
|
||||||
if quotaCooldownDisabled.Load() {
|
if quotaCooldownDisabledForAuth(auth) {
|
||||||
auth.NextRetryAfter = time.Time{}
|
auth.NextRetryAfter = time.Time{}
|
||||||
} else {
|
} else {
|
||||||
auth.NextRetryAfter = now.Add(1 * time.Minute)
|
auth.NextRetryAfter = now.Add(1 * time.Minute)
|
||||||
@@ -1462,11 +1473,11 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
|
|||||||
}
|
}
|
||||||
|
|
||||||
// nextQuotaCooldown returns the next cooldown duration and updated backoff level for repeated quota errors.
|
// nextQuotaCooldown returns the next cooldown duration and updated backoff level for repeated quota errors.
|
||||||
func nextQuotaCooldown(prevLevel int) (time.Duration, int) {
|
func nextQuotaCooldown(prevLevel int, disableCooling bool) (time.Duration, int) {
|
||||||
if prevLevel < 0 {
|
if prevLevel < 0 {
|
||||||
prevLevel = 0
|
prevLevel = 0
|
||||||
}
|
}
|
||||||
if quotaCooldownDisabled.Load() {
|
if disableCooling {
|
||||||
return 0, prevLevel
|
return 0, prevLevel
|
||||||
}
|
}
|
||||||
cooldown := quotaBackoffBase * time.Duration(1<<prevLevel)
|
cooldown := quotaBackoffBase * time.Duration(1<<prevLevel)
|
||||||
@@ -1642,6 +1653,9 @@ func (m *Manager) persist(ctx context.Context, auth *Auth) error {
|
|||||||
if m.store == nil || auth == nil {
|
if m.store == nil || auth == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
if shouldSkipPersist(ctx) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
if auth.Attributes != nil {
|
if auth.Attributes != nil {
|
||||||
if v := strings.ToLower(strings.TrimSpace(auth.Attributes["runtime_only"])); v == "true" {
|
if v := strings.ToLower(strings.TrimSpace(auth.Attributes["runtime_only"])); v == "true" {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
97
sdk/cliproxy/auth/conductor_overrides_test.go
Normal file
97
sdk/cliproxy/auth/conductor_overrides_test.go
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestManager_ShouldRetryAfterError_RespectsAuthRequestRetryOverride(t *testing.T) {
|
||||||
|
m := NewManager(nil, nil, nil)
|
||||||
|
m.SetRetryConfig(3, 30*time.Second)
|
||||||
|
|
||||||
|
model := "test-model"
|
||||||
|
next := time.Now().Add(5 * time.Second)
|
||||||
|
|
||||||
|
auth := &Auth{
|
||||||
|
ID: "auth-1",
|
||||||
|
Provider: "claude",
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"request_retry": float64(0),
|
||||||
|
},
|
||||||
|
ModelStates: map[string]*ModelState{
|
||||||
|
model: {
|
||||||
|
Unavailable: true,
|
||||||
|
Status: StatusError,
|
||||||
|
NextRetryAfter: next,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if _, errRegister := m.Register(context.Background(), auth); errRegister != nil {
|
||||||
|
t.Fatalf("register auth: %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, maxWait := m.retrySettings()
|
||||||
|
wait, shouldRetry := m.shouldRetryAfterError(&Error{HTTPStatus: 500, Message: "boom"}, 0, []string{"claude"}, model, maxWait)
|
||||||
|
if shouldRetry {
|
||||||
|
t.Fatalf("expected shouldRetry=false for request_retry=0, got true (wait=%v)", wait)
|
||||||
|
}
|
||||||
|
|
||||||
|
auth.Metadata["request_retry"] = float64(1)
|
||||||
|
if _, errUpdate := m.Update(context.Background(), auth); errUpdate != nil {
|
||||||
|
t.Fatalf("update auth: %v", errUpdate)
|
||||||
|
}
|
||||||
|
|
||||||
|
wait, shouldRetry = m.shouldRetryAfterError(&Error{HTTPStatus: 500, Message: "boom"}, 0, []string{"claude"}, model, maxWait)
|
||||||
|
if !shouldRetry {
|
||||||
|
t.Fatalf("expected shouldRetry=true for request_retry=1, got false")
|
||||||
|
}
|
||||||
|
if wait <= 0 {
|
||||||
|
t.Fatalf("expected wait > 0, got %v", wait)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, shouldRetry = m.shouldRetryAfterError(&Error{HTTPStatus: 500, Message: "boom"}, 1, []string{"claude"}, model, maxWait)
|
||||||
|
if shouldRetry {
|
||||||
|
t.Fatalf("expected shouldRetry=false on attempt=1 for request_retry=1, got true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_MarkResult_RespectsAuthDisableCoolingOverride(t *testing.T) {
|
||||||
|
prev := quotaCooldownDisabled.Load()
|
||||||
|
quotaCooldownDisabled.Store(false)
|
||||||
|
t.Cleanup(func() { quotaCooldownDisabled.Store(prev) })
|
||||||
|
|
||||||
|
m := NewManager(nil, nil, nil)
|
||||||
|
|
||||||
|
auth := &Auth{
|
||||||
|
ID: "auth-1",
|
||||||
|
Provider: "claude",
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"disable_cooling": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if _, errRegister := m.Register(context.Background(), auth); errRegister != nil {
|
||||||
|
t.Fatalf("register auth: %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
model := "test-model"
|
||||||
|
m.MarkResult(context.Background(), Result{
|
||||||
|
AuthID: "auth-1",
|
||||||
|
Provider: "claude",
|
||||||
|
Model: model,
|
||||||
|
Success: false,
|
||||||
|
Error: &Error{HTTPStatus: 500, Message: "boom"},
|
||||||
|
})
|
||||||
|
|
||||||
|
updated, ok := m.GetByID("auth-1")
|
||||||
|
if !ok || updated == nil {
|
||||||
|
t.Fatalf("expected auth to be present")
|
||||||
|
}
|
||||||
|
state := updated.ModelStates[model]
|
||||||
|
if state == nil {
|
||||||
|
t.Fatalf("expected model state to be present")
|
||||||
|
}
|
||||||
|
if !state.NextRetryAfter.IsZero() {
|
||||||
|
t.Fatalf("expected NextRetryAfter to be zero when disable_cooling=true, got %v", state.NextRetryAfter)
|
||||||
|
}
|
||||||
|
}
|
||||||
24
sdk/cliproxy/auth/persist_policy.go
Normal file
24
sdk/cliproxy/auth/persist_policy.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
type skipPersistContextKey struct{}
|
||||||
|
|
||||||
|
// WithSkipPersist returns a derived context that disables persistence for Manager Update/Register calls.
|
||||||
|
// It is intended for code paths that are reacting to file watcher events, where the file on disk is
|
||||||
|
// already the source of truth and persisting again would create a write-back loop.
|
||||||
|
func WithSkipPersist(ctx context.Context) context.Context {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, skipPersistContextKey{}, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldSkipPersist(ctx context.Context) bool {
|
||||||
|
if ctx == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
v := ctx.Value(skipPersistContextKey{})
|
||||||
|
enabled, ok := v.(bool)
|
||||||
|
return ok && enabled
|
||||||
|
}
|
||||||
62
sdk/cliproxy/auth/persist_policy_test.go
Normal file
62
sdk/cliproxy/auth/persist_policy_test.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type countingStore struct {
|
||||||
|
saveCount atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *countingStore) List(context.Context) ([]*Auth, error) { return nil, nil }
|
||||||
|
|
||||||
|
func (s *countingStore) Save(context.Context, *Auth) (string, error) {
|
||||||
|
s.saveCount.Add(1)
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *countingStore) Delete(context.Context, string) error { return nil }
|
||||||
|
|
||||||
|
func TestWithSkipPersist_DisablesUpdatePersistence(t *testing.T) {
|
||||||
|
store := &countingStore{}
|
||||||
|
mgr := NewManager(store, nil, nil)
|
||||||
|
auth := &Auth{
|
||||||
|
ID: "auth-1",
|
||||||
|
Provider: "antigravity",
|
||||||
|
Metadata: map[string]any{"type": "antigravity"},
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := mgr.Update(context.Background(), auth); err != nil {
|
||||||
|
t.Fatalf("Update returned error: %v", err)
|
||||||
|
}
|
||||||
|
if got := store.saveCount.Load(); got != 1 {
|
||||||
|
t.Fatalf("expected 1 Save call, got %d", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctxSkip := WithSkipPersist(context.Background())
|
||||||
|
if _, err := mgr.Update(ctxSkip, auth); err != nil {
|
||||||
|
t.Fatalf("Update(skipPersist) returned error: %v", err)
|
||||||
|
}
|
||||||
|
if got := store.saveCount.Load(); got != 1 {
|
||||||
|
t.Fatalf("expected Save call count to remain 1, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithSkipPersist_DisablesRegisterPersistence(t *testing.T) {
|
||||||
|
store := &countingStore{}
|
||||||
|
mgr := NewManager(store, nil, nil)
|
||||||
|
auth := &Auth{
|
||||||
|
ID: "auth-1",
|
||||||
|
Provider: "antigravity",
|
||||||
|
Metadata: map[string]any{"type": "antigravity"},
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := mgr.Register(WithSkipPersist(context.Background()), auth); err != nil {
|
||||||
|
t.Fatalf("Register(skipPersist) returned error: %v", err)
|
||||||
|
}
|
||||||
|
if got := store.saveCount.Load(); got != 0 {
|
||||||
|
t.Fatalf("expected 0 Save calls, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -194,6 +194,108 @@ func (a *Auth) ProxyInfo() string {
|
|||||||
return "via proxy"
|
return "via proxy"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DisableCoolingOverride returns the auth-file scoped disable_cooling override when present.
|
||||||
|
// The value is read from metadata key "disable_cooling" (or legacy "disable-cooling").
|
||||||
|
func (a *Auth) DisableCoolingOverride() (bool, bool) {
|
||||||
|
if a == nil || a.Metadata == nil {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
if val, ok := a.Metadata["disable_cooling"]; ok {
|
||||||
|
if parsed, okParse := parseBoolAny(val); okParse {
|
||||||
|
return parsed, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if val, ok := a.Metadata["disable-cooling"]; ok {
|
||||||
|
if parsed, okParse := parseBoolAny(val); okParse {
|
||||||
|
return parsed, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestRetryOverride returns the auth-file scoped request_retry override when present.
|
||||||
|
// The value is read from metadata key "request_retry" (or legacy "request-retry").
|
||||||
|
func (a *Auth) RequestRetryOverride() (int, bool) {
|
||||||
|
if a == nil || a.Metadata == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
if val, ok := a.Metadata["request_retry"]; ok {
|
||||||
|
if parsed, okParse := parseIntAny(val); okParse {
|
||||||
|
if parsed < 0 {
|
||||||
|
parsed = 0
|
||||||
|
}
|
||||||
|
return parsed, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if val, ok := a.Metadata["request-retry"]; ok {
|
||||||
|
if parsed, okParse := parseIntAny(val); okParse {
|
||||||
|
if parsed < 0 {
|
||||||
|
parsed = 0
|
||||||
|
}
|
||||||
|
return parsed, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseBoolAny(val any) (bool, bool) {
|
||||||
|
switch typed := val.(type) {
|
||||||
|
case bool:
|
||||||
|
return typed, true
|
||||||
|
case string:
|
||||||
|
trimmed := strings.TrimSpace(typed)
|
||||||
|
if trimmed == "" {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
parsed, err := strconv.ParseBool(trimmed)
|
||||||
|
if err != nil {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
return parsed, true
|
||||||
|
case float64:
|
||||||
|
return typed != 0, true
|
||||||
|
case json.Number:
|
||||||
|
parsed, err := typed.Int64()
|
||||||
|
if err != nil {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
return parsed != 0, true
|
||||||
|
default:
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseIntAny(val any) (int, bool) {
|
||||||
|
switch typed := val.(type) {
|
||||||
|
case int:
|
||||||
|
return typed, true
|
||||||
|
case int32:
|
||||||
|
return int(typed), true
|
||||||
|
case int64:
|
||||||
|
return int(typed), true
|
||||||
|
case float64:
|
||||||
|
return int(typed), true
|
||||||
|
case json.Number:
|
||||||
|
parsed, err := typed.Int64()
|
||||||
|
if err != nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return int(parsed), true
|
||||||
|
case string:
|
||||||
|
trimmed := strings.TrimSpace(typed)
|
||||||
|
if trimmed == "" {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
parsed, err := strconv.Atoi(trimmed)
|
||||||
|
if err != nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return parsed, true
|
||||||
|
default:
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Auth) AccountInfo() (string, string) {
|
func (a *Auth) AccountInfo() (string, string) {
|
||||||
if a == nil {
|
if a == nil {
|
||||||
return "", ""
|
return "", ""
|
||||||
|
|||||||
@@ -135,6 +135,7 @@ func (s *Service) ensureAuthUpdateQueue(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) consumeAuthUpdates(ctx context.Context) {
|
func (s *Service) consumeAuthUpdates(ctx context.Context) {
|
||||||
|
ctx = coreauth.WithSkipPersist(ctx)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
|||||||
Reference in New Issue
Block a user