mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-21 16:40:22 +00:00
Compare commits
93 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c1bb77c7c9 | ||
|
|
6bcac3a55a | ||
|
|
fc346f4537 | ||
|
|
43e531a3b6 | ||
|
|
d24ea4ce2a | ||
|
|
2c30c981ae | ||
|
|
aa1da8a858 | ||
|
|
f1e9a787d7 | ||
|
|
4eeec297de | ||
|
|
77cc4ce3a0 | ||
|
|
37dfea1d3f | ||
|
|
e6626c672a | ||
|
|
c66cb0afd2 | ||
|
|
fb48eee973 | ||
|
|
bb44e5ec44 | ||
|
|
c785c1a3ca | ||
|
|
0659ffab75 | ||
|
|
7cb398d167 | ||
|
|
c3e12c5e58 | ||
|
|
1825fc7503 | ||
|
|
48732ba05e | ||
|
|
acf483c9e6 | ||
|
|
3b3e0d1141 | ||
|
|
7acd428507 | ||
|
|
450d1227bd | ||
|
|
492b9c46f0 | ||
|
|
6e634fe3f9 | ||
|
|
eb7571936c | ||
|
|
5382764d8a | ||
|
|
49c8ec69d0 | ||
|
|
3b421c8181 | ||
|
|
21d2329947 | ||
|
|
0993413bab | ||
|
|
713388dd7b | ||
|
|
e6c7af0fa9 | ||
|
|
837aa6e3aa | ||
|
|
d210be06c2 | ||
|
|
afc8a0f9be | ||
|
|
af8e9ef458 | ||
|
|
cec6f993ad | ||
|
|
950de29f48 | ||
|
|
d6ec33e8e1 | ||
|
|
081cfe806e | ||
|
|
c1c62a6c04 | ||
|
|
d693d7993b | ||
|
|
5936f9895c | ||
|
|
2fdf5d2793 | ||
|
|
b3da00d2ed | ||
|
|
740277a9f2 | ||
|
|
f91807b6b9 | ||
|
|
57d18bb226 | ||
|
|
10b9c6cb8a | ||
|
|
b24786f8a7 | ||
|
|
7b0eb41ebc | ||
|
|
70949929db | ||
|
|
7c9c89dace | ||
|
|
ef5901c81b | ||
|
|
d4829c82f7 | ||
|
|
a5f4166a9b | ||
|
|
0cbfe7f457 | ||
|
|
f2b1ec4f9e | ||
|
|
1cc21cc45b | ||
|
|
07cf616e2b | ||
|
|
2b8c466e88 | ||
|
|
ca2174ea48 | ||
|
|
c09fb2a79d | ||
|
|
4445a165e9 | ||
|
|
e92e2af71a | ||
|
|
a6bdd9a652 | ||
|
|
349a6349b3 | ||
|
|
00822770ec | ||
|
|
1a0ceda0fc | ||
|
|
b9ae4ab803 | ||
|
|
72add453d2 | ||
|
|
2789396435 | ||
|
|
61da7bd981 | ||
|
|
1f8f198c45 | ||
|
|
a45c6defa7 | ||
|
|
40bee3e8d9 | ||
|
|
93147dddeb | ||
|
|
c0f9b15a58 | ||
|
|
6f2fbdcbae | ||
|
|
65debb874f | ||
|
|
3caadac003 | ||
|
|
6a9e3a6b84 | ||
|
|
269972440a | ||
|
|
cce13e6ad2 | ||
|
|
8a565dcad8 | ||
|
|
d536110404 | ||
|
|
48e957ddff | ||
|
|
94563d622c | ||
|
|
ce0c6aa82b | ||
|
|
3c85d2a4d7 |
@@ -72,6 +72,7 @@ func main() {
|
|||||||
// Command-line flags to control the application's behavior.
|
// Command-line flags to control the application's behavior.
|
||||||
var login bool
|
var login bool
|
||||||
var codexLogin bool
|
var codexLogin bool
|
||||||
|
var codexDeviceLogin bool
|
||||||
var claudeLogin bool
|
var claudeLogin bool
|
||||||
var qwenLogin bool
|
var qwenLogin bool
|
||||||
var kiloLogin bool
|
var kiloLogin bool
|
||||||
@@ -99,6 +100,7 @@ func main() {
|
|||||||
// Define command-line flags for different operation modes.
|
// Define command-line flags for different operation modes.
|
||||||
flag.BoolVar(&login, "login", false, "Login Google Account")
|
flag.BoolVar(&login, "login", false, "Login Google Account")
|
||||||
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
|
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
|
||||||
|
flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow")
|
||||||
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
|
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
|
||||||
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
|
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
|
||||||
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
|
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
|
||||||
@@ -502,6 +504,9 @@ func main() {
|
|||||||
} else if codexLogin {
|
} else if codexLogin {
|
||||||
// Handle Codex login
|
// Handle Codex login
|
||||||
cmd.DoCodexLogin(cfg, options)
|
cmd.DoCodexLogin(cfg, options)
|
||||||
|
} else if codexDeviceLogin {
|
||||||
|
// Handle Codex device-code login
|
||||||
|
cmd.DoCodexDeviceLogin(cfg, options)
|
||||||
} else if claudeLogin {
|
} else if claudeLogin {
|
||||||
// Handle Claude login
|
// Handle Claude login
|
||||||
cmd.DoClaudeLogin(cfg, options)
|
cmd.DoClaudeLogin(cfg, options)
|
||||||
|
|||||||
@@ -73,6 +73,10 @@ proxy-url: ''
|
|||||||
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
||||||
force-model-prefix: false
|
force-model-prefix: false
|
||||||
|
|
||||||
|
# When true, forward filtered upstream response headers to downstream clients.
|
||||||
|
# Default is false (disabled).
|
||||||
|
passthrough-headers: false
|
||||||
|
|
||||||
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
||||||
request-retry: 3
|
request-retry: 3
|
||||||
|
|
||||||
@@ -160,6 +164,7 @@ nonstream-keepalive-interval: 0
|
|||||||
# sensitive-words: # optional: words to obfuscate with zero-width characters
|
# sensitive-words: # optional: words to obfuscate with zero-width characters
|
||||||
# - "API"
|
# - "API"
|
||||||
# - "proxy"
|
# - "proxy"
|
||||||
|
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
|
||||||
|
|
||||||
# Default headers for Claude API requests. Update when Claude Code releases new versions.
|
# Default headers for Claude API requests. Update when Claude Code releases new versions.
|
||||||
# These are used as fallbacks when the client does not send its own headers.
|
# These are used as fallbacks when the client does not send its own headers.
|
||||||
|
|||||||
@@ -159,13 +159,13 @@ func (MyExecutor) CountTokens(context.Context, *coreauth.Auth, clipexec.Request,
|
|||||||
return clipexec.Response{}, errors.New("count tokens not implemented")
|
return clipexec.Response{}, errors.New("count tokens not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (<-chan clipexec.StreamChunk, error) {
|
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (*clipexec.StreamResult, error) {
|
||||||
ch := make(chan clipexec.StreamChunk, 1)
|
ch := make(chan clipexec.StreamChunk, 1)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
ch <- clipexec.StreamChunk{Payload: []byte("data: {\"ok\":true}\n\n")}
|
ch <- clipexec.StreamChunk{Payload: []byte("data: {\"ok\":true}\n\n")}
|
||||||
}()
|
}()
|
||||||
return ch, nil
|
return &clipexec.StreamResult{Chunks: ch}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {
|
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ func (EchoExecutor) Execute(context.Context, *coreauth.Auth, clipexec.Request, c
|
|||||||
return clipexec.Response{}, errors.New("echo executor: Execute not implemented")
|
return clipexec.Response{}, errors.New("echo executor: Execute not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (<-chan clipexec.StreamChunk, error) {
|
func (EchoExecutor) ExecuteStream(context.Context, *coreauth.Auth, clipexec.Request, clipexec.Options) (*clipexec.StreamResult, error) {
|
||||||
return nil, errors.New("echo executor: ExecuteStream not implemented")
|
return nil, errors.New("echo executor: ExecuteStream not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -951,11 +951,17 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
|
|||||||
if store == nil {
|
if store == nil {
|
||||||
return "", fmt.Errorf("token store unavailable")
|
return "", fmt.Errorf("token store unavailable")
|
||||||
}
|
}
|
||||||
|
if h.postAuthHook != nil {
|
||||||
|
if err := h.postAuthHook(ctx, record); err != nil {
|
||||||
|
return "", fmt.Errorf("post-auth hook failed: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
return store.Save(ctx, record)
|
return store.Save(ctx, record)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Claude authentication...")
|
fmt.Println("Initializing Claude authentication...")
|
||||||
|
|
||||||
@@ -1100,6 +1106,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
||||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
|
||||||
|
|
||||||
@@ -1358,6 +1365,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestCodexToken(c *gin.Context) {
|
func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Codex authentication...")
|
fmt.Println("Initializing Codex authentication...")
|
||||||
|
|
||||||
@@ -1503,6 +1511,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Antigravity authentication...")
|
fmt.Println("Initializing Antigravity authentication...")
|
||||||
|
|
||||||
@@ -1667,6 +1676,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestQwenToken(c *gin.Context) {
|
func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Qwen authentication...")
|
fmt.Println("Initializing Qwen authentication...")
|
||||||
|
|
||||||
@@ -1722,6 +1732,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing Kimi authentication...")
|
fmt.Println("Initializing Kimi authentication...")
|
||||||
|
|
||||||
@@ -1798,6 +1809,7 @@ func (h *Handler) RequestKimiToken(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
fmt.Println("Initializing iFlow authentication...")
|
fmt.Println("Initializing iFlow authentication...")
|
||||||
|
|
||||||
@@ -1917,8 +1929,6 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
|||||||
state := fmt.Sprintf("gh-%d", time.Now().UnixNano())
|
state := fmt.Sprintf("gh-%d", time.Now().UnixNano())
|
||||||
|
|
||||||
// Initialize Copilot auth service
|
// Initialize Copilot auth service
|
||||||
// We need to import "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" first if not present
|
|
||||||
// Assuming copilot package is imported as "copilot"
|
|
||||||
deviceClient := copilot.NewDeviceFlowClient(h.cfg)
|
deviceClient := copilot.NewDeviceFlowClient(h.cfg)
|
||||||
|
|
||||||
// Initiate device flow
|
// Initiate device flow
|
||||||
@@ -1932,7 +1942,7 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
|||||||
authURL := deviceCode.VerificationURI
|
authURL := deviceCode.VerificationURI
|
||||||
userCode := deviceCode.UserCode
|
userCode := deviceCode.UserCode
|
||||||
|
|
||||||
RegisterOAuthSession(state, "github")
|
RegisterOAuthSession(state, "github-copilot")
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
fmt.Printf("Please visit %s and enter code: %s\n", authURL, userCode)
|
fmt.Printf("Please visit %s and enter code: %s\n", authURL, userCode)
|
||||||
@@ -1944,9 +1954,13 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
username, errUser := deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
userInfo, errUser := deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
||||||
if errUser != nil {
|
if errUser != nil {
|
||||||
log.Warnf("Failed to fetch user info: %v", errUser)
|
log.Warnf("Failed to fetch user info: %v", errUser)
|
||||||
|
}
|
||||||
|
|
||||||
|
username := userInfo.Login
|
||||||
|
if username == "" {
|
||||||
username = "github-user"
|
username = "github-user"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1955,18 +1969,26 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
|||||||
TokenType: tokenData.TokenType,
|
TokenType: tokenData.TokenType,
|
||||||
Scope: tokenData.Scope,
|
Scope: tokenData.Scope,
|
||||||
Username: username,
|
Username: username,
|
||||||
|
Email: userInfo.Email,
|
||||||
|
Name: userInfo.Name,
|
||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
}
|
}
|
||||||
|
|
||||||
fileName := fmt.Sprintf("github-%s.json", username)
|
fileName := fmt.Sprintf("github-copilot-%s.json", username)
|
||||||
|
label := userInfo.Email
|
||||||
|
if label == "" {
|
||||||
|
label = username
|
||||||
|
}
|
||||||
record := &coreauth.Auth{
|
record := &coreauth.Auth{
|
||||||
ID: fileName,
|
ID: fileName,
|
||||||
Provider: "github",
|
Provider: "github-copilot",
|
||||||
|
Label: label,
|
||||||
FileName: fileName,
|
FileName: fileName,
|
||||||
Storage: tokenStorage,
|
Storage: tokenStorage,
|
||||||
Metadata: map[string]any{
|
Metadata: map[string]any{
|
||||||
"email": username,
|
"email": userInfo.Email,
|
||||||
"username": username,
|
"username": username,
|
||||||
|
"name": userInfo.Name,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1980,7 +2002,7 @@ func (h *Handler) RequestGitHubToken(c *gin.Context) {
|
|||||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||||
fmt.Println("You can now use GitHub Copilot services through this CLI")
|
fmt.Println("You can now use GitHub Copilot services through this CLI")
|
||||||
CompleteOAuthSession(state)
|
CompleteOAuthSession(state)
|
||||||
CompleteOAuthSessionsByProvider("github")
|
CompleteOAuthSessionsByProvider("github-copilot")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
@@ -2521,6 +2543,14 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, gin.H{"status": "wait"})
|
c.JSON(http.StatusOK, gin.H{"status": "wait"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PopulateAuthContext extracts request info and adds it to the context
|
||||||
|
func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context {
|
||||||
|
info := &coreauth.RequestInfo{
|
||||||
|
Query: c.Request.URL.Query(),
|
||||||
|
Headers: c.Request.Header,
|
||||||
|
}
|
||||||
|
return coreauth.WithRequestInfo(ctx, info)
|
||||||
|
}
|
||||||
const kiroCallbackPort = 9876
|
const kiroCallbackPort = 9876
|
||||||
|
|
||||||
func (h *Handler) RequestKiroToken(c *gin.Context) {
|
func (h *Handler) RequestKiroToken(c *gin.Context) {
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ type Handler struct {
|
|||||||
allowRemoteOverride bool
|
allowRemoteOverride bool
|
||||||
envSecret string
|
envSecret string
|
||||||
logDir string
|
logDir string
|
||||||
|
postAuthHook coreauth.PostAuthHook
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHandler creates a new management handler instance.
|
// NewHandler creates a new management handler instance.
|
||||||
@@ -128,6 +129,11 @@ func (h *Handler) SetLogDirectory(dir string) {
|
|||||||
h.logDir = dir
|
h.logDir = dir
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetPostAuthHook registers a hook to be called after auth record creation but before persistence.
|
||||||
|
func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) {
|
||||||
|
h.postAuthHook = hook
|
||||||
|
}
|
||||||
|
|
||||||
// Middleware enforces access control for management endpoints.
|
// Middleware enforces access control for management endpoints.
|
||||||
// All requests (local and remote) require a valid management key.
|
// All requests (local and remote) require a valid management key.
|
||||||
// Additionally, remote access requires allow-remote-management=true.
|
// Additionally, remote access requires allow-remote-management=true.
|
||||||
|
|||||||
@@ -215,7 +215,7 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
|||||||
|
|
||||||
// Don't log as error for context canceled - it's usually client closing connection
|
// Don't log as error for context canceled - it's usually client closing connection
|
||||||
if errors.Is(err, context.Canceled) {
|
if errors.Is(err, context.Canceled) {
|
||||||
log.Debugf("amp upstream proxy [%s]: client canceled request for %s %s", errType, req.Method, req.URL.Path)
|
return
|
||||||
} else {
|
} else {
|
||||||
log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err)
|
log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -493,6 +493,30 @@ func TestReverseProxy_ErrorHandler(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReverseProxy_ErrorHandler_ContextCanceled(t *testing.T) {
|
||||||
|
// Test that context.Canceled errors return 499 without generic error response
|
||||||
|
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource(""))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a canceled context to trigger the cancellation path
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel() // Cancel immediately
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil).WithContext(ctx)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Directly invoke the ErrorHandler with context.Canceled
|
||||||
|
proxy.ErrorHandler(rr, req, context.Canceled)
|
||||||
|
|
||||||
|
// Body should be empty for canceled requests (no JSON error response)
|
||||||
|
body := rr.Body.Bytes()
|
||||||
|
if len(body) > 0 {
|
||||||
|
t.Fatalf("expected empty body for canceled context, got: %s", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestReverseProxy_FullRoundTrip_Gzip(t *testing.T) {
|
func TestReverseProxy_FullRoundTrip_Gzip(t *testing.T) {
|
||||||
// Upstream returns gzipped JSON without Content-Encoding header
|
// Upstream returns gzipped JSON without Content-Encoding header
|
||||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ type serverOptionConfig struct {
|
|||||||
keepAliveEnabled bool
|
keepAliveEnabled bool
|
||||||
keepAliveTimeout time.Duration
|
keepAliveTimeout time.Duration
|
||||||
keepAliveOnTimeout func()
|
keepAliveOnTimeout func()
|
||||||
|
postAuthHook auth.PostAuthHook
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerOption customises HTTP server construction.
|
// ServerOption customises HTTP server construction.
|
||||||
@@ -112,6 +113,13 @@ func WithRequestLoggerFactory(factory func(*config.Config, string) logging.Reque
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithPostAuthHook registers a hook to be called after auth record creation.
|
||||||
|
func WithPostAuthHook(hook auth.PostAuthHook) ServerOption {
|
||||||
|
return func(cfg *serverOptionConfig) {
|
||||||
|
cfg.postAuthHook = hook
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Server represents the main API server.
|
// Server represents the main API server.
|
||||||
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
|
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
|
||||||
type Server struct {
|
type Server struct {
|
||||||
@@ -263,6 +271,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
|||||||
}
|
}
|
||||||
logDir := logging.ResolveLogDirectory(cfg)
|
logDir := logging.ResolveLogDirectory(cfg)
|
||||||
s.mgmt.SetLogDirectory(logDir)
|
s.mgmt.SetLogDirectory(logDir)
|
||||||
|
if optionState.postAuthHook != nil {
|
||||||
|
s.mgmt.SetPostAuthHook(optionState.postAuthHook)
|
||||||
|
}
|
||||||
s.localPassword = optionState.localPassword
|
s.localPassword = optionState.localPassword
|
||||||
|
|
||||||
// Setup routes
|
// Setup routes
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import (
|
|||||||
// OAuth configuration constants for Claude/Anthropic
|
// OAuth configuration constants for Claude/Anthropic
|
||||||
const (
|
const (
|
||||||
AuthURL = "https://claude.ai/oauth/authorize"
|
AuthURL = "https://claude.ai/oauth/authorize"
|
||||||
TokenURL = "https://console.anthropic.com/v1/oauth/token"
|
TokenURL = "https://api.anthropic.com/v1/oauth/token"
|
||||||
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||||
RedirectURI = "http://localhost:54545/callback"
|
RedirectURI = "http://localhost:54545/callback"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -36,11 +36,21 @@ type ClaudeTokenStorage struct {
|
|||||||
|
|
||||||
// Expire is the timestamp when the current access token expires.
|
// Expire is the timestamp when the current access token expires.
|
||||||
Expire string `json:"expired"`
|
Expire string `json:"expired"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *ClaudeTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile serializes the Claude token storage to a JSON file.
|
// SaveTokenToFile serializes the Claude token storage to a JSON file.
|
||||||
// This method creates the necessary directory structure and writes the token
|
// This method creates the necessary directory structure and writes the token
|
||||||
// data in JSON format to the specified file path for persistent storage.
|
// data in JSON format to the specified file path for persistent storage.
|
||||||
|
// It merges any injected metadata into the top-level JSON object.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - authFilePath: The full path where the token file should be saved
|
// - authFilePath: The full path where the token file should be saved
|
||||||
@@ -65,8 +75,14 @@ func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
_ = f.Close()
|
_ = f.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
|
|
||||||
// Encode and write the token data as JSON
|
// Encode and write the token data as JSON
|
||||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||||
return fmt.Errorf("failed to write token to file: %w", err)
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -71,16 +71,26 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
|
|||||||
// It performs an HTTP POST request to the OpenAI token endpoint with the provided
|
// It performs an HTTP POST request to the OpenAI token endpoint with the provided
|
||||||
// authorization code and PKCE verifier.
|
// authorization code and PKCE verifier.
|
||||||
func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
||||||
|
return o.ExchangeCodeForTokensWithRedirect(ctx, code, RedirectURI, pkceCodes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExchangeCodeForTokensWithRedirect exchanges an authorization code for tokens using
|
||||||
|
// a caller-provided redirect URI. This supports alternate auth flows such as device
|
||||||
|
// login while preserving the existing token parsing and storage behavior.
|
||||||
|
func (o *CodexAuth) ExchangeCodeForTokensWithRedirect(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
||||||
if pkceCodes == nil {
|
if pkceCodes == nil {
|
||||||
return nil, fmt.Errorf("PKCE codes are required for token exchange")
|
return nil, fmt.Errorf("PKCE codes are required for token exchange")
|
||||||
}
|
}
|
||||||
|
if strings.TrimSpace(redirectURI) == "" {
|
||||||
|
return nil, fmt.Errorf("redirect URI is required for token exchange")
|
||||||
|
}
|
||||||
|
|
||||||
// Prepare token exchange request
|
// Prepare token exchange request
|
||||||
data := url.Values{
|
data := url.Values{
|
||||||
"grant_type": {"authorization_code"},
|
"grant_type": {"authorization_code"},
|
||||||
"client_id": {ClientID},
|
"client_id": {ClientID},
|
||||||
"code": {code},
|
"code": {code},
|
||||||
"redirect_uri": {RedirectURI},
|
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||||
"code_verifier": {pkceCodes.CodeVerifier},
|
"code_verifier": {pkceCodes.CodeVerifier},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -266,6 +276,10 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
return tokenData, nil
|
return tokenData, nil
|
||||||
}
|
}
|
||||||
|
if isNonRetryableRefreshErr(err) {
|
||||||
|
log.Warnf("Token refresh attempt %d failed with non-retryable error: %v", attempt+1, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
lastErr = err
|
lastErr = err
|
||||||
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
||||||
@@ -274,6 +288,14 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str
|
|||||||
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isNonRetryableRefreshErr(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
raw := strings.ToLower(err.Error())
|
||||||
|
return strings.Contains(raw, "refresh_token_reused")
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateTokenStorage updates an existing CodexTokenStorage with new token data.
|
// UpdateTokenStorage updates an existing CodexTokenStorage with new token data.
|
||||||
// This is typically called after a successful token refresh to persist the new credentials.
|
// This is typically called after a successful token refresh to persist the new credentials.
|
||||||
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {
|
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {
|
||||||
|
|||||||
44
internal/auth/codex/openai_auth_test.go
Normal file
44
internal/auth/codex/openai_auth_test.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package codex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return f(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) {
|
||||||
|
var calls int32
|
||||||
|
auth := &CodexAuth{
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
atomic.AddInt32(&calls, 1)
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Body: io.NopCloser(strings.NewReader(`{"error":"invalid_grant","code":"refresh_token_reused"}`)),
|
||||||
|
Header: make(http.Header),
|
||||||
|
Request: req,
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error for non-retryable refresh failure")
|
||||||
|
}
|
||||||
|
if !strings.Contains(strings.ToLower(err.Error()), "refresh_token_reused") {
|
||||||
|
t.Fatalf("expected refresh_token_reused in error, got: %v", err)
|
||||||
|
}
|
||||||
|
if got := atomic.LoadInt32(&calls); got != 1 {
|
||||||
|
t.Fatalf("expected 1 refresh attempt, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -32,11 +32,21 @@ type CodexTokenStorage struct {
|
|||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
// Expire is the timestamp when the current access token expires.
|
// Expire is the timestamp when the current access token expires.
|
||||||
Expire string `json:"expired"`
|
Expire string `json:"expired"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *CodexTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile serializes the Codex token storage to a JSON file.
|
// SaveTokenToFile serializes the Codex token storage to a JSON file.
|
||||||
// This method creates the necessary directory structure and writes the token
|
// This method creates the necessary directory structure and writes the token
|
||||||
// data in JSON format to the specified file path for persistent storage.
|
// data in JSON format to the specified file path for persistent storage.
|
||||||
|
// It merges any injected metadata into the top-level JSON object.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - authFilePath: The full path where the token file should be saved
|
// - authFilePath: The full path where the token file should be saved
|
||||||
@@ -58,7 +68,13 @@ func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
_ = f.Close()
|
_ = f.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||||
return fmt.Errorf("failed to write token to file: %w", err)
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -82,15 +82,21 @@ func (c *CopilotAuth) WaitForAuthorization(ctx context.Context, deviceCode *Devi
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fetch the GitHub username
|
// Fetch the GitHub username
|
||||||
username, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
userInfo, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("copilot: failed to fetch user info: %v", err)
|
log.Warnf("copilot: failed to fetch user info: %v", err)
|
||||||
username = "unknown"
|
}
|
||||||
|
|
||||||
|
username := userInfo.Login
|
||||||
|
if username == "" {
|
||||||
|
username = "github-user"
|
||||||
}
|
}
|
||||||
|
|
||||||
return &CopilotAuthBundle{
|
return &CopilotAuthBundle{
|
||||||
TokenData: tokenData,
|
TokenData: tokenData,
|
||||||
Username: username,
|
Username: username,
|
||||||
|
Email: userInfo.Email,
|
||||||
|
Name: userInfo.Name,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -150,12 +156,12 @@ func (c *CopilotAuth) ValidateToken(ctx context.Context, accessToken string) (bo
|
|||||||
return false, "", nil
|
return false, "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
username, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
|
userInfo, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, "", err
|
return false, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return true, username, nil
|
return true, userInfo.Login, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateTokenStorage creates a new CopilotTokenStorage from auth bundle.
|
// CreateTokenStorage creates a new CopilotTokenStorage from auth bundle.
|
||||||
@@ -165,6 +171,8 @@ func (c *CopilotAuth) CreateTokenStorage(bundle *CopilotAuthBundle) *CopilotToke
|
|||||||
TokenType: bundle.TokenData.TokenType,
|
TokenType: bundle.TokenData.TokenType,
|
||||||
Scope: bundle.TokenData.Scope,
|
Scope: bundle.TokenData.Scope,
|
||||||
Username: bundle.Username,
|
Username: bundle.Username,
|
||||||
|
Email: bundle.Email,
|
||||||
|
Name: bundle.Name,
|
||||||
Type: "github-copilot",
|
Type: "github-copilot",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient {
|
|||||||
func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
|
func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
|
||||||
data := url.Values{}
|
data := url.Values{}
|
||||||
data.Set("client_id", copilotClientID)
|
data.Set("client_id", copilotClientID)
|
||||||
data.Set("scope", "user:email")
|
data.Set("scope", "read:user user:email")
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotDeviceCodeURL, strings.NewReader(data.Encode()))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotDeviceCodeURL, strings.NewReader(data.Encode()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -211,15 +211,25 @@ func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode st
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchUserInfo retrieves the GitHub username for the authenticated user.
|
// GitHubUserInfo holds GitHub user profile information.
|
||||||
func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (string, error) {
|
type GitHubUserInfo struct {
|
||||||
|
// Login is the GitHub username.
|
||||||
|
Login string
|
||||||
|
// Email is the primary email address (may be empty if not public).
|
||||||
|
Email string
|
||||||
|
// Name is the display name.
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchUserInfo retrieves the GitHub user profile for the authenticated user.
|
||||||
|
func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (GitHubUserInfo, error) {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
|
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotUserInfoURL, nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotUserInfoURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", NewAuthenticationError(ErrUserInfoFailed, err)
|
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
|
||||||
}
|
}
|
||||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
req.Header.Set("Accept", "application/json")
|
req.Header.Set("Accept", "application/json")
|
||||||
@@ -227,7 +237,7 @@ func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string
|
|||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", NewAuthenticationError(ErrUserInfoFailed, err)
|
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if errClose := resp.Body.Close(); errClose != nil {
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
@@ -237,19 +247,25 @@ func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string
|
|||||||
|
|
||||||
if !isHTTPSuccess(resp.StatusCode) {
|
if !isHTTPSuccess(resp.StatusCode) {
|
||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
|
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
|
||||||
}
|
}
|
||||||
|
|
||||||
var userInfo struct {
|
var raw struct {
|
||||||
Login string `json:"login"`
|
Login string `json:"login"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Name string `json:"name"`
|
||||||
}
|
}
|
||||||
if err = json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
if err = json.NewDecoder(resp.Body).Decode(&raw); err != nil {
|
||||||
return "", NewAuthenticationError(ErrUserInfoFailed, err)
|
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if userInfo.Login == "" {
|
if raw.Login == "" {
|
||||||
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
|
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
|
||||||
}
|
}
|
||||||
|
|
||||||
return userInfo.Login, nil
|
return GitHubUserInfo{
|
||||||
|
Login: raw.Login,
|
||||||
|
Email: raw.Email,
|
||||||
|
Name: raw.Name,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
213
internal/auth/copilot/oauth_test.go
Normal file
213
internal/auth/copilot/oauth_test.go
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
package copilot
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// roundTripFunc adapts a plain function to the http.RoundTripper interface so
// that tests can substitute their own transport.
type roundTripFunc func(*http.Request) (*http.Response, error)

// RoundTrip passes the request to the wrapped function and returns its result unchanged.
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return fn(req)
}
|
||||||
|
|
||||||
|
// newTestClient returns an *http.Client whose requests are redirected to the given test server,
|
||||||
|
// regardless of the original URL host.
|
||||||
|
func newTestClient(srv *httptest.Server) *http.Client {
|
||||||
|
return &http.Client{
|
||||||
|
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
req2 := req.Clone(req.Context())
|
||||||
|
req2.URL.Scheme = "http"
|
||||||
|
req2.URL.Host = strings.TrimPrefix(srv.URL, "http://")
|
||||||
|
return srv.Client().Transport.RoundTrip(req2)
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFetchUserInfo_FullProfile verifies that FetchUserInfo returns login, email, and name.
|
||||||
|
func TestFetchUserInfo_FullProfile(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"login": "octocat",
|
||||||
|
"email": "octocat@github.com",
|
||||||
|
"name": "The Octocat",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||||
|
info, err := client.FetchUserInfo(context.Background(), "test-token")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if info.Login != "octocat" {
|
||||||
|
t.Errorf("Login: got %q, want %q", info.Login, "octocat")
|
||||||
|
}
|
||||||
|
if info.Email != "octocat@github.com" {
|
||||||
|
t.Errorf("Email: got %q, want %q", info.Email, "octocat@github.com")
|
||||||
|
}
|
||||||
|
if info.Name != "The Octocat" {
|
||||||
|
t.Errorf("Name: got %q, want %q", info.Name, "The Octocat")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFetchUserInfo_EmptyEmail verifies graceful handling when email is absent (private account).
|
||||||
|
func TestFetchUserInfo_EmptyEmail(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
// GitHub returns null for private emails.
|
||||||
|
_, _ = w.Write([]byte(`{"login":"privateuser","email":null,"name":"Private User"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||||
|
info, err := client.FetchUserInfo(context.Background(), "test-token")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if info.Login != "privateuser" {
|
||||||
|
t.Errorf("Login: got %q, want %q", info.Login, "privateuser")
|
||||||
|
}
|
||||||
|
if info.Email != "" {
|
||||||
|
t.Errorf("Email: got %q, want empty string", info.Email)
|
||||||
|
}
|
||||||
|
if info.Name != "Private User" {
|
||||||
|
t.Errorf("Name: got %q, want %q", info.Name, "Private User")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFetchUserInfo_EmptyToken verifies error is returned for empty access token.
|
||||||
|
func TestFetchUserInfo_EmptyToken(t *testing.T) {
|
||||||
|
client := &DeviceFlowClient{httpClient: http.DefaultClient}
|
||||||
|
_, err := client.FetchUserInfo(context.Background(), "")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for empty token, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFetchUserInfo_EmptyLogin verifies error is returned when API returns no login.
|
||||||
|
func TestFetchUserInfo_EmptyLogin(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"email":"someone@example.com","name":"No Login"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||||
|
_, err := client.FetchUserInfo(context.Background(), "test-token")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for empty login, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFetchUserInfo_HTTPError verifies error is returned on non-2xx response.
|
||||||
|
func TestFetchUserInfo_HTTPError(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
_, _ = w.Write([]byte(`{"message":"Bad credentials"}`))
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
|
||||||
|
_, err := client.FetchUserInfo(context.Background(), "bad-token")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 401 response, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCopilotTokenStorage_EmailNameFields verifies Email and Name serialise correctly.
|
||||||
|
func TestCopilotTokenStorage_EmailNameFields(t *testing.T) {
|
||||||
|
ts := &CopilotTokenStorage{
|
||||||
|
AccessToken: "ghu_abc",
|
||||||
|
TokenType: "bearer",
|
||||||
|
Scope: "read:user user:email",
|
||||||
|
Username: "octocat",
|
||||||
|
Email: "octocat@github.com",
|
||||||
|
Name: "The Octocat",
|
||||||
|
Type: "github-copilot",
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(ts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var out map[string]any
|
||||||
|
if err = json.Unmarshal(data, &out); err != nil {
|
||||||
|
t.Fatalf("unmarshal error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, key := range []string{"access_token", "username", "email", "name", "type"} {
|
||||||
|
if _, ok := out[key]; !ok {
|
||||||
|
t.Errorf("expected key %q in JSON output, not found", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if out["email"] != "octocat@github.com" {
|
||||||
|
t.Errorf("email: got %v, want %q", out["email"], "octocat@github.com")
|
||||||
|
}
|
||||||
|
if out["name"] != "The Octocat" {
|
||||||
|
t.Errorf("name: got %v, want %q", out["name"], "The Octocat")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCopilotTokenStorage_OmitEmptyEmailName verifies email/name are omitted when empty (omitempty).
|
||||||
|
func TestCopilotTokenStorage_OmitEmptyEmailName(t *testing.T) {
|
||||||
|
ts := &CopilotTokenStorage{
|
||||||
|
AccessToken: "ghu_abc",
|
||||||
|
Username: "octocat",
|
||||||
|
Type: "github-copilot",
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(ts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var out map[string]any
|
||||||
|
if err = json.Unmarshal(data, &out); err != nil {
|
||||||
|
t.Fatalf("unmarshal error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := out["email"]; ok {
|
||||||
|
t.Error("email key should be omitted when empty (omitempty), but was present")
|
||||||
|
}
|
||||||
|
if _, ok := out["name"]; ok {
|
||||||
|
t.Error("name key should be omitted when empty (omitempty), but was present")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCopilotAuthBundle_EmailNameFields verifies bundle carries email and name through the pipeline.
|
||||||
|
func TestCopilotAuthBundle_EmailNameFields(t *testing.T) {
|
||||||
|
bundle := &CopilotAuthBundle{
|
||||||
|
TokenData: &CopilotTokenData{AccessToken: "ghu_abc"},
|
||||||
|
Username: "octocat",
|
||||||
|
Email: "octocat@github.com",
|
||||||
|
Name: "The Octocat",
|
||||||
|
}
|
||||||
|
if bundle.Email != "octocat@github.com" {
|
||||||
|
t.Errorf("bundle.Email: got %q, want %q", bundle.Email, "octocat@github.com")
|
||||||
|
}
|
||||||
|
if bundle.Name != "The Octocat" {
|
||||||
|
t.Errorf("bundle.Name: got %q, want %q", bundle.Name, "The Octocat")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGitHubUserInfo_Struct verifies the exported GitHubUserInfo struct fields are accessible.
|
||||||
|
func TestGitHubUserInfo_Struct(t *testing.T) {
|
||||||
|
info := GitHubUserInfo{
|
||||||
|
Login: "octocat",
|
||||||
|
Email: "octocat@github.com",
|
||||||
|
Name: "The Octocat",
|
||||||
|
}
|
||||||
|
if info.Login == "" || info.Email == "" || info.Name == "" {
|
||||||
|
t.Error("GitHubUserInfo fields should not be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -26,6 +26,10 @@ type CopilotTokenStorage struct {
|
|||||||
ExpiresAt string `json:"expires_at,omitempty"`
|
ExpiresAt string `json:"expires_at,omitempty"`
|
||||||
// Username is the GitHub username associated with this token.
|
// Username is the GitHub username associated with this token.
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
|
// Email is the GitHub email address associated with this token.
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
// Name is the GitHub display name associated with this token.
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
// Type indicates the authentication provider type, always "github-copilot" for this storage.
|
// Type indicates the authentication provider type, always "github-copilot" for this storage.
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
}
|
}
|
||||||
@@ -46,6 +50,10 @@ type CopilotAuthBundle struct {
|
|||||||
TokenData *CopilotTokenData
|
TokenData *CopilotTokenData
|
||||||
// Username is the GitHub username.
|
// Username is the GitHub username.
|
||||||
Username string
|
Username string
|
||||||
|
// Email is the GitHub email address.
|
||||||
|
Email string
|
||||||
|
// Name is the GitHub display name.
|
||||||
|
Name string
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeviceCodeResponse represents GitHub's device code response.
|
// DeviceCodeResponse represents GitHub's device code response.
|
||||||
|
|||||||
@@ -35,11 +35,21 @@ type GeminiTokenStorage struct {
|
|||||||
|
|
||||||
// Type indicates the authentication provider type, always "gemini" for this storage.
|
// Type indicates the authentication provider type, always "gemini" for this storage.
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *GeminiTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile serializes the Gemini token storage to a JSON file.
|
// SaveTokenToFile serializes the Gemini token storage to a JSON file.
|
||||||
// This method creates the necessary directory structure and writes the token
|
// This method creates the necessary directory structure and writes the token
|
||||||
// data in JSON format to the specified file path for persistent storage.
|
// data in JSON format to the specified file path for persistent storage.
|
||||||
|
// It merges any injected metadata into the top-level JSON object.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - authFilePath: The full path where the token file should be saved
|
// - authFilePath: The full path where the token file should be saved
|
||||||
@@ -49,6 +59,11 @@ type GeminiTokenStorage struct {
|
|||||||
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||||
misc.LogSavingCredentials(authFilePath)
|
misc.LogSavingCredentials(authFilePath)
|
||||||
ts.Type = "gemini"
|
ts.Type = "gemini"
|
||||||
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||||
return fmt.Errorf("failed to create directory: %v", err)
|
return fmt.Errorf("failed to create directory: %v", err)
|
||||||
}
|
}
|
||||||
@@ -63,7 +78,9 @@ func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
enc := json.NewEncoder(f)
|
||||||
|
enc.SetIndent("", " ")
|
||||||
|
if err := enc.Encode(data); err != nil {
|
||||||
return fmt.Errorf("failed to write token to file: %w", err)
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -21,6 +21,15 @@ type IFlowTokenStorage struct {
|
|||||||
Scope string `json:"scope"`
|
Scope string `json:"scope"`
|
||||||
Cookie string `json:"cookie"`
|
Cookie string `json:"cookie"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *IFlowTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile serialises the token storage to disk.
|
// SaveTokenToFile serialises the token storage to disk.
|
||||||
@@ -37,7 +46,13 @@ func (ts *IFlowTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
}
|
}
|
||||||
defer func() { _ = f.Close() }()
|
defer func() { _ = f.Close() }()
|
||||||
|
|
||||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||||
return fmt.Errorf("iflow token: encode token failed: %w", err)
|
return fmt.Errorf("iflow token: encode token failed: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -29,6 +29,15 @@ type KimiTokenStorage struct {
|
|||||||
Expired string `json:"expired,omitempty"`
|
Expired string `json:"expired,omitempty"`
|
||||||
// Type indicates the authentication provider type, always "kimi" for this storage.
|
// Type indicates the authentication provider type, always "kimi" for this storage.
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *KimiTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// KimiTokenData holds the raw OAuth token response from Kimi.
|
// KimiTokenData holds the raw OAuth token response from Kimi.
|
||||||
@@ -86,9 +95,15 @@ func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
_ = f.Close()
|
_ = f.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
|
|
||||||
encoder := json.NewEncoder(f)
|
encoder := json.NewEncoder(f)
|
||||||
encoder.SetIndent("", " ")
|
encoder.SetIndent("", " ")
|
||||||
if err = encoder.Encode(ts); err != nil {
|
if err = encoder.Encode(data); err != nil {
|
||||||
return fmt.Errorf("failed to write token to file: %w", err)
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -30,11 +30,21 @@ type QwenTokenStorage struct {
|
|||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
// Expire is the timestamp when the current access token expires.
|
// Expire is the timestamp when the current access token expires.
|
||||||
Expire string `json:"expired"`
|
Expire string `json:"expired"`
|
||||||
|
|
||||||
|
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||||
|
// It is not exported to JSON directly to allow flattening during serialization.
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||||
|
func (ts *QwenTokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
|
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
|
||||||
// This method creates the necessary directory structure and writes the token
|
// This method creates the necessary directory structure and writes the token
|
||||||
// data in JSON format to the specified file path for persistent storage.
|
// data in JSON format to the specified file path for persistent storage.
|
||||||
|
// It merges any injected metadata into the top-level JSON object.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - authFilePath: The full path where the token file should be saved
|
// - authFilePath: The full path where the token file should be saved
|
||||||
@@ -56,7 +66,13 @@ func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error {
|
|||||||
_ = f.Close()
|
_ = f.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
// Merge metadata using helper
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||||
return fmt.Errorf("failed to write token to file: %w", err)
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
60
internal/cmd/openai_device_login.go
Normal file
60
internal/cmd/openai_device_login.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Login-option metadata used to request the device-code variant of the Codex login.
const (
	// codexLoginModeMetadataKey is the metadata key under which the Codex login mode is passed.
	codexLoginModeMetadataKey = "codex_login_mode"
	// codexLoginModeDevice is the mode value that selects the device-code flow.
	codexLoginModeDevice = "device"
)
|
||||||
|
|
||||||
|
// DoCodexDeviceLogin triggers the Codex device-code flow while keeping the
|
||||||
|
// existing codex-login OAuth callback flow intact.
|
||||||
|
func DoCodexDeviceLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
promptFn := options.Prompt
|
||||||
|
if promptFn == nil {
|
||||||
|
promptFn = defaultProjectPrompt()
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
|
||||||
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
|
NoBrowser: options.NoBrowser,
|
||||||
|
CallbackPort: options.CallbackPort,
|
||||||
|
Metadata: map[string]string{
|
||||||
|
codexLoginModeMetadataKey: codexLoginModeDevice,
|
||||||
|
},
|
||||||
|
Prompt: promptFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
|
||||||
|
if err != nil {
|
||||||
|
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
|
||||||
|
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||||
|
if authErr.Type == codex.ErrPortInUse.Type {
|
||||||
|
os.Exit(codex.ErrPortInUse.Code)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Printf("Codex device authentication failed: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if savedPath != "" {
|
||||||
|
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||||
|
}
|
||||||
|
fmt.Println("Codex device authentication successful!")
|
||||||
|
}
|
||||||
@@ -314,6 +314,10 @@ type CloakConfig struct {
|
|||||||
// SensitiveWords is a list of words to obfuscate with zero-width characters.
|
// SensitiveWords is a list of words to obfuscate with zero-width characters.
|
||||||
// This can help bypass certain content filters.
|
// This can help bypass certain content filters.
|
||||||
SensitiveWords []string `yaml:"sensitive-words,omitempty" json:"sensitive-words,omitempty"`
|
SensitiveWords []string `yaml:"sensitive-words,omitempty" json:"sensitive-words,omitempty"`
|
||||||
|
|
||||||
|
// CacheUserID controls whether Claude user_id values are cached per API key.
|
||||||
|
// When false, a fresh random user_id is generated for every request.
|
||||||
|
CacheUserID *bool `yaml:"cache-user-id,omitempty" json:"cache-user-id,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClaudeKey represents the configuration for a Claude API key,
|
// ClaudeKey represents the configuration for a Claude API key,
|
||||||
@@ -759,22 +763,24 @@ func (cfg *Config) SanitizeOAuthModelAlias() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Inject default Kiro aliases if no user-configured kiro aliases exist
|
// Inject channel defaults when the channel is absent in user config.
|
||||||
|
// Presence is checked case-insensitively and includes explicit nil/empty markers.
|
||||||
if cfg.OAuthModelAlias == nil {
|
if cfg.OAuthModelAlias == nil {
|
||||||
cfg.OAuthModelAlias = make(map[string][]OAuthModelAlias)
|
cfg.OAuthModelAlias = make(map[string][]OAuthModelAlias)
|
||||||
}
|
}
|
||||||
if _, hasKiro := cfg.OAuthModelAlias["kiro"]; !hasKiro {
|
hasChannel := func(channel string) bool {
|
||||||
// Check case-insensitive too
|
|
||||||
found := false
|
|
||||||
for k := range cfg.OAuthModelAlias {
|
for k := range cfg.OAuthModelAlias {
|
||||||
if strings.EqualFold(strings.TrimSpace(k), "kiro") {
|
if strings.EqualFold(strings.TrimSpace(k), channel) {
|
||||||
found = true
|
return true
|
||||||
break
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !found {
|
return false
|
||||||
cfg.OAuthModelAlias["kiro"] = defaultKiroAliases()
|
}
|
||||||
}
|
if !hasChannel("kiro") {
|
||||||
|
cfg.OAuthModelAlias["kiro"] = defaultKiroAliases()
|
||||||
|
}
|
||||||
|
if !hasChannel("github-copilot") {
|
||||||
|
cfg.OAuthModelAlias["github-copilot"] = defaultGitHubCopilotAliases()
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(cfg.OAuthModelAlias) == 0 {
|
if len(cfg.OAuthModelAlias) == 0 {
|
||||||
|
|||||||
@@ -42,6 +42,21 @@ func defaultKiroAliases() []OAuthModelAlias {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// defaultGitHubCopilotAliases returns default oauth-model-alias entries that
|
||||||
|
// expose Claude hyphen-style IDs for GitHub Copilot Claude models.
|
||||||
|
// This keeps compatibility with clients (e.g. Claude Code) that use
|
||||||
|
// Anthropic-style model IDs like "claude-opus-4-6".
|
||||||
|
func defaultGitHubCopilotAliases() []OAuthModelAlias {
|
||||||
|
return []OAuthModelAlias{
|
||||||
|
{Name: "claude-haiku-4.5", Alias: "claude-haiku-4-5", Fork: true},
|
||||||
|
{Name: "claude-opus-4.1", Alias: "claude-opus-4-1", Fork: true},
|
||||||
|
{Name: "claude-opus-4.5", Alias: "claude-opus-4-5", Fork: true},
|
||||||
|
{Name: "claude-opus-4.6", Alias: "claude-opus-4-6", Fork: true},
|
||||||
|
{Name: "claude-sonnet-4.5", Alias: "claude-sonnet-4-5", Fork: true},
|
||||||
|
{Name: "claude-sonnet-4.6", Alias: "claude-sonnet-4-6", Fork: true},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// defaultAntigravityAliases returns the default oauth-model-alias configuration
|
// defaultAntigravityAliases returns the default oauth-model-alias configuration
|
||||||
// for the antigravity channel when neither field exists.
|
// for the antigravity channel when neither field exists.
|
||||||
func defaultAntigravityAliases() []OAuthModelAlias {
|
func defaultAntigravityAliases() []OAuthModelAlias {
|
||||||
|
|||||||
@@ -107,6 +107,44 @@ func TestSanitizeOAuthModelAlias_InjectsDefaultKiroAliases(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSanitizeOAuthModelAlias_InjectsDefaultGitHubCopilotAliases(t *testing.T) {
|
||||||
|
cfg := &Config{
|
||||||
|
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||||
|
"codex": {
|
||||||
|
{Name: "gpt-5", Alias: "g5"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.SanitizeOAuthModelAlias()
|
||||||
|
|
||||||
|
copilotAliases := cfg.OAuthModelAlias["github-copilot"]
|
||||||
|
if len(copilotAliases) == 0 {
|
||||||
|
t.Fatal("expected default github-copilot aliases to be injected")
|
||||||
|
}
|
||||||
|
|
||||||
|
aliasSet := make(map[string]bool, len(copilotAliases))
|
||||||
|
for _, a := range copilotAliases {
|
||||||
|
aliasSet[a.Alias] = true
|
||||||
|
if !a.Fork {
|
||||||
|
t.Fatalf("expected all default github-copilot aliases to have fork=true, got fork=false for %q", a.Alias)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
expectedAliases := []string{
|
||||||
|
"claude-haiku-4-5",
|
||||||
|
"claude-opus-4-1",
|
||||||
|
"claude-opus-4-5",
|
||||||
|
"claude-opus-4-6",
|
||||||
|
"claude-sonnet-4-5",
|
||||||
|
"claude-sonnet-4-6",
|
||||||
|
}
|
||||||
|
for _, expected := range expectedAliases {
|
||||||
|
if !aliasSet[expected] {
|
||||||
|
t.Fatalf("expected default github-copilot alias %q to be present", expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSanitizeOAuthModelAlias_DoesNotOverrideUserKiroAliases(t *testing.T) {
|
func TestSanitizeOAuthModelAlias_DoesNotOverrideUserKiroAliases(t *testing.T) {
|
||||||
// When user has configured kiro aliases, defaults should NOT be injected
|
// When user has configured kiro aliases, defaults should NOT be injected
|
||||||
cfg := &Config{
|
cfg := &Config{
|
||||||
@@ -128,6 +166,26 @@ func TestSanitizeOAuthModelAlias_DoesNotOverrideUserKiroAliases(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSanitizeOAuthModelAlias_DoesNotOverrideUserGitHubCopilotAliases(t *testing.T) {
|
||||||
|
cfg := &Config{
|
||||||
|
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||||
|
"github-copilot": {
|
||||||
|
{Name: "claude-opus-4.6", Alias: "my-opus", Fork: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.SanitizeOAuthModelAlias()
|
||||||
|
|
||||||
|
copilotAliases := cfg.OAuthModelAlias["github-copilot"]
|
||||||
|
if len(copilotAliases) != 1 {
|
||||||
|
t.Fatalf("expected 1 user-configured github-copilot alias, got %d", len(copilotAliases))
|
||||||
|
}
|
||||||
|
if copilotAliases[0].Alias != "my-opus" {
|
||||||
|
t.Fatalf("expected user alias to be preserved, got %q", copilotAliases[0].Alias)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletion(t *testing.T) {
|
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletion(t *testing.T) {
|
||||||
// When user explicitly deletes kiro aliases (key exists with nil value),
|
// When user explicitly deletes kiro aliases (key exists with nil value),
|
||||||
// defaults should NOT be re-injected on subsequent sanitize calls (#222).
|
// defaults should NOT be re-injected on subsequent sanitize calls (#222).
|
||||||
@@ -154,6 +212,24 @@ func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletion(t *testing
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSanitizeOAuthModelAlias_GitHubCopilotDoesNotReinjectAfterExplicitDeletion(t *testing.T) {
|
||||||
|
cfg := &Config{
|
||||||
|
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||||
|
"github-copilot": nil, // explicitly deleted
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.SanitizeOAuthModelAlias()
|
||||||
|
|
||||||
|
copilotAliases := cfg.OAuthModelAlias["github-copilot"]
|
||||||
|
if len(copilotAliases) != 0 {
|
||||||
|
t.Fatalf("expected github-copilot aliases to remain empty after explicit deletion, got %d aliases", len(copilotAliases))
|
||||||
|
}
|
||||||
|
if _, exists := cfg.OAuthModelAlias["github-copilot"]; !exists {
|
||||||
|
t.Fatal("expected github-copilot key to be preserved as nil marker after sanitization")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletionEmpty(t *testing.T) {
|
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletionEmpty(t *testing.T) {
|
||||||
// Same as above but with empty slice instead of nil (PUT with empty body).
|
// Same as above but with empty slice instead of nil (PUT with empty body).
|
||||||
cfg := &Config{
|
cfg := &Config{
|
||||||
|
|||||||
@@ -20,6 +20,10 @@ type SDKConfig struct {
|
|||||||
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
||||||
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
||||||
|
|
||||||
|
// PassthroughHeaders controls whether upstream response headers are forwarded to downstream clients.
|
||||||
|
// Default is false (disabled).
|
||||||
|
PassthroughHeaders bool `yaml:"passthrough-headers" json:"passthrough-headers"`
|
||||||
|
|
||||||
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
||||||
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package misc
|
package misc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -24,3 +25,37 @@ func LogSavingCredentials(path string) {
|
|||||||
func LogCredentialSeparator() {
|
func LogCredentialSeparator() {
|
||||||
log.Debug(credentialSeparator)
|
log.Debug(credentialSeparator)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MergeMetadata converts source into a JSON-style map and overlays metadata on
// top of it.
//
// If source is already a map[string]any it is shallow-copied, so the caller's
// map is never mutated. Any other value is round-tripped through encoding/json
// so that struct tags (field names, omitempty, "-") are respected. Keys in
// metadata overwrite keys of the same name derived from source.
//
// On success the returned map is never nil, even when source is nil (or a nil
// pointer, which marshals to JSON null) and metadata is empty, so callers can
// safely add keys to the result. An error is returned when source cannot be
// marshaled or marshals to something other than a JSON object or null.
func MergeMetadata(source any, metadata map[string]any) (map[string]any, error) {
	var data map[string]any

	if srcMap, ok := source.(map[string]any); ok {
		// Fast path: copy the map instead of aliasing it.
		data = make(map[string]any, len(srcMap)+len(metadata))
		for k, v := range srcMap {
			data[k] = v
		}
	} else {
		// Slow path: marshal to JSON and back to a map to respect JSON tags.
		temp, err := json.Marshal(source)
		if err != nil {
			return nil, fmt.Errorf("failed to marshal source: %w", err)
		}
		if err := json.Unmarshal(temp, &data); err != nil {
			return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
		}
	}

	// Unmarshaling JSON null leaves data nil; always hand back a writable map.
	if data == nil {
		data = make(map[string]any, len(metadata))
	}

	// Ranging over a nil metadata map is a no-op, so no explicit nil guard is needed.
	for k, v := range metadata {
		data[k] = v
	}

	return data, nil
}
|
||||||
|
|||||||
@@ -129,7 +129,19 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
|||||||
// These models are available through the GitHub Copilot API at api.githubcopilot.com.
|
// These models are available through the GitHub Copilot API at api.githubcopilot.com.
|
||||||
func GetGitHubCopilotModels() []*ModelInfo {
|
func GetGitHubCopilotModels() []*ModelInfo {
|
||||||
now := int64(1732752000) // 2024-11-27
|
now := int64(1732752000) // 2024-11-27
|
||||||
return []*ModelInfo{
|
gpt4oEntries := []struct {
|
||||||
|
ID string
|
||||||
|
DisplayName string
|
||||||
|
Description string
|
||||||
|
}{
|
||||||
|
{ID: "gpt-4o-2024-11-20", DisplayName: "GPT-4o (2024-11-20)", Description: "OpenAI GPT-4o 2024-11-20 via GitHub Copilot"},
|
||||||
|
{ID: "gpt-4o-2024-08-06", DisplayName: "GPT-4o (2024-08-06)", Description: "OpenAI GPT-4o 2024-08-06 via GitHub Copilot"},
|
||||||
|
{ID: "gpt-4o-2024-05-13", DisplayName: "GPT-4o (2024-05-13)", Description: "OpenAI GPT-4o 2024-05-13 via GitHub Copilot"},
|
||||||
|
{ID: "gpt-4o", DisplayName: "GPT-4o", Description: "OpenAI GPT-4o via GitHub Copilot"},
|
||||||
|
{ID: "gpt-4-o-preview", DisplayName: "GPT-4-o Preview", Description: "OpenAI GPT-4-o Preview via GitHub Copilot"},
|
||||||
|
}
|
||||||
|
|
||||||
|
models := []*ModelInfo{
|
||||||
{
|
{
|
||||||
ID: "gpt-4.1",
|
ID: "gpt-4.1",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -141,6 +153,23 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
ContextLength: 128000,
|
ContextLength: 128000,
|
||||||
MaxCompletionTokens: 16384,
|
MaxCompletionTokens: 16384,
|
||||||
},
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entry := range gpt4oEntries {
|
||||||
|
models = append(models, &ModelInfo{
|
||||||
|
ID: entry.ID,
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: entry.DisplayName,
|
||||||
|
Description: entry.Description,
|
||||||
|
ContextLength: 128000,
|
||||||
|
MaxCompletionTokens: 16384,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return append(models, []*ModelInfo{
|
||||||
{
|
{
|
||||||
ID: "gpt-5",
|
ID: "gpt-5",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -377,6 +406,17 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
ContextLength: 1048576,
|
ContextLength: 1048576,
|
||||||
MaxCompletionTokens: 65536,
|
MaxCompletionTokens: 65536,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3.1-pro-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "Gemini 3.1 Pro (Preview)",
|
||||||
|
Description: "Google Gemini 3.1 Pro Preview via GitHub Copilot",
|
||||||
|
ContextLength: 1048576,
|
||||||
|
MaxCompletionTokens: 65536,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-3-flash-preview",
|
ID: "gemini-3-flash-preview",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -411,7 +451,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
MaxCompletionTokens: 16384,
|
MaxCompletionTokens: 16384,
|
||||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||||
},
|
},
|
||||||
}
|
}...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetKiroModels returns the Kiro (AWS CodeWhisperer) model definitions
|
// GetKiroModels returns the Kiro (AWS CodeWhisperer) model definitions
|
||||||
|
|||||||
@@ -196,6 +196,21 @@ func GetGeminiModels() []*ModelInfo {
|
|||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3.1-pro-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1771459200,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-3.1-pro-preview",
|
||||||
|
Version: "3.1",
|
||||||
|
DisplayName: "Gemini 3.1 Pro Preview",
|
||||||
|
Description: "Gemini 3.1 Pro Preview",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-3-flash-preview",
|
ID: "gemini-3-flash-preview",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -306,6 +321,21 @@ func GetGeminiVertexModels() []*ModelInfo {
|
|||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3.1-pro-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1771459200,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-3.1-pro-preview",
|
||||||
|
Version: "3.1",
|
||||||
|
DisplayName: "Gemini 3.1 Pro Preview",
|
||||||
|
Description: "Gemini 3.1 Pro Preview",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-3-pro-image-preview",
|
ID: "gemini-3-pro-image-preview",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -448,6 +478,21 @@ func GetGeminiCLIModels() []*ModelInfo {
|
|||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3.1-pro-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1771459200,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-3.1-pro-preview",
|
||||||
|
Version: "3.1",
|
||||||
|
DisplayName: "Gemini 3.1 Pro Preview",
|
||||||
|
Description: "Gemini 3.1 Pro Preview",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-3-flash-preview",
|
ID: "gemini-3-flash-preview",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -529,6 +574,21 @@ func GetAIStudioModels() []*ModelInfo {
|
|||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3.1-pro-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1771459200,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/gemini-3.1-pro-preview",
|
||||||
|
Version: "3.1",
|
||||||
|
DisplayName: "Gemini 3.1 Pro Preview",
|
||||||
|
Description: "Gemini 3.1 Pro Preview",
|
||||||
|
InputTokenLimit: 1048576,
|
||||||
|
OutputTokenLimit: 65536,
|
||||||
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "gemini-3-flash-preview",
|
ID: "gemini-3-flash-preview",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -915,6 +975,7 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
|||||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||||
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||||
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||||
|
"gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
|
|||||||
@@ -164,12 +164,12 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
|||||||
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
|
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, &param)
|
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, &param)
|
||||||
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out))}
|
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out)), Headers: wsResp.Headers.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteStream performs a streaming request to the AI Studio API.
|
// ExecuteStream performs a streaming request to the AI Studio API.
|
||||||
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
@@ -254,7 +254,6 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
|||||||
return nil, statusErr{code: firstEvent.Status, msg: body.String()}
|
return nil, statusErr{code: firstEvent.Status, msg: body.String()}
|
||||||
}
|
}
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func(first wsrelay.StreamEvent) {
|
go func(first wsrelay.StreamEvent) {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
var param any
|
var param any
|
||||||
@@ -318,7 +317,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}(firstEvent)
|
}(firstEvent)
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: firstEvent.Headers.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CountTokens counts tokens for the given request using the AI Studio API.
|
// CountTokens counts tokens for the given request using the AI Studio API.
|
||||||
|
|||||||
@@ -232,7 +232,7 @@ attemptLoop:
|
|||||||
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
||||||
var param any
|
var param any
|
||||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, &param)
|
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, &param)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
|
||||||
reporter.ensurePublished(ctx)
|
reporter.ensurePublished(ctx)
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
@@ -436,7 +436,7 @@ attemptLoop:
|
|||||||
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
||||||
var param any
|
var param any
|
||||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, &param)
|
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, &param)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
|
||||||
reporter.ensurePublished(ctx)
|
reporter.ensurePublished(ctx)
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
@@ -645,7 +645,7 @@ func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteStream performs a streaming request to the Antigravity API.
|
// ExecuteStream performs a streaming request to the Antigravity API.
|
||||||
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
@@ -775,7 +775,6 @@ attemptLoop:
|
|||||||
}
|
}
|
||||||
|
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func(resp *http.Response) {
|
go func(resp *http.Response) {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -820,7 +819,7 @@ attemptLoop:
|
|||||||
reporter.ensurePublished(ctx)
|
reporter.ensurePublished(ctx)
|
||||||
}
|
}
|
||||||
}(httpResp)
|
}(httpResp)
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
@@ -968,7 +967,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
|||||||
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
||||||
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
|
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
|
||||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
|
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: httpResp.Header.Clone()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
lastStatus = httpResp.StatusCode
|
lastStatus = httpResp.StatusCode
|
||||||
|
|||||||
@@ -117,7 +117,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
|
|
||||||
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
|
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
|
||||||
// based on client type and configuration.
|
// based on client type and configuration.
|
||||||
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
|
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
|
||||||
|
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
@@ -223,11 +223,11 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
data,
|
data,
|
||||||
&param,
|
&param,
|
||||||
)
|
)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
@@ -258,7 +258,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
|
|
||||||
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
|
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
|
||||||
// based on client type and configuration.
|
// based on client type and configuration.
|
||||||
body = applyCloaking(ctx, e.cfg, auth, body, baseModel)
|
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
|
||||||
|
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
@@ -330,7 +330,6 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -399,7 +398,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
@@ -488,7 +487,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
count := gjson.GetBytes(data, "input_tokens").Int()
|
count := gjson.GetBytes(data, "input_tokens").Int()
|
||||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
return cliproxyexecutor.Response{Payload: []byte(out), Headers: resp.Header.Clone()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||||
@@ -983,10 +982,10 @@ func getClientUserAgent(ctx context.Context) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getCloakConfigFromAuth extracts cloak configuration from auth attributes.
|
// getCloakConfigFromAuth extracts cloak configuration from auth attributes.
|
||||||
// Returns (cloakMode, strictMode, sensitiveWords).
|
// Returns (cloakMode, strictMode, sensitiveWords, cacheUserID).
|
||||||
func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string) {
|
func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string, bool) {
|
||||||
if auth == nil || auth.Attributes == nil {
|
if auth == nil || auth.Attributes == nil {
|
||||||
return "auto", false, nil
|
return "auto", false, nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
cloakMode := auth.Attributes["cloak_mode"]
|
cloakMode := auth.Attributes["cloak_mode"]
|
||||||
@@ -1004,7 +1003,9 @@ func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return cloakMode, strictMode, sensitiveWords
|
cacheUserID := strings.EqualFold(strings.TrimSpace(auth.Attributes["cloak_cache_user_id"]), "true")
|
||||||
|
|
||||||
|
return cloakMode, strictMode, sensitiveWords, cacheUserID
|
||||||
}
|
}
|
||||||
|
|
||||||
// resolveClaudeKeyCloakConfig finds the matching ClaudeKey config and returns its CloakConfig.
|
// resolveClaudeKeyCloakConfig finds the matching ClaudeKey config and returns its CloakConfig.
|
||||||
@@ -1037,16 +1038,24 @@ func resolveClaudeKeyCloakConfig(cfg *config.Config, auth *cliproxyauth.Auth) *c
|
|||||||
}
|
}
|
||||||
|
|
||||||
// injectFakeUserID generates and injects a fake user ID into the request metadata.
|
// injectFakeUserID generates and injects a fake user ID into the request metadata.
|
||||||
func injectFakeUserID(payload []byte) []byte {
|
// When useCache is false, a new user ID is generated for every call.
|
||||||
|
func injectFakeUserID(payload []byte, apiKey string, useCache bool) []byte {
|
||||||
|
generateID := func() string {
|
||||||
|
if useCache {
|
||||||
|
return cachedUserID(apiKey)
|
||||||
|
}
|
||||||
|
return generateFakeUserID()
|
||||||
|
}
|
||||||
|
|
||||||
metadata := gjson.GetBytes(payload, "metadata")
|
metadata := gjson.GetBytes(payload, "metadata")
|
||||||
if !metadata.Exists() {
|
if !metadata.Exists() {
|
||||||
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID())
|
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateID())
|
||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
|
|
||||||
existingUserID := gjson.GetBytes(payload, "metadata.user_id").String()
|
existingUserID := gjson.GetBytes(payload, "metadata.user_id").String()
|
||||||
if existingUserID == "" || !isValidUserID(existingUserID) {
|
if existingUserID == "" || !isValidUserID(existingUserID) {
|
||||||
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateFakeUserID())
|
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateID())
|
||||||
}
|
}
|
||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
@@ -1083,7 +1092,7 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
|||||||
|
|
||||||
// applyCloaking applies cloaking transformations to the payload based on config and client.
|
// applyCloaking applies cloaking transformations to the payload based on config and client.
|
||||||
// Cloaking includes: system prompt injection, fake user ID, and sensitive word obfuscation.
|
// Cloaking includes: system prompt injection, fake user ID, and sensitive word obfuscation.
|
||||||
func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string) []byte {
|
func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string, apiKey string) []byte {
|
||||||
clientUserAgent := getClientUserAgent(ctx)
|
clientUserAgent := getClientUserAgent(ctx)
|
||||||
|
|
||||||
// Get cloak config from ClaudeKey configuration
|
// Get cloak config from ClaudeKey configuration
|
||||||
@@ -1093,16 +1102,20 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
|
|||||||
var cloakMode string
|
var cloakMode string
|
||||||
var strictMode bool
|
var strictMode bool
|
||||||
var sensitiveWords []string
|
var sensitiveWords []string
|
||||||
|
var cacheUserID bool
|
||||||
|
|
||||||
if cloakCfg != nil {
|
if cloakCfg != nil {
|
||||||
cloakMode = cloakCfg.Mode
|
cloakMode = cloakCfg.Mode
|
||||||
strictMode = cloakCfg.StrictMode
|
strictMode = cloakCfg.StrictMode
|
||||||
sensitiveWords = cloakCfg.SensitiveWords
|
sensitiveWords = cloakCfg.SensitiveWords
|
||||||
|
if cloakCfg.CacheUserID != nil {
|
||||||
|
cacheUserID = *cloakCfg.CacheUserID
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to auth attributes if no config found
|
// Fallback to auth attributes if no config found
|
||||||
if cloakMode == "" {
|
if cloakMode == "" {
|
||||||
attrMode, attrStrict, attrWords := getCloakConfigFromAuth(auth)
|
attrMode, attrStrict, attrWords, attrCache := getCloakConfigFromAuth(auth)
|
||||||
cloakMode = attrMode
|
cloakMode = attrMode
|
||||||
if !strictMode {
|
if !strictMode {
|
||||||
strictMode = attrStrict
|
strictMode = attrStrict
|
||||||
@@ -1110,6 +1123,12 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
|
|||||||
if len(sensitiveWords) == 0 {
|
if len(sensitiveWords) == 0 {
|
||||||
sensitiveWords = attrWords
|
sensitiveWords = attrWords
|
||||||
}
|
}
|
||||||
|
if cloakCfg == nil || cloakCfg.CacheUserID == nil {
|
||||||
|
cacheUserID = attrCache
|
||||||
|
}
|
||||||
|
} else if cloakCfg == nil || cloakCfg.CacheUserID == nil {
|
||||||
|
_, _, _, attrCache := getCloakConfigFromAuth(auth)
|
||||||
|
cacheUserID = attrCache
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine if cloaking should be applied
|
// Determine if cloaking should be applied
|
||||||
@@ -1123,7 +1142,7 @@ func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.A
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Inject fake user ID
|
// Inject fake user ID
|
||||||
payload = injectFakeUserID(payload)
|
payload = injectFakeUserID(payload, apiKey, cacheUserID)
|
||||||
|
|
||||||
// Apply sensitive word obfuscation
|
// Apply sensitive word obfuscation
|
||||||
if len(sensitiveWords) > 0 {
|
if len(sensitiveWords) > 0 {
|
||||||
|
|||||||
@@ -2,9 +2,18 @@ package executor
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestApplyClaudeToolPrefix(t *testing.T) {
|
func TestApplyClaudeToolPrefix(t *testing.T) {
|
||||||
@@ -199,6 +208,119 @@ func TestApplyClaudeToolPrefix_NestedToolReference(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) {
|
||||||
|
resetUserIDCache()
|
||||||
|
|
||||||
|
var userIDs []string
|
||||||
|
var requestModels []string
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
userID := gjson.GetBytes(body, "metadata.user_id").String()
|
||||||
|
model := gjson.GetBytes(body, "model").String()
|
||||||
|
userIDs = append(userIDs, userID)
|
||||||
|
requestModels = append(requestModels, model)
|
||||||
|
t.Logf("HTTP Server received request: model=%s, user_id=%s, url=%s", model, userID, r.URL.String())
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
t.Logf("End-to-end test: Fake HTTP server started at %s", server.URL)
|
||||||
|
|
||||||
|
cacheEnabled := true
|
||||||
|
executor := NewClaudeExecutor(&config.Config{
|
||||||
|
ClaudeKey: []config.ClaudeKey{
|
||||||
|
{
|
||||||
|
APIKey: "key-123",
|
||||||
|
BaseURL: server.URL,
|
||||||
|
Cloak: &config.CloakConfig{
|
||||||
|
CacheUserID: &cacheEnabled,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||||
|
"api_key": "key-123",
|
||||||
|
"base_url": server.URL,
|
||||||
|
}}
|
||||||
|
|
||||||
|
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||||
|
models := []string{"claude-3-5-sonnet", "claude-3-5-haiku"}
|
||||||
|
for _, model := range models {
|
||||||
|
t.Logf("Sending request for model: %s", model)
|
||||||
|
modelPayload, _ := sjson.SetBytes(payload, "model", model)
|
||||||
|
if _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: model,
|
||||||
|
Payload: modelPayload,
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("claude"),
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("Execute(%s) error: %v", model, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(userIDs) != 2 {
|
||||||
|
t.Fatalf("expected 2 requests, got %d", len(userIDs))
|
||||||
|
}
|
||||||
|
if userIDs[0] == "" || userIDs[1] == "" {
|
||||||
|
t.Fatal("expected user_id to be populated")
|
||||||
|
}
|
||||||
|
t.Logf("user_id[0] (model=%s): %s", requestModels[0], userIDs[0])
|
||||||
|
t.Logf("user_id[1] (model=%s): %s", requestModels[1], userIDs[1])
|
||||||
|
if userIDs[0] != userIDs[1] {
|
||||||
|
t.Fatalf("expected user_id to be reused across models, got %q and %q", userIDs[0], userIDs[1])
|
||||||
|
}
|
||||||
|
if !isValidUserID(userIDs[0]) {
|
||||||
|
t.Fatalf("user_id %q is not valid", userIDs[0])
|
||||||
|
}
|
||||||
|
t.Logf("✓ End-to-end test passed: Same user_id (%s) was used for both models", userIDs[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
|
||||||
|
resetUserIDCache()
|
||||||
|
|
||||||
|
var userIDs []string
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
userIDs = append(userIDs, gjson.GetBytes(body, "metadata.user_id").String())
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
executor := NewClaudeExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||||
|
"api_key": "key-123",
|
||||||
|
"base_url": server.URL,
|
||||||
|
}}
|
||||||
|
|
||||||
|
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
if _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "claude-3-5-sonnet",
|
||||||
|
Payload: payload,
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FromString("claude"),
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("Execute call %d error: %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(userIDs) != 2 {
|
||||||
|
t.Fatalf("expected 2 requests, got %d", len(userIDs))
|
||||||
|
}
|
||||||
|
if userIDs[0] == "" || userIDs[1] == "" {
|
||||||
|
t.Fatal("expected user_id to be populated")
|
||||||
|
}
|
||||||
|
if userIDs[0] == userIDs[1] {
|
||||||
|
t.Fatalf("expected user_id to change when caching is not enabled, got identical values %q", userIDs[0])
|
||||||
|
}
|
||||||
|
if !isValidUserID(userIDs[0]) || !isValidUserID(userIDs[1]) {
|
||||||
|
t.Fatalf("user_ids should be valid, got %q and %q", userIDs[0], userIDs[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStripClaudeToolPrefixFromResponse_NestedToolReference(t *testing.T) {
|
func TestStripClaudeToolPrefixFromResponse_NestedToolReference(t *testing.T) {
|
||||||
input := []byte(`{"content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"proxy_mcp__nia__manage_resource"}]}]}`)
|
input := []byte(`{"content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"proxy_mcp__nia__manage_resource"}]}]}`)
|
||||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||||
|
|||||||
@@ -183,7 +183,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
|
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, &param)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, &param)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"}
|
err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"}
|
||||||
@@ -273,11 +273,11 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
|||||||
reporter.ensurePublished(ctx)
|
reporter.ensurePublished(ctx)
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, &param)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, &param)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"}
|
return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"}
|
||||||
}
|
}
|
||||||
@@ -362,7 +362,6 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -397,7 +396,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
|||||||
@@ -363,7 +363,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
log.Debugf("Executing Codex Websockets stream request with auth ID: %s, model: %s", auth.ID, req.Model)
|
log.Debugf("Executing Codex Websockets stream request with auth ID: %s, model: %s", auth.ID, req.Model)
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
ctx = context.Background()
|
ctx = context.Background()
|
||||||
@@ -436,7 +436,9 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
|||||||
})
|
})
|
||||||
|
|
||||||
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||||
|
var upstreamHeaders http.Header
|
||||||
if respHS != nil {
|
if respHS != nil {
|
||||||
|
upstreamHeaders = respHS.Header.Clone()
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
|
recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
|
||||||
}
|
}
|
||||||
if errDial != nil {
|
if errDial != nil {
|
||||||
@@ -516,7 +518,6 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
|||||||
markCodexWebsocketCreateSent(sess, conn, wsReqBody)
|
markCodexWebsocketCreateSent(sess, conn, wsReqBody)
|
||||||
|
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func() {
|
go func() {
|
||||||
terminateReason := "completed"
|
terminateReason := "completed"
|
||||||
var terminateErr error
|
var terminateErr error
|
||||||
@@ -627,7 +628,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: upstreamHeaders, Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *CodexWebsocketsExecutor) dialCodexWebsocket(ctx context.Context, auth *cliproxyauth.Auth, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) {
|
func (e *CodexWebsocketsExecutor) dialCodexWebsocket(ctx context.Context, auth *cliproxyauth.Auth, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) {
|
||||||
@@ -1343,7 +1344,7 @@ func (e *CodexAutoExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
return e.httpExec.Execute(ctx, auth, req, opts)
|
return e.httpExec.Execute(ctx, auth, req, opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *CodexAutoExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
func (e *CodexAutoExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||||
if e == nil || e.httpExec == nil || e.wsExec == nil {
|
if e == nil || e.httpExec == nil || e.wsExec == nil {
|
||||||
return nil, fmt.Errorf("codex auto executor: executor is nil")
|
return nil, fmt.Errorf("codex auto executor: executor is nil")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -225,7 +225,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, &param)
|
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, &param)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -256,7 +256,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteStream performs a streaming request to the Gemini CLI API.
|
// ExecuteStream performs a streaming request to the Gemini CLI API.
|
||||||
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
@@ -382,7 +382,6 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
}
|
}
|
||||||
|
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func(resp *http.Response, reqBody []byte, attemptModel string) {
|
go func(resp *http.Response, reqBody []byte, attemptModel string) {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -441,7 +440,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
}
|
}
|
||||||
}(httpResp, append([]byte(nil), payload...), attemptModel)
|
}(httpResp, append([]byte(nil), payload...), attemptModel)
|
||||||
|
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(lastBody) > 0 {
|
if len(lastBody) > 0 {
|
||||||
@@ -546,7 +545,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
|||||||
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
|
||||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
|
||||||
}
|
}
|
||||||
lastStatus = resp.StatusCode
|
lastStatus = resp.StatusCode
|
||||||
lastBody = append([]byte(nil), data...)
|
lastBody = append([]byte(nil), data...)
|
||||||
|
|||||||
@@ -205,12 +205,12 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
reporter.publish(ctx, parseGeminiUsage(data))
|
reporter.publish(ctx, parseGeminiUsage(data))
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteStream performs a streaming request to the Gemini API.
|
// ExecuteStream performs a streaming request to the Gemini API.
|
||||||
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
@@ -298,7 +298,6 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -335,7 +334,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CountTokens counts tokens for the given request using the Gemini API.
|
// CountTokens counts tokens for the given request using the Gemini API.
|
||||||
@@ -416,7 +415,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
|
|
||||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refresh refreshes the authentication credentials (no-op for Gemini API key).
|
// Refresh refreshes the authentication credentials (no-op for Gemini API key).
|
||||||
|
|||||||
@@ -253,7 +253,7 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteStream performs a streaming request to the Vertex AI API.
|
// ExecuteStream performs a streaming request to the Vertex AI API.
|
||||||
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
@@ -419,7 +419,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
to := sdktranslator.FromString("gemini")
|
to := sdktranslator.FromString("gemini")
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -524,12 +524,12 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
|||||||
reporter.publish(ctx, parseGeminiUsage(data))
|
reporter.publish(ctx, parseGeminiUsage(data))
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// executeStreamWithServiceAccount handles streaming authentication using service account credentials.
|
// executeStreamWithServiceAccount handles streaming authentication using service account credentials.
|
||||||
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
@@ -618,7 +618,6 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
|||||||
}
|
}
|
||||||
|
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -650,11 +649,11 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
|||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// executeStreamWithAPIKey handles streaming authentication using API key credentials.
|
// executeStreamWithAPIKey handles streaming authentication using API key credentials.
|
||||||
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
@@ -743,7 +742,6 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
}
|
}
|
||||||
|
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -775,7 +773,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// countTokensWithServiceAccount counts tokens using service account credentials.
|
// countTokensWithServiceAccount counts tokens using service account credentials.
|
||||||
@@ -859,7 +857,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
|||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// countTokensWithAPIKey handles token counting using API key credentials.
|
// countTokensWithAPIKey handles token counting using API key credentials.
|
||||||
@@ -943,7 +941,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
|||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
return cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
|
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
|
||||||
|
|||||||
@@ -35,12 +35,12 @@ const (
|
|||||||
maxScannerBufferSize = 20_971_520
|
maxScannerBufferSize = 20_971_520
|
||||||
|
|
||||||
// Copilot API header values.
|
// Copilot API header values.
|
||||||
copilotUserAgent = "GitHubCopilotChat/0.35.0"
|
copilotUserAgent = "GitHubCopilotChat/0.35.0"
|
||||||
copilotEditorVersion = "vscode/1.107.0"
|
copilotEditorVersion = "vscode/1.107.0"
|
||||||
copilotPluginVersion = "copilot-chat/0.35.0"
|
copilotPluginVersion = "copilot-chat/0.35.0"
|
||||||
copilotIntegrationID = "vscode-chat"
|
copilotIntegrationID = "vscode-chat"
|
||||||
copilotOpenAIIntent = "conversation-panel"
|
copilotOpenAIIntent = "conversation-panel"
|
||||||
copilotGitHubAPIVer = "2025-04-01"
|
copilotGitHubAPIVer = "2025-04-01"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GitHubCopilotExecutor handles requests to the GitHub Copilot API.
|
// GitHubCopilotExecutor handles requests to the GitHub Copilot API.
|
||||||
@@ -232,7 +232,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteStream handles streaming requests to GitHub Copilot.
|
// ExecuteStream handles streaming requests to GitHub Copilot.
|
||||||
func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
apiToken, baseURL, errToken := e.ensureAPIToken(ctx, auth)
|
apiToken, baseURL, errToken := e.ensureAPIToken(ctx, auth)
|
||||||
if errToken != nil {
|
if errToken != nil {
|
||||||
return nil, errToken
|
return nil, errToken
|
||||||
@@ -341,7 +341,6 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
|||||||
}
|
}
|
||||||
|
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
@@ -394,7 +393,10 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{
|
||||||
|
Headers: httpResp.Header.Clone(),
|
||||||
|
Chunks: out,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CountTokens is not supported for GitHub Copilot.
|
// CountTokens is not supported for GitHub Copilot.
|
||||||
|
|||||||
@@ -169,12 +169,12 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||||
// the original model name in the response for client compatibility.
|
// the original model name in the response for client compatibility.
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteStream performs a streaming chat completion request.
|
// ExecuteStream performs a streaming chat completion request.
|
||||||
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
@@ -262,7 +262,6 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
}
|
}
|
||||||
|
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -294,7 +293,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
reporter.ensurePublished(ctx)
|
reporter.ensurePublished(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
|||||||
@@ -168,7 +168,7 @@ func (e *KiloExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteStream performs a streaming request.
|
// ExecuteStream performs a streaming request.
|
||||||
func (e *KiloExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *KiloExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
@@ -253,7 +253,6 @@ func (e *KiloExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
}
|
}
|
||||||
|
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer httpResp.Body.Close()
|
defer httpResp.Body.Close()
|
||||||
@@ -286,7 +285,10 @@ func (e *KiloExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
reporter.ensurePublished(ctx)
|
reporter.ensurePublished(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{
|
||||||
|
Headers: httpResp.Header.Clone(),
|
||||||
|
Chunks: out,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refresh validates the Kilo token.
|
// Refresh validates the Kilo token.
|
||||||
@@ -456,4 +458,3 @@ func FetchKiloModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.C
|
|||||||
|
|
||||||
return allModels
|
return allModels
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -161,12 +161,12 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||||
// the original model name in the response for client compatibility.
|
// the original model name in the response for client compatibility.
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteStream performs a streaming chat completion request to Kimi.
|
// ExecuteStream performs a streaming chat completion request to Kimi.
|
||||||
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
if from.String() == "claude" {
|
if from.String() == "claude" {
|
||||||
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
|
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
|
||||||
@@ -253,7 +253,6 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -285,7 +284,7 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CountTokens estimates token count for Kimi requests.
|
// CountTokens estimates token count for Kimi requests.
|
||||||
|
|||||||
@@ -1053,7 +1053,7 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
|||||||
|
|
||||||
// ExecuteStream handles streaming requests to Kiro API.
|
// ExecuteStream handles streaming requests to Kiro API.
|
||||||
// Supports automatic token refresh on 401/403 errors and quota fallback on 429.
|
// Supports automatic token refresh on 401/403 errors and quota fallback on 429.
|
||||||
func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
accessToken, profileArn := kiroCredentials(auth)
|
accessToken, profileArn := kiroCredentials(auth)
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return nil, fmt.Errorf("kiro: access token not found in auth")
|
return nil, fmt.Errorf("kiro: access token not found in auth")
|
||||||
@@ -1110,7 +1110,11 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
// Route to MCP endpoint instead of normal Kiro API
|
// Route to MCP endpoint instead of normal Kiro API
|
||||||
if kiroclaude.HasWebSearchTool(req.Payload) {
|
if kiroclaude.HasWebSearchTool(req.Payload) {
|
||||||
log.Infof("kiro: detected pure web_search request, routing to MCP endpoint")
|
log.Infof("kiro: detected pure web_search request, routing to MCP endpoint")
|
||||||
return e.handleWebSearchStream(ctx, auth, req, opts, accessToken, profileArn)
|
streamWebSearch, errWebSearch := e.handleWebSearchStream(ctx, auth, req, opts, accessToken, profileArn)
|
||||||
|
if errWebSearch != nil {
|
||||||
|
return nil, errWebSearch
|
||||||
|
}
|
||||||
|
return &cliproxyexecutor.StreamResult{Chunks: streamWebSearch}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||||
@@ -1128,7 +1132,11 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
|
|
||||||
// Execute stream with retry on 401/403 and 429 (quota exhausted)
|
// Execute stream with retry on 401/403 and 429 (quota exhausted)
|
||||||
// Note: currentOrigin and kiroPayload are built inside executeStreamWithRetry for each endpoint
|
// Note: currentOrigin and kiroPayload are built inside executeStreamWithRetry for each endpoint
|
||||||
return e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey)
|
streamKiro, errStreamKiro := e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey)
|
||||||
|
if errStreamKiro != nil {
|
||||||
|
return nil, errStreamKiro
|
||||||
|
}
|
||||||
|
return &cliproxyexecutor.StreamResult{Chunks: streamKiro}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// executeStreamWithRetry performs the streaming HTTP request with automatic retry on auth errors.
|
// executeStreamWithRetry performs the streaming HTTP request with automatic retry on auth errors.
|
||||||
@@ -2543,6 +2551,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
isThinkingBlockOpen := false // Track if thinking content block SSE event is open
|
isThinkingBlockOpen := false // Track if thinking content block SSE event is open
|
||||||
thinkingBlockIndex := -1 // Index of the thinking content block
|
thinkingBlockIndex := -1 // Index of the thinking content block
|
||||||
var accumulatedThinkingContent strings.Builder // Accumulate thinking content for token counting
|
var accumulatedThinkingContent strings.Builder // Accumulate thinking content for token counting
|
||||||
|
hasOfficialReasoningEvent := false // Disable tag parsing after official reasoning events appear
|
||||||
|
|
||||||
// Buffer for handling partial tag matches at chunk boundaries
|
// Buffer for handling partial tag matches at chunk boundaries
|
||||||
var pendingContent strings.Builder // Buffer content that might be part of a tag
|
var pendingContent strings.Builder // Buffer content that might be part of a tag
|
||||||
@@ -2978,6 +2987,31 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
lastUsageUpdateTime = time.Now()
|
lastUsageUpdateTime = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if hasOfficialReasoningEvent {
|
||||||
|
processText := strings.TrimSpace(strings.ReplaceAll(strings.ReplaceAll(contentDelta, kirocommon.ThinkingStartTag, ""), kirocommon.ThinkingEndTag, ""))
|
||||||
|
if processText != "" {
|
||||||
|
if !isTextBlockOpen {
|
||||||
|
contentBlockIndex++
|
||||||
|
isTextBlockOpen = true
|
||||||
|
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
|
||||||
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||||
|
for _, chunk := range sseData {
|
||||||
|
if chunk != "" {
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
claudeEvent := kiroclaude.BuildClaudeStreamEvent(processText, contentBlockIndex)
|
||||||
|
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
|
||||||
|
for _, chunk := range sseData {
|
||||||
|
if chunk != "" {
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// TAG-BASED THINKING PARSING: Parse <thinking> tags from content
|
// TAG-BASED THINKING PARSING: Parse <thinking> tags from content
|
||||||
// Combine pending content with new content for processing
|
// Combine pending content with new content for processing
|
||||||
pendingContent.WriteString(contentDelta)
|
pendingContent.WriteString(contentDelta)
|
||||||
@@ -3256,6 +3290,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
}
|
}
|
||||||
|
|
||||||
if thinkingText != "" {
|
if thinkingText != "" {
|
||||||
|
hasOfficialReasoningEvent = true
|
||||||
// Close text block if open before starting thinking block
|
// Close text block if open before starting thinking block
|
||||||
if isTextBlockOpen && contentBlockIndex >= 0 {
|
if isTextBlockOpen && contentBlockIndex >= 0 {
|
||||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||||
|
|||||||
@@ -172,11 +172,11 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
|||||||
// Translate response back to source format when needed
|
// Translate response back to source format when needed
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
@@ -258,7 +258,6 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -298,7 +297,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
// Ensure we record the request if no usage chunk was ever seen
|
// Ensure we record the request if no usage chunk was ever seen
|
||||||
reporter.ensurePublished(ctx)
|
reporter.ensurePublished(ctx)
|
||||||
}()
|
}()
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||||
@@ -22,9 +23,151 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
|
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
|
||||||
|
qwenRateLimitPerMin = 60 // 60 requests per minute per credential
|
||||||
|
qwenRateLimitWindow = time.Minute // sliding window duration
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// qwenBeijingLoc caches the Beijing timezone to avoid repeated LoadLocation syscalls.
|
||||||
|
var qwenBeijingLoc = func() *time.Location {
|
||||||
|
loc, err := time.LoadLocation("Asia/Shanghai")
|
||||||
|
if err != nil || loc == nil {
|
||||||
|
log.Warnf("qwen: failed to load Asia/Shanghai timezone: %v, using fixed UTC+8", err)
|
||||||
|
return time.FixedZone("CST", 8*3600)
|
||||||
|
}
|
||||||
|
return loc
|
||||||
|
}()
|
||||||
|
|
||||||
|
// qwenQuotaCodes is the set of lower-case error identifiers that are treated
// as "quota exhausted". Membership is tested with the comma-ok idiom; the
// values carry no data.
var qwenQuotaCodes = map[string]struct{}{
	"quota_exceeded":     {},
	"insufficient_quota": {},
}
|
||||||
|
|
||||||
|
// qwenRateLimiter is the package-level state behind the per-credential Qwen
// rate limit (60 requests per minute per account). The embedded mutex guards
// the requests map, which stores the timestamps of recently accepted requests
// keyed by auth ID.
var qwenRateLimiter = struct {
	sync.Mutex
	// requests maps authID -> timestamps of accepted requests.
	requests map[string][]time.Time
}{
	requests: map[string][]time.Time{},
}
|
||||||
|
|
||||||
|
// redactAuthID returns a shortened form of an auth ID that is safe to log.
//
// An empty ID yields an empty string. IDs of at most eight characters are
// returned unchanged; longer IDs keep their first four and last four
// characters, joined by "...", so that log lines can still be correlated.
//
// Slicing is done on runes rather than bytes so that an ID containing
// multi-byte UTF-8 characters is never cut in the middle of a character,
// which would put invalid UTF-8 into the log output.
func redactAuthID(id string) string {
	if id == "" {
		return ""
	}
	runes := []rune(id)
	if len(runes) <= 8 {
		return id
	}
	return string(runes[:4]) + "..." + string(runes[len(runes)-4:])
}
|
||||||
|
|
||||||
|
// checkQwenRateLimit checks if the credential has exceeded the rate limit.
|
||||||
|
// Returns nil if allowed, or a statusErr with retryAfter if rate limited.
|
||||||
|
func checkQwenRateLimit(authID string) error {
|
||||||
|
if authID == "" {
|
||||||
|
// Empty authID should not bypass rate limiting in production
|
||||||
|
// Use debug level to avoid log spam for certain auth flows
|
||||||
|
log.Debug("qwen rate limit check: empty authID, skipping rate limit")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
windowStart := now.Add(-qwenRateLimitWindow)
|
||||||
|
|
||||||
|
qwenRateLimiter.Lock()
|
||||||
|
defer qwenRateLimiter.Unlock()
|
||||||
|
|
||||||
|
// Get and filter timestamps within the window
|
||||||
|
timestamps := qwenRateLimiter.requests[authID]
|
||||||
|
var validTimestamps []time.Time
|
||||||
|
for _, ts := range timestamps {
|
||||||
|
if ts.After(windowStart) {
|
||||||
|
validTimestamps = append(validTimestamps, ts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Always prune expired entries to prevent memory leak
|
||||||
|
// Delete empty entries, otherwise update with pruned slice
|
||||||
|
if len(validTimestamps) == 0 {
|
||||||
|
delete(qwenRateLimiter.requests, authID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if rate limit exceeded
|
||||||
|
if len(validTimestamps) >= qwenRateLimitPerMin {
|
||||||
|
// Calculate when the oldest request will expire
|
||||||
|
oldestInWindow := validTimestamps[0]
|
||||||
|
retryAfter := oldestInWindow.Add(qwenRateLimitWindow).Sub(now)
|
||||||
|
if retryAfter < time.Second {
|
||||||
|
retryAfter = time.Second
|
||||||
|
}
|
||||||
|
retryAfterSec := int(retryAfter.Seconds())
|
||||||
|
return statusErr{
|
||||||
|
code: http.StatusTooManyRequests,
|
||||||
|
msg: fmt.Sprintf(`{"error":{"code":"rate_limit_exceeded","message":"Qwen rate limit: %d requests/minute exceeded, retry after %ds","type":"rate_limit_exceeded"}}`, qwenRateLimitPerMin, retryAfterSec),
|
||||||
|
retryAfter: &retryAfter,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record this request and update the map with pruned timestamps
|
||||||
|
validTimestamps = append(validTimestamps, now)
|
||||||
|
qwenRateLimiter.requests[authID] = validTimestamps
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isQwenQuotaError checks if the error response indicates a quota exceeded error.
|
||||||
|
// Qwen returns HTTP 403 with error.code="insufficient_quota" when daily quota is exhausted.
|
||||||
|
func isQwenQuotaError(body []byte) bool {
|
||||||
|
code := strings.ToLower(gjson.GetBytes(body, "error.code").String())
|
||||||
|
errType := strings.ToLower(gjson.GetBytes(body, "error.type").String())
|
||||||
|
|
||||||
|
// Primary check: exact match on error.code or error.type (most reliable)
|
||||||
|
if _, ok := qwenQuotaCodes[code]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if _, ok := qwenQuotaCodes[errType]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: check message only if code/type don't match (less reliable)
|
||||||
|
msg := strings.ToLower(gjson.GetBytes(body, "error.message").String())
|
||||||
|
if strings.Contains(msg, "insufficient_quota") || strings.Contains(msg, "quota exceeded") ||
|
||||||
|
strings.Contains(msg, "free allocated quota exceeded") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrapQwenError wraps an HTTP error response, detecting quota errors and mapping them to 429.
|
||||||
|
// Returns the appropriate status code and retryAfter duration for statusErr.
|
||||||
|
// Only checks for quota errors when httpCode is 403 or 429 to avoid false positives.
|
||||||
|
func wrapQwenError(ctx context.Context, httpCode int, body []byte) (errCode int, retryAfter *time.Duration) {
|
||||||
|
errCode = httpCode
|
||||||
|
// Only check quota errors for expected status codes to avoid false positives
|
||||||
|
// Qwen returns 403 for quota errors, 429 for rate limits
|
||||||
|
if (httpCode == http.StatusForbidden || httpCode == http.StatusTooManyRequests) && isQwenQuotaError(body) {
|
||||||
|
errCode = http.StatusTooManyRequests // Map to 429 to trigger quota logic
|
||||||
|
cooldown := timeUntilNextDay()
|
||||||
|
retryAfter = &cooldown
|
||||||
|
logWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d), cooling down until tomorrow (%v)", httpCode, errCode, cooldown)
|
||||||
|
}
|
||||||
|
return errCode, retryAfter
|
||||||
|
}
|
||||||
|
|
||||||
|
// timeUntilNextDay returns duration until midnight Beijing time (UTC+8).
|
||||||
|
// Qwen's daily quota resets at 00:00 Beijing time.
|
||||||
|
func timeUntilNextDay() time.Duration {
|
||||||
|
now := time.Now()
|
||||||
|
nowLocal := now.In(qwenBeijingLoc)
|
||||||
|
tomorrow := time.Date(nowLocal.Year(), nowLocal.Month(), nowLocal.Day()+1, 0, 0, 0, 0, qwenBeijingLoc)
|
||||||
|
return tomorrow.Sub(now)
|
||||||
|
}
|
||||||
|
|
||||||
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
|
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
|
||||||
// If access token is unavailable, it falls back to legacy via ClientAdapter.
|
// If access token is unavailable, it falls back to legacy via ClientAdapter.
|
||||||
type QwenExecutor struct {
|
type QwenExecutor struct {
|
||||||
@@ -67,6 +210,17 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check rate limit before proceeding
|
||||||
|
var authID string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
}
|
||||||
|
if err := checkQwenRateLimit(authID); err != nil {
|
||||||
|
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
token, baseURL := qwenCreds(auth)
|
token, baseURL := qwenCreds(auth)
|
||||||
@@ -102,9 +256,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
applyQwenHeaders(httpReq, token, false)
|
applyQwenHeaders(httpReq, token, false)
|
||||||
var authID, authLabel, authType, authValue string
|
var authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authID = auth.ID
|
|
||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
@@ -135,8 +288,10 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
||||||
|
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
|
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
data, err := io.ReadAll(httpResp.Body)
|
data, err := io.ReadAll(httpResp.Body)
|
||||||
@@ -150,14 +305,25 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||||
// the original model name in the response for client compatibility.
|
// the original model name in the response for client compatibility.
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check rate limit before proceeding
|
||||||
|
var authID string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
}
|
||||||
|
if err := checkQwenRateLimit(authID); err != nil {
|
||||||
|
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
token, baseURL := qwenCreds(auth)
|
token, baseURL := qwenCreds(auth)
|
||||||
@@ -200,9 +366,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
applyQwenHeaders(httpReq, token, true)
|
applyQwenHeaders(httpReq, token, true)
|
||||||
var authID, authLabel, authType, authValue string
|
var authLabel, authType, authValue string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
authID = auth.ID
|
|
||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
@@ -228,15 +393,16 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
|
||||||
|
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
|
||||||
|
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
log.Errorf("qwen executor: close response body error: %v", errClose)
|
log.Errorf("qwen executor: close response body error: %v", errClose)
|
||||||
}
|
}
|
||||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
stream = out
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -268,7 +434,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return stream, nil
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
|||||||
89
internal/runtime/executor/user_id_cache.go
Normal file
89
internal/runtime/executor/user_id_cache.go
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// userIDCacheEntry pairs a cached user ID with the instant at which it stops
// being served.
type userIDCacheEntry struct {
	// value is the cached user ID string.
	value string
	// expire is the deadline; the entry is treated as stale once it is no
	// longer after the current time.
	expire time.Time
}

const (
	// userIDTTL is how far into the future an entry's deadline is pushed
	// each time it is written or renewed.
	userIDTTL = time.Hour
	// userIDCacheCleanupPeriod is the interval of the background sweep that
	// removes expired entries.
	userIDCacheCleanupPeriod = 15 * time.Minute
)

var (
	// userIDCache holds the entries, keyed by the output of userIDCacheKey.
	userIDCache = map[string]userIDCacheEntry{}
	// userIDCacheMu guards every access to userIDCache.
	userIDCacheMu sync.RWMutex
	// userIDCacheCleanupOnce ensures the sweep goroutine is started at most
	// once.
	userIDCacheCleanupOnce sync.Once
)
|
||||||
|
|
||||||
|
func startUserIDCacheCleanup() {
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(userIDCacheCleanupPeriod)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for range ticker.C {
|
||||||
|
purgeExpiredUserIDs()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func purgeExpiredUserIDs() {
|
||||||
|
now := time.Now()
|
||||||
|
userIDCacheMu.Lock()
|
||||||
|
for key, entry := range userIDCache {
|
||||||
|
if !entry.expire.After(now) {
|
||||||
|
delete(userIDCache, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
userIDCacheMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// userIDCacheKey derives the map key for an API key: the lowercase hex
// encoding of its SHA-256 digest, so the raw key is never used as a map key.
func userIDCacheKey(apiKey string) string {
	digest := sha256.New()
	digest.Write([]byte(apiKey)) // hash.Hash writes never return an error
	return hex.EncodeToString(digest.Sum(nil))
}
|
||||||
|
|
||||||
|
func cachedUserID(apiKey string) string {
|
||||||
|
if apiKey == "" {
|
||||||
|
return generateFakeUserID()
|
||||||
|
}
|
||||||
|
|
||||||
|
userIDCacheCleanupOnce.Do(startUserIDCacheCleanup)
|
||||||
|
|
||||||
|
key := userIDCacheKey(apiKey)
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
userIDCacheMu.RLock()
|
||||||
|
entry, ok := userIDCache[key]
|
||||||
|
valid := ok && entry.value != "" && entry.expire.After(now) && isValidUserID(entry.value)
|
||||||
|
userIDCacheMu.RUnlock()
|
||||||
|
if valid {
|
||||||
|
userIDCacheMu.Lock()
|
||||||
|
entry = userIDCache[key]
|
||||||
|
if entry.value != "" && entry.expire.After(now) && isValidUserID(entry.value) {
|
||||||
|
entry.expire = now.Add(userIDTTL)
|
||||||
|
userIDCache[key] = entry
|
||||||
|
userIDCacheMu.Unlock()
|
||||||
|
return entry.value
|
||||||
|
}
|
||||||
|
userIDCacheMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
newID := generateFakeUserID()
|
||||||
|
|
||||||
|
userIDCacheMu.Lock()
|
||||||
|
entry, ok = userIDCache[key]
|
||||||
|
if !ok || entry.value == "" || !entry.expire.After(now) || !isValidUserID(entry.value) {
|
||||||
|
entry.value = newID
|
||||||
|
}
|
||||||
|
entry.expire = now.Add(userIDTTL)
|
||||||
|
userIDCache[key] = entry
|
||||||
|
userIDCacheMu.Unlock()
|
||||||
|
return entry.value
|
||||||
|
}
|
||||||
86
internal/runtime/executor/user_id_cache_test.go
Normal file
86
internal/runtime/executor/user_id_cache_test.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func resetUserIDCache() {
|
||||||
|
userIDCacheMu.Lock()
|
||||||
|
userIDCache = make(map[string]userIDCacheEntry)
|
||||||
|
userIDCacheMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCachedUserID_ReusesWithinTTL(t *testing.T) {
|
||||||
|
resetUserIDCache()
|
||||||
|
|
||||||
|
first := cachedUserID("api-key-1")
|
||||||
|
second := cachedUserID("api-key-1")
|
||||||
|
|
||||||
|
if first == "" {
|
||||||
|
t.Fatal("expected generated user_id to be non-empty")
|
||||||
|
}
|
||||||
|
if first != second {
|
||||||
|
t.Fatalf("expected cached user_id to be reused, got %q and %q", first, second)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
|
||||||
|
resetUserIDCache()
|
||||||
|
|
||||||
|
expiredID := cachedUserID("api-key-expired")
|
||||||
|
cacheKey := userIDCacheKey("api-key-expired")
|
||||||
|
userIDCacheMu.Lock()
|
||||||
|
userIDCache[cacheKey] = userIDCacheEntry{
|
||||||
|
value: expiredID,
|
||||||
|
expire: time.Now().Add(-time.Minute),
|
||||||
|
}
|
||||||
|
userIDCacheMu.Unlock()
|
||||||
|
|
||||||
|
newID := cachedUserID("api-key-expired")
|
||||||
|
if newID == expiredID {
|
||||||
|
t.Fatalf("expected expired user_id to be replaced, got %q", newID)
|
||||||
|
}
|
||||||
|
if newID == "" {
|
||||||
|
t.Fatal("expected regenerated user_id to be non-empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCachedUserID_IsScopedByAPIKey(t *testing.T) {
|
||||||
|
resetUserIDCache()
|
||||||
|
|
||||||
|
first := cachedUserID("api-key-1")
|
||||||
|
second := cachedUserID("api-key-2")
|
||||||
|
|
||||||
|
if first == second {
|
||||||
|
t.Fatalf("expected different API keys to have different user_ids, got %q", first)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCachedUserID_RenewsTTLOnHit(t *testing.T) {
|
||||||
|
resetUserIDCache()
|
||||||
|
|
||||||
|
key := "api-key-renew"
|
||||||
|
id := cachedUserID(key)
|
||||||
|
cacheKey := userIDCacheKey(key)
|
||||||
|
|
||||||
|
soon := time.Now()
|
||||||
|
userIDCacheMu.Lock()
|
||||||
|
userIDCache[cacheKey] = userIDCacheEntry{
|
||||||
|
value: id,
|
||||||
|
expire: soon.Add(2 * time.Second),
|
||||||
|
}
|
||||||
|
userIDCacheMu.Unlock()
|
||||||
|
|
||||||
|
if refreshed := cachedUserID(key); refreshed != id {
|
||||||
|
t.Fatalf("expected cached user_id to be reused before expiry, got %q", refreshed)
|
||||||
|
}
|
||||||
|
|
||||||
|
userIDCacheMu.RLock()
|
||||||
|
entry := userIDCache[cacheKey]
|
||||||
|
userIDCacheMu.RUnlock()
|
||||||
|
|
||||||
|
if entry.expire.Sub(soon) < 30*time.Minute {
|
||||||
|
t.Fatalf("expected TTL to renew, got %v remaining", entry.expire.Sub(soon))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -10,53 +10,10 @@ import (
|
|||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// validReasoningEffortLevels contains the standard values accepted by the
|
|
||||||
// OpenAI reasoning_effort field. Provider-specific extensions (xhigh, minimal,
|
|
||||||
// auto) are NOT in this set and must be clamped before use.
|
|
||||||
var validReasoningEffortLevels = map[string]struct{}{
|
|
||||||
"none": {},
|
|
||||||
"low": {},
|
|
||||||
"medium": {},
|
|
||||||
"high": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
// clampReasoningEffort maps any thinking level string to a value that is safe
|
|
||||||
// to send as OpenAI reasoning_effort. Non-standard CPA-internal values are
|
|
||||||
// mapped to the nearest standard equivalent.
|
|
||||||
//
|
|
||||||
// Mapping rules:
|
|
||||||
// - none / low / medium / high → returned as-is (already valid)
|
|
||||||
// - xhigh → "high" (nearest lower standard level)
|
|
||||||
// - minimal → "low" (nearest higher standard level)
|
|
||||||
// - auto → "medium" (reasonable default)
|
|
||||||
// - anything else → "medium" (safe default)
|
|
||||||
func clampReasoningEffort(level string) string {
|
|
||||||
if _, ok := validReasoningEffortLevels[level]; ok {
|
|
||||||
return level
|
|
||||||
}
|
|
||||||
var clamped string
|
|
||||||
switch level {
|
|
||||||
case string(thinking.LevelXHigh):
|
|
||||||
clamped = string(thinking.LevelHigh)
|
|
||||||
case string(thinking.LevelMinimal):
|
|
||||||
clamped = string(thinking.LevelLow)
|
|
||||||
case string(thinking.LevelAuto):
|
|
||||||
clamped = string(thinking.LevelMedium)
|
|
||||||
default:
|
|
||||||
clamped = string(thinking.LevelMedium)
|
|
||||||
}
|
|
||||||
log.WithFields(log.Fields{
|
|
||||||
"original": level,
|
|
||||||
"clamped": clamped,
|
|
||||||
}).Debug("openai: reasoning_effort clamped to nearest valid standard value")
|
|
||||||
return clamped
|
|
||||||
}
|
|
||||||
|
|
||||||
// Applier implements thinking.ProviderApplier for OpenAI models.
|
// Applier implements thinking.ProviderApplier for OpenAI models.
|
||||||
//
|
//
|
||||||
// OpenAI-specific behavior:
|
// OpenAI-specific behavior:
|
||||||
@@ -101,7 +58,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
|||||||
}
|
}
|
||||||
|
|
||||||
if config.Mode == thinking.ModeLevel {
|
if config.Mode == thinking.ModeLevel {
|
||||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(string(config.Level)))
|
result, _ := sjson.SetBytes(body, "reasoning_effort", string(config.Level))
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -122,7 +79,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
|||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
|
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,7 +114,7 @@ func applyCompatibleOpenAI(body []byte, config thinking.ThinkingConfig) ([]byte,
|
|||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
|
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -231,8 +231,12 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
|
|
||||||
} else if functionResponseResult.IsObject() {
|
} else if functionResponseResult.IsObject() {
|
||||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||||
} else {
|
} else if functionResponseResult.Raw != "" {
|
||||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||||
|
} else {
|
||||||
|
// Content field is missing entirely — .Raw is empty which
|
||||||
|
// causes sjson.SetRaw to produce invalid JSON (e.g. "result":}).
|
||||||
|
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
|
||||||
}
|
}
|
||||||
|
|
||||||
partJSON := `{}`
|
partJSON := `{}`
|
||||||
|
|||||||
@@ -661,6 +661,85 @@ func TestConvertClaudeRequestToAntigravity_ThinkingOnly_NoHint(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultNoContent(t *testing.T) {
|
||||||
|
// Bug repro: tool_result with no content field produces invalid JSON
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-opus-4-6-thinking",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "MyTool-123-456",
|
||||||
|
"name": "MyTool",
|
||||||
|
"input": {"key": "value"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "MyTool-123-456"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, true)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Errorf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the functionResponse has a valid result value
|
||||||
|
fr := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse.response.result")
|
||||||
|
if !fr.Exists() {
|
||||||
|
t.Error("functionResponse.response.result should exist")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToAntigravity_ToolResultNullContent(t *testing.T) {
|
||||||
|
// Bug repro: tool_result with null content produces invalid JSON
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "claude-opus-4-6-thinking",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": "MyTool-123-456",
|
||||||
|
"name": "MyTool",
|
||||||
|
"input": {"key": "value"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": "MyTool-123-456",
|
||||||
|
"content": null
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, true)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if !gjson.Valid(outputStr) {
|
||||||
|
t.Errorf("Result is not valid JSON:\n%s", outputStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
|
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
|
||||||
// When tools + thinking but no system instruction, should create one with hint
|
// When tools + thinking but no system instruction, should create one with hint
|
||||||
inputJSON := []byte(`{
|
inputJSON := []byte(`{
|
||||||
|
|||||||
@@ -95,9 +95,9 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
|||||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||||
}
|
}
|
||||||
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
|
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
||||||
if thoughtsTokenCount > 0 {
|
if thoughtsTokenCount > 0 {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -199,6 +199,21 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
|||||||
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
|
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case "file":
|
||||||
|
fileData := part.Get("file.file_data").String()
|
||||||
|
if strings.HasPrefix(fileData, "data:") {
|
||||||
|
semicolonIdx := strings.Index(fileData, ";")
|
||||||
|
commaIdx := strings.Index(fileData, ",")
|
||||||
|
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
|
||||||
|
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
|
||||||
|
data := fileData[commaIdx+1:]
|
||||||
|
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||||
|
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
|
||||||
|
docPart, _ = sjson.Set(docPart, "source.data", data)
|
||||||
|
msg, _ = sjson.SetRaw(msg, "content.-1", docPart)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -155,6 +155,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
|||||||
var textAggregate strings.Builder
|
var textAggregate strings.Builder
|
||||||
var partsJSON []string
|
var partsJSON []string
|
||||||
hasImage := false
|
hasImage := false
|
||||||
|
hasFile := false
|
||||||
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
|
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
|
||||||
parts.ForEach(func(_, part gjson.Result) bool {
|
parts.ForEach(func(_, part gjson.Result) bool {
|
||||||
ptype := part.Get("type").String()
|
ptype := part.Get("type").String()
|
||||||
@@ -207,6 +208,30 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
|||||||
hasImage = true
|
hasImage = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case "input_file":
|
||||||
|
fileData := part.Get("file_data").String()
|
||||||
|
if fileData != "" {
|
||||||
|
mediaType := "application/octet-stream"
|
||||||
|
data := fileData
|
||||||
|
if strings.HasPrefix(fileData, "data:") {
|
||||||
|
trimmed := strings.TrimPrefix(fileData, "data:")
|
||||||
|
mediaAndData := strings.SplitN(trimmed, ";base64,", 2)
|
||||||
|
if len(mediaAndData) == 2 {
|
||||||
|
if mediaAndData[0] != "" {
|
||||||
|
mediaType = mediaAndData[0]
|
||||||
|
}
|
||||||
|
data = mediaAndData[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
contentPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||||
|
contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType)
|
||||||
|
contentPart, _ = sjson.Set(contentPart, "source.data", data)
|
||||||
|
partsJSON = append(partsJSON, contentPart)
|
||||||
|
if role == "" {
|
||||||
|
role = "user"
|
||||||
|
}
|
||||||
|
hasFile = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
@@ -228,7 +253,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
|||||||
if len(partsJSON) > 0 {
|
if len(partsJSON) > 0 {
|
||||||
msg := `{"role":"","content":[]}`
|
msg := `{"role":"","content":[]}`
|
||||||
msg, _ = sjson.Set(msg, "role", role)
|
msg, _ = sjson.Set(msg, "role", role)
|
||||||
if len(partsJSON) == 1 && !hasImage {
|
if len(partsJSON) == 1 && !hasImage && !hasFile {
|
||||||
// Preserve legacy behavior for single text content
|
// Preserve legacy behavior for single text content
|
||||||
msg, _ = sjson.Delete(msg, "content")
|
msg, _ = sjson.Delete(msg, "content")
|
||||||
textPart := gjson.Parse(partsJSON[0])
|
textPart := gjson.Parse(partsJSON[0])
|
||||||
|
|||||||
@@ -22,8 +22,9 @@ var (
|
|||||||
|
|
||||||
// ConvertCodexResponseToClaudeParams holds parameters for response conversion.
|
// ConvertCodexResponseToClaudeParams holds parameters for response conversion.
|
||||||
type ConvertCodexResponseToClaudeParams struct {
|
type ConvertCodexResponseToClaudeParams struct {
|
||||||
HasToolCall bool
|
HasToolCall bool
|
||||||
BlockIndex int
|
BlockIndex int
|
||||||
|
HasReceivedArgumentsDelta bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertCodexResponseToClaude performs sophisticated streaming response format conversion.
|
// ConvertCodexResponseToClaude performs sophisticated streaming response format conversion.
|
||||||
@@ -137,6 +138,7 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
|||||||
itemType := itemResult.Get("type").String()
|
itemType := itemResult.Get("type").String()
|
||||||
if itemType == "function_call" {
|
if itemType == "function_call" {
|
||||||
(*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true
|
(*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true
|
||||||
|
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
|
||||||
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
|
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
|
||||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||||
template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String())
|
template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String())
|
||||||
@@ -171,12 +173,29 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
|||||||
output += fmt.Sprintf("data: %s\n\n", template)
|
output += fmt.Sprintf("data: %s\n\n", template)
|
||||||
}
|
}
|
||||||
} else if typeStr == "response.function_call_arguments.delta" {
|
} else if typeStr == "response.function_call_arguments.delta" {
|
||||||
|
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = true
|
||||||
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||||
template, _ = sjson.Set(template, "delta.partial_json", rootResult.Get("delta").String())
|
template, _ = sjson.Set(template, "delta.partial_json", rootResult.Get("delta").String())
|
||||||
|
|
||||||
output += "event: content_block_delta\n"
|
output += "event: content_block_delta\n"
|
||||||
output += fmt.Sprintf("data: %s\n\n", template)
|
output += fmt.Sprintf("data: %s\n\n", template)
|
||||||
|
} else if typeStr == "response.function_call_arguments.done" {
|
||||||
|
// Some models (e.g. gpt-5.3-codex-spark) send function call arguments
|
||||||
|
// in a single "done" event without preceding "delta" events.
|
||||||
|
// Emit the full arguments as a single input_json_delta so the
|
||||||
|
// downstream Claude client receives the complete tool input.
|
||||||
|
// When delta events were already received, skip to avoid duplicating arguments.
|
||||||
|
if !(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta {
|
||||||
|
if args := rootResult.Get("arguments").String(); args != "" {
|
||||||
|
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||||
|
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||||
|
template, _ = sjson.Set(template, "delta.partial_json", args)
|
||||||
|
|
||||||
|
output += "event: content_block_delta\n"
|
||||||
|
output += fmt.Sprintf("data: %s\n\n", template)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return []string{output}
|
return []string{output}
|
||||||
|
|||||||
@@ -180,7 +180,19 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
|||||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||||
}
|
}
|
||||||
case "file":
|
case "file":
|
||||||
// Files are not specified in examples; skip for now
|
if role == "user" {
|
||||||
|
fileData := it.Get("file.file_data").String()
|
||||||
|
filename := it.Get("file.filename").String()
|
||||||
|
if fileData != "" {
|
||||||
|
part := `{}`
|
||||||
|
part, _ = sjson.Set(part, "type", "input_file")
|
||||||
|
part, _ = sjson.Set(part, "file_data", fileData)
|
||||||
|
if filename != "" {
|
||||||
|
part, _ = sjson.Set(part, "filename", filename)
|
||||||
|
}
|
||||||
|
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
|||||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature")
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature")
|
||||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
|
||||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
|
||||||
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "truncation")
|
||||||
|
rawJSON = applyResponsesCompactionCompatibility(rawJSON)
|
||||||
|
|
||||||
// Delete the user field as it is not supported by the Codex upstream.
|
// Delete the user field as it is not supported by the Codex upstream.
|
||||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "user")
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "user")
|
||||||
@@ -36,6 +38,23 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
|||||||
return rawJSON
|
return rawJSON
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// applyResponsesCompactionCompatibility handles OpenAI Responses context_management.compaction
|
||||||
|
// for Codex upstream compatibility.
|
||||||
|
//
|
||||||
|
// Codex /responses currently rejects context_management with:
|
||||||
|
// {"detail":"Unsupported parameter: context_management"}.
|
||||||
|
//
|
||||||
|
// Compatibility strategy:
|
||||||
|
// 1) Remove context_management before forwarding to Codex upstream.
|
||||||
|
func applyResponsesCompactionCompatibility(rawJSON []byte) []byte {
|
||||||
|
if !gjson.GetBytes(rawJSON, "context_management").Exists() {
|
||||||
|
return rawJSON
|
||||||
|
}
|
||||||
|
|
||||||
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "context_management")
|
||||||
|
return rawJSON
|
||||||
|
}
|
||||||
|
|
||||||
// convertSystemRoleToDeveloper traverses the input array and converts any message items
|
// convertSystemRoleToDeveloper traverses the input array and converts any message items
|
||||||
// with role "system" to role "developer". This is necessary because Codex API does not
|
// with role "system" to role "developer". This is necessary because Codex API does not
|
||||||
// accept "system" role in the input array.
|
// accept "system" role in the input array.
|
||||||
|
|||||||
@@ -280,3 +280,41 @@ func TestUserFieldDeletion(t *testing.T) {
|
|||||||
t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw)
|
t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestContextManagementCompactionCompatibility(t *testing.T) {
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "gpt-5.2",
|
||||||
|
"context_management": [
|
||||||
|
{
|
||||||
|
"type": "compaction",
|
||||||
|
"compact_threshold": 12000
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"input": [{"role":"user","content":"hello"}]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if gjson.Get(outputStr, "context_management").Exists() {
|
||||||
|
t.Fatalf("context_management should be removed for Codex compatibility")
|
||||||
|
}
|
||||||
|
if gjson.Get(outputStr, "truncation").Exists() {
|
||||||
|
t.Fatalf("truncation should be removed for Codex compatibility")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTruncationRemovedForCodexCompatibility(t *testing.T) {
|
||||||
|
inputJSON := []byte(`{
|
||||||
|
"model": "gpt-5.2",
|
||||||
|
"truncation": "disabled",
|
||||||
|
"input": [{"role":"user","content":"hello"}]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
||||||
|
outputStr := string(output)
|
||||||
|
|
||||||
|
if gjson.Get(outputStr, "truncation").Exists() {
|
||||||
|
t.Fatalf("truncation should be removed for Codex compatibility")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
|||||||
}
|
}
|
||||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
||||||
if thoughtsTokenCount > 0 {
|
if thoughtsTokenCount > 0 {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -100,9 +100,9 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
|||||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.total_tokens", totalTokenCountResult.Int())
|
baseTemplate, _ = sjson.Set(baseTemplate, "usage.total_tokens", totalTokenCountResult.Int())
|
||||||
}
|
}
|
||||||
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
|
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount)
|
||||||
if thoughtsTokenCount > 0 {
|
if thoughtsTokenCount > 0 {
|
||||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
baseTemplate, _ = sjson.Set(baseTemplate, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||||
}
|
}
|
||||||
@@ -297,7 +297,7 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
|
|||||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||||
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
||||||
if thoughtsTokenCount > 0 {
|
if thoughtsTokenCount > 0 {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -531,8 +531,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
|||||||
|
|
||||||
// usage mapping
|
// usage mapping
|
||||||
if um := root.Get("usageMetadata"); um.Exists() {
|
if um := root.Get("usageMetadata"); um.Exists() {
|
||||||
// input tokens = prompt + thoughts
|
// input tokens = prompt only (thoughts go to output)
|
||||||
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
|
input := um.Get("promptTokenCount").Int()
|
||||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens", input)
|
completed, _ = sjson.Set(completed, "response.usage.input_tokens", input)
|
||||||
// cached token details: align with OpenAI "cached_tokens" semantics.
|
// cached token details: align with OpenAI "cached_tokens" semantics.
|
||||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
||||||
@@ -737,8 +737,8 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
|||||||
|
|
||||||
// usage mapping
|
// usage mapping
|
||||||
if um := root.Get("usageMetadata"); um.Exists() {
|
if um := root.Get("usageMetadata"); um.Exists() {
|
||||||
// input tokens = prompt + thoughts
|
// input tokens = prompt only (thoughts go to output)
|
||||||
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
|
input := um.Get("promptTokenCount").Int()
|
||||||
resp, _ = sjson.Set(resp, "usage.input_tokens", input)
|
resp, _ = sjson.Set(resp, "usage.input_tokens", input)
|
||||||
// cached token details: align with OpenAI "cached_tokens" semantics.
|
// cached token details: align with OpenAI "cached_tokens" semantics.
|
||||||
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
||||||
|
|||||||
@@ -243,13 +243,11 @@ func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isA
|
|||||||
// Process messages and build history
|
// Process messages and build history
|
||||||
history, currentUserMsg, currentToolResults := processMessages(messages, modelID, origin)
|
history, currentUserMsg, currentToolResults := processMessages(messages, modelID, origin)
|
||||||
|
|
||||||
// Build content with system prompt (only on first turn to avoid re-injection)
|
// Build content with system prompt.
|
||||||
|
// Keep thinking tags on subsequent turns so multi-turn Claude sessions
|
||||||
|
// continue to emit reasoning events.
|
||||||
if currentUserMsg != nil {
|
if currentUserMsg != nil {
|
||||||
effectiveSystemPrompt := systemPrompt
|
currentUserMsg.Content = buildFinalContent(currentUserMsg.Content, systemPrompt, currentToolResults)
|
||||||
if len(history) > 0 {
|
|
||||||
effectiveSystemPrompt = "" // Don't re-inject on subsequent turns
|
|
||||||
}
|
|
||||||
currentUserMsg.Content = buildFinalContent(currentUserMsg.Content, effectiveSystemPrompt, currentToolResults)
|
|
||||||
|
|
||||||
// Deduplicate currentToolResults
|
// Deduplicate currentToolResults
|
||||||
currentToolResults = deduplicateToolResults(currentToolResults)
|
currentToolResults = deduplicateToolResults(currentToolResults)
|
||||||
@@ -475,6 +473,15 @@ func IsThinkingEnabledWithHeaders(body []byte, headers http.Header) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check model name directly for thinking hints.
|
||||||
|
// This enables thinking variants even when clients don't send explicit thinking fields.
|
||||||
|
model := strings.TrimSpace(gjson.GetBytes(body, "model").String())
|
||||||
|
modelLower := strings.ToLower(model)
|
||||||
|
if strings.Contains(modelLower, "thinking") || strings.Contains(modelLower, "-reason") {
|
||||||
|
log.Debugf("kiro: thinking mode enabled via model name hint: %s", model)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
log.Debugf("kiro: IsThinkingEnabled returning false (no thinking mode detected)")
|
log.Debugf("kiro: IsThinkingEnabled returning false (no thinking mode detected)")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,19 +53,25 @@ var KnownCommandTools = map[string]bool{
|
|||||||
"execute_python": true,
|
"execute_python": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequiredFieldsByTool maps tool names to their required fields.
|
// RequiredFieldsByTool maps tool names to their required field groups.
|
||||||
// If any of these fields are missing, the tool input is considered truncated.
|
// Each outer element is a required group; each inner slice lists alternative field names (OR logic).
|
||||||
var RequiredFieldsByTool = map[string][]string{
|
// A group is satisfied when ANY one of its alternatives exists in the parsed input.
|
||||||
"Write": {"file_path", "content"},
|
// All groups must be satisfied for the tool input to be considered valid.
|
||||||
"write_to_file": {"path", "content"},
|
//
|
||||||
"fsWrite": {"path", "content"},
|
// Example:
|
||||||
"create_file": {"path", "content"},
|
// {{"cmd", "command"}} means the tool needs EITHER "cmd" OR "command".
|
||||||
"edit_file": {"path"},
|
// {{"file_path"}, {"content"}} means the tool needs BOTH "file_path" AND "content".
|
||||||
"apply_diff": {"path", "diff"},
|
var RequiredFieldsByTool = map[string][][]string{
|
||||||
"str_replace_editor": {"path", "old_str", "new_str"},
|
"Write": {{"file_path"}, {"content"}},
|
||||||
"Bash": {"command"},
|
"write_to_file": {{"path"}, {"content"}},
|
||||||
"execute": {"command"},
|
"fsWrite": {{"path"}, {"content"}},
|
||||||
"run_command": {"command"},
|
"create_file": {{"path"}, {"content"}},
|
||||||
|
"edit_file": {{"path"}},
|
||||||
|
"apply_diff": {{"path"}, {"diff"}},
|
||||||
|
"str_replace_editor": {{"path"}, {"old_str"}, {"new_str"}},
|
||||||
|
"Bash": {{"cmd", "command"}},
|
||||||
|
"execute": {{"command"}},
|
||||||
|
"run_command": {{"command"}},
|
||||||
}
|
}
|
||||||
|
|
||||||
// DetectTruncation checks if the tool use input appears to be truncated.
|
// DetectTruncation checks if the tool use input appears to be truncated.
|
||||||
@@ -104,9 +110,9 @@ func DetectTruncation(toolName, toolUseID, rawInput string, parsedInput map[stri
|
|||||||
|
|
||||||
// Scenario 3: JSON parsed but critical fields are missing
|
// Scenario 3: JSON parsed but critical fields are missing
|
||||||
if parsedInput != nil {
|
if parsedInput != nil {
|
||||||
requiredFields, hasRequirements := RequiredFieldsByTool[toolName]
|
requiredGroups, hasRequirements := RequiredFieldsByTool[toolName]
|
||||||
if hasRequirements {
|
if hasRequirements {
|
||||||
missingFields := findMissingRequiredFields(parsedInput, requiredFields)
|
missingFields := findMissingRequiredFields(parsedInput, requiredGroups)
|
||||||
if len(missingFields) > 0 {
|
if len(missingFields) > 0 {
|
||||||
info.IsTruncated = true
|
info.IsTruncated = true
|
||||||
info.TruncationType = TruncationTypeMissingFields
|
info.TruncationType = TruncationTypeMissingFields
|
||||||
@@ -253,12 +259,21 @@ func extractParsedFieldNames(parsed map[string]interface{}) map[string]string {
|
|||||||
return fields
|
return fields
|
||||||
}
|
}
|
||||||
|
|
||||||
// findMissingRequiredFields checks which required fields are missing from the parsed input.
|
// findMissingRequiredFields checks which required field groups are unsatisfied.
|
||||||
func findMissingRequiredFields(parsed map[string]interface{}, required []string) []string {
|
// Each group is a slice of alternative field names; the group is satisfied when ANY alternative exists.
|
||||||
|
// Returns the list of unsatisfied groups (represented by their alternatives joined with "/").
|
||||||
|
func findMissingRequiredFields(parsed map[string]interface{}, requiredGroups [][]string) []string {
|
||||||
var missing []string
|
var missing []string
|
||||||
for _, field := range required {
|
for _, group := range requiredGroups {
|
||||||
if _, exists := parsed[field]; !exists {
|
satisfied := false
|
||||||
missing = append(missing, field)
|
for _, field := range group {
|
||||||
|
if _, exists := parsed[field]; exists {
|
||||||
|
satisfied = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !satisfied {
|
||||||
|
missing = append(missing, strings.Join(group, "/"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return missing
|
return missing
|
||||||
|
|||||||
@@ -234,16 +234,16 @@ func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin s
|
|||||||
// Kiro API supports official thinking/reasoning mode via <thinking_mode> tag.
|
// Kiro API supports official thinking/reasoning mode via <thinking_mode> tag.
|
||||||
// When set to "enabled", Kiro returns reasoning content as official reasoningContentEvent
|
// When set to "enabled", Kiro returns reasoning content as official reasoningContentEvent
|
||||||
// rather than inline <thinking> tags in assistantResponseEvent.
|
// rather than inline <thinking> tags in assistantResponseEvent.
|
||||||
// We use a high max_thinking_length to allow extensive reasoning.
|
// Use a conservative thinking budget to reduce latency/cost spikes in long sessions.
|
||||||
if thinkingEnabled {
|
if thinkingEnabled {
|
||||||
thinkingHint := `<thinking_mode>enabled</thinking_mode>
|
thinkingHint := `<thinking_mode>enabled</thinking_mode>
|
||||||
<max_thinking_length>200000</max_thinking_length>`
|
<max_thinking_length>16000</max_thinking_length>`
|
||||||
if systemPrompt != "" {
|
if systemPrompt != "" {
|
||||||
systemPrompt = thinkingHint + "\n\n" + systemPrompt
|
systemPrompt = thinkingHint + "\n\n" + systemPrompt
|
||||||
} else {
|
} else {
|
||||||
systemPrompt = thinkingHint
|
systemPrompt = thinkingHint
|
||||||
}
|
}
|
||||||
log.Debugf("kiro-openai: injected thinking prompt (official mode)")
|
log.Infof("kiro-openai: injected thinking prompt (official mode), has_tools: %v", len(kiroTools) > 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process messages and build history
|
// Process messages and build history
|
||||||
@@ -578,6 +578,7 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
|
|||||||
|
|
||||||
// Truncate history if too long to prevent Kiro API errors
|
// Truncate history if too long to prevent Kiro API errors
|
||||||
history = truncateHistoryIfNeeded(history)
|
history = truncateHistoryIfNeeded(history)
|
||||||
|
history, currentToolResults = filterOrphanedToolResults(history, currentToolResults)
|
||||||
|
|
||||||
return history, currentUserMsg, currentToolResults
|
return history, currentUserMsg, currentToolResults
|
||||||
}
|
}
|
||||||
@@ -593,6 +594,61 @@ func truncateHistoryIfNeeded(history []KiroHistoryMessage) []KiroHistoryMessage
|
|||||||
return history[len(history)-kiroMaxHistoryMessages:]
|
return history[len(history)-kiroMaxHistoryMessages:]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func filterOrphanedToolResults(history []KiroHistoryMessage, currentToolResults []KiroToolResult) ([]KiroHistoryMessage, []KiroToolResult) {
|
||||||
|
// Remove tool results with no matching tool_use in retained history.
|
||||||
|
// This happens after truncation when the assistant turn that produced tool_use
|
||||||
|
// is dropped but a later user/tool_result survives.
|
||||||
|
validToolUseIDs := make(map[string]bool)
|
||||||
|
for _, h := range history {
|
||||||
|
if h.AssistantResponseMessage == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, tu := range h.AssistantResponseMessage.ToolUses {
|
||||||
|
validToolUseIDs[tu.ToolUseID] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, h := range history {
|
||||||
|
if h.UserInputMessage == nil || h.UserInputMessage.UserInputMessageContext == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ctx := h.UserInputMessage.UserInputMessageContext
|
||||||
|
if len(ctx.ToolResults) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := make([]KiroToolResult, 0, len(ctx.ToolResults))
|
||||||
|
for _, tr := range ctx.ToolResults {
|
||||||
|
if validToolUseIDs[tr.ToolUseID] {
|
||||||
|
filtered = append(filtered, tr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Debugf("kiro-openai: dropping orphaned tool_result in history[%d]: toolUseId=%s (no matching tool_use)", i, tr.ToolUseID)
|
||||||
|
}
|
||||||
|
ctx.ToolResults = filtered
|
||||||
|
if len(ctx.ToolResults) == 0 && len(ctx.Tools) == 0 {
|
||||||
|
h.UserInputMessage.UserInputMessageContext = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(currentToolResults) > 0 {
|
||||||
|
filtered := make([]KiroToolResult, 0, len(currentToolResults))
|
||||||
|
for _, tr := range currentToolResults {
|
||||||
|
if validToolUseIDs[tr.ToolUseID] {
|
||||||
|
filtered = append(filtered, tr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Debugf("kiro-openai: dropping orphaned tool_result in currentMessage: toolUseId=%s (no matching tool_use)", tr.ToolUseID)
|
||||||
|
}
|
||||||
|
if len(filtered) != len(currentToolResults) {
|
||||||
|
log.Infof("kiro-openai: dropped %d orphaned tool_result(s) from currentMessage", len(currentToolResults)-len(filtered))
|
||||||
|
}
|
||||||
|
currentToolResults = filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
return history, currentToolResults
|
||||||
|
}
|
||||||
|
|
||||||
// buildUserMessageFromOpenAI builds a user message from OpenAI format and extracts tool results
|
// buildUserMessageFromOpenAI builds a user message from OpenAI format and extracts tool results
|
||||||
func buildUserMessageFromOpenAI(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) {
|
func buildUserMessageFromOpenAI(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) {
|
||||||
content := msg.Get("content")
|
content := msg.Get("content")
|
||||||
@@ -831,7 +887,6 @@ func hasThinkingTagInBody(body []byte) bool {
|
|||||||
return strings.Contains(bodyStr, "<thinking_mode>") || strings.Contains(bodyStr, "<max_thinking_length>")
|
return strings.Contains(bodyStr, "<thinking_mode>") || strings.Contains(bodyStr, "<max_thinking_length>")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// extractToolChoiceHint extracts tool_choice from OpenAI request and returns a system prompt hint.
|
// extractToolChoiceHint extracts tool_choice from OpenAI request and returns a system prompt hint.
|
||||||
// OpenAI tool_choice values:
|
// OpenAI tool_choice values:
|
||||||
// - "none": Don't use any tools
|
// - "none": Don't use any tools
|
||||||
|
|||||||
@@ -384,3 +384,57 @@ func TestAssistantEndsConversation(t *testing.T) {
|
|||||||
t.Error("Expected a 'Continue' message to be created when assistant is last")
|
t.Error("Expected a 'Continue' message to be created when assistant is last")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFilterOrphanedToolResults_RemovesHistoryAndCurrentOrphans(t *testing.T) {
|
||||||
|
history := []KiroHistoryMessage{
|
||||||
|
{
|
||||||
|
AssistantResponseMessage: &KiroAssistantResponseMessage{
|
||||||
|
Content: "assistant",
|
||||||
|
ToolUses: []KiroToolUse{
|
||||||
|
{ToolUseID: "keep-1", Name: "Read", Input: map[string]interface{}{}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
UserInputMessage: &KiroUserInputMessage{
|
||||||
|
Content: "user-with-mixed-results",
|
||||||
|
UserInputMessageContext: &KiroUserInputMessageContext{
|
||||||
|
ToolResults: []KiroToolResult{
|
||||||
|
{ToolUseID: "keep-1", Status: "success", Content: []KiroTextContent{{Text: "ok"}}},
|
||||||
|
{ToolUseID: "orphan-1", Status: "success", Content: []KiroTextContent{{Text: "bad"}}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
UserInputMessage: &KiroUserInputMessage{
|
||||||
|
Content: "user-only-orphans",
|
||||||
|
UserInputMessageContext: &KiroUserInputMessageContext{
|
||||||
|
ToolResults: []KiroToolResult{
|
||||||
|
{ToolUseID: "orphan-2", Status: "success", Content: []KiroTextContent{{Text: "bad"}}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
currentToolResults := []KiroToolResult{
|
||||||
|
{ToolUseID: "keep-1", Status: "success", Content: []KiroTextContent{{Text: "ok"}}},
|
||||||
|
{ToolUseID: "orphan-3", Status: "success", Content: []KiroTextContent{{Text: "bad"}}},
|
||||||
|
}
|
||||||
|
|
||||||
|
filteredHistory, filteredCurrent := filterOrphanedToolResults(history, currentToolResults)
|
||||||
|
|
||||||
|
ctx1 := filteredHistory[1].UserInputMessage.UserInputMessageContext
|
||||||
|
if ctx1 == nil || len(ctx1.ToolResults) != 1 || ctx1.ToolResults[0].ToolUseID != "keep-1" {
|
||||||
|
t.Fatalf("expected mixed history message to keep only keep-1, got: %+v", ctx1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if filteredHistory[2].UserInputMessage.UserInputMessageContext != nil {
|
||||||
|
t.Fatalf("expected orphan-only history context to be removed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(filteredCurrent) != 1 || filteredCurrent[0].ToolUseID != "keep-1" {
|
||||||
|
t.Fatalf("expected current tool results to keep only keep-1, got: %+v", filteredCurrent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -112,12 +112,13 @@ func (h *ClaudeCodeAPIHandler) ClaudeCountTokens(c *gin.Context) {
|
|||||||
|
|
||||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||||
|
|
||||||
resp, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
resp, upstreamHeaders, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
cliCancel(errMsg.Error)
|
cliCancel(errMsg.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = c.Writer.Write(resp)
|
_, _ = c.Writer.Write(resp)
|
||||||
cliCancel()
|
cliCancel()
|
||||||
}
|
}
|
||||||
@@ -165,7 +166,7 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO
|
|||||||
|
|
||||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||||
|
|
||||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||||
stopKeepAlive()
|
stopKeepAlive()
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
@@ -194,6 +195,7 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = c.Writer.Write(resp)
|
_, _ = c.Writer.Write(resp)
|
||||||
cliCancel()
|
cliCancel()
|
||||||
}
|
}
|
||||||
@@ -225,7 +227,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
|
|||||||
// This allows proper cleanup and cancellation of ongoing requests
|
// This allows proper cleanup and cancellation of ongoing requests
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
|
|
||||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||||
setSSEHeaders := func() {
|
setSSEHeaders := func() {
|
||||||
c.Header("Content-Type", "text/event-stream")
|
c.Header("Content-Type", "text/event-stream")
|
||||||
c.Header("Cache-Control", "no-cache")
|
c.Header("Cache-Control", "no-cache")
|
||||||
@@ -257,6 +259,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
|
|||||||
if !ok {
|
if !ok {
|
||||||
// Stream closed without data? Send DONE or just headers.
|
// Stream closed without data? Send DONE or just headers.
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
cliCancel(nil)
|
cliCancel(nil)
|
||||||
return
|
return
|
||||||
@@ -264,6 +267,7 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
|
|||||||
|
|
||||||
// Success! Set headers now.
|
// Success! Set headers now.
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
|
|
||||||
// Write the first chunk
|
// Write the first chunk
|
||||||
if len(chunk) > 0 {
|
if len(chunk) > 0 {
|
||||||
|
|||||||
@@ -159,7 +159,8 @@ func (h *GeminiCLIAPIHandler) handleInternalStreamGenerateContent(c *gin.Context
|
|||||||
modelName := modelResult.String()
|
modelName := modelResult.String()
|
||||||
|
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
h.forwardCLIStream(c, flusher, "", func(err error) { cliCancel(err) }, dataChan, errChan)
|
h.forwardCLIStream(c, flusher, "", func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -172,12 +173,13 @@ func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJ
|
|||||||
modelName := modelResult.String()
|
modelName := modelResult.String()
|
||||||
|
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
cliCancel(errMsg.Error)
|
cliCancel(errMsg.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = c.Writer.Write(resp)
|
_, _ = c.Writer.Write(resp)
|
||||||
cliCancel()
|
cliCancel()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -188,7 +188,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName
|
|||||||
}
|
}
|
||||||
|
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||||
|
|
||||||
setSSEHeaders := func() {
|
setSSEHeaders := func() {
|
||||||
c.Header("Content-Type", "text/event-stream")
|
c.Header("Content-Type", "text/event-stream")
|
||||||
@@ -223,6 +223,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName
|
|||||||
if alt == "" {
|
if alt == "" {
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
cliCancel(nil)
|
cliCancel(nil)
|
||||||
return
|
return
|
||||||
@@ -232,6 +233,7 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName
|
|||||||
if alt == "" {
|
if alt == "" {
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
|
|
||||||
// Write first chunk
|
// Write first chunk
|
||||||
if alt == "" {
|
if alt == "" {
|
||||||
@@ -262,12 +264,13 @@ func (h *GeminiAPIHandler) handleCountTokens(c *gin.Context, modelName string, r
|
|||||||
c.Header("Content-Type", "application/json")
|
c.Header("Content-Type", "application/json")
|
||||||
alt := h.GetAlt(c)
|
alt := h.GetAlt(c)
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
resp, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
resp, upstreamHeaders, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
cliCancel(errMsg.Error)
|
cliCancel(errMsg.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = c.Writer.Write(resp)
|
_, _ = c.Writer.Write(resp)
|
||||||
cliCancel()
|
cliCancel()
|
||||||
}
|
}
|
||||||
@@ -286,13 +289,14 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin
|
|||||||
alt := h.GetAlt(c)
|
alt := h.GetAlt(c)
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||||
stopKeepAlive()
|
stopKeepAlive()
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
cliCancel(errMsg.Error)
|
cliCancel(errMsg.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = c.Writer.Write(resp)
|
_, _ = c.Writer.Write(resp)
|
||||||
cliCancel()
|
cliCancel()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -179,6 +179,12 @@ func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
|
|||||||
return retries
|
return retries
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PassthroughHeadersEnabled returns whether upstream response headers should be forwarded to clients.
|
||||||
|
// Default is false.
|
||||||
|
func PassthroughHeadersEnabled(cfg *config.SDKConfig) bool {
|
||||||
|
return cfg != nil && cfg.PassthroughHeaders
|
||||||
|
}
|
||||||
|
|
||||||
func requestExecutionMetadata(ctx context.Context) map[string]any {
|
func requestExecutionMetadata(ctx context.Context) map[string]any {
|
||||||
// Idempotency-Key is an optional client-supplied header used to correlate retries.
|
// Idempotency-Key is an optional client-supplied header used to correlate retries.
|
||||||
// It is forwarded as execution metadata; when absent we generate a UUID.
|
// It is forwarded as execution metadata; when absent we generate a UUID.
|
||||||
@@ -462,10 +468,10 @@ func appendAPIResponse(c *gin.Context, data []byte) {
|
|||||||
|
|
||||||
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
|
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
|
||||||
// This path is the only supported execution route.
|
// This path is the only supported execution route.
|
||||||
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) {
|
||||||
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
return nil, errMsg
|
return nil, nil, errMsg
|
||||||
}
|
}
|
||||||
reqMeta := requestExecutionMetadata(ctx)
|
reqMeta := requestExecutionMetadata(ctx)
|
||||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||||
@@ -498,17 +504,20 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
|||||||
addon = hdr.Clone()
|
addon = hdr.Clone()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
||||||
}
|
}
|
||||||
return resp.Payload, nil
|
if !PassthroughHeadersEnabled(h.Cfg) {
|
||||||
|
return resp.Payload, nil, nil
|
||||||
|
}
|
||||||
|
return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
|
// ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
|
||||||
// This path is the only supported execution route.
|
// This path is the only supported execution route.
|
||||||
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) {
|
||||||
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
return nil, errMsg
|
return nil, nil, errMsg
|
||||||
}
|
}
|
||||||
reqMeta := requestExecutionMetadata(ctx)
|
reqMeta := requestExecutionMetadata(ctx)
|
||||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||||
@@ -541,20 +550,24 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
|||||||
addon = hdr.Clone()
|
addon = hdr.Clone()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
||||||
}
|
}
|
||||||
return resp.Payload, nil
|
if !PassthroughHeadersEnabled(h.Cfg) {
|
||||||
|
return resp.Payload, nil, nil
|
||||||
|
}
|
||||||
|
return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
|
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
|
||||||
// This path is the only supported execution route.
|
// This path is the only supported execution route.
|
||||||
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) {
|
// The returned http.Header carries upstream response headers captured before streaming begins.
|
||||||
|
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, http.Header, <-chan *interfaces.ErrorMessage) {
|
||||||
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||||
errChan <- errMsg
|
errChan <- errMsg
|
||||||
close(errChan)
|
close(errChan)
|
||||||
return nil, errChan
|
return nil, nil, errChan
|
||||||
}
|
}
|
||||||
reqMeta := requestExecutionMetadata(ctx)
|
reqMeta := requestExecutionMetadata(ctx)
|
||||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||||
@@ -573,7 +586,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|||||||
SourceFormat: sdktranslator.FromString(handlerType),
|
SourceFormat: sdktranslator.FromString(handlerType),
|
||||||
}
|
}
|
||||||
opts.Metadata = reqMeta
|
opts.Metadata = reqMeta
|
||||||
chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
streamResult, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||||
status := http.StatusInternalServerError
|
status := http.StatusInternalServerError
|
||||||
@@ -590,8 +603,19 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|||||||
}
|
}
|
||||||
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
||||||
close(errChan)
|
close(errChan)
|
||||||
return nil, errChan
|
return nil, nil, errChan
|
||||||
}
|
}
|
||||||
|
passthroughHeadersEnabled := PassthroughHeadersEnabled(h.Cfg)
|
||||||
|
// Capture upstream headers from the initial connection synchronously before the goroutine starts.
|
||||||
|
// Keep a mutable map so bootstrap retries can replace it before first payload is sent.
|
||||||
|
var upstreamHeaders http.Header
|
||||||
|
if passthroughHeadersEnabled {
|
||||||
|
upstreamHeaders = cloneHeader(FilterUpstreamHeaders(streamResult.Headers))
|
||||||
|
if upstreamHeaders == nil {
|
||||||
|
upstreamHeaders = make(http.Header)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
chunks := streamResult.Chunks
|
||||||
dataChan := make(chan []byte)
|
dataChan := make(chan []byte)
|
||||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||||
go func() {
|
go func() {
|
||||||
@@ -665,9 +689,12 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|||||||
if !sentPayload {
|
if !sentPayload {
|
||||||
if bootstrapRetries < maxBootstrapRetries && bootstrapEligible(streamErr) {
|
if bootstrapRetries < maxBootstrapRetries && bootstrapEligible(streamErr) {
|
||||||
bootstrapRetries++
|
bootstrapRetries++
|
||||||
retryChunks, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
retryResult, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||||
if retryErr == nil {
|
if retryErr == nil {
|
||||||
chunks = retryChunks
|
if passthroughHeadersEnabled {
|
||||||
|
replaceHeader(upstreamHeaders, FilterUpstreamHeaders(retryResult.Headers))
|
||||||
|
}
|
||||||
|
chunks = retryResult.Chunks
|
||||||
continue outer
|
continue outer
|
||||||
}
|
}
|
||||||
streamErr = retryErr
|
streamErr = retryErr
|
||||||
@@ -690,6 +717,12 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(chunk.Payload) > 0 {
|
if len(chunk.Payload) > 0 {
|
||||||
|
if handlerType == "openai-response" {
|
||||||
|
if err := validateSSEDataJSON(chunk.Payload); err != nil {
|
||||||
|
_ = sendErr(&interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
sentPayload = true
|
sentPayload = true
|
||||||
if okSendData := sendData(cloneBytes(chunk.Payload)); !okSendData {
|
if okSendData := sendData(cloneBytes(chunk.Payload)); !okSendData {
|
||||||
return
|
return
|
||||||
@@ -698,7 +731,36 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return dataChan, errChan
|
return dataChan, upstreamHeaders, errChan
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateSSEDataJSON checks every "data:" line of an SSE chunk and returns an
// error for the first one whose payload is not valid JSON.
//
// The chunk is split on "\n" and each line is trimmed of surrounding
// whitespace. Blank lines, lines that do not start with "data:" (for example
// "event:" fields), empty data payloads and the "[DONE]" sentinel are skipped.
// The returned error reports the payload length and a quoted preview of at
// most maxPreviewBytes bytes. It returns nil when no offending line is found.
func validateSSEDataJSON(chunk []byte) error {
	// Named so that it does not shadow the max builtin.
	const maxPreviewBytes = 512

	// Built once instead of on every loop iteration.
	dataPrefix := []byte("data:")

	for _, line := range bytes.Split(chunk, []byte("\n")) {
		line = bytes.TrimSpace(line)
		// An empty line has no prefix either, so this also skips blank lines.
		if !bytes.HasPrefix(line, dataPrefix) {
			continue
		}
		data := bytes.TrimSpace(line[len(dataPrefix):])
		// string(data) == "..." is compiled without allocating a copy.
		if len(data) == 0 || string(data) == "[DONE]" || json.Valid(data) {
			continue
		}
		preview := data
		if len(preview) > maxPreviewBytes {
			preview = preview[:maxPreviewBytes]
		}
		return fmt.Errorf("invalid SSE data JSON (len=%d): %q", len(data), preview)
	}
	return nil
}
|
||||||
|
|
||||||
func statusFromError(err error) int {
|
func statusFromError(err error) int {
|
||||||
@@ -758,13 +820,33 @@ func cloneBytes(src []byte) []byte {
|
|||||||
return dst
|
return dst
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cloneHeader returns a deep copy of src, so that later changes to either
// header (keys or value slices) are not visible through the other.
// A nil src yields nil.
func cloneHeader(src http.Header) http.Header {
	// http.Header.Clone already handles the nil case and copies every value
	// slice, so no hand-written loop is needed.
	return src.Clone()
}
|
||||||
|
|
||||||
|
// replaceHeader overwrites the contents of dst in place with a deep copy of
// src: every existing key is removed from dst, then each key of src is added
// with its own copy of the value slice. Because the map is mutated rather than
// reassigned, anyone already holding dst observes the new contents.
//
// A nil dst is left untouched: a nil map cannot be written to, and the caller
// could not observe a freshly allocated one anyway.
func replaceHeader(dst http.Header, src http.Header) {
	if dst == nil {
		return
	}
	for key := range dst {
		delete(dst, key)
	}
	for key, values := range src {
		// Copy the slice so dst never aliases src's backing array.
		dst[key] = append([]string(nil), values...)
	}
}
|
||||||
|
|
||||||
// WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message.
|
// WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message.
|
||||||
func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) {
|
func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) {
|
||||||
status := http.StatusInternalServerError
|
status := http.StatusInternalServerError
|
||||||
if msg != nil && msg.StatusCode > 0 {
|
if msg != nil && msg.StatusCode > 0 {
|
||||||
status = msg.StatusCode
|
status = msg.StatusCode
|
||||||
}
|
}
|
||||||
if msg != nil && msg.Addon != nil {
|
if msg != nil && msg.Addon != nil && PassthroughHeadersEnabled(h.Cfg) {
|
||||||
for key, values := range msg.Addon {
|
for key, values := range msg.Addon {
|
||||||
if len(values) == 0 {
|
if len(values) == 0 {
|
||||||
continue
|
continue
|
||||||
|
|||||||
68
sdk/api/handlers/handlers_error_response_test.go
Normal file
68
sdk/api/handlers/handlers_error_response_test.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
|
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWriteErrorResponse_AddonHeadersDisabledByDefault(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
handler := NewBaseAPIHandlers(nil, nil)
|
||||||
|
handler.WriteErrorResponse(c, &interfaces.ErrorMessage{
|
||||||
|
StatusCode: http.StatusTooManyRequests,
|
||||||
|
Error: errors.New("rate limit"),
|
||||||
|
Addon: http.Header{
|
||||||
|
"Retry-After": {"30"},
|
||||||
|
"X-Request-Id": {"req-1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("status = %d, want %d", recorder.Code, http.StatusTooManyRequests)
|
||||||
|
}
|
||||||
|
if got := recorder.Header().Get("Retry-After"); got != "" {
|
||||||
|
t.Fatalf("Retry-After should be empty when passthrough is disabled, got %q", got)
|
||||||
|
}
|
||||||
|
if got := recorder.Header().Get("X-Request-Id"); got != "" {
|
||||||
|
t.Fatalf("X-Request-Id should be empty when passthrough is disabled, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteErrorResponse_AddonHeadersEnabled(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
c.Writer.Header().Set("X-Request-Id", "old-value")
|
||||||
|
|
||||||
|
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{PassthroughHeaders: true}, nil)
|
||||||
|
handler.WriteErrorResponse(c, &interfaces.ErrorMessage{
|
||||||
|
StatusCode: http.StatusTooManyRequests,
|
||||||
|
Error: errors.New("rate limit"),
|
||||||
|
Addon: http.Header{
|
||||||
|
"Retry-After": {"30"},
|
||||||
|
"X-Request-Id": {"new-1", "new-2"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("status = %d, want %d", recorder.Code, http.StatusTooManyRequests)
|
||||||
|
}
|
||||||
|
if got := recorder.Header().Get("Retry-After"); got != "30" {
|
||||||
|
t.Fatalf("Retry-After = %q, want %q", got, "30")
|
||||||
|
}
|
||||||
|
if got := recorder.Header().Values("X-Request-Id"); !reflect.DeepEqual(got, []string{"new-1", "new-2"}) {
|
||||||
|
t.Fatalf("X-Request-Id = %#v, want %#v", got, []string{"new-1", "new-2"})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -23,7 +23,7 @@ func (e *failOnceStreamExecutor) Execute(context.Context, *coreauth.Auth, coreex
|
|||||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
|
func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||||
e.mu.Lock()
|
e.mu.Lock()
|
||||||
e.calls++
|
e.calls++
|
||||||
call := e.calls
|
call := e.calls
|
||||||
@@ -40,12 +40,18 @@ func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth,
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
close(ch)
|
close(ch)
|
||||||
return ch, nil
|
return &coreexecutor.StreamResult{
|
||||||
|
Headers: http.Header{"X-Upstream-Attempt": {"1"}},
|
||||||
|
Chunks: ch,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
|
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
|
||||||
close(ch)
|
close(ch)
|
||||||
return ch, nil
|
return &coreexecutor.StreamResult{
|
||||||
|
Headers: http.Header{"X-Upstream-Attempt": {"2"}},
|
||||||
|
Chunks: ch,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||||
@@ -81,7 +87,7 @@ func (e *payloadThenErrorStreamExecutor) Execute(context.Context, *coreauth.Auth
|
|||||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
|
func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||||
e.mu.Lock()
|
e.mu.Lock()
|
||||||
e.calls++
|
e.calls++
|
||||||
e.mu.Unlock()
|
e.mu.Unlock()
|
||||||
@@ -97,7 +103,7 @@ func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreaut
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
close(ch)
|
close(ch)
|
||||||
return ch, nil
|
return &coreexecutor.StreamResult{Chunks: ch}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *payloadThenErrorStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
func (e *payloadThenErrorStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||||
@@ -128,13 +134,44 @@ type authAwareStreamExecutor struct {
|
|||||||
authIDs []string
|
authIDs []string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type invalidJSONStreamExecutor struct{}
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) Identifier() string { return "codex" }
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
|
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||||
|
ch := make(chan coreexecutor.StreamChunk, 1)
|
||||||
|
ch <- coreexecutor.StreamChunk{Payload: []byte("event: response.completed\ndata: {\"type\"")}
|
||||||
|
close(ch)
|
||||||
|
return &coreexecutor.StreamResult{Chunks: ch}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
|
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
return nil, &coreauth.Error{
|
||||||
|
Code: "not_implemented",
|
||||||
|
Message: "HttpRequest not implemented",
|
||||||
|
HTTPStatus: http.StatusNotImplemented,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }
|
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }
|
||||||
|
|
||||||
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
|
func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||||
_ = ctx
|
_ = ctx
|
||||||
_ = req
|
_ = req
|
||||||
_ = opts
|
_ = opts
|
||||||
@@ -160,12 +197,12 @@ func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *corea
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
close(ch)
|
close(ch)
|
||||||
return ch, nil
|
return &coreexecutor.StreamResult{Chunks: ch}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
|
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
|
||||||
close(ch)
|
close(ch)
|
||||||
return ch, nil
|
return &coreexecutor.StreamResult{Chunks: ch}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *authAwareStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
func (e *authAwareStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||||
@@ -231,11 +268,12 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||||
|
PassthroughHeaders: true,
|
||||||
Streaming: sdkconfig.StreamingConfig{
|
Streaming: sdkconfig.StreamingConfig{
|
||||||
BootstrapRetries: 1,
|
BootstrapRetries: 1,
|
||||||
},
|
},
|
||||||
}, manager)
|
}, manager)
|
||||||
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
dataChan, upstreamHeaders, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||||
if dataChan == nil || errChan == nil {
|
if dataChan == nil || errChan == nil {
|
||||||
t.Fatalf("expected non-nil channels")
|
t.Fatalf("expected non-nil channels")
|
||||||
}
|
}
|
||||||
@@ -257,6 +295,70 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
|||||||
if executor.Calls() != 2 {
|
if executor.Calls() != 2 {
|
||||||
t.Fatalf("expected 2 stream attempts, got %d", executor.Calls())
|
t.Fatalf("expected 2 stream attempts, got %d", executor.Calls())
|
||||||
}
|
}
|
||||||
|
upstreamAttemptHeader := upstreamHeaders.Get("X-Upstream-Attempt")
|
||||||
|
if upstreamAttemptHeader != "2" {
|
||||||
|
t.Fatalf("expected upstream header from retry attempt, got %q", upstreamAttemptHeader)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecuteStreamWithAuthManager_HeaderPassthroughDisabledByDefault(t *testing.T) {
|
||||||
|
executor := &failOnceStreamExecutor{}
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
manager.RegisterExecutor(executor)
|
||||||
|
|
||||||
|
auth1 := &coreauth.Auth{
|
||||||
|
ID: "auth1",
|
||||||
|
Provider: "codex",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Metadata: map[string]any{"email": "test1@example.com"},
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||||
|
t.Fatalf("manager.Register(auth1): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
auth2 := &coreauth.Auth{
|
||||||
|
ID: "auth2",
|
||||||
|
Provider: "codex",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Metadata: map[string]any{"email": "test2@example.com"},
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), auth2); err != nil {
|
||||||
|
t.Fatalf("manager.Register(auth2): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||||
|
Streaming: sdkconfig.StreamingConfig{
|
||||||
|
BootstrapRetries: 1,
|
||||||
|
},
|
||||||
|
}, manager)
|
||||||
|
dataChan, upstreamHeaders, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||||
|
if dataChan == nil || errChan == nil {
|
||||||
|
t.Fatalf("expected non-nil channels")
|
||||||
|
}
|
||||||
|
|
||||||
|
var got []byte
|
||||||
|
for chunk := range dataChan {
|
||||||
|
got = append(got, chunk...)
|
||||||
|
}
|
||||||
|
for msg := range errChan {
|
||||||
|
if msg != nil {
|
||||||
|
t.Fatalf("unexpected error: %+v", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(got) != "ok" {
|
||||||
|
t.Fatalf("expected payload ok, got %q", string(got))
|
||||||
|
}
|
||||||
|
if upstreamHeaders != nil {
|
||||||
|
t.Fatalf("expected nil upstream headers when passthrough is disabled, got %#v", upstreamHeaders)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
|
func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
|
||||||
@@ -296,7 +398,7 @@ func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
|
|||||||
BootstrapRetries: 1,
|
BootstrapRetries: 1,
|
||||||
},
|
},
|
||||||
}, manager)
|
}, manager)
|
||||||
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||||
if dataChan == nil || errChan == nil {
|
if dataChan == nil || errChan == nil {
|
||||||
t.Fatalf("expected non-nil channels")
|
t.Fatalf("expected non-nil channels")
|
||||||
}
|
}
|
||||||
@@ -367,7 +469,7 @@ func TestExecuteStreamWithAuthManager_PinnedAuthKeepsSameUpstream(t *testing.T)
|
|||||||
},
|
},
|
||||||
}, manager)
|
}, manager)
|
||||||
ctx := WithPinnedAuthID(context.Background(), "auth1")
|
ctx := WithPinnedAuthID(context.Background(), "auth1")
|
||||||
dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||||
if dataChan == nil || errChan == nil {
|
if dataChan == nil || errChan == nil {
|
||||||
t.Fatalf("expected non-nil channels")
|
t.Fatalf("expected non-nil channels")
|
||||||
}
|
}
|
||||||
@@ -431,7 +533,7 @@ func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *test
|
|||||||
ctx := WithSelectedAuthIDCallback(context.Background(), func(authID string) {
|
ctx := WithSelectedAuthIDCallback(context.Background(), func(authID string) {
|
||||||
selectedAuthID = authID
|
selectedAuthID = authID
|
||||||
})
|
})
|
||||||
dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||||
if dataChan == nil || errChan == nil {
|
if dataChan == nil || errChan == nil {
|
||||||
t.Fatalf("expected non-nil channels")
|
t.Fatalf("expected non-nil channels")
|
||||||
}
|
}
|
||||||
@@ -453,3 +555,55 @@ func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *test
|
|||||||
t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2")
|
t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExecuteStreamWithAuthManager_ValidatesOpenAIResponsesStreamDataJSON(t *testing.T) {
|
||||||
|
executor := &invalidJSONStreamExecutor{}
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
manager.RegisterExecutor(executor)
|
||||||
|
|
||||||
|
auth1 := &coreauth.Auth{
|
||||||
|
ID: "auth1",
|
||||||
|
Provider: "codex",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Metadata: map[string]any{"email": "test1@example.com"},
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||||
|
t.Fatalf("manager.Register(auth1): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||||
|
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai-response", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||||
|
if dataChan == nil || errChan == nil {
|
||||||
|
t.Fatalf("expected non-nil channels")
|
||||||
|
}
|
||||||
|
|
||||||
|
var got []byte
|
||||||
|
for chunk := range dataChan {
|
||||||
|
got = append(got, chunk...)
|
||||||
|
}
|
||||||
|
if len(got) != 0 {
|
||||||
|
t.Fatalf("expected empty payload, got %q", string(got))
|
||||||
|
}
|
||||||
|
|
||||||
|
gotErr := false
|
||||||
|
for msg := range errChan {
|
||||||
|
if msg == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if msg.StatusCode != http.StatusBadGateway {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusBadGateway, msg.StatusCode)
|
||||||
|
}
|
||||||
|
if msg.Error == nil {
|
||||||
|
t.Fatalf("expected error")
|
||||||
|
}
|
||||||
|
gotErr = true
|
||||||
|
}
|
||||||
|
if !gotErr {
|
||||||
|
t.Fatalf("expected terminal error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
80
sdk/api/handlers/header_filter.go
Normal file
80
sdk/api/handlers/header_filter.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// hopByHopHeaders lists RFC 7230 Section 6.1 hop-by-hop headers that MUST NOT
// be forwarded by proxies, plus security-sensitive headers that should not leak.
//
// Keys are written in canonical MIME header form because FilterUpstreamHeaders
// looks names up after passing them through http.CanonicalHeaderKey. The map is
// used as a set; the empty struct values carry no information.
var hopByHopHeaders = map[string]struct{}{
	// RFC 7230 hop-by-hop
	"Connection":          {},
	"Keep-Alive":          {},
	"Proxy-Authenticate":  {},
	"Proxy-Authorization": {},
	"Te":                  {},
	"Trailer":             {},
	"Transfer-Encoding":   {},
	"Upgrade":             {},
	// Security-sensitive
	"Set-Cookie": {},
	// CPA-managed (set by handlers, not upstream)
	"Content-Length":   {},
	"Content-Encoding": {},
}
|
||||||
|
|
||||||
|
// FilterUpstreamHeaders returns a copy of src with hop-by-hop and security-sensitive
|
||||||
|
// headers removed. Returns nil if src is nil or empty after filtering.
|
||||||
|
func FilterUpstreamHeaders(src http.Header) http.Header {
|
||||||
|
if src == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
connectionScoped := connectionScopedHeaders(src)
|
||||||
|
dst := make(http.Header)
|
||||||
|
for key, values := range src {
|
||||||
|
canonicalKey := http.CanonicalHeaderKey(key)
|
||||||
|
if _, blocked := hopByHopHeaders[canonicalKey]; blocked {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, scoped := connectionScoped[canonicalKey]; scoped {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dst[key] = values
|
||||||
|
}
|
||||||
|
if len(dst) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// connectionScopedHeaders returns the set of header names listed in the
// Connection header(s) of src, in canonical form.
//
// Each Connection value is treated as a comma-separated list; tokens are
// trimmed and empty tokens are ignored. The Connection key itself is matched
// case-insensitively so that headers stored under a non-canonical key (for
// example by direct map assignment) are still honoured. The returned map is
// never nil.
func connectionScopedHeaders(src http.Header) map[string]struct{} {
	scoped := make(map[string]struct{})
	for key, rawValues := range src {
		if !strings.EqualFold(key, "Connection") {
			continue
		}
		for _, rawValue := range rawValues {
			for _, token := range strings.Split(rawValue, ",") {
				headerName := strings.TrimSpace(token)
				if headerName == "" {
					continue
				}
				scoped[http.CanonicalHeaderKey(headerName)] = struct{}{}
			}
		}
	}
	return scoped
}
|
||||||
|
|
||||||
|
// WriteUpstreamHeaders writes filtered upstream headers to the gin response writer.
// Headers already set by CPA (e.g., Content-Type) are NOT overwritten.
//
// "Already set" means the canonical key is present in dst, even when its value
// is empty or nil; such entries are left untouched rather than having upstream
// values appended to them. This function performs no filtering itself, so src
// is expected to have been filtered by the caller. A nil dst or empty src is a
// no-op.
func WriteUpstreamHeaders(dst http.Header, src http.Header) {
	if dst == nil || len(src) == 0 {
		return
	}
	for key, values := range src {
		// Don't overwrite headers already set by CPA handlers. Check key presence
		// rather than Get(key) != "" so an explicitly empty header is respected.
		if _, exists := dst[http.CanonicalHeaderKey(key)]; exists {
			continue
		}
		for _, v := range values {
			dst.Add(key, v)
		}
	}
}
|
||||||
55
sdk/api/handlers/header_filter_test.go
Normal file
55
sdk/api/handlers/header_filter_test.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFilterUpstreamHeaders_RemovesConnectionScopedHeaders(t *testing.T) {
|
||||||
|
src := http.Header{}
|
||||||
|
src.Add("Connection", "keep-alive, x-hop-a, x-hop-b")
|
||||||
|
src.Add("Connection", "x-hop-c")
|
||||||
|
src.Set("Keep-Alive", "timeout=5")
|
||||||
|
src.Set("X-Hop-A", "a")
|
||||||
|
src.Set("X-Hop-B", "b")
|
||||||
|
src.Set("X-Hop-C", "c")
|
||||||
|
src.Set("X-Request-Id", "req-1")
|
||||||
|
src.Set("Set-Cookie", "session=secret")
|
||||||
|
|
||||||
|
filtered := FilterUpstreamHeaders(src)
|
||||||
|
if filtered == nil {
|
||||||
|
t.Fatalf("expected filtered headers, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
requestID := filtered.Get("X-Request-Id")
|
||||||
|
if requestID != "req-1" {
|
||||||
|
t.Fatalf("expected X-Request-Id to be preserved, got %q", requestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
blockedHeaderKeys := []string{
|
||||||
|
"Connection",
|
||||||
|
"Keep-Alive",
|
||||||
|
"X-Hop-A",
|
||||||
|
"X-Hop-B",
|
||||||
|
"X-Hop-C",
|
||||||
|
"Set-Cookie",
|
||||||
|
}
|
||||||
|
for _, key := range blockedHeaderKeys {
|
||||||
|
value := filtered.Get(key)
|
||||||
|
if value != "" {
|
||||||
|
t.Fatalf("expected %s to be removed, got %q", key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterUpstreamHeaders_ReturnsNilWhenAllHeadersBlocked(t *testing.T) {
|
||||||
|
src := http.Header{}
|
||||||
|
src.Add("Connection", "x-hop-a")
|
||||||
|
src.Set("X-Hop-A", "a")
|
||||||
|
src.Set("Set-Cookie", "session=secret")
|
||||||
|
|
||||||
|
filtered := FilterUpstreamHeaders(src)
|
||||||
|
if filtered != nil {
|
||||||
|
t.Fatalf("expected nil when all headers are filtered, got %#v", filtered)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -519,12 +519,13 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []
|
|||||||
|
|
||||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
|
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
cliCancel(errMsg.Error)
|
cliCancel(errMsg.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = c.Writer.Write(resp)
|
_, _ = c.Writer.Write(resp)
|
||||||
cliCancel()
|
cliCancel()
|
||||||
}
|
}
|
||||||
@@ -534,12 +535,13 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponseViaResponses(c *gin.Context
|
|||||||
|
|
||||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, OpenaiResponse, modelName, rawJSON, h.GetAlt(c))
|
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, OpenaiResponse, modelName, rawJSON, h.GetAlt(c))
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
cliCancel(errMsg.Error)
|
cliCancel(errMsg.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
converted := convertResponsesObjectToChatCompletion(cliCtx, modelName, originalChatJSON, rawJSON, resp)
|
converted := convertResponsesObjectToChatCompletion(cliCtx, modelName, originalChatJSON, rawJSON, resp)
|
||||||
if converted == nil {
|
if converted == nil {
|
||||||
h.WriteErrorResponse(c, &interfaces.ErrorMessage{
|
h.WriteErrorResponse(c, &interfaces.ErrorMessage{
|
||||||
@@ -575,7 +577,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
|
|||||||
|
|
||||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
|
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
|
||||||
|
|
||||||
setSSEHeaders := func() {
|
setSSEHeaders := func() {
|
||||||
c.Header("Content-Type", "text/event-stream")
|
c.Header("Content-Type", "text/event-stream")
|
||||||
@@ -608,6 +610,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
|
|||||||
if !ok {
|
if !ok {
|
||||||
// Stream closed without data? Send DONE or just headers.
|
// Stream closed without data? Send DONE or just headers.
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
cliCancel(nil)
|
cliCancel(nil)
|
||||||
@@ -616,6 +619,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
|
|||||||
|
|
||||||
// Success! Commit to streaming headers.
|
// Success! Commit to streaming headers.
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
|
|
||||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
|
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
@@ -641,7 +645,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponseViaResponses(c *gin.Context, r
|
|||||||
|
|
||||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, OpenaiResponse, modelName, rawJSON, h.GetAlt(c))
|
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, OpenaiResponse, modelName, rawJSON, h.GetAlt(c))
|
||||||
var param any
|
var param any
|
||||||
|
|
||||||
setSSEHeaders := func() {
|
setSSEHeaders := func() {
|
||||||
@@ -672,6 +676,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponseViaResponses(c *gin.Context, r
|
|||||||
case chunk, ok := <-dataChan:
|
case chunk, ok := <-dataChan:
|
||||||
if !ok {
|
if !ok {
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
cliCancel(nil)
|
cliCancel(nil)
|
||||||
@@ -679,6 +684,7 @@ func (h *OpenAIAPIHandler) handleStreamingResponseViaResponses(c *gin.Context, r
|
|||||||
}
|
}
|
||||||
|
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
writeConvertedResponsesChunk(c, cliCtx, modelName, originalChatJSON, rawJSON, chunk, ¶m)
|
writeConvertedResponsesChunk(c, cliCtx, modelName, originalChatJSON, rawJSON, chunk, ¶m)
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
|
|
||||||
@@ -704,13 +710,14 @@ func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context,
|
|||||||
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
|
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
|
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
|
||||||
stopKeepAlive()
|
stopKeepAlive()
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
cliCancel(errMsg.Error)
|
cliCancel(errMsg.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
completionsResp := convertChatCompletionsResponseToCompletions(resp)
|
completionsResp := convertChatCompletionsResponseToCompletions(resp)
|
||||||
_, _ = c.Writer.Write(completionsResp)
|
_, _ = c.Writer.Write(completionsResp)
|
||||||
cliCancel()
|
cliCancel()
|
||||||
@@ -741,7 +748,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra
|
|||||||
|
|
||||||
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
|
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
|
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
|
||||||
|
|
||||||
setSSEHeaders := func() {
|
setSSEHeaders := func() {
|
||||||
c.Header("Content-Type", "text/event-stream")
|
c.Header("Content-Type", "text/event-stream")
|
||||||
@@ -772,6 +779,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra
|
|||||||
case chunk, ok := <-dataChan:
|
case chunk, ok := <-dataChan:
|
||||||
if !ok {
|
if !ok {
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
cliCancel(nil)
|
cliCancel(nil)
|
||||||
@@ -780,6 +788,7 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra
|
|||||||
|
|
||||||
// Success! Set headers.
|
// Success! Set headers.
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
|
|
||||||
// Write the first chunk
|
// Write the first chunk
|
||||||
converted := convertChatCompletionsStreamChunkToCompletions(chunk)
|
converted := convertChatCompletionsStreamChunkToCompletions(chunk)
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ func (e *compactCaptureExecutor) Execute(ctx context.Context, auth *coreauth.Aut
|
|||||||
return coreexecutor.Response{Payload: []byte(`{"ok":true}`)}, nil
|
return coreexecutor.Response{Payload: []byte(`{"ok":true}`)}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *compactCaptureExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
|
func (e *compactCaptureExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -139,13 +139,14 @@ func (h *OpenAIResponsesAPIHandler) Compact(c *gin.Context) {
|
|||||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "responses/compact")
|
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "responses/compact")
|
||||||
stopKeepAlive()
|
stopKeepAlive()
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
cliCancel(errMsg.Error)
|
cliCancel(errMsg.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = c.Writer.Write(resp)
|
_, _ = c.Writer.Write(resp)
|
||||||
cliCancel()
|
cliCancel()
|
||||||
}
|
}
|
||||||
@@ -164,13 +165,14 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r
|
|||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||||
|
|
||||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||||
stopKeepAlive()
|
stopKeepAlive()
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
cliCancel(errMsg.Error)
|
cliCancel(errMsg.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = c.Writer.Write(resp)
|
_, _ = c.Writer.Write(resp)
|
||||||
cliCancel()
|
cliCancel()
|
||||||
}
|
}
|
||||||
@@ -180,12 +182,13 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponseViaChat(c *gin.Con
|
|||||||
|
|
||||||
modelName := gjson.GetBytes(chatJSON, "model").String()
|
modelName := gjson.GetBytes(chatJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, OpenAI, modelName, chatJSON, "")
|
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, OpenAI, modelName, chatJSON, "")
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
cliCancel(errMsg.Error)
|
cliCancel(errMsg.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
var param any
|
var param any
|
||||||
converted := responsesconverter.ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(cliCtx, modelName, originalResponsesJSON, originalResponsesJSON, resp, ¶m)
|
converted := responsesconverter.ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(cliCtx, modelName, originalResponsesJSON, originalResponsesJSON, resp, ¶m)
|
||||||
if converted == "" {
|
if converted == "" {
|
||||||
@@ -223,7 +226,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
|
|||||||
// New core execution path
|
// New core execution path
|
||||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||||
|
|
||||||
setSSEHeaders := func() {
|
setSSEHeaders := func() {
|
||||||
c.Header("Content-Type", "text/event-stream")
|
c.Header("Content-Type", "text/event-stream")
|
||||||
@@ -256,6 +259,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
|
|||||||
if !ok {
|
if !ok {
|
||||||
// Stream closed without data? Send headers and done.
|
// Stream closed without data? Send headers and done.
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = c.Writer.Write([]byte("\n"))
|
_, _ = c.Writer.Write([]byte("\n"))
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
cliCancel(nil)
|
cliCancel(nil)
|
||||||
@@ -264,6 +268,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
|
|||||||
|
|
||||||
// Success! Set headers.
|
// Success! Set headers.
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
|
|
||||||
// Write first chunk logic (matching forwardResponsesStream)
|
// Write first chunk logic (matching forwardResponsesStream)
|
||||||
if bytes.HasPrefix(chunk, []byte("event:")) {
|
if bytes.HasPrefix(chunk, []byte("event:")) {
|
||||||
@@ -294,7 +299,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponseViaChat(c *gin.Contex
|
|||||||
|
|
||||||
modelName := gjson.GetBytes(chatJSON, "model").String()
|
modelName := gjson.GetBytes(chatJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, OpenAI, modelName, chatJSON, "")
|
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, OpenAI, modelName, chatJSON, "")
|
||||||
var param any
|
var param any
|
||||||
|
|
||||||
setSSEHeaders := func() {
|
setSSEHeaders := func() {
|
||||||
@@ -324,6 +329,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponseViaChat(c *gin.Contex
|
|||||||
case chunk, ok := <-dataChan:
|
case chunk, ok := <-dataChan:
|
||||||
if !ok {
|
if !ok {
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
_, _ = c.Writer.Write([]byte("\n"))
|
_, _ = c.Writer.Write([]byte("\n"))
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
cliCancel(nil)
|
cliCancel(nil)
|
||||||
@@ -331,6 +337,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponseViaChat(c *gin.Contex
|
|||||||
}
|
}
|
||||||
|
|
||||||
setSSEHeaders()
|
setSSEHeaders()
|
||||||
|
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||||
writeChatAsResponsesChunk(c, cliCtx, modelName, originalResponsesJSON, chunk, ¶m)
|
writeChatAsResponsesChunk(c, cliCtx, modelName, originalResponsesJSON, chunk, ¶m)
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
|
|
||||||
@@ -411,8 +418,8 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flush
|
|||||||
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||||
errText = errMsg.Error.Error()
|
errText = errMsg.Error.Error()
|
||||||
}
|
}
|
||||||
body := handlers.BuildErrorResponseBody(status, errText)
|
chunk := handlers.BuildOpenAIResponsesStreamErrorChunk(status, errText, 0)
|
||||||
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body))
|
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(chunk))
|
||||||
},
|
},
|
||||||
WriteDone: func() {
|
WriteDone: func() {
|
||||||
_, _ = c.Writer.Write([]byte("\n"))
|
_, _ = c.Writer.Write([]byte("\n"))
|
||||||
|
|||||||
@@ -0,0 +1,43 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||||
|
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestForwardResponsesStreamTerminalErrorUsesResponsesErrorChunk(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, nil)
|
||||||
|
h := NewOpenAIResponsesAPIHandler(base)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
|
|
||||||
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected gin writer to implement http.Flusher")
|
||||||
|
}
|
||||||
|
|
||||||
|
data := make(chan []byte)
|
||||||
|
errs := make(chan *interfaces.ErrorMessage, 1)
|
||||||
|
errs <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: errors.New("unexpected EOF")}
|
||||||
|
close(errs)
|
||||||
|
|
||||||
|
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs)
|
||||||
|
body := recorder.Body.String()
|
||||||
|
if !strings.Contains(body, `"type":"error"`) {
|
||||||
|
t.Fatalf("expected responses error chunk, got: %q", body)
|
||||||
|
}
|
||||||
|
if strings.Contains(body, `"error":{`) {
|
||||||
|
t.Fatalf("expected streaming error chunk (top-level type), got HTTP error body: %q", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -153,7 +153,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
|||||||
pinnedAuthID = strings.TrimSpace(authID)
|
pinnedAuthID = strings.TrimSpace(authID)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
|
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
|
||||||
|
|
||||||
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
|
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
|
||||||
if errForward != nil {
|
if errForward != nil {
|
||||||
|
|||||||
119
sdk/api/handlers/openai_responses_stream_error.go
Normal file
119
sdk/api/handlers/openai_responses_stream_error.go
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// openAIResponsesStreamErrorChunk is the JSON payload emitted for an error
// event in a Responses stream. It is serialized as a flat object with the
// keys "type", "code", "message" and "sequence_number".
type openAIResponsesStreamErrorChunk struct {
	// Type is the chunk discriminator; BuildOpenAIResponsesStreamErrorChunk always sets it to "error".
	Type string `json:"type"`
	// Code is a short machine-readable error identifier.
	Code string `json:"code"`
	// Message is the human-readable error text.
	Message string `json:"message"`
	// SequenceNumber is the sequence number reported for the chunk.
	SequenceNumber int `json:"sequence_number"`
}
|
||||||
|
|
||||||
|
// openAIResponsesStreamErrorCode maps an HTTP status to the code string used in
// a Responses streaming error chunk.
//
// A handful of statuses have dedicated codes; any other status of 500 or above
// is "internal_server_error", any other status of 400 or above is
// "invalid_request_error", and everything else is "unknown_error".
func openAIResponsesStreamErrorCode(status int) string {
	switch {
	case status == http.StatusUnauthorized:
		return "invalid_api_key"
	case status == http.StatusForbidden:
		return "insufficient_quota"
	case status == http.StatusTooManyRequests:
		return "rate_limit_exceeded"
	case status == http.StatusNotFound:
		return "model_not_found"
	case status == http.StatusRequestTimeout:
		return "request_timeout"
	case status >= http.StatusInternalServerError:
		return "internal_server_error"
	case status >= http.StatusBadRequest:
		return "invalid_request_error"
	}
	return "unknown_error"
}
|
||||||
|
|
||||||
|
// BuildOpenAIResponsesStreamErrorChunk builds an OpenAI Responses streaming error chunk.
|
||||||
|
//
|
||||||
|
// Important: OpenAI's HTTP error bodies are shaped like {"error":{...}}; those are valid for
|
||||||
|
// non-streaming responses, but streaming clients validate SSE `data:` payloads against a union
|
||||||
|
// of chunks that requires a top-level `type` field.
|
||||||
|
func BuildOpenAIResponsesStreamErrorChunk(status int, errText string, sequenceNumber int) []byte {
|
||||||
|
if status <= 0 {
|
||||||
|
status = http.StatusInternalServerError
|
||||||
|
}
|
||||||
|
if sequenceNumber < 0 {
|
||||||
|
sequenceNumber = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
message := strings.TrimSpace(errText)
|
||||||
|
if message == "" {
|
||||||
|
message = http.StatusText(status)
|
||||||
|
}
|
||||||
|
|
||||||
|
code := openAIResponsesStreamErrorCode(status)
|
||||||
|
|
||||||
|
trimmed := strings.TrimSpace(errText)
|
||||||
|
if trimmed != "" && json.Valid([]byte(trimmed)) {
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(trimmed), &payload); err == nil {
|
||||||
|
if t, ok := payload["type"].(string); ok && strings.TrimSpace(t) == "error" {
|
||||||
|
if m, ok := payload["message"].(string); ok && strings.TrimSpace(m) != "" {
|
||||||
|
message = strings.TrimSpace(m)
|
||||||
|
}
|
||||||
|
if v, ok := payload["code"]; ok && v != nil {
|
||||||
|
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
|
||||||
|
code = strings.TrimSpace(c)
|
||||||
|
} else {
|
||||||
|
code = strings.TrimSpace(fmt.Sprint(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := payload["sequence_number"].(float64); ok && sequenceNumber == 0 {
|
||||||
|
sequenceNumber = int(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if e, ok := payload["error"].(map[string]any); ok {
|
||||||
|
if m, ok := e["message"].(string); ok && strings.TrimSpace(m) != "" {
|
||||||
|
message = strings.TrimSpace(m)
|
||||||
|
}
|
||||||
|
if v, ok := e["code"]; ok && v != nil {
|
||||||
|
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
|
||||||
|
code = strings.TrimSpace(c)
|
||||||
|
} else {
|
||||||
|
code = strings.TrimSpace(fmt.Sprint(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(code) == "" {
|
||||||
|
code = "unknown_error"
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(openAIResponsesStreamErrorChunk{
|
||||||
|
Type: "error",
|
||||||
|
Code: code,
|
||||||
|
Message: message,
|
||||||
|
SequenceNumber: sequenceNumber,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extremely defensive fallback.
|
||||||
|
data, _ = json.Marshal(openAIResponsesStreamErrorChunk{
|
||||||
|
Type: "error",
|
||||||
|
Code: "internal_server_error",
|
||||||
|
Message: message,
|
||||||
|
SequenceNumber: sequenceNumber,
|
||||||
|
})
|
||||||
|
if len(data) > 0 {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
return []byte(`{"type":"error","code":"internal_server_error","message":"internal error","sequence_number":0}`)
|
||||||
|
}
|
||||||
48
sdk/api/handlers/openai_responses_stream_error_test.go
Normal file
48
sdk/api/handlers/openai_responses_stream_error_test.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildOpenAIResponsesStreamErrorChunk(t *testing.T) {
|
||||||
|
chunk := BuildOpenAIResponsesStreamErrorChunk(http.StatusInternalServerError, "unexpected EOF", 0)
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal(chunk, &payload); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if payload["type"] != "error" {
|
||||||
|
t.Fatalf("type = %v, want %q", payload["type"], "error")
|
||||||
|
}
|
||||||
|
if payload["code"] != "internal_server_error" {
|
||||||
|
t.Fatalf("code = %v, want %q", payload["code"], "internal_server_error")
|
||||||
|
}
|
||||||
|
if payload["message"] != "unexpected EOF" {
|
||||||
|
t.Fatalf("message = %v, want %q", payload["message"], "unexpected EOF")
|
||||||
|
}
|
||||||
|
if payload["sequence_number"] != float64(0) {
|
||||||
|
t.Fatalf("sequence_number = %v, want %v", payload["sequence_number"], 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildOpenAIResponsesStreamErrorChunkExtractsHTTPErrorBody(t *testing.T) {
|
||||||
|
chunk := BuildOpenAIResponsesStreamErrorChunk(
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
`{"error":{"message":"oops","type":"server_error","code":"internal_server_error"}}`,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal(chunk, &payload); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if payload["type"] != "error" {
|
||||||
|
t.Fatalf("type = %v, want %q", payload["type"], "error")
|
||||||
|
}
|
||||||
|
if payload["code"] != "internal_server_error" {
|
||||||
|
t.Fatalf("code = %v, want %q", payload["code"], "internal_server_error")
|
||||||
|
}
|
||||||
|
if payload["message"] != "oops" {
|
||||||
|
t.Fatalf("message = %v, want %q", payload["message"], "oops")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,8 +2,6 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -48,6 +46,10 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
|||||||
opts = &LoginOptions{}
|
opts = &LoginOptions{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if shouldUseCodexDeviceFlow(opts) {
|
||||||
|
return a.loginWithDeviceFlow(ctx, cfg, opts)
|
||||||
|
}
|
||||||
|
|
||||||
callbackPort := a.CallbackPort
|
callbackPort := a.CallbackPort
|
||||||
if opts.CallbackPort > 0 {
|
if opts.CallbackPort > 0 {
|
||||||
callbackPort = opts.CallbackPort
|
callbackPort = opts.CallbackPort
|
||||||
@@ -186,39 +188,5 @@ waitForCallback:
|
|||||||
return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err)
|
return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenStorage := authSvc.CreateTokenStorage(authBundle)
|
return a.buildAuthRecord(authSvc, authBundle)
|
||||||
|
|
||||||
if tokenStorage == nil || tokenStorage.Email == "" {
|
|
||||||
return nil, fmt.Errorf("codex token storage missing account information")
|
|
||||||
}
|
|
||||||
|
|
||||||
planType := ""
|
|
||||||
hashAccountID := ""
|
|
||||||
if tokenStorage.IDToken != "" {
|
|
||||||
if claims, errParse := codex.ParseJWTToken(tokenStorage.IDToken); errParse == nil && claims != nil {
|
|
||||||
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
|
|
||||||
accountID := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID)
|
|
||||||
if accountID != "" {
|
|
||||||
digest := sha256.Sum256([]byte(accountID))
|
|
||||||
hashAccountID = hex.EncodeToString(digest[:])[:8]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true)
|
|
||||||
metadata := map[string]any{
|
|
||||||
"email": tokenStorage.Email,
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println("Codex authentication successful")
|
|
||||||
if authBundle.APIKey != "" {
|
|
||||||
fmt.Println("Codex API key obtained and stored")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &coreauth.Auth{
|
|
||||||
ID: fileName,
|
|
||||||
Provider: a.Provider(),
|
|
||||||
FileName: fileName,
|
|
||||||
Storage: tokenStorage,
|
|
||||||
Metadata: metadata,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
291
sdk/auth/codex_device.go
Normal file
291
sdk/auth/codex_device.go
Normal file
@@ -0,0 +1,291 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Login-mode selection and the endpoints used by the Codex device flow.
const (
	// codexLoginModeMetadataKey is the LoginOptions.Metadata key that selects
	// the login mode; codexLoginModeDevice is the value that enables the
	// device flow.
	codexLoginModeMetadataKey = "codex_login_mode"
	codexLoginModeDevice      = "device"

	// codexDeviceUserCodeURL issues the device auth ID and user code.
	codexDeviceUserCodeURL = "https://auth.openai.com/api/accounts/deviceauth/usercode"
	// codexDeviceTokenURL is polled until the login has been approved.
	codexDeviceTokenURL = "https://auth.openai.com/api/accounts/deviceauth/token"
	// codexDeviceVerificationURL is the page shown to (and opened for) the user.
	codexDeviceVerificationURL = "https://auth.openai.com/codex/device"
	// codexDeviceTokenExchangeRedirectURI is sent as the redirect URI when the
	// authorization code is exchanged for tokens.
	codexDeviceTokenExchangeRedirectURI = "https://auth.openai.com/deviceauth/callback"
)

// Timing parameters of the device flow.
const (
	// codexDeviceTimeout bounds the whole polling phase.
	codexDeviceTimeout = 15 * time.Minute
	// codexDeviceDefaultPollIntervalSeconds is used when the server supplies
	// no usable polling interval.
	codexDeviceDefaultPollIntervalSeconds = 5
)
|
||||||
|
|
||||||
|
// codexDeviceUserCodeRequest is the JSON body POSTed to the user-code endpoint
// to start a device login.
type codexDeviceUserCodeRequest struct {
	// ClientID identifies the OAuth client requesting the device code.
	ClientID string `json:"client_id"`
}
|
||||||
|
|
||||||
|
// codexDeviceUserCodeResponse is the decoded reply of the user-code endpoint.
type codexDeviceUserCodeResponse struct {
	// DeviceAuthID is echoed back to the token endpoint while polling.
	DeviceAuthID string `json:"device_auth_id"`
	// UserCode is the code the user enters on the verification page.
	UserCode string `json:"user_code"`
	// UserCodeAlt carries the same code under the alternate "usercode" key;
	// it is consulted only when UserCode is empty.
	UserCodeAlt string `json:"usercode"`
	// Interval is kept raw because it is parsed leniently as either a JSON
	// string or a JSON number of seconds.
	Interval json.RawMessage `json:"interval"`
}
|
||||||
|
|
||||||
|
// codexDeviceTokenRequest is the JSON body POSTed to the token endpoint on
// every polling attempt.
type codexDeviceTokenRequest struct {
	// DeviceAuthID is the identifier returned by the user-code endpoint.
	DeviceAuthID string `json:"device_auth_id"`
	// UserCode is the code that was shown to the user.
	UserCode string `json:"user_code"`
}
|
||||||
|
|
||||||
|
// codexDeviceTokenResponse is the decoded successful reply of the token
// endpoint. All three fields are required by the subsequent code-for-token
// exchange.
type codexDeviceTokenResponse struct {
	// AuthorizationCode is exchanged for the actual tokens.
	AuthorizationCode string `json:"authorization_code"`
	// CodeVerifier and CodeChallenge are the PKCE pair supplied by the server.
	CodeVerifier  string `json:"code_verifier"`
	CodeChallenge string `json:"code_challenge"`
}
|
||||||
|
|
||||||
|
func shouldUseCodexDeviceFlow(opts *LoginOptions) bool {
|
||||||
|
if opts == nil || opts.Metadata == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.EqualFold(strings.TrimSpace(opts.Metadata[codexLoginModeMetadataKey]), codexLoginModeDevice)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *CodexAuthenticator) loginWithDeviceFlow(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{})
|
||||||
|
|
||||||
|
userCodeResp, err := requestCodexDeviceUserCode(ctx, httpClient)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceCode := strings.TrimSpace(userCodeResp.UserCode)
|
||||||
|
if deviceCode == "" {
|
||||||
|
deviceCode = strings.TrimSpace(userCodeResp.UserCodeAlt)
|
||||||
|
}
|
||||||
|
deviceAuthID := strings.TrimSpace(userCodeResp.DeviceAuthID)
|
||||||
|
if deviceCode == "" || deviceAuthID == "" {
|
||||||
|
return nil, fmt.Errorf("codex device flow did not return required fields")
|
||||||
|
}
|
||||||
|
|
||||||
|
pollInterval := parseCodexDevicePollInterval(userCodeResp.Interval)
|
||||||
|
|
||||||
|
fmt.Println("Starting Codex device authentication...")
|
||||||
|
fmt.Printf("Codex device URL: %s\n", codexDeviceVerificationURL)
|
||||||
|
fmt.Printf("Codex device code: %s\n", deviceCode)
|
||||||
|
|
||||||
|
if !opts.NoBrowser {
|
||||||
|
if !browser.IsAvailable() {
|
||||||
|
log.Warn("No browser available; please open the device URL manually")
|
||||||
|
} else if errOpen := browser.OpenURL(codexDeviceVerificationURL); errOpen != nil {
|
||||||
|
log.Warnf("Failed to open browser automatically: %v", errOpen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenResp, err := pollCodexDeviceToken(ctx, httpClient, deviceAuthID, deviceCode, pollInterval)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
authCode := strings.TrimSpace(tokenResp.AuthorizationCode)
|
||||||
|
codeVerifier := strings.TrimSpace(tokenResp.CodeVerifier)
|
||||||
|
codeChallenge := strings.TrimSpace(tokenResp.CodeChallenge)
|
||||||
|
if authCode == "" || codeVerifier == "" || codeChallenge == "" {
|
||||||
|
return nil, fmt.Errorf("codex device flow token response missing required fields")
|
||||||
|
}
|
||||||
|
|
||||||
|
authSvc := codex.NewCodexAuth(cfg)
|
||||||
|
authBundle, err := authSvc.ExchangeCodeForTokensWithRedirect(
|
||||||
|
ctx,
|
||||||
|
authCode,
|
||||||
|
codexDeviceTokenExchangeRedirectURI,
|
||||||
|
&codex.PKCECodes{
|
||||||
|
CodeVerifier: codeVerifier,
|
||||||
|
CodeChallenge: codeChallenge,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.buildAuthRecord(authSvc, authBundle)
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestCodexDeviceUserCode(ctx context.Context, client *http.Client) (*codexDeviceUserCodeResponse, error) {
|
||||||
|
body, err := json.Marshal(codexDeviceUserCodeRequest{ClientID: codex.ClientID})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to encode codex device request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, codexDeviceUserCodeURL, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create codex device request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to request codex device code: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read codex device code response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !codexDeviceIsSuccessStatus(resp.StatusCode) {
|
||||||
|
trimmed := strings.TrimSpace(string(respBody))
|
||||||
|
if resp.StatusCode == http.StatusNotFound {
|
||||||
|
return nil, fmt.Errorf("codex device endpoint is unavailable (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if trimmed == "" {
|
||||||
|
trimmed = "empty response body"
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("codex device code request failed with status %d: %s", resp.StatusCode, trimmed)
|
||||||
|
}
|
||||||
|
|
||||||
|
var parsed codexDeviceUserCodeResponse
|
||||||
|
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode codex device code response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &parsed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func pollCodexDeviceToken(ctx context.Context, client *http.Client, deviceAuthID, userCode string, interval time.Duration) (*codexDeviceTokenResponse, error) {
|
||||||
|
deadline := time.Now().Add(codexDeviceTimeout)
|
||||||
|
|
||||||
|
for {
|
||||||
|
if time.Now().After(deadline) {
|
||||||
|
return nil, fmt.Errorf("codex device authentication timed out after 15 minutes")
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(codexDeviceTokenRequest{
|
||||||
|
DeviceAuthID: deviceAuthID,
|
||||||
|
UserCode: userCode,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to encode codex device poll request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, codexDeviceTokenURL, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create codex device poll request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to poll codex device token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody, readErr := io.ReadAll(resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if readErr != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read codex device poll response: %w", readErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case codexDeviceIsSuccessStatus(resp.StatusCode):
|
||||||
|
var parsed codexDeviceTokenResponse
|
||||||
|
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode codex device token response: %w", err)
|
||||||
|
}
|
||||||
|
return &parsed, nil
|
||||||
|
case resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusNotFound:
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(interval):
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
trimmed := strings.TrimSpace(string(respBody))
|
||||||
|
if trimmed == "" {
|
||||||
|
trimmed = "empty response body"
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("codex device token polling failed with status %d: %s", resp.StatusCode, trimmed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseCodexDevicePollInterval(raw json.RawMessage) time.Duration {
|
||||||
|
defaultInterval := time.Duration(codexDeviceDefaultPollIntervalSeconds) * time.Second
|
||||||
|
if len(raw) == 0 {
|
||||||
|
return defaultInterval
|
||||||
|
}
|
||||||
|
|
||||||
|
var asString string
|
||||||
|
if err := json.Unmarshal(raw, &asString); err == nil {
|
||||||
|
if seconds, convErr := strconv.Atoi(strings.TrimSpace(asString)); convErr == nil && seconds > 0 {
|
||||||
|
return time.Duration(seconds) * time.Second
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var asInt int
|
||||||
|
if err := json.Unmarshal(raw, &asInt); err == nil && asInt > 0 {
|
||||||
|
return time.Duration(asInt) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
return defaultInterval
|
||||||
|
}
|
||||||
|
|
||||||
|
// codexDeviceIsSuccessStatus reports whether code is an HTTP 2xx status.
func codexDeviceIsSuccessStatus(code int) bool {
	switch {
	case code < 200, code >= 300:
		return false
	default:
		return true
	}
}
|
||||||
|
|
||||||
|
func (a *CodexAuthenticator) buildAuthRecord(authSvc *codex.CodexAuth, authBundle *codex.CodexAuthBundle) (*coreauth.Auth, error) {
|
||||||
|
tokenStorage := authSvc.CreateTokenStorage(authBundle)
|
||||||
|
|
||||||
|
if tokenStorage == nil || tokenStorage.Email == "" {
|
||||||
|
return nil, fmt.Errorf("codex token storage missing account information")
|
||||||
|
}
|
||||||
|
|
||||||
|
planType := ""
|
||||||
|
hashAccountID := ""
|
||||||
|
if tokenStorage.IDToken != "" {
|
||||||
|
if claims, errParse := codex.ParseJWTToken(tokenStorage.IDToken); errParse == nil && claims != nil {
|
||||||
|
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
|
||||||
|
accountID := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID)
|
||||||
|
if accountID != "" {
|
||||||
|
digest := sha256.Sum256([]byte(accountID))
|
||||||
|
hashAccountID = hex.EncodeToString(digest[:])[:8]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true)
|
||||||
|
metadata := map[string]any{
|
||||||
|
"email": tokenStorage.Email,
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Codex authentication successful")
|
||||||
|
if authBundle.APIKey != "" {
|
||||||
|
fmt.Println("Codex API key obtained and stored")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &coreauth.Auth{
|
||||||
|
ID: fileName,
|
||||||
|
Provider: a.Provider(),
|
||||||
|
FileName: fileName,
|
||||||
|
Storage: tokenStorage,
|
||||||
|
Metadata: metadata,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
@@ -64,8 +64,16 @@ func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (str
|
|||||||
return "", fmt.Errorf("auth filestore: create dir failed: %w", err)
|
return "", fmt.Errorf("auth filestore: create dir failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// metadataSetter is a private interface for TokenStorage implementations that support metadata injection.
|
||||||
|
type metadataSetter interface {
|
||||||
|
SetMetadata(map[string]any)
|
||||||
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case auth.Storage != nil:
|
case auth.Storage != nil:
|
||||||
|
if setter, ok := auth.Storage.(metadataSetter); ok {
|
||||||
|
setter.SetMetadata(auth.Metadata)
|
||||||
|
}
|
||||||
if err = auth.Storage.SaveTokenToFile(path); err != nil {
|
if err = auth.Storage.SaveTokenToFile(path); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -86,6 +86,8 @@ func (a GitHubCopilotAuthenticator) Login(ctx context.Context, cfg *config.Confi
|
|||||||
metadata := map[string]any{
|
metadata := map[string]any{
|
||||||
"type": "github-copilot",
|
"type": "github-copilot",
|
||||||
"username": authBundle.Username,
|
"username": authBundle.Username,
|
||||||
|
"email": authBundle.Email,
|
||||||
|
"name": authBundle.Name,
|
||||||
"access_token": authBundle.TokenData.AccessToken,
|
"access_token": authBundle.TokenData.AccessToken,
|
||||||
"token_type": authBundle.TokenData.TokenType,
|
"token_type": authBundle.TokenData.TokenType,
|
||||||
"scope": authBundle.TokenData.Scope,
|
"scope": authBundle.TokenData.Scope,
|
||||||
@@ -98,13 +100,18 @@ func (a GitHubCopilotAuthenticator) Login(ctx context.Context, cfg *config.Confi
|
|||||||
|
|
||||||
fileName := fmt.Sprintf("github-copilot-%s.json", authBundle.Username)
|
fileName := fmt.Sprintf("github-copilot-%s.json", authBundle.Username)
|
||||||
|
|
||||||
|
label := authBundle.Email
|
||||||
|
if label == "" {
|
||||||
|
label = authBundle.Username
|
||||||
|
}
|
||||||
|
|
||||||
fmt.Printf("\nGitHub Copilot authentication successful for user: %s\n", authBundle.Username)
|
fmt.Printf("\nGitHub Copilot authentication successful for user: %s\n", authBundle.Username)
|
||||||
|
|
||||||
return &coreauth.Auth{
|
return &coreauth.Auth{
|
||||||
ID: fileName,
|
ID: fileName,
|
||||||
Provider: a.Provider(),
|
Provider: a.Provider(),
|
||||||
FileName: fileName,
|
FileName: fileName,
|
||||||
Label: authBundle.Username,
|
Label: label,
|
||||||
Storage: tokenStorage,
|
Storage: tokenStorage,
|
||||||
Metadata: metadata,
|
Metadata: metadata,
|
||||||
}, nil
|
}, nil
|
||||||
|
|||||||
@@ -30,8 +30,9 @@ type ProviderExecutor interface {
|
|||||||
Identifier() string
|
Identifier() string
|
||||||
// Execute handles non-streaming execution and returns the provider response payload.
|
// Execute handles non-streaming execution and returns the provider response payload.
|
||||||
Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error)
|
Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error)
|
||||||
// ExecuteStream handles streaming execution and returns a channel of provider chunks.
|
// ExecuteStream handles streaming execution and returns a StreamResult containing
|
||||||
ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error)
|
// upstream headers and a channel of provider chunks.
|
||||||
|
ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error)
|
||||||
// Refresh attempts to refresh provider credentials and returns the updated auth state.
|
// Refresh attempts to refresh provider credentials and returns the updated auth state.
|
||||||
Refresh(ctx context.Context, auth *Auth) (*Auth, error)
|
Refresh(ctx context.Context, auth *Auth) (*Auth, error)
|
||||||
// CountTokens returns the token count for the given request.
|
// CountTokens returns the token count for the given request.
|
||||||
@@ -558,7 +559,7 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
|
|||||||
|
|
||||||
// ExecuteStream performs a streaming execution using the configured selector and executor.
|
// ExecuteStream performs a streaming execution using the configured selector and executor.
|
||||||
// It supports multiple providers for the same model and round-robins the starting provider per model.
|
// It supports multiple providers for the same model and round-robins the starting provider per model.
|
||||||
func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||||
normalized := m.normalizeProviders(providers)
|
normalized := m.normalizeProviders(providers)
|
||||||
if len(normalized) == 0 {
|
if len(normalized) == 0 {
|
||||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
@@ -568,9 +569,9 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
|||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for attempt := 0; ; attempt++ {
|
for attempt := 0; ; attempt++ {
|
||||||
chunks, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
|
result, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
|
||||||
if errStream == nil {
|
if errStream == nil {
|
||||||
return chunks, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
lastErr = errStream
|
lastErr = errStream
|
||||||
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, normalized, req.Model, maxWait)
|
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, normalized, req.Model, maxWait)
|
||||||
@@ -699,7 +700,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||||
if len(providers) == 0 {
|
if len(providers) == 0 {
|
||||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
}
|
}
|
||||||
@@ -730,7 +731,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
|||||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
streamResult, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||||
if errStream != nil {
|
if errStream != nil {
|
||||||
if errCtx := execCtx.Err(); errCtx != nil {
|
if errCtx := execCtx.Err(); errCtx != nil {
|
||||||
return nil, errCtx
|
return nil, errCtx
|
||||||
@@ -778,8 +779,11 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
|||||||
if !failed {
|
if !failed {
|
||||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
||||||
}
|
}
|
||||||
}(execCtx, auth.Clone(), provider, chunks)
|
}(execCtx, auth.Clone(), provider, streamResult.Chunks)
|
||||||
return out, nil
|
return &cliproxyexecutor.StreamResult{
|
||||||
|
Headers: streamResult.Headers,
|
||||||
|
Chunks: out,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1824,9 +1828,7 @@ func (m *Manager) persist(ctx context.Context, auth *Auth) error {
|
|||||||
// every few seconds and triggers refresh operations when required.
|
// every few seconds and triggers refresh operations when required.
|
||||||
// Only one loop is kept alive; starting a new one cancels the previous run.
|
// Only one loop is kept alive; starting a new one cancels the previous run.
|
||||||
func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duration) {
|
func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duration) {
|
||||||
if interval <= 0 || interval > refreshCheckInterval {
|
if interval <= 0 {
|
||||||
interval = refreshCheckInterval
|
|
||||||
} else {
|
|
||||||
interval = refreshCheckInterval
|
interval = refreshCheckInterval
|
||||||
}
|
}
|
||||||
if m.refreshCancel != nil {
|
if m.refreshCancel != nil {
|
||||||
|
|||||||
@@ -24,10 +24,10 @@ func (e *replaceAwareExecutor) Execute(context.Context, *Auth, cliproxyexecutor.
|
|||||||
return cliproxyexecutor.Response{}, nil
|
return cliproxyexecutor.Response{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||||
ch := make(chan cliproxyexecutor.StreamChunk)
|
ch := make(chan cliproxyexecutor.StreamChunk)
|
||||||
close(ch)
|
close(ch)
|
||||||
return ch, nil
|
return &cliproxyexecutor.StreamResult{Chunks: ch}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *replaceAwareExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
|
func (e *replaceAwareExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
|
||||||
@@ -89,7 +89,11 @@ func TestManagerExecutorReturnsRegisteredExecutor(t *testing.T) {
|
|||||||
if !okResolved {
|
if !okResolved {
|
||||||
t.Fatal("expected registered executor to be found")
|
t.Fatal("expected registered executor to be found")
|
||||||
}
|
}
|
||||||
if resolved != current {
|
resolvedExecutor, okResolvedExecutor := resolved.(*replaceAwareExecutor)
|
||||||
|
if !okResolvedExecutor {
|
||||||
|
t.Fatalf("expected resolved executor type %T, got %T", current, resolved)
|
||||||
|
}
|
||||||
|
if resolvedExecutor != current {
|
||||||
t.Fatal("expected resolved executor to match registered executor")
|
t.Fatal("expected resolved executor to match registered executor")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
|
"math/rand/v2"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -248,6 +249,9 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Pick selects the next available auth for the provider in a round-robin manner.
|
// Pick selects the next available auth for the provider in a round-robin manner.
|
||||||
|
// For gemini-cli virtual auths (identified by the gemini_virtual_parent attribute),
|
||||||
|
// a two-level round-robin is used: first cycling across credential groups (parent
|
||||||
|
// accounts), then cycling within each group's project auths.
|
||||||
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
||||||
_ = opts
|
_ = opts
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
@@ -265,21 +269,87 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
|
|||||||
if limit <= 0 {
|
if limit <= 0 {
|
||||||
limit = 4096
|
limit = 4096
|
||||||
}
|
}
|
||||||
if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit {
|
|
||||||
s.cursors = make(map[string]int)
|
|
||||||
}
|
|
||||||
index := s.cursors[key]
|
|
||||||
|
|
||||||
|
// Check if any available auth has gemini_virtual_parent attribute,
|
||||||
|
// indicating gemini-cli virtual auths that should use credential-level polling.
|
||||||
|
groups, parentOrder := groupByVirtualParent(available)
|
||||||
|
if len(parentOrder) > 1 {
|
||||||
|
// Two-level round-robin: first select a credential group, then pick within it.
|
||||||
|
groupKey := key + "::group"
|
||||||
|
s.ensureCursorKey(groupKey, limit)
|
||||||
|
if _, exists := s.cursors[groupKey]; !exists {
|
||||||
|
// Seed with a random initial offset so the starting credential is randomized.
|
||||||
|
s.cursors[groupKey] = rand.IntN(len(parentOrder))
|
||||||
|
}
|
||||||
|
groupIndex := s.cursors[groupKey]
|
||||||
|
if groupIndex >= 2_147_483_640 {
|
||||||
|
groupIndex = 0
|
||||||
|
}
|
||||||
|
s.cursors[groupKey] = groupIndex + 1
|
||||||
|
|
||||||
|
selectedParent := parentOrder[groupIndex%len(parentOrder)]
|
||||||
|
group := groups[selectedParent]
|
||||||
|
|
||||||
|
// Second level: round-robin within the selected credential group.
|
||||||
|
innerKey := key + "::cred:" + selectedParent
|
||||||
|
s.ensureCursorKey(innerKey, limit)
|
||||||
|
innerIndex := s.cursors[innerKey]
|
||||||
|
if innerIndex >= 2_147_483_640 {
|
||||||
|
innerIndex = 0
|
||||||
|
}
|
||||||
|
s.cursors[innerKey] = innerIndex + 1
|
||||||
|
s.mu.Unlock()
|
||||||
|
return group[innerIndex%len(group)], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flat round-robin for non-grouped auths (original behavior).
|
||||||
|
s.ensureCursorKey(key, limit)
|
||||||
|
index := s.cursors[key]
|
||||||
if index >= 2_147_483_640 {
|
if index >= 2_147_483_640 {
|
||||||
index = 0
|
index = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
s.cursors[key] = index + 1
|
s.cursors[key] = index + 1
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
// log.Debugf("available: %d, index: %d, key: %d", len(available), index, index%len(available))
|
|
||||||
return available[index%len(available)], nil
|
return available[index%len(available)], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ensureCursorKey ensures the cursor map has capacity for the given key.
|
||||||
|
// Must be called with s.mu held.
|
||||||
|
func (s *RoundRobinSelector) ensureCursorKey(key string, limit int) {
|
||||||
|
if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit {
|
||||||
|
s.cursors = make(map[string]int)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// groupByVirtualParent groups auths by their gemini_virtual_parent attribute.
|
||||||
|
// Returns a map of parentID -> auths and a sorted slice of parent IDs for stable iteration.
|
||||||
|
// Only auths with a non-empty gemini_virtual_parent are grouped; if any auth lacks
|
||||||
|
// this attribute, nil/nil is returned so the caller falls back to flat round-robin.
|
||||||
|
func groupByVirtualParent(auths []*Auth) (map[string][]*Auth, []string) {
|
||||||
|
if len(auths) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
groups := make(map[string][]*Auth)
|
||||||
|
for _, a := range auths {
|
||||||
|
parent := ""
|
||||||
|
if a.Attributes != nil {
|
||||||
|
parent = strings.TrimSpace(a.Attributes["gemini_virtual_parent"])
|
||||||
|
}
|
||||||
|
if parent == "" {
|
||||||
|
// Non-virtual auth present; fall back to flat round-robin.
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
groups[parent] = append(groups[parent], a)
|
||||||
|
}
|
||||||
|
// Collect parent IDs in sorted order for stable cursor indexing.
|
||||||
|
parentOrder := make([]string, 0, len(groups))
|
||||||
|
for p := range groups {
|
||||||
|
parentOrder = append(parentOrder, p)
|
||||||
|
}
|
||||||
|
sort.Strings(parentOrder)
|
||||||
|
return groups, parentOrder
|
||||||
|
}
|
||||||
|
|
||||||
// Pick selects the first available auth for the provider in a deterministic manner.
|
// Pick selects the first available auth for the provider in a deterministic manner.
|
||||||
func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
||||||
_ = opts
|
_ = opts
|
||||||
|
|||||||
@@ -402,3 +402,128 @@ func TestRoundRobinSelectorPick_CursorKeyCap(t *testing.T) {
|
|||||||
t.Fatalf("selector.cursors missing key %q", "gemini:m3")
|
t.Fatalf("selector.cursors missing key %q", "gemini:m3")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRoundRobinSelectorPick_GeminiCLICredentialGrouping(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
selector := &RoundRobinSelector{}
|
||||||
|
|
||||||
|
// Simulate two gemini-cli credentials, each with multiple projects:
|
||||||
|
// Credential A (parent = "cred-a.json") has 3 projects
|
||||||
|
// Credential B (parent = "cred-b.json") has 2 projects
|
||||||
|
auths := []*Auth{
|
||||||
|
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||||
|
{ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||||
|
{ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||||
|
{ID: "cred-b.json::proj-b1", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}},
|
||||||
|
{ID: "cred-b.json::proj-b2", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Two-level round-robin: consecutive picks must alternate between credentials.
|
||||||
|
// Credential group order is randomized, but within each call the group cursor
|
||||||
|
// advances by 1, so consecutive picks should cycle through different parents.
|
||||||
|
picks := make([]string, 6)
|
||||||
|
parents := make([]string, 6)
|
||||||
|
for i := 0; i < 6; i++ {
|
||||||
|
got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("Pick() #%d auth = nil", i)
|
||||||
|
}
|
||||||
|
picks[i] = got.ID
|
||||||
|
parents[i] = got.Attributes["gemini_virtual_parent"]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify property: consecutive picks must alternate between credential groups.
|
||||||
|
for i := 1; i < len(parents); i++ {
|
||||||
|
if parents[i] == parents[i-1] {
|
||||||
|
t.Fatalf("Pick() #%d and #%d both from same parent %q (IDs: %q, %q); expected alternating credentials",
|
||||||
|
i-1, i, parents[i], picks[i-1], picks[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify property: each credential's projects are picked in sequence (round-robin within group).
|
||||||
|
credPicks := map[string][]string{}
|
||||||
|
for i, id := range picks {
|
||||||
|
credPicks[parents[i]] = append(credPicks[parents[i]], id)
|
||||||
|
}
|
||||||
|
for parent, ids := range credPicks {
|
||||||
|
for i := 1; i < len(ids); i++ {
|
||||||
|
if ids[i] == ids[i-1] {
|
||||||
|
t.Fatalf("Credential %q picked same project %q twice in a row", parent, ids[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
selector := &RoundRobinSelector{}
|
||||||
|
|
||||||
|
// All auths from the same parent - should fall back to flat round-robin
|
||||||
|
// because there's only one credential group (no benefit from two-level).
|
||||||
|
auths := []*Auth{
|
||||||
|
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||||
|
{ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||||
|
{ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
// With single parent group, parentOrder has length 1, so it uses flat round-robin.
|
||||||
|
// Sorted by ID: proj-a1, proj-a2, proj-a3
|
||||||
|
want := []string{
|
||||||
|
"cred-a.json::proj-a1",
|
||||||
|
"cred-a.json::proj-a2",
|
||||||
|
"cred-a.json::proj-a3",
|
||||||
|
"cred-a.json::proj-a1",
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, expectedID := range want {
|
||||||
|
got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("Pick() #%d auth = nil", i)
|
||||||
|
}
|
||||||
|
if got.ID != expectedID {
|
||||||
|
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
selector := &RoundRobinSelector{}
|
||||||
|
|
||||||
|
// Mix of virtual and non-virtual auths (e.g., a regular gemini-cli auth without projects
|
||||||
|
// alongside virtual ones). Should fall back to flat round-robin.
|
||||||
|
auths := []*Auth{
|
||||||
|
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||||
|
{ID: "cred-regular.json"}, // no gemini_virtual_parent
|
||||||
|
}
|
||||||
|
|
||||||
|
// groupByVirtualParent returns nil when any auth lacks the attribute,
|
||||||
|
// so flat round-robin is used. Sorted by ID: cred-a.json::proj-a1, cred-regular.json
|
||||||
|
want := []string{
|
||||||
|
"cred-a.json::proj-a1",
|
||||||
|
"cred-regular.json",
|
||||||
|
"cred-a.json::proj-a1",
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, expectedID := range want {
|
||||||
|
got, err := selector.Pick(context.Background(), "gemini-cli", "", cliproxyexecutor.Options{}, auths)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("Pick() #%d auth = nil", i)
|
||||||
|
}
|
||||||
|
if got.ID != expectedID {
|
||||||
|
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -12,6 +15,33 @@ import (
|
|||||||
baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth"
|
baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PostAuthHook is the signature of a callback that is invoked after an Auth record is created
|
||||||
|
// but before it is persisted to storage. The hook receives the record by pointer, so it can modify the
|
||||||
|
// Auth record (e.g., injecting metadata) based on external context.
// The context may carry a *RequestInfo; use GetRequestInfo to read it.
|
||||||
|
type PostAuthHook func(context.Context, *Auth) error
|
||||||
|
|
||||||
|
// RequestInfo holds information extracted from the HTTP request.
|
||||||
|
// It is injected into the context passed to PostAuthHook (see WithRequestInfo and GetRequestInfo).
|
||||||
|
type RequestInfo struct {
|
||||||
|
// Query holds URL query parameters — presumably those of the originating request; confirm where RequestInfo is built.
Query url.Values
|
||||||
|
// Headers holds HTTP headers — presumably the request headers; confirm where RequestInfo is built.
Headers http.Header
|
||||||
|
}
|
||||||
|
|
||||||
|
// requestInfoKey is the unexported context key type under which
// WithRequestInfo stores a *RequestInfo and GetRequestInfo looks it up.
type requestInfoKey struct{}
|
||||||
|
|
||||||
|
// WithRequestInfo returns a new context with the given RequestInfo attached.
|
||||||
|
func WithRequestInfo(ctx context.Context, info *RequestInfo) context.Context {
|
||||||
|
return context.WithValue(ctx, requestInfoKey{}, info)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRequestInfo retrieves the RequestInfo from the context, if present.
|
||||||
|
func GetRequestInfo(ctx context.Context) *RequestInfo {
|
||||||
|
if val, ok := ctx.Value(requestInfoKey{}).(*RequestInfo); ok {
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Auth encapsulates the runtime state and metadata associated with a single credential.
|
// Auth encapsulates the runtime state and metadata associated with a single credential.
|
||||||
type Auth struct {
|
type Auth struct {
|
||||||
// ID uniquely identifies the auth record across restarts.
|
// ID uniquely identifies the auth record across restarts.
|
||||||
|
|||||||
@@ -153,6 +153,16 @@ func (b *Builder) WithLocalManagementPassword(password string) *Builder {
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithPostAuthHook registers a hook to be called after an Auth record is created
|
||||||
|
// but before it is persisted to storage.
|
||||||
|
func (b *Builder) WithPostAuthHook(hook coreauth.PostAuthHook) *Builder {
|
||||||
|
if hook == nil {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
b.serverOptions = append(b.serverOptions, api.WithPostAuthHook(hook))
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
// Build validates inputs, applies defaults, and returns a ready-to-run service.
|
// Build validates inputs, applies defaults, and returns a ready-to-run service.
|
||||||
func (b *Builder) Build() (*Service, error) {
|
func (b *Builder) Build() (*Service, error) {
|
||||||
if b.cfg == nil {
|
if b.cfg == nil {
|
||||||
|
|||||||
@@ -57,6 +57,8 @@ type Response struct {
|
|||||||
Payload []byte
|
Payload []byte
|
||||||
// Metadata exposes optional structured data for translators.
|
// Metadata exposes optional structured data for translators.
|
||||||
Metadata map[string]any
|
Metadata map[string]any
|
||||||
|
// Headers carries upstream HTTP response headers for passthrough to clients.
|
||||||
|
Headers http.Header
|
||||||
}
|
}
|
||||||
|
|
||||||
// StreamChunk represents a single streaming payload unit emitted by provider executors.
|
// StreamChunk represents a single streaming payload unit emitted by provider executors.
|
||||||
@@ -67,6 +69,15 @@ type StreamChunk struct {
|
|||||||
Err error
|
Err error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StreamResult wraps the streaming response, providing both the chunk channel
|
||||||
|
// and the upstream HTTP response headers captured before streaming begins.
|
||||||
|
type StreamResult struct {
|
||||||
|
// Headers carries upstream HTTP response headers from the initial connection.
|
||||||
|
Headers http.Header
|
||||||
|
// Chunks is the channel of streaming payload units; it is receive-only for holders of a StreamResult.
|
||||||
|
Chunks <-chan StreamChunk
|
||||||
|
}
|
||||||
|
|
||||||
// StatusError represents an error that carries an HTTP-like status code.
|
// StatusError represents an error that carries an HTTP-like status code.
|
||||||
// Provider executors should implement this when possible to enable
|
// Provider executors should implement this when possible to enable
|
||||||
// better auth state updates on failures (e.g., 401/402/429).
|
// better auth state updates on failures (e.g., 401/402/429).
|
||||||
|
|||||||
@@ -90,3 +90,26 @@ func TestApplyOAuthModelAlias_ForkAddsMultipleAliases(t *testing.T) {
|
|||||||
t.Fatalf("expected forked model name %q, got %q", "models/g5-2", out[2].Name)
|
t.Fatalf("expected forked model name %q, got %q", "models/g5-2", out[2].Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyOAuthModelAlias_DefaultGitHubCopilotAliasViaSanitize(t *testing.T) {
|
||||||
|
cfg := &config.Config{}
|
||||||
|
cfg.SanitizeOAuthModelAlias()
|
||||||
|
|
||||||
|
models := []*ModelInfo{
|
||||||
|
{ID: "claude-opus-4.6", Name: "models/claude-opus-4.6"},
|
||||||
|
}
|
||||||
|
|
||||||
|
out := applyOAuthModelAlias(cfg, "github-copilot", "oauth", models)
|
||||||
|
if len(out) != 2 {
|
||||||
|
t.Fatalf("expected 2 models (original + default alias), got %d", len(out))
|
||||||
|
}
|
||||||
|
if out[0].ID != "claude-opus-4.6" {
|
||||||
|
t.Fatalf("expected first model id %q, got %q", "claude-opus-4.6", out[0].ID)
|
||||||
|
}
|
||||||
|
if out[1].ID != "claude-opus-4-6" {
|
||||||
|
t.Fatalf("expected second model id %q, got %q", "claude-opus-4-6", out[1].ID)
|
||||||
|
}
|
||||||
|
if out[1].Name != "models/claude-opus-4-6" {
|
||||||
|
t.Fatalf("expected aliased model name %q, got %q", "models/claude-opus-4-6", out[1].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user