Compare commits

..

50 Commits

Author SHA1 Message Date
Luis Pater
a862984dca Merge pull request #58 from router-for-me/plus
v6.6.48
2025-12-23 22:34:15 +08:00
Luis Pater
f0365f0465 Merge branch 'main' into plus 2025-12-23 22:34:08 +08:00
Luis Pater
6d1e20e940 fix(claude_executor): update header logic for API key handling
Refined header assignment to use `x-api-key` for Anthropic API requests, ensuring correct authorization behavior based on request attributes and URL validation.
2025-12-23 22:30:25 +08:00
Luis Pater
e52b542e22 Merge pull request #684 from packyme/main
docs(readme): add PackyCode sponsor
2025-12-23 17:19:25 +08:00
SmallL-U
8f6abb8a86 fix(readme): correct closing tbody tag 2025-12-23 17:17:57 +08:00
SmallL-U
ed8eaae964 docs(readme): add PackyCode sponsor 2025-12-23 17:11:34 +08:00
Luis Pater
e8de87ee90 Merge branch 'router-for-me:main' into main 2025-12-23 08:48:29 +08:00
Luis Pater
4e572ec8b9 fix(translators): handle string system instructions in Claude translators
Updated Antigravity, Gemini, and Gemini-CLI translators to process `systemResult` of type `string` for system instructions. Ensures properly formatted JSON with dynamic content assignment.
2025-12-23 08:44:36 +08:00
Luis Pater
6c7f18c448 Merge branch 'router-for-me:main' into main 2025-12-23 03:50:35 +08:00
Luis Pater
24bc9cba67 Fixed: #639
fix(antigravity): validate function arguments before serialization

Ensure `function.arguments` is a valid JSON before setting raw bytes, fallback to setting as parameterized content if invalid.
2025-12-23 03:49:45 +08:00
Luis Pater
97356b1a04 Merge branch 'router-for-me:main' into main 2025-12-23 03:17:44 +08:00
Luis Pater
1084b53fba Fixed: #655
refactor(antigravity): clean up tool key filtering and improve signature caching logic
2025-12-23 03:16:51 +08:00
Luis Pater
b1aecc2bf1 Merge branch 'router-for-me:main' into main 2025-12-23 02:49:37 +08:00
Luis Pater
83b90e106f refactor(antigravity): add sandbox URL constant and update base URLs routine 2025-12-23 02:47:56 +08:00
Luis Pater
f52114dab2 Merge branch 'router-for-me:main' into main 2025-12-23 02:26:33 +08:00
Luis Pater
5106caf641 Fixed: #654
feat: handle array input for system instructions in translators

Enhanced Gemini, Gemini-CLI, and Antigravity translators to process array content for system instructions. Adds support for assigning roles and handling multiple content parts dynamically.
2025-12-23 02:24:26 +08:00
Luis Pater
12370ee84e Merge branch 'router-for-me:main' into main 2025-12-22 22:53:29 +08:00
Luis Pater
b84ccc6e7a feat: add unit tests for routing strategies and implement dynamic selector updates
Added comprehensive tests for `FillFirstSelector` and `RoundRobinSelector` to ensure proper behavior, including deterministic, cyclical, and concurrent scenarios. Introduced dynamic routing strategy updates in `service.go`, normalizing strategies and seamlessly switching between `fill-first` and `round-robin`. Updated `Manager` to support selector changes via the new `SetSelector` method.
2025-12-22 22:52:23 +08:00
Luis Pater
e19ddb53e7 Merge pull request #663 from jroth1111/feat/fill-first-selector
feat: add fill-first routing strategy
2025-12-22 22:26:32 +08:00
gwizz
2a0100b2d6 docs: add routing strategy example 2025-12-23 00:39:18 +11:00
gwizz
c020fa60d0 fix: keep round-robin as default routing 2025-12-22 23:39:41 +11:00
gwizz
b078be4613 feat: add fill-first routing strategy 2025-12-22 23:38:10 +11:00
Luis Pater
5f65dd5bb4 Merge branch 'router-for-me:main' into main 2025-12-22 16:58:26 +08:00
Luis Pater
27b43ed63f Merge pull request #658 from moxi000/fix-responses-convert
Fix responses-format handling for chat completions(Support Cursor)
2025-12-22 16:47:46 +08:00
moxi
f6a3a1d0ba Remove compat test under translator per review 2025-12-22 16:44:50 +08:00
moxi
830fd8eac2 Fix responses-format handling for chat completions 2025-12-22 13:54:02 +08:00
Luis Pater
a86d501dc2 refactor: replace json.Marshal and json.Unmarshal with sjson and gjson
Optimized the handling of JSON serialization and deserialization by replacing redundant `json.Marshal` and `json.Unmarshal` calls with `sjson` and `gjson`. Introduced a `marshalJSONValue` utility for compact JSON encoding, improving performance and code simplicity. Removed unused `encoding/json` imports.
2025-12-22 11:44:06 +08:00
Luis Pater
e755e567ea Merge branch 'router-for-me:main' into main 2025-12-21 19:54:13 +08:00
Luis Pater
dbcbe48ead Merge pull request #641 from router-for-me/url-OAuth-add-ter
OAuth and management
2025-12-21 17:25:24 +08:00
Luis Pater
63908869f6 Merge pull request #611 from soilSpoon/feature/antigravity
feat(antigravity): Improve Claude model compatibility
2025-12-21 16:27:29 +08:00
이대희
7dc40ba6d4 Improve tool-call parsing, schema sanitization, and hint injection
Improve parsing of tool call inputs and Antigravity compatibility to avoid invalid thinking/tool_use errors.

- Parse tool call inputs robustly by accepting both object and JSON-string formats and only produce a functionCall part when valid args exist, reducing spurious or malformed parts.
- Preserve the skip_thought_signature_validator approach for calls without a valid thinking signature but stop toggling/tracking a separate "disable thinking" flag; this prevents unnecessary removal of thinkingConfig.
- Sanitize tool input schemas before attaching them to the Antigravity request to improve compatibility.
- Append the interleaved-thinking hint as a new parts entry instead of overwriting/setting text directly, preserving structure.
- Remove unused tracking logic and related comments to simplify flow.

These changes reduce errors related to missing/invalid thinking signatures, improve schema compatibility, and make hint injection safer and more consistent.
2025-12-21 17:16:40 +09:00
이대희
4070c9de81 Remove interleaved-thinking header from requests
Removes the addition of the "anthropic-beta: interleaved-thinking-2025-05-14" header for Claude thinking models when building HTTP requests.

This prevents sending an experimental/feature flag header that is no longer required and avoids potential compatibility or routing issues with downstream services. Keeps request headers simpler and more standard.
2025-12-21 15:29:36 +09:00
이대희
1e9e4a86a2 Improve thinking/tool signature handling for Claude and Gemini requests
Prefer cached signatures and avoid injecting dummy thinking blocks; instead remove unsigned thinking blocks and add a skip sentinel for tool calls without a valid signature. Generate stable session IDs from the first user message, apply schema cleaning only for Claude models, and reorder thinking parts so thinking appears first. For Gemini, remove thinking blocks and attach a skip sentinel to function calls. Simplify response handling by passing raw function args through (remove special Bash conversion). Update and add tests to reflect the new behavior.

These changes prevent rejected dummy signatures, improve compatibility with Antigravity’s signature validation, provide more stable session IDs for conversation grouping, and make request/response translation more robust.
2025-12-21 15:15:50 +09:00
이대희
406a27271a Remove opencode-antigravity-auth submodule
Remove the opencode-antigravity-auth submodule reference from the repository.

Cleans up the project by eliminating an external submodule pointer that is no longer needed or maintained, reducing repository complexity and avoiding dangling submodule state.
2025-12-21 14:54:49 +09:00
이대희
9f9a4fc2af Remove unused submodules
Removes two obsolete git submodules to clean up repository state and reduce maintenance overhead.

This eliminates external references that are no longer needed, simplifying dependency management and repository maintenance going forward.
2025-12-21 14:48:50 +09:00
Supra4E8C
781bc1521b fix(oauth): prevent stale session timeouts after login
- stop callback forwarders by instance to avoid cross-session shutdowns
  - clear pending sessions for a provider after successful auth
2025-12-21 10:48:40 +08:00
Supra4E8C
05d201ece8 fix(gemini): gate callback prompt on project_id 2025-12-21 07:21:12 +08:00
Supra4E8C
cd0c94f48a fix(sdk/auth): prevent OAuth manual prompt goroutine leak,Use timer-based manual prompt per provider and remove oauth_callback helper. 2025-12-21 07:06:28 +08:00
Supra4E8C
24970baa57 management: allow prefix updates in provider PATCH handlers 2025-12-21 02:14:28 +08:00
Supra4E8C
9855615f1e fix(gemini): avoid stale manual oauth prompt and accept schemeless callbacks 2025-12-20 19:03:38 +08:00
Supra4E8C
93414f1baa feat (auth): CLI OAuth supports pasting callback URLs to complete login
- Added callback URL resolution and terminal prompt logic
  - Codex/Claude/iFlow/Antigravity/Gemini login supports callback URL or local callback completion
  - Update Gemini login option signature and manager call
  - CLI default prompt function is compatible with null input to continue waiting
2025-12-20 18:25:55 +08:00
이대희
e04b02113a refactor: Improve cache eviction ordering and clean up session ID usage
Improve the cache eviction routine to sort entries by timestamp using the standard library sort routine (stable, clearer and faster than the prior manual selection/bubble logic), and remove a redundant request-derived session ID helper in favor of the centralized session ID function. Also drop now-unused crypto/encoding imports.

This yields clearer, more maintainable eviction logic and removes duplicated/unused code and imports to reduce surface area and potential inconsistencies.
2025-12-19 13:14:51 +09:00
이대희
3275494fde refactor: Use helper to extract wrapped "thinking" text
Improve robustness when handling "thinking" content by using a dedicated helper to extract the thinking text. This ensures wrapped or nested thinking objects are handled correctly instead of relying on a direct string extraction, reducing parsing errors for complex payloads.
2025-12-19 13:09:57 +09:00
이대희
c1f8211acb fix: Normalize Bash tool args and add signature caching support
Normalize Bash tool arguments by converting a "command" key into "cmd" using JSON-aware parsing, avoiding brittle string replacements that could corrupt values. Apply this conversion in both streaming and non-streaming response paths so bash-style tool calls are emitted with the expected "cmd" field.

Add support for accumulating thinking text and carrying session identifiers to enable signature caching/restore for unsigned thinking blocks, improving handling of thinking-state continuity across requests/responses.

Also perform small cleanups: import logging, tidy comments and test descriptions. These changes make tool-argument handling more robust and enable reliable signature restoration for thinking blocks.
2025-12-19 11:12:16 +09:00
이대희
98fa2a1597 feat(translator/antigravity/claude): support interleaved thinking, signature restoration and system hint injection 2025-12-19 10:30:59 +09:00
이대희
0e7c79ba23 feat(translator/antigravity/claude): support interleaved thinking, signature restoration and system hint injection 2025-12-19 10:28:25 +09:00
이대희
b6ba15fcbd fix(runtime/executor): Antigravity executor schema handling and Claude-specific headers 2025-12-19 10:28:23 +09:00
이대희
e44167d7a4 refactor(util/schema): rename and extend Gemini schema cleaning for Antigravity and add empty-schema placeholders 2025-12-19 10:28:17 +09:00
이대희
1bfa75f780 feat(util): add helper to detect Claude thinking models 2025-12-19 10:28:15 +09:00
이대희
bbcb5552f3 feat(cache): add signature cache for Claude thinking blocks 2025-12-19 10:28:12 +09:00
64 changed files with 4089 additions and 1957 deletions

BIN
assets/packycode.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.1 KiB

View File

@@ -71,6 +71,10 @@ quota-exceeded:
switch-project: true # Whether to automatically switch to another project when a quota is exceeded switch-project: true # Whether to automatically switch to another project when a quota is exceeded
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
# Routing strategy for selecting credentials when multiple match.
routing:
strategy: "round-robin" # round-robin (default), fill-first
# When true, enable authentication for the WebSocket API (/v1/ws). # When true, enable authentication for the WebSocket API (/v1/ws).
ws-auth: false ws-auth: false

View File

@@ -201,6 +201,19 @@ func stopCallbackForwarder(port int) {
stopForwarderInstance(port, forwarder) stopForwarderInstance(port, forwarder)
} }
func stopCallbackForwarderInstance(port int, forwarder *callbackForwarder) {
if forwarder == nil {
return
}
callbackForwardersMu.Lock()
if current := callbackForwarders[port]; current == forwarder {
delete(callbackForwarders, port)
}
callbackForwardersMu.Unlock()
stopForwarderInstance(port, forwarder)
}
func stopForwarderInstance(port int, forwarder *callbackForwarder) { func stopForwarderInstance(port int, forwarder *callbackForwarder) {
if forwarder == nil || forwarder.server == nil { if forwarder == nil || forwarder.server == nil {
return return
@@ -789,6 +802,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
RegisterOAuthSession(state, "anthropic") RegisterOAuthSession(state, "anthropic")
isWebUI := isWebUIRequest(c) isWebUI := isWebUIRequest(c)
var forwarder *callbackForwarder
if isWebUI { if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/anthropic/callback") targetURL, errTarget := h.managementCallbackURL("/anthropic/callback")
if errTarget != nil { if errTarget != nil {
@@ -796,7 +810,8 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return return
} }
if _, errStart := startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil { var errStart error
if forwarder, errStart = startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start anthropic callback forwarder") log.WithError(errStart).Error("failed to start anthropic callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return return
@@ -805,7 +820,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
go func() { go func() {
if isWebUI { if isWebUI {
defer stopCallbackForwarder(anthropicCallbackPort) defer stopCallbackForwarderInstance(anthropicCallbackPort, forwarder)
} }
// Helper: wait for callback file // Helper: wait for callback file
@@ -813,6 +828,9 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
waitForFile := func(path string, timeout time.Duration) (map[string]string, error) { waitForFile := func(path string, timeout time.Duration) (map[string]string, error) {
deadline := time.Now().Add(timeout) deadline := time.Now().Add(timeout)
for { for {
if !IsOAuthSessionPending(state, "anthropic") {
return nil, errOAuthSessionNotPending
}
if time.Now().After(deadline) { if time.Now().After(deadline) {
SetOAuthSessionError(state, "Timeout waiting for OAuth callback") SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
return nil, fmt.Errorf("timeout waiting for OAuth callback") return nil, fmt.Errorf("timeout waiting for OAuth callback")
@@ -832,6 +850,9 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
// Wait up to 5 minutes // Wait up to 5 minutes
resultMap, errWait := waitForFile(waitFile, 5*time.Minute) resultMap, errWait := waitForFile(waitFile, 5*time.Minute)
if errWait != nil { if errWait != nil {
if errors.Is(errWait, errOAuthSessionNotPending) {
return
}
authErr := claude.NewAuthenticationError(claude.ErrCallbackTimeout, errWait) authErr := claude.NewAuthenticationError(claude.ErrCallbackTimeout, errWait)
log.Error(claude.GetUserFriendlyMessage(authErr)) log.Error(claude.GetUserFriendlyMessage(authErr))
return return
@@ -937,6 +958,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
} }
fmt.Println("You can now use Claude services through this CLI") fmt.Println("You can now use Claude services through this CLI")
CompleteOAuthSession(state) CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("anthropic")
}() }()
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
@@ -972,6 +994,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
RegisterOAuthSession(state, "gemini") RegisterOAuthSession(state, "gemini")
isWebUI := isWebUIRequest(c) isWebUI := isWebUIRequest(c)
var forwarder *callbackForwarder
if isWebUI { if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/google/callback") targetURL, errTarget := h.managementCallbackURL("/google/callback")
if errTarget != nil { if errTarget != nil {
@@ -979,7 +1002,8 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return return
} }
if _, errStart := startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil { var errStart error
if forwarder, errStart = startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start gemini callback forwarder") log.WithError(errStart).Error("failed to start gemini callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return return
@@ -988,7 +1012,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
go func() { go func() {
if isWebUI { if isWebUI {
defer stopCallbackForwarder(geminiCallbackPort) defer stopCallbackForwarderInstance(geminiCallbackPort, forwarder)
} }
// Wait for callback file written by server route // Wait for callback file written by server route
@@ -997,6 +1021,9 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
deadline := time.Now().Add(5 * time.Minute) deadline := time.Now().Add(5 * time.Minute)
var authCode string var authCode string
for { for {
if !IsOAuthSessionPending(state, "gemini") {
return
}
if time.Now().After(deadline) { if time.Now().After(deadline) {
log.Error("oauth flow timed out") log.Error("oauth flow timed out")
SetOAuthSessionError(state, "OAuth flow timed out") SetOAuthSessionError(state, "OAuth flow timed out")
@@ -1097,7 +1124,9 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
// Initialize authenticated HTTP client via GeminiAuth to honor proxy settings // Initialize authenticated HTTP client via GeminiAuth to honor proxy settings
gemAuth := geminiAuth.NewGeminiAuth() gemAuth := geminiAuth.NewGeminiAuth()
gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true) gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, &geminiAuth.WebLoginOptions{
NoBrowser: true,
})
if errGetClient != nil { if errGetClient != nil {
log.Errorf("failed to get authenticated client: %v", errGetClient) log.Errorf("failed to get authenticated client: %v", errGetClient)
SetOAuthSessionError(state, "Failed to get authenticated client") SetOAuthSessionError(state, "Failed to get authenticated client")
@@ -1170,6 +1199,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
} }
CompleteOAuthSession(state) CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("gemini")
fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath) fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath)
}() }()
@@ -1211,6 +1241,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
RegisterOAuthSession(state, "codex") RegisterOAuthSession(state, "codex")
isWebUI := isWebUIRequest(c) isWebUI := isWebUIRequest(c)
var forwarder *callbackForwarder
if isWebUI { if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/codex/callback") targetURL, errTarget := h.managementCallbackURL("/codex/callback")
if errTarget != nil { if errTarget != nil {
@@ -1218,7 +1249,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return return
} }
if _, errStart := startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil { var errStart error
if forwarder, errStart = startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start codex callback forwarder") log.WithError(errStart).Error("failed to start codex callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return return
@@ -1227,7 +1259,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
go func() { go func() {
if isWebUI { if isWebUI {
defer stopCallbackForwarder(codexCallbackPort) defer stopCallbackForwarderInstance(codexCallbackPort, forwarder)
} }
// Wait for callback file // Wait for callback file
@@ -1235,6 +1267,9 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
deadline := time.Now().Add(5 * time.Minute) deadline := time.Now().Add(5 * time.Minute)
var code string var code string
for { for {
if !IsOAuthSessionPending(state, "codex") {
return
}
if time.Now().After(deadline) { if time.Now().After(deadline) {
authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback")) authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback"))
log.Error(codex.GetUserFriendlyMessage(authErr)) log.Error(codex.GetUserFriendlyMessage(authErr))
@@ -1350,6 +1385,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
} }
fmt.Println("You can now use Codex services through this CLI") fmt.Println("You can now use Codex services through this CLI")
CompleteOAuthSession(state) CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("codex")
}() }()
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
@@ -1395,6 +1431,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
RegisterOAuthSession(state, "antigravity") RegisterOAuthSession(state, "antigravity")
isWebUI := isWebUIRequest(c) isWebUI := isWebUIRequest(c)
var forwarder *callbackForwarder
if isWebUI { if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/antigravity/callback") targetURL, errTarget := h.managementCallbackURL("/antigravity/callback")
if errTarget != nil { if errTarget != nil {
@@ -1402,7 +1439,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return return
} }
if _, errStart := startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil { var errStart error
if forwarder, errStart = startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start antigravity callback forwarder") log.WithError(errStart).Error("failed to start antigravity callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return return
@@ -1411,13 +1449,16 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
go func() { go func() {
if isWebUI { if isWebUI {
defer stopCallbackForwarder(antigravityCallbackPort) defer stopCallbackForwarderInstance(antigravityCallbackPort, forwarder)
} }
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state)) waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state))
deadline := time.Now().Add(5 * time.Minute) deadline := time.Now().Add(5 * time.Minute)
var authCode string var authCode string
for { for {
if !IsOAuthSessionPending(state, "antigravity") {
return
}
if time.Now().After(deadline) { if time.Now().After(deadline) {
log.Error("oauth flow timed out") log.Error("oauth flow timed out")
SetOAuthSessionError(state, "OAuth flow timed out") SetOAuthSessionError(state, "OAuth flow timed out")
@@ -1580,6 +1621,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
} }
CompleteOAuthSession(state) CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("antigravity")
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
if projectID != "" { if projectID != "" {
fmt.Printf("Using GCP project: %s\n", projectID) fmt.Printf("Using GCP project: %s\n", projectID)
@@ -1657,6 +1699,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
RegisterOAuthSession(state, "iflow") RegisterOAuthSession(state, "iflow")
isWebUI := isWebUIRequest(c) isWebUI := isWebUIRequest(c)
var forwarder *callbackForwarder
if isWebUI { if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/iflow/callback") targetURL, errTarget := h.managementCallbackURL("/iflow/callback")
if errTarget != nil { if errTarget != nil {
@@ -1664,7 +1707,8 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "callback server unavailable"}) c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "callback server unavailable"})
return return
} }
if _, errStart := startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil { var errStart error
if forwarder, errStart = startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start iflow callback forwarder") log.WithError(errStart).Error("failed to start iflow callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to start callback server"}) c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to start callback server"})
return return
@@ -1673,7 +1717,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
go func() { go func() {
if isWebUI { if isWebUI {
defer stopCallbackForwarder(iflowauth.CallbackPort) defer stopCallbackForwarderInstance(iflowauth.CallbackPort, forwarder)
} }
fmt.Println("Waiting for authentication...") fmt.Println("Waiting for authentication...")
@@ -1681,6 +1725,9 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
deadline := time.Now().Add(5 * time.Minute) deadline := time.Now().Add(5 * time.Minute)
var resultMap map[string]string var resultMap map[string]string
for { for {
if !IsOAuthSessionPending(state, "iflow") {
return
}
if time.Now().After(deadline) { if time.Now().After(deadline) {
SetOAuthSessionError(state, "Authentication failed") SetOAuthSessionError(state, "Authentication failed")
fmt.Println("Authentication failed: timeout waiting for callback") fmt.Println("Authentication failed: timeout waiting for callback")
@@ -1747,6 +1794,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
} }
fmt.Println("You can now use iFlow services through this CLI") fmt.Println("You can now use iFlow services through this CLI")
CompleteOAuthSession(state) CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("iflow")
}() }()
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state}) c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})

View File

@@ -145,71 +145,74 @@ func (h *Handler) PutGeminiKeys(c *gin.Context) {
h.persist(c) h.persist(c)
} }
func (h *Handler) PatchGeminiKey(c *gin.Context) { func (h *Handler) PatchGeminiKey(c *gin.Context) {
type geminiKeyPatch struct {
APIKey *string `json:"api-key"`
Prefix *string `json:"prefix"`
BaseURL *string `json:"base-url"`
ProxyURL *string `json:"proxy-url"`
Headers *map[string]string `json:"headers"`
ExcludedModels *[]string `json:"excluded-models"`
}
var body struct { var body struct {
Index *int `json:"index"` Index *int `json:"index"`
Match *string `json:"match"` Match *string `json:"match"`
Value *config.GeminiKey `json:"value"` Value *geminiKeyPatch `json:"value"`
} }
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
c.JSON(400, gin.H{"error": "invalid body"}) c.JSON(400, gin.H{"error": "invalid body"})
return return
} }
value := *body.Value targetIndex := -1
value.APIKey = strings.TrimSpace(value.APIKey) if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) {
value.BaseURL = strings.TrimSpace(value.BaseURL) targetIndex = *body.Index
value.ProxyURL = strings.TrimSpace(value.ProxyURL) }
value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels) if targetIndex == -1 && body.Match != nil {
if value.APIKey == "" { match := strings.TrimSpace(*body.Match)
// Treat empty API key as delete. if match != "" {
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) { for i := range h.cfg.GeminiKey {
h.cfg.GeminiKey = append(h.cfg.GeminiKey[:*body.Index], h.cfg.GeminiKey[*body.Index+1:]...) if h.cfg.GeminiKey[i].APIKey == match {
h.cfg.SanitizeGeminiKeys() targetIndex = i
h.persist(c) break
return
}
if body.Match != nil {
match := strings.TrimSpace(*body.Match)
if match != "" {
out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey))
removed := false
for i := range h.cfg.GeminiKey {
if !removed && h.cfg.GeminiKey[i].APIKey == match {
removed = true
continue
}
out = append(out, h.cfg.GeminiKey[i])
}
if removed {
h.cfg.GeminiKey = out
h.cfg.SanitizeGeminiKeys()
h.persist(c)
return
} }
} }
} }
}
if targetIndex == -1 {
c.JSON(404, gin.H{"error": "item not found"}) c.JSON(404, gin.H{"error": "item not found"})
return return
} }
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) { entry := h.cfg.GeminiKey[targetIndex]
h.cfg.GeminiKey[*body.Index] = value if body.Value.APIKey != nil {
h.cfg.SanitizeGeminiKeys() trimmed := strings.TrimSpace(*body.Value.APIKey)
h.persist(c) if trimmed == "" {
return h.cfg.GeminiKey = append(h.cfg.GeminiKey[:targetIndex], h.cfg.GeminiKey[targetIndex+1:]...)
} h.cfg.SanitizeGeminiKeys()
if body.Match != nil { h.persist(c)
match := strings.TrimSpace(*body.Match) return
for i := range h.cfg.GeminiKey {
if h.cfg.GeminiKey[i].APIKey == match {
h.cfg.GeminiKey[i] = value
h.cfg.SanitizeGeminiKeys()
h.persist(c)
return
}
} }
entry.APIKey = trimmed
} }
c.JSON(404, gin.H{"error": "item not found"}) if body.Value.Prefix != nil {
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
}
if body.Value.BaseURL != nil {
entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL)
}
if body.Value.ProxyURL != nil {
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
}
if body.Value.Headers != nil {
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
}
if body.Value.ExcludedModels != nil {
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
}
h.cfg.GeminiKey[targetIndex] = entry
h.cfg.SanitizeGeminiKeys()
h.persist(c)
} }
func (h *Handler) DeleteGeminiKey(c *gin.Context) { func (h *Handler) DeleteGeminiKey(c *gin.Context) {
if val := strings.TrimSpace(c.Query("api-key")); val != "" { if val := strings.TrimSpace(c.Query("api-key")); val != "" {
out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey)) out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey))
@@ -268,35 +271,70 @@ func (h *Handler) PutClaudeKeys(c *gin.Context) {
h.persist(c) h.persist(c)
} }
func (h *Handler) PatchClaudeKey(c *gin.Context) { func (h *Handler) PatchClaudeKey(c *gin.Context) {
type claudeKeyPatch struct {
APIKey *string `json:"api-key"`
Prefix *string `json:"prefix"`
BaseURL *string `json:"base-url"`
ProxyURL *string `json:"proxy-url"`
Models *[]config.ClaudeModel `json:"models"`
Headers *map[string]string `json:"headers"`
ExcludedModels *[]string `json:"excluded-models"`
}
var body struct { var body struct {
Index *int `json:"index"` Index *int `json:"index"`
Match *string `json:"match"` Match *string `json:"match"`
Value *config.ClaudeKey `json:"value"` Value *claudeKeyPatch `json:"value"`
} }
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
c.JSON(400, gin.H{"error": "invalid body"}) c.JSON(400, gin.H{"error": "invalid body"})
return return
} }
value := *body.Value targetIndex := -1
normalizeClaudeKey(&value)
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.ClaudeKey) { if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.ClaudeKey) {
h.cfg.ClaudeKey[*body.Index] = value targetIndex = *body.Index
h.cfg.SanitizeClaudeKeys()
h.persist(c)
return
} }
if body.Match != nil { if targetIndex == -1 && body.Match != nil {
match := strings.TrimSpace(*body.Match)
for i := range h.cfg.ClaudeKey { for i := range h.cfg.ClaudeKey {
if h.cfg.ClaudeKey[i].APIKey == *body.Match { if h.cfg.ClaudeKey[i].APIKey == match {
h.cfg.ClaudeKey[i] = value targetIndex = i
h.cfg.SanitizeClaudeKeys() break
h.persist(c)
return
} }
} }
} }
c.JSON(404, gin.H{"error": "item not found"}) if targetIndex == -1 {
c.JSON(404, gin.H{"error": "item not found"})
return
}
entry := h.cfg.ClaudeKey[targetIndex]
if body.Value.APIKey != nil {
entry.APIKey = strings.TrimSpace(*body.Value.APIKey)
}
if body.Value.Prefix != nil {
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
}
if body.Value.BaseURL != nil {
entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL)
}
if body.Value.ProxyURL != nil {
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
}
if body.Value.Models != nil {
entry.Models = append([]config.ClaudeModel(nil), (*body.Value.Models)...)
}
if body.Value.Headers != nil {
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
}
if body.Value.ExcludedModels != nil {
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
}
normalizeClaudeKey(&entry)
h.cfg.ClaudeKey[targetIndex] = entry
h.cfg.SanitizeClaudeKeys()
h.persist(c)
} }
func (h *Handler) DeleteClaudeKey(c *gin.Context) { func (h *Handler) DeleteClaudeKey(c *gin.Context) {
if val := c.Query("api-key"); val != "" { if val := c.Query("api-key"); val != "" {
out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey)) out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey))
@@ -356,62 +394,73 @@ func (h *Handler) PutOpenAICompat(c *gin.Context) {
h.persist(c) h.persist(c)
} }
func (h *Handler) PatchOpenAICompat(c *gin.Context) { func (h *Handler) PatchOpenAICompat(c *gin.Context) {
type openAICompatPatch struct {
Name *string `json:"name"`
Prefix *string `json:"prefix"`
BaseURL *string `json:"base-url"`
APIKeyEntries *[]config.OpenAICompatibilityAPIKey `json:"api-key-entries"`
Models *[]config.OpenAICompatibilityModel `json:"models"`
Headers *map[string]string `json:"headers"`
}
var body struct { var body struct {
Name *string `json:"name"` Name *string `json:"name"`
Index *int `json:"index"` Index *int `json:"index"`
Value *config.OpenAICompatibility `json:"value"` Value *openAICompatPatch `json:"value"`
} }
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
c.JSON(400, gin.H{"error": "invalid body"}) c.JSON(400, gin.H{"error": "invalid body"})
return return
} }
normalizeOpenAICompatibilityEntry(body.Value) targetIndex := -1
// If base-url becomes empty, delete the provider instead of updating if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) {
if strings.TrimSpace(body.Value.BaseURL) == "" { targetIndex = *body.Index
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) { }
h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:*body.Index], h.cfg.OpenAICompatibility[*body.Index+1:]...) if targetIndex == -1 && body.Name != nil {
match := strings.TrimSpace(*body.Name)
for i := range h.cfg.OpenAICompatibility {
if h.cfg.OpenAICompatibility[i].Name == match {
targetIndex = i
break
}
}
}
if targetIndex == -1 {
c.JSON(404, gin.H{"error": "item not found"})
return
}
entry := h.cfg.OpenAICompatibility[targetIndex]
if body.Value.Name != nil {
entry.Name = strings.TrimSpace(*body.Value.Name)
}
if body.Value.Prefix != nil {
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
}
if body.Value.BaseURL != nil {
trimmed := strings.TrimSpace(*body.Value.BaseURL)
if trimmed == "" {
h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:targetIndex], h.cfg.OpenAICompatibility[targetIndex+1:]...)
h.cfg.SanitizeOpenAICompatibility() h.cfg.SanitizeOpenAICompatibility()
h.persist(c) h.persist(c)
return return
} }
if body.Name != nil { entry.BaseURL = trimmed
out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility))
removed := false
for i := range h.cfg.OpenAICompatibility {
if !removed && h.cfg.OpenAICompatibility[i].Name == *body.Name {
removed = true
continue
}
out = append(out, h.cfg.OpenAICompatibility[i])
}
if removed {
h.cfg.OpenAICompatibility = out
h.cfg.SanitizeOpenAICompatibility()
h.persist(c)
return
}
}
c.JSON(404, gin.H{"error": "item not found"})
return
} }
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) { if body.Value.APIKeyEntries != nil {
h.cfg.OpenAICompatibility[*body.Index] = *body.Value entry.APIKeyEntries = append([]config.OpenAICompatibilityAPIKey(nil), (*body.Value.APIKeyEntries)...)
h.cfg.SanitizeOpenAICompatibility()
h.persist(c)
return
} }
if body.Name != nil { if body.Value.Models != nil {
for i := range h.cfg.OpenAICompatibility { entry.Models = append([]config.OpenAICompatibilityModel(nil), (*body.Value.Models)...)
if h.cfg.OpenAICompatibility[i].Name == *body.Name {
h.cfg.OpenAICompatibility[i] = *body.Value
h.cfg.SanitizeOpenAICompatibility()
h.persist(c)
return
}
}
} }
c.JSON(404, gin.H{"error": "item not found"}) if body.Value.Headers != nil {
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
}
normalizeOpenAICompatibilityEntry(&entry)
h.cfg.OpenAICompatibility[targetIndex] = entry
h.cfg.SanitizeOpenAICompatibility()
h.persist(c)
} }
func (h *Handler) DeleteOpenAICompat(c *gin.Context) { func (h *Handler) DeleteOpenAICompat(c *gin.Context) {
if name := c.Query("name"); name != "" { if name := c.Query("name"); name != "" {
out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility)) out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility))
@@ -563,66 +612,72 @@ func (h *Handler) PutCodexKeys(c *gin.Context) {
h.persist(c) h.persist(c)
} }
func (h *Handler) PatchCodexKey(c *gin.Context) { func (h *Handler) PatchCodexKey(c *gin.Context) {
type codexKeyPatch struct {
APIKey *string `json:"api-key"`
Prefix *string `json:"prefix"`
BaseURL *string `json:"base-url"`
ProxyURL *string `json:"proxy-url"`
Headers *map[string]string `json:"headers"`
ExcludedModels *[]string `json:"excluded-models"`
}
var body struct { var body struct {
Index *int `json:"index"` Index *int `json:"index"`
Match *string `json:"match"` Match *string `json:"match"`
Value *config.CodexKey `json:"value"` Value *codexKeyPatch `json:"value"`
} }
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
c.JSON(400, gin.H{"error": "invalid body"}) c.JSON(400, gin.H{"error": "invalid body"})
return return
} }
value := *body.Value targetIndex := -1
value.APIKey = strings.TrimSpace(value.APIKey) if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) {
value.BaseURL = strings.TrimSpace(value.BaseURL) targetIndex = *body.Index
value.ProxyURL = strings.TrimSpace(value.ProxyURL) }
value.Headers = config.NormalizeHeaders(value.Headers) if targetIndex == -1 && body.Match != nil {
value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels) match := strings.TrimSpace(*body.Match)
// If base-url becomes empty, delete instead of update for i := range h.cfg.CodexKey {
if value.BaseURL == "" { if h.cfg.CodexKey[i].APIKey == match {
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) { targetIndex = i
h.cfg.CodexKey = append(h.cfg.CodexKey[:*body.Index], h.cfg.CodexKey[*body.Index+1:]...) break
h.cfg.SanitizeCodexKeys()
h.persist(c)
return
}
if body.Match != nil {
out := make([]config.CodexKey, 0, len(h.cfg.CodexKey))
removed := false
for i := range h.cfg.CodexKey {
if !removed && h.cfg.CodexKey[i].APIKey == *body.Match {
removed = true
continue
}
out = append(out, h.cfg.CodexKey[i])
}
if removed {
h.cfg.CodexKey = out
h.cfg.SanitizeCodexKeys()
h.persist(c)
return
}
}
} else {
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) {
h.cfg.CodexKey[*body.Index] = value
h.cfg.SanitizeCodexKeys()
h.persist(c)
return
}
if body.Match != nil {
for i := range h.cfg.CodexKey {
if h.cfg.CodexKey[i].APIKey == *body.Match {
h.cfg.CodexKey[i] = value
h.cfg.SanitizeCodexKeys()
h.persist(c)
return
}
} }
} }
} }
c.JSON(404, gin.H{"error": "item not found"}) if targetIndex == -1 {
c.JSON(404, gin.H{"error": "item not found"})
return
}
entry := h.cfg.CodexKey[targetIndex]
if body.Value.APIKey != nil {
entry.APIKey = strings.TrimSpace(*body.Value.APIKey)
}
if body.Value.Prefix != nil {
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
}
if body.Value.BaseURL != nil {
trimmed := strings.TrimSpace(*body.Value.BaseURL)
if trimmed == "" {
h.cfg.CodexKey = append(h.cfg.CodexKey[:targetIndex], h.cfg.CodexKey[targetIndex+1:]...)
h.cfg.SanitizeCodexKeys()
h.persist(c)
return
}
entry.BaseURL = trimmed
}
if body.Value.ProxyURL != nil {
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
}
if body.Value.Headers != nil {
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
}
if body.Value.ExcludedModels != nil {
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
}
h.cfg.CodexKey[targetIndex] = entry
h.cfg.SanitizeCodexKeys()
h.persist(c)
} }
func (h *Handler) DeleteCodexKey(c *gin.Context) { func (h *Handler) DeleteCodexKey(c *gin.Context) {
if val := c.Query("api-key"); val != "" { if val := c.Query("api-key"); val != "" {
out := make([]config.CodexKey, 0, len(h.cfg.CodexKey)) out := make([]config.CodexKey, 0, len(h.cfg.CodexKey))

View File

@@ -111,6 +111,27 @@ func (s *oauthSessionStore) Complete(state string) {
delete(s.sessions, state) delete(s.sessions, state)
} }
func (s *oauthSessionStore) CompleteProvider(provider string) int {
provider = strings.ToLower(strings.TrimSpace(provider))
if provider == "" {
return 0
}
now := time.Now()
s.mu.Lock()
defer s.mu.Unlock()
s.purgeExpiredLocked(now)
removed := 0
for state, session := range s.sessions {
if strings.EqualFold(session.Provider, provider) {
delete(s.sessions, state)
removed++
}
}
return removed
}
func (s *oauthSessionStore) Get(state string) (oauthSession, bool) { func (s *oauthSessionStore) Get(state string) (oauthSession, bool) {
state = strings.TrimSpace(state) state = strings.TrimSpace(state)
now := time.Now() now := time.Now()
@@ -158,6 +179,10 @@ func SetOAuthSessionError(state, message string) { oauthSessions.SetError(state,
func CompleteOAuthSession(state string) { oauthSessions.Complete(state) } func CompleteOAuthSession(state string) { oauthSessions.Complete(state) }
func CompleteOAuthSessionsByProvider(provider string) int {
return oauthSessions.CompleteProvider(provider)
}
func GetOAuthSession(state string) (provider string, status string, ok bool) { func GetOAuthSession(state string) (provider string, status string, ok bool) {
session, ok := oauthSessions.Get(state) session, ok := oauthSessions.Get(state)
if !ok { if !ok {

View File

@@ -18,6 +18,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser" "github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
@@ -46,6 +47,12 @@ var (
type GeminiAuth struct { type GeminiAuth struct {
} }
// WebLoginOptions customizes the interactive OAuth flow.
type WebLoginOptions struct {
NoBrowser bool
Prompt func(string) (string, error)
}
// NewGeminiAuth creates a new instance of GeminiAuth. // NewGeminiAuth creates a new instance of GeminiAuth.
func NewGeminiAuth() *GeminiAuth { func NewGeminiAuth() *GeminiAuth {
return &GeminiAuth{} return &GeminiAuth{}
@@ -59,12 +66,12 @@ func NewGeminiAuth() *GeminiAuth {
// - ctx: The context for the HTTP client // - ctx: The context for the HTTP client
// - ts: The Gemini token storage containing authentication tokens // - ts: The Gemini token storage containing authentication tokens
// - cfg: The configuration containing proxy settings // - cfg: The configuration containing proxy settings
// - noBrowser: Optional parameter to disable browser opening // - opts: Optional parameters to customize browser and prompt behavior
// //
// Returns: // Returns:
// - *http.Client: An HTTP client configured with authentication // - *http.Client: An HTTP client configured with authentication
// - error: An error if the client configuration fails, nil otherwise // - error: An error if the client configuration fails, nil otherwise
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, noBrowser ...bool) (*http.Client, error) { func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
// Configure proxy settings for the HTTP client if a proxy URL is provided. // Configure proxy settings for the HTTP client if a proxy URL is provided.
proxyURL, err := url.Parse(cfg.ProxyURL) proxyURL, err := url.Parse(cfg.ProxyURL)
if err == nil { if err == nil {
@@ -109,7 +116,7 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
// If no token is found in storage, initiate the web-based OAuth flow. // If no token is found in storage, initiate the web-based OAuth flow.
if ts.Token == nil { if ts.Token == nil {
fmt.Printf("Could not load token from file, starting OAuth flow.\n") fmt.Printf("Could not load token from file, starting OAuth flow.\n")
token, err = g.getTokenFromWeb(ctx, conf, noBrowser...) token, err = g.getTokenFromWeb(ctx, conf, opts)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get token from web: %w", err) return nil, fmt.Errorf("failed to get token from web: %w", err)
} }
@@ -205,15 +212,15 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
// Parameters: // Parameters:
// - ctx: The context for the HTTP client // - ctx: The context for the HTTP client
// - config: The OAuth2 configuration // - config: The OAuth2 configuration
// - noBrowser: Optional parameter to disable browser opening // - opts: Optional parameters to customize browser and prompt behavior
// //
// Returns: // Returns:
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow // - *oauth2.Token: The OAuth2 token obtained from the authorization flow
// - error: An error if the token acquisition fails, nil otherwise // - error: An error if the token acquisition fails, nil otherwise
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, noBrowser ...bool) (*oauth2.Token, error) { func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
// Use a channel to pass the authorization code from the HTTP handler to the main function. // Use a channel to pass the authorization code from the HTTP handler to the main function.
codeChan := make(chan string) codeChan := make(chan string, 1)
errChan := make(chan error) errChan := make(chan error, 1)
// Create a new HTTP server with its own multiplexer. // Create a new HTTP server with its own multiplexer.
mux := http.NewServeMux() mux := http.NewServeMux()
@@ -223,17 +230,26 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
if err := r.URL.Query().Get("error"); err != "" { if err := r.URL.Query().Get("error"); err != "" {
_, _ = fmt.Fprintf(w, "Authentication failed: %s", err) _, _ = fmt.Fprintf(w, "Authentication failed: %s", err)
errChan <- fmt.Errorf("authentication failed via callback: %s", err) select {
case errChan <- fmt.Errorf("authentication failed via callback: %s", err):
default:
}
return return
} }
code := r.URL.Query().Get("code") code := r.URL.Query().Get("code")
if code == "" { if code == "" {
_, _ = fmt.Fprint(w, "Authentication failed: code not found.") _, _ = fmt.Fprint(w, "Authentication failed: code not found.")
errChan <- fmt.Errorf("code not found in callback") select {
case errChan <- fmt.Errorf("code not found in callback"):
default:
}
return return
} }
_, _ = fmt.Fprint(w, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>") _, _ = fmt.Fprint(w, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
codeChan <- code select {
case codeChan <- code:
default:
}
}) })
// Start the server in a goroutine. // Start the server in a goroutine.
@@ -250,7 +266,12 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
// Open the authorization URL in the user's browser. // Open the authorization URL in the user's browser.
authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent")) authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
if len(noBrowser) == 1 && !noBrowser[0] { noBrowser := false
if opts != nil {
noBrowser = opts.NoBrowser
}
if !noBrowser {
fmt.Println("Opening browser for authentication...") fmt.Println("Opening browser for authentication...")
// Check if browser is available // Check if browser is available
@@ -281,13 +302,60 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
// Wait for the authorization code or an error. // Wait for the authorization code or an error.
var authCode string var authCode string
select { timeoutTimer := time.NewTimer(5 * time.Minute)
case code := <-codeChan: defer timeoutTimer.Stop()
authCode = code
case err := <-errChan: var manualPromptTimer *time.Timer
return nil, err var manualPromptC <-chan time.Time
case <-time.After(5 * time.Minute): // Timeout if opts != nil && opts.Prompt != nil {
return nil, fmt.Errorf("oauth flow timed out") manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
waitForCallback:
for {
select {
case code := <-codeChan:
authCode = code
break waitForCallback
case err := <-errChan:
return nil, err
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
select {
case code := <-codeChan:
authCode = code
break waitForCallback
case err := <-errChan:
return nil, err
default:
}
input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ")
if err != nil {
return nil, err
}
parsed, err := misc.ParseOAuthCallback(input)
if err != nil {
return nil, err
}
if parsed == nil {
continue
}
if parsed.Error != "" {
return nil, fmt.Errorf("authentication failed via callback: %s", parsed.Error)
}
if parsed.Code == "" {
return nil, fmt.Errorf("code not found in callback")
}
authCode = parsed.Code
break waitForCallback
case <-timeoutTimer.C:
return nil, fmt.Errorf("oauth flow timed out")
}
} }
// Shutdown the server. // Shutdown the server.

164
internal/cache/signature_cache.go vendored Normal file
View File

@@ -0,0 +1,164 @@
package cache
import (
"crypto/sha256"
"encoding/hex"
"sort"
"sync"
"time"
)
// SignatureEntry holds a cached thinking signature with timestamp
type SignatureEntry struct {
Signature string
Timestamp time.Time
}
const (
// SignatureCacheTTL is how long signatures are valid
SignatureCacheTTL = 1 * time.Hour
// MaxEntriesPerSession limits memory usage per session
MaxEntriesPerSession = 100
// SignatureTextHashLen is the length of the hash key (16 hex chars = 64-bit key space)
SignatureTextHashLen = 16
// MinValidSignatureLen is the minimum length for a signature to be considered valid
MinValidSignatureLen = 50
)
// signatureCache stores signatures by sessionId -> textHash -> SignatureEntry
var signatureCache sync.Map
// sessionCache is the inner map type
type sessionCache struct {
mu sync.RWMutex
entries map[string]SignatureEntry
}
// hashText creates a stable, Unicode-safe key from text content
func hashText(text string) string {
h := sha256.Sum256([]byte(text))
return hex.EncodeToString(h[:])[:SignatureTextHashLen]
}
// getOrCreateSession gets or creates a session cache
func getOrCreateSession(sessionID string) *sessionCache {
if val, ok := signatureCache.Load(sessionID); ok {
return val.(*sessionCache)
}
sc := &sessionCache{entries: make(map[string]SignatureEntry)}
actual, _ := signatureCache.LoadOrStore(sessionID, sc)
return actual.(*sessionCache)
}
// CacheSignature stores a thinking signature for a given session and text.
// Used for Claude models that require signed thinking blocks in multi-turn conversations.
func CacheSignature(sessionID, text, signature string) {
if sessionID == "" || text == "" || signature == "" {
return
}
if len(signature) < MinValidSignatureLen {
return
}
sc := getOrCreateSession(sessionID)
textHash := hashText(text)
sc.mu.Lock()
defer sc.mu.Unlock()
// Evict expired entries if at capacity
if len(sc.entries) >= MaxEntriesPerSession {
now := time.Now()
for key, entry := range sc.entries {
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
delete(sc.entries, key)
}
}
// If still at capacity, remove oldest entries
if len(sc.entries) >= MaxEntriesPerSession {
// Find and remove oldest quarter
oldest := make([]struct {
key string
ts time.Time
}, 0, len(sc.entries))
for key, entry := range sc.entries {
oldest = append(oldest, struct {
key string
ts time.Time
}{key, entry.Timestamp})
}
// Sort by timestamp (oldest first) using sort.Slice
sort.Slice(oldest, func(i, j int) bool {
return oldest[i].ts.Before(oldest[j].ts)
})
toRemove := len(oldest) / 4
if toRemove < 1 {
toRemove = 1
}
for i := 0; i < toRemove; i++ {
delete(sc.entries, oldest[i].key)
}
}
}
sc.entries[textHash] = SignatureEntry{
Signature: signature,
Timestamp: time.Now(),
}
}
// GetCachedSignature retrieves a cached signature for a given session and text.
// Returns empty string if not found or expired.
func GetCachedSignature(sessionID, text string) string {
if sessionID == "" || text == "" {
return ""
}
val, ok := signatureCache.Load(sessionID)
if !ok {
return ""
}
sc := val.(*sessionCache)
textHash := hashText(text)
sc.mu.RLock()
entry, exists := sc.entries[textHash]
sc.mu.RUnlock()
if !exists {
return ""
}
// Check if expired
if time.Since(entry.Timestamp) > SignatureCacheTTL {
sc.mu.Lock()
delete(sc.entries, textHash)
sc.mu.Unlock()
return ""
}
return entry.Signature
}
// ClearSignatureCache clears signature cache for a specific session or all sessions.
func ClearSignatureCache(sessionID string) {
if sessionID != "" {
signatureCache.Delete(sessionID)
} else {
signatureCache.Range(func(key, _ any) bool {
signatureCache.Delete(key)
return true
})
}
}
// HasValidSignature checks if a signature is valid (non-empty and long enough)
func HasValidSignature(signature string) bool {
return signature != "" && len(signature) >= MinValidSignatureLen
}

216
internal/cache/signature_cache_test.go vendored Normal file
View File

@@ -0,0 +1,216 @@
package cache
import (
"testing"
"time"
)
func TestCacheSignature_BasicStorageAndRetrieval(t *testing.T) {
ClearSignatureCache("")
sessionID := "test-session-1"
text := "This is some thinking text content"
signature := "abc123validSignature1234567890123456789012345678901234567890"
// Store signature
CacheSignature(sessionID, text, signature)
// Retrieve signature
retrieved := GetCachedSignature(sessionID, text)
if retrieved != signature {
t.Errorf("Expected signature '%s', got '%s'", signature, retrieved)
}
}
func TestCacheSignature_DifferentSessions(t *testing.T) {
ClearSignatureCache("")
text := "Same text in different sessions"
sig1 := "signature1_1234567890123456789012345678901234567890123456"
sig2 := "signature2_1234567890123456789012345678901234567890123456"
CacheSignature("session-a", text, sig1)
CacheSignature("session-b", text, sig2)
if GetCachedSignature("session-a", text) != sig1 {
t.Error("Session-a signature mismatch")
}
if GetCachedSignature("session-b", text) != sig2 {
t.Error("Session-b signature mismatch")
}
}
func TestCacheSignature_NotFound(t *testing.T) {
ClearSignatureCache("")
// Non-existent session
if got := GetCachedSignature("nonexistent", "some text"); got != "" {
t.Errorf("Expected empty string for nonexistent session, got '%s'", got)
}
// Existing session but different text
CacheSignature("session-x", "text-a", "sigA12345678901234567890123456789012345678901234567890")
if got := GetCachedSignature("session-x", "text-b"); got != "" {
t.Errorf("Expected empty string for different text, got '%s'", got)
}
}
func TestCacheSignature_EmptyInputs(t *testing.T) {
ClearSignatureCache("")
// All empty/invalid inputs should be no-ops
CacheSignature("", "text", "sig12345678901234567890123456789012345678901234567890")
CacheSignature("session", "", "sig12345678901234567890123456789012345678901234567890")
CacheSignature("session", "text", "")
CacheSignature("session", "text", "short") // Too short
if got := GetCachedSignature("session", "text"); got != "" {
t.Errorf("Expected empty after invalid cache attempts, got '%s'", got)
}
}
func TestCacheSignature_ShortSignatureRejected(t *testing.T) {
ClearSignatureCache("")
sessionID := "test-short-sig"
text := "Some text"
shortSig := "abc123" // Less than 50 chars
CacheSignature(sessionID, text, shortSig)
if got := GetCachedSignature(sessionID, text); got != "" {
t.Errorf("Short signature should be rejected, got '%s'", got)
}
}
func TestClearSignatureCache_SpecificSession(t *testing.T) {
ClearSignatureCache("")
sig := "validSig1234567890123456789012345678901234567890123456"
CacheSignature("session-1", "text", sig)
CacheSignature("session-2", "text", sig)
ClearSignatureCache("session-1")
if got := GetCachedSignature("session-1", "text"); got != "" {
t.Error("session-1 should be cleared")
}
if got := GetCachedSignature("session-2", "text"); got != sig {
t.Error("session-2 should still exist")
}
}
func TestClearSignatureCache_AllSessions(t *testing.T) {
ClearSignatureCache("")
sig := "validSig1234567890123456789012345678901234567890123456"
CacheSignature("session-1", "text", sig)
CacheSignature("session-2", "text", sig)
ClearSignatureCache("")
if got := GetCachedSignature("session-1", "text"); got != "" {
t.Error("session-1 should be cleared")
}
if got := GetCachedSignature("session-2", "text"); got != "" {
t.Error("session-2 should be cleared")
}
}
func TestHasValidSignature(t *testing.T) {
tests := []struct {
name string
signature string
expected bool
}{
{"valid long signature", "abc123validSignature1234567890123456789012345678901234567890", true},
{"exactly 50 chars", "12345678901234567890123456789012345678901234567890", true},
{"49 chars - invalid", "1234567890123456789012345678901234567890123456789", false},
{"empty string", "", false},
{"short signature", "abc", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := HasValidSignature(tt.signature)
if result != tt.expected {
t.Errorf("HasValidSignature(%q) = %v, expected %v", tt.signature, result, tt.expected)
}
})
}
}
func TestCacheSignature_TextHashCollisionResistance(t *testing.T) {
ClearSignatureCache("")
sessionID := "hash-test-session"
// Different texts should produce different hashes
text1 := "First thinking text"
text2 := "Second thinking text"
sig1 := "signature1_1234567890123456789012345678901234567890123456"
sig2 := "signature2_1234567890123456789012345678901234567890123456"
CacheSignature(sessionID, text1, sig1)
CacheSignature(sessionID, text2, sig2)
if GetCachedSignature(sessionID, text1) != sig1 {
t.Error("text1 signature mismatch")
}
if GetCachedSignature(sessionID, text2) != sig2 {
t.Error("text2 signature mismatch")
}
}
func TestCacheSignature_UnicodeText(t *testing.T) {
ClearSignatureCache("")
sessionID := "unicode-session"
text := "한글 텍스트와 이모지 🎉 그리고 特殊文字"
sig := "unicodeSig123456789012345678901234567890123456789012345"
CacheSignature(sessionID, text, sig)
if got := GetCachedSignature(sessionID, text); got != sig {
t.Errorf("Unicode text signature retrieval failed, got '%s'", got)
}
}
func TestCacheSignature_Overwrite(t *testing.T) {
ClearSignatureCache("")
sessionID := "overwrite-session"
text := "Same text"
sig1 := "firstSignature12345678901234567890123456789012345678901"
sig2 := "secondSignature1234567890123456789012345678901234567890"
CacheSignature(sessionID, text, sig1)
CacheSignature(sessionID, text, sig2) // Overwrite
if got := GetCachedSignature(sessionID, text); got != sig2 {
t.Errorf("Expected overwritten signature '%s', got '%s'", sig2, got)
}
}
// Note: TTL expiration test is tricky to test without mocking time
// We test the logic path exists but actual expiration would require time manipulation
func TestCacheSignature_ExpirationLogic(t *testing.T) {
ClearSignatureCache("")
// This test verifies the expiration check exists
// In a real scenario, we'd mock time.Now()
sessionID := "expiration-test"
text := "text"
sig := "validSig1234567890123456789012345678901234567890123456"
CacheSignature(sessionID, text, sig)
// Fresh entry should be retrievable
if got := GetCachedSignature(sessionID, text); got != sig {
t.Errorf("Fresh entry should be retrievable, got '%s'", got)
}
// We can't easily test actual expiration without time mocking
// but the logic is verified by the implementation
_ = time.Now() // Acknowledge we're not testing time passage
}

View File

@@ -24,12 +24,17 @@ func DoClaudeLogin(cfg *config.Config, options *LoginOptions) {
options = &LoginOptions{} options = &LoginOptions{}
} }
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager() manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{ authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser, NoBrowser: options.NoBrowser,
Metadata: map[string]string{}, Metadata: map[string]string{},
Prompt: options.Prompt, Prompt: promptFn,
} }
_, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts) _, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts)

View File

@@ -15,11 +15,16 @@ func DoAntigravityLogin(cfg *config.Config, options *LoginOptions) {
options = &LoginOptions{} options = &LoginOptions{}
} }
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager() manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{ authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser, NoBrowser: options.NoBrowser,
Metadata: map[string]string{}, Metadata: map[string]string{},
Prompt: options.Prompt, Prompt: promptFn,
} }
record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, authOpts) record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, authOpts)

View File

@@ -20,13 +20,7 @@ func DoIFlowLogin(cfg *config.Config, options *LoginOptions) {
promptFn := options.Prompt promptFn := options.Prompt
if promptFn == nil { if promptFn == nil {
promptFn = func(prompt string) (string, error) { promptFn = defaultProjectPrompt()
fmt.Println()
fmt.Println(prompt)
var value string
_, err := fmt.Scanln(&value)
return value, err
}
} }
authOpts := &sdkAuth.LoginOptions{ authOpts := &sdkAuth.LoginOptions{

View File

@@ -55,11 +55,22 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
ctx := context.Background() ctx := context.Background()
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
trimmedProjectID := strings.TrimSpace(projectID)
callbackPrompt := promptFn
if trimmedProjectID == "" {
callbackPrompt = nil
}
loginOpts := &sdkAuth.LoginOptions{ loginOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser, NoBrowser: options.NoBrowser,
ProjectID: strings.TrimSpace(projectID), ProjectID: trimmedProjectID,
Metadata: map[string]string{}, Metadata: map[string]string{},
Prompt: options.Prompt, Prompt: callbackPrompt,
} }
authenticator := sdkAuth.NewGeminiAuthenticator() authenticator := sdkAuth.NewGeminiAuthenticator()
@@ -76,7 +87,10 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
} }
geminiAuth := gemini.NewGeminiAuth() geminiAuth := gemini.NewGeminiAuth()
httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, options.NoBrowser) httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, &gemini.WebLoginOptions{
NoBrowser: options.NoBrowser,
Prompt: callbackPrompt,
})
if errClient != nil { if errClient != nil {
log.Errorf("Gemini authentication failed: %v", errClient) log.Errorf("Gemini authentication failed: %v", errClient)
return return
@@ -90,12 +104,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
return return
} }
promptFn := options.Prompt selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn)
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
selectedProjectID := promptForProjectSelection(projects, strings.TrimSpace(projectID), promptFn)
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects) projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
if errSelection != nil { if errSelection != nil {
log.Errorf("Invalid project selection: %v", errSelection) log.Errorf("Invalid project selection: %v", errSelection)

View File

@@ -35,12 +35,17 @@ func DoCodexLogin(cfg *config.Config, options *LoginOptions) {
options = &LoginOptions{} options = &LoginOptions{}
} }
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager() manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{ authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser, NoBrowser: options.NoBrowser,
Metadata: map[string]string{}, Metadata: map[string]string{},
Prompt: options.Prompt, Prompt: promptFn,
} }
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts) _, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)

View File

@@ -60,6 +60,9 @@ type Config struct {
// QuotaExceeded defines the behavior when a quota is exceeded. // QuotaExceeded defines the behavior when a quota is exceeded.
QuotaExceeded QuotaExceeded `yaml:"quota-exceeded" json:"quota-exceeded"` QuotaExceeded QuotaExceeded `yaml:"quota-exceeded" json:"quota-exceeded"`
// Routing controls credential selection behavior.
Routing RoutingConfig `yaml:"routing" json:"routing"`
// WebsocketAuth enables or disables authentication for the WebSocket API. // WebsocketAuth enables or disables authentication for the WebSocket API.
WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"` WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"`
@@ -136,6 +139,13 @@ type QuotaExceeded struct {
SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"` SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"`
} }
// RoutingConfig configures how credentials are selected for requests.
type RoutingConfig struct {
// Strategy selects the credential selection strategy.
// Supported values: "round-robin" (default), "fill-first".
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
}
// AmpModelMapping defines a model name mapping for Amp CLI requests. // AmpModelMapping defines a model name mapping for Amp CLI requests.
// When Amp requests a model that isn't available locally, this mapping // When Amp requests a model that isn't available locally, this mapping
// allows routing to an alternative model that IS available. // allows routing to an alternative model that IS available.

View File

@@ -4,6 +4,8 @@ import (
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"net/url"
"strings"
) )
// GenerateRandomState generates a cryptographically secure random state parameter // GenerateRandomState generates a cryptographically secure random state parameter
@@ -19,3 +21,83 @@ func GenerateRandomState() (string, error) {
} }
return hex.EncodeToString(bytes), nil return hex.EncodeToString(bytes), nil
} }
// OAuthCallback captures the parsed OAuth callback parameters.
type OAuthCallback struct {
Code string
State string
Error string
ErrorDescription string
}
// ParseOAuthCallback extracts OAuth parameters from a callback URL.
// It returns nil when the input is empty.
func ParseOAuthCallback(input string) (*OAuthCallback, error) {
trimmed := strings.TrimSpace(input)
if trimmed == "" {
return nil, nil
}
candidate := trimmed
if !strings.Contains(candidate, "://") {
if strings.HasPrefix(candidate, "?") {
candidate = "http://localhost" + candidate
} else if strings.ContainsAny(candidate, "/?#") || strings.Contains(candidate, ":") {
candidate = "http://" + candidate
} else if strings.Contains(candidate, "=") {
candidate = "http://localhost/?" + candidate
} else {
return nil, fmt.Errorf("invalid callback URL")
}
}
parsedURL, err := url.Parse(candidate)
if err != nil {
return nil, err
}
query := parsedURL.Query()
code := strings.TrimSpace(query.Get("code"))
state := strings.TrimSpace(query.Get("state"))
errCode := strings.TrimSpace(query.Get("error"))
errDesc := strings.TrimSpace(query.Get("error_description"))
if parsedURL.Fragment != "" {
if fragQuery, errFrag := url.ParseQuery(parsedURL.Fragment); errFrag == nil {
if code == "" {
code = strings.TrimSpace(fragQuery.Get("code"))
}
if state == "" {
state = strings.TrimSpace(fragQuery.Get("state"))
}
if errCode == "" {
errCode = strings.TrimSpace(fragQuery.Get("error"))
}
if errDesc == "" {
errDesc = strings.TrimSpace(fragQuery.Get("error_description"))
}
}
}
if code != "" && state == "" && strings.Contains(code, "#") {
parts := strings.SplitN(code, "#", 2)
code = parts[0]
state = parts[1]
}
if errCode == "" && errDesc != "" {
errCode = errDesc
errDesc = ""
}
if code == "" && errCode == "" {
return nil, fmt.Errorf("callback URL missing code")
}
return &OAuthCallback{
Code: code,
State: state,
Error: errCode,
ErrorDescription: errDesc,
}, nil
}

View File

@@ -7,6 +7,8 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"context" "context"
"crypto/sha256"
"encoding/binary"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@@ -31,18 +33,18 @@ import (
) )
const ( const (
antigravityBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com" antigravityBaseURLDaily = "https://daily-cloudcode-pa.googleapis.com"
// antigravityBaseURLAutopush = "https://autopush-cloudcode-pa.sandbox.googleapis.com" antigravitySandboxBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com" antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
antigravityCountTokensPath = "/v1internal:countTokens" antigravityCountTokensPath = "/v1internal:countTokens"
antigravityStreamPath = "/v1internal:streamGenerateContent" antigravityStreamPath = "/v1internal:streamGenerateContent"
antigravityGeneratePath = "/v1internal:generateContent" antigravityGeneratePath = "/v1internal:generateContent"
antigravityModelsPath = "/v1internal:fetchAvailableModels" antigravityModelsPath = "/v1internal:fetchAvailableModels"
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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" antigravityClientSecret = "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"
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64" defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64"
antigravityAuthType = "antigravity" antigravityAuthType = "antigravity"
refreshSkew = 3000 * time.Second refreshSkew = 3000 * time.Second
) )
var ( var (
@@ -1014,7 +1016,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
// Use the centralized schema cleaner to handle unsupported keywords, // Use the centralized schema cleaner to handle unsupported keywords,
// const->enum conversion, and flattening of types/anyOf. // const->enum conversion, and flattening of types/anyOf.
strJSON = util.CleanJSONSchemaForGemini(strJSON) strJSON = util.CleanJSONSchemaForAntigravity(strJSON)
payload = []byte(strJSON) payload = []byte(strJSON)
} }
@@ -1154,7 +1156,7 @@ func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string {
} }
return []string{ return []string{
antigravityBaseURLDaily, antigravityBaseURLDaily,
// antigravityBaseURLAutopush, antigravitySandboxBaseURLDaily,
antigravityBaseURLProd, antigravityBaseURLProd,
} }
} }
@@ -1190,7 +1192,7 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
template, _ = sjson.Set(template, "project", generateProjectID()) template, _ = sjson.Set(template, "project", generateProjectID())
} }
template, _ = sjson.Set(template, "requestId", generateRequestID()) template, _ = sjson.Set(template, "requestId", generateRequestID())
template, _ = sjson.Set(template, "request.sessionId", generateSessionID()) template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
template, _ = sjson.Delete(template, "request.safetySettings") template, _ = sjson.Delete(template, "request.safetySettings")
template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED") template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
@@ -1232,6 +1234,23 @@ func generateSessionID() string {
return "-" + strconv.FormatInt(n, 10) return "-" + strconv.FormatInt(n, 10)
} }
func generateStableSessionID(payload []byte) string {
contents := gjson.GetBytes(payload, "request.contents")
if contents.IsArray() {
for _, content := range contents.Array() {
if content.Get("role").String() == "user" {
text := content.Get("parts.0.text").String()
if text != "" {
h := sha256.Sum256([]byte(text))
n := int64(binary.BigEndian.Uint64(h[:8])) & 0x7FFFFFFFFFFFFFFF
return "-" + strconv.FormatInt(n, 10)
}
}
}
}
return generateSessionID()
}
func generateProjectID() string { func generateProjectID() string {
adjectives := []string{"useful", "bright", "swift", "calm", "bold"} adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
nouns := []string{"fuze", "wave", "spark", "flow", "core"} nouns := []string{"fuze", "wave", "spark", "flow", "core"}

View File

@@ -662,7 +662,14 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
} }
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string) { func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string) {
r.Header.Set("Authorization", "Bearer "+apiKey) useAPIKey := auth != nil && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["api_key"]) != ""
isAnthropicBase := r.URL != nil && strings.EqualFold(r.URL.Scheme, "https") && strings.EqualFold(r.URL.Host, "api.anthropic.com")
if isAnthropicBase && useAPIKey {
r.Header.Del("Authorization")
r.Header.Set("x-api-key", apiKey)
} else {
r.Header.Set("Authorization", "Bearer "+apiKey)
}
r.Header.Set("Content-Type", "application/json") r.Header.Set("Content-Type", "application/json")
var ginHeaders http.Header var ginHeaders http.Header

View File

@@ -7,15 +7,40 @@ package claude
import ( import (
"bytes" "bytes"
"crypto/sha256"
"encoding/hex"
"strings" "strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common" "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
const geminiCLIClaudeThoughtSignature = "skip_thought_signature_validator" // deriveSessionID generates a stable session ID from the request.
// Uses the hash of the first user message to identify the conversation.
func deriveSessionID(rawJSON []byte) string {
messages := gjson.GetBytes(rawJSON, "messages")
if !messages.IsArray() {
return ""
}
for _, msg := range messages.Array() {
if msg.Get("role").String() == "user" {
content := msg.Get("content").String()
if content == "" {
// Try to get text from content array
content = msg.Get("content.0.text").String()
}
if content != "" {
h := sha256.Sum256([]byte(content))
return hex.EncodeToString(h[:16])
}
}
}
return ""
}
// ConvertClaudeRequestToAntigravity parses and transforms a Claude Code API request into Gemini CLI API format. // ConvertClaudeRequestToAntigravity parses and transforms a Claude Code API request into Gemini CLI API format.
// It extracts the model name, system instruction, message contents, and tool declarations // It extracts the model name, system instruction, message contents, and tool declarations
@@ -37,7 +62,9 @@ const geminiCLIClaudeThoughtSignature = "skip_thought_signature_validator"
// - []byte: The transformed request data in Gemini CLI API format // - []byte: The transformed request data in Gemini CLI API format
func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte { func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
rawJSON := bytes.Clone(inputRawJSON) rawJSON := bytes.Clone(inputRawJSON)
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
// Derive session ID for signature caching
sessionID := deriveSessionID(rawJSON)
// system instruction // system instruction
systemInstructionJSON := "" systemInstructionJSON := ""
@@ -59,21 +86,28 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
hasSystemInstruction = true hasSystemInstruction = true
} }
} }
} else if systemResult.Type == gjson.String {
systemInstructionJSON = `{"role":"user","parts":[{"text":""}]}`
systemInstructionJSON, _ = sjson.Set(systemInstructionJSON, "parts.0.text", systemResult.String())
hasSystemInstruction = true
} }
// contents // contents
contentsJSON := "[]" contentsJSON := "[]"
hasContents := false hasContents := false
messagesResult := gjson.GetBytes(rawJSON, "messages") messagesResult := gjson.GetBytes(rawJSON, "messages")
if messagesResult.IsArray() { if messagesResult.IsArray() {
messageResults := messagesResult.Array() messageResults := messagesResult.Array()
for i := 0; i < len(messageResults); i++ { numMessages := len(messageResults)
for i := 0; i < numMessages; i++ {
messageResult := messageResults[i] messageResult := messageResults[i]
roleResult := messageResult.Get("role") roleResult := messageResult.Get("role")
if roleResult.Type != gjson.String { if roleResult.Type != gjson.String {
continue continue
} }
role := roleResult.String() originalRole := roleResult.String()
role := originalRole
if role == "assistant" { if role == "assistant" {
role = "model" role = "model"
} }
@@ -82,20 +116,58 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
contentsResult := messageResult.Get("content") contentsResult := messageResult.Get("content")
if contentsResult.IsArray() { if contentsResult.IsArray() {
contentResults := contentsResult.Array() contentResults := contentsResult.Array()
for j := 0; j < len(contentResults); j++ { numContents := len(contentResults)
var currentMessageThinkingSignature string
for j := 0; j < numContents; j++ {
contentResult := contentResults[j] contentResult := contentResults[j]
contentTypeResult := contentResult.Get("type") contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "thinking" { if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "thinking" {
prompt := contentResult.Get("thinking").String() // Use GetThinkingText to handle wrapped thinking objects
thinkingText := util.GetThinkingText(contentResult)
signatureResult := contentResult.Get("signature") signatureResult := contentResult.Get("signature")
signature := geminiCLIClaudeThoughtSignature clientSignature := ""
if signatureResult.Exists() { if signatureResult.Exists() && signatureResult.String() != "" {
signature = signatureResult.String() clientSignature = signatureResult.String()
} }
// Always try cached signature first (more reliable than client-provided)
// Client may send stale or invalid signatures from different sessions
signature := ""
if sessionID != "" && thinkingText != "" {
if cachedSig := cache.GetCachedSignature(sessionID, thinkingText); cachedSig != "" {
signature = cachedSig
log.Debugf("Using cached signature for thinking block")
}
}
// Fallback to client signature only if cache miss and client signature is valid
if signature == "" && cache.HasValidSignature(clientSignature) {
signature = clientSignature
log.Debugf("Using client-provided signature for thinking block")
}
// Store for subsequent tool_use in the same message
if cache.HasValidSignature(signature) {
currentMessageThinkingSignature = signature
}
// Skip trailing unsigned thinking blocks on last assistant message
isUnsigned := !cache.HasValidSignature(signature)
// If unsigned, skip entirely (don't convert to text)
// Claude requires assistant messages to start with thinking blocks when thinking is enabled
// Converting to text would break this requirement
if isUnsigned {
// TypeScript plugin approach: drop unsigned thinking blocks entirely
log.Debugf("Dropping unsigned thinking block (no valid signature)")
continue
}
// Valid signature, send as thought block
partJSON := `{}` partJSON := `{}`
partJSON, _ = sjson.Set(partJSON, "thought", true) partJSON, _ = sjson.Set(partJSON, "thought", true)
if prompt != "" { if thinkingText != "" {
partJSON, _ = sjson.Set(partJSON, "text", prompt) partJSON, _ = sjson.Set(partJSON, "text", thinkingText)
} }
if signature != "" { if signature != "" {
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", signature) partJSON, _ = sjson.Set(partJSON, "thoughtSignature", signature)
@@ -109,24 +181,47 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
} }
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON) clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" { } else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
// NOTE: Do NOT inject dummy thinking blocks here.
// Antigravity API validates signatures, so dummy values are rejected.
// The TypeScript plugin removes unsigned thinking blocks instead of injecting dummies.
functionName := contentResult.Get("name").String() functionName := contentResult.Get("name").String()
functionArgs := contentResult.Get("input").String() argsResult := contentResult.Get("input")
functionID := contentResult.Get("id").String() functionID := contentResult.Get("id").String()
if gjson.Valid(functionArgs) {
argsResult := gjson.Parse(functionArgs) // Handle both object and string input formats
if argsResult.IsObject() { var argsRaw string
partJSON := `{}` if argsResult.IsObject() {
if !strings.Contains(modelName, "claude") { argsRaw = argsResult.Raw
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", geminiCLIClaudeThoughtSignature) } else if argsResult.Type == gjson.String {
} // Input is a JSON string, parse and validate it
if functionID != "" { parsed := gjson.Parse(argsResult.String())
partJSON, _ = sjson.Set(partJSON, "functionCall.id", functionID) if parsed.IsObject() {
} argsRaw = parsed.Raw
partJSON, _ = sjson.Set(partJSON, "functionCall.name", functionName)
partJSON, _ = sjson.SetRaw(partJSON, "functionCall.args", argsResult.Raw)
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
} }
} }
if argsRaw != "" {
partJSON := `{}`
// Use skip_thought_signature_validator for tool calls without valid thinking signature
// This is the approach used in opencode-google-antigravity-auth for Gemini
// and also works for Claude through Antigravity API
const skipSentinel = "skip_thought_signature_validator"
if cache.HasValidSignature(currentMessageThinkingSignature) {
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", currentMessageThinkingSignature)
} else {
// No valid signature - use skip sentinel to bypass validation
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", skipSentinel)
}
if functionID != "" {
partJSON, _ = sjson.Set(partJSON, "functionCall.id", functionID)
}
partJSON, _ = sjson.Set(partJSON, "functionCall.name", functionName)
partJSON, _ = sjson.SetRaw(partJSON, "functionCall.args", argsRaw)
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" { } else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
toolCallID := contentResult.Get("tool_use_id").String() toolCallID := contentResult.Get("tool_use_id").String()
if toolCallID != "" { if toolCallID != "" {
@@ -180,6 +275,37 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
} }
} }
} }
// Reorder parts for 'model' role to ensure thinking block is first
if role == "model" {
partsResult := gjson.Get(clientContentJSON, "parts")
if partsResult.IsArray() {
parts := partsResult.Array()
var thinkingParts []gjson.Result
var otherParts []gjson.Result
for _, part := range parts {
if part.Get("thought").Bool() {
thinkingParts = append(thinkingParts, part)
} else {
otherParts = append(otherParts, part)
}
}
if len(thinkingParts) > 0 {
firstPartIsThinking := parts[0].Get("thought").Bool()
if !firstPartIsThinking || len(thinkingParts) > 1 {
var newParts []interface{}
for _, p := range thinkingParts {
newParts = append(newParts, p.Value())
}
for _, p := range otherParts {
newParts = append(newParts, p.Value())
}
clientContentJSON, _ = sjson.Set(clientContentJSON, "parts", newParts)
}
}
}
}
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON) contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
hasContents = true hasContents = true
} else if contentsResult.Type == gjson.String { } else if contentsResult.Type == gjson.String {
@@ -198,6 +324,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
// tools // tools
toolsJSON := "" toolsJSON := ""
toolDeclCount := 0 toolDeclCount := 0
allowedToolKeys := []string{"name", "description", "behavior", "parameters", "parametersJsonSchema", "response", "responseJsonSchema"}
toolsResult := gjson.GetBytes(rawJSON, "tools") toolsResult := gjson.GetBytes(rawJSON, "tools")
if toolsResult.IsArray() { if toolsResult.IsArray() {
toolsJSON = `[{"functionDeclarations":[]}]` toolsJSON = `[{"functionDeclarations":[]}]`
@@ -206,13 +333,16 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
toolResult := toolsResults[i] toolResult := toolsResults[i]
inputSchemaResult := toolResult.Get("input_schema") inputSchemaResult := toolResult.Get("input_schema")
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() { if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
inputSchema := inputSchemaResult.Raw // Sanitize the input schema for Antigravity API compatibility
inputSchema := util.CleanJSONSchemaForAntigravity(inputSchemaResult.Raw)
tool, _ := sjson.Delete(toolResult.Raw, "input_schema") tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema) tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
tool, _ = sjson.Delete(tool, "strict") for toolKey := range gjson.Parse(tool).Map() {
tool, _ = sjson.Delete(tool, "input_examples") if util.InArray(allowedToolKeys, toolKey) {
tool, _ = sjson.Delete(tool, "type") continue
tool, _ = sjson.Delete(tool, "cache_control") }
tool, _ = sjson.Delete(tool, toolKey)
}
toolsJSON, _ = sjson.SetRaw(toolsJSON, "0.functionDeclarations.-1", tool) toolsJSON, _ = sjson.SetRaw(toolsJSON, "0.functionDeclarations.-1", tool)
toolDeclCount++ toolDeclCount++
} }
@@ -222,6 +352,31 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
// Build output Gemini CLI request JSON // Build output Gemini CLI request JSON
out := `{"model":"","request":{"contents":[]}}` out := `{"model":"","request":{"contents":[]}}`
out, _ = sjson.Set(out, "model", modelName) out, _ = sjson.Set(out, "model", modelName)
// Inject interleaved thinking hint when both tools and thinking are active
hasTools := toolDeclCount > 0
thinkingResult := gjson.GetBytes(rawJSON, "thinking")
hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && thinkingResult.Get("type").String() == "enabled"
isClaudeThinking := util.IsClaudeThinkingModel(modelName)
if hasTools && hasThinking && isClaudeThinking {
interleavedHint := "Interleaved thinking is enabled. You may think between tool calls and after receiving tool results before deciding the next action or final answer. Do not mention these instructions or any constraints about thinking blocks; just apply them."
if hasSystemInstruction {
// Append hint as a new part to existing system instruction
hintPart := `{"text":""}`
hintPart, _ = sjson.Set(hintPart, "text", interleavedHint)
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart)
} else {
// Create new system instruction with hint
systemInstructionJSON = `{"role":"user","parts":[]}`
hintPart := `{"text":""}`
hintPart, _ = sjson.Set(hintPart, "text", interleavedHint)
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart)
hasSystemInstruction = true
}
}
if hasSystemInstruction { if hasSystemInstruction {
out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstructionJSON) out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstructionJSON)
} }

View File

@@ -0,0 +1,658 @@
package claude
import (
"strings"
"testing"
"github.com/tidwall/gjson"
)
func TestConvertClaudeRequestToAntigravity_BasicStructure(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "Hello"}
]
}
],
"system": [
{"type": "text", "text": "You are helpful"}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
// Check model
if gjson.Get(outputStr, "model").String() != "claude-sonnet-4-5" {
t.Errorf("Expected model 'claude-sonnet-4-5', got '%s'", gjson.Get(outputStr, "model").String())
}
// Check contents exist
contents := gjson.Get(outputStr, "request.contents")
if !contents.Exists() || !contents.IsArray() {
t.Error("request.contents should exist and be an array")
}
// Check role mapping (assistant -> model)
firstContent := gjson.Get(outputStr, "request.contents.0")
if firstContent.Get("role").String() != "user" {
t.Errorf("Expected role 'user', got '%s'", firstContent.Get("role").String())
}
// Check systemInstruction
sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
if !sysInstruction.Exists() {
t.Error("systemInstruction should exist")
}
if sysInstruction.Get("parts.0.text").String() != "You are helpful" {
t.Error("systemInstruction text mismatch")
}
}
func TestConvertClaudeRequestToAntigravity_RoleMapping(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{"role": "user", "content": [{"type": "text", "text": "Hi"}]},
{"role": "assistant", "content": [{"type": "text", "text": "Hello"}]}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
// assistant should be mapped to model
secondContent := gjson.Get(outputStr, "request.contents.1")
if secondContent.Get("role").String() != "model" {
t.Errorf("Expected role 'model' (mapped from 'assistant'), got '%s'", secondContent.Get("role").String())
}
}
func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
// Valid signature must be at least 50 characters
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Let me think...", "signature": "` + validSignature + `"},
{"type": "text", "text": "Answer"}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// Check thinking block conversion
firstPart := gjson.Get(outputStr, "request.contents.0.parts.0")
if !firstPart.Get("thought").Bool() {
t.Error("thinking block should have thought: true")
}
if firstPart.Get("text").String() != "Let me think..." {
t.Error("thinking text mismatch")
}
if firstPart.Get("thoughtSignature").String() != validSignature {
t.Errorf("Expected thoughtSignature '%s', got '%s'", validSignature, firstPart.Get("thoughtSignature").String())
}
}
func TestConvertClaudeRequestToAntigravity_ThinkingBlockWithoutSignature(t *testing.T) {
// Unsigned thinking blocks should be removed entirely (not converted to text)
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Let me think..."},
{"type": "text", "text": "Answer"}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// Without signature, thinking block should be removed (not converted to text)
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(parts) != 1 {
t.Fatalf("Expected 1 part (thinking removed), got %d", len(parts))
}
// Only text part should remain
if parts[0].Get("thought").Bool() {
t.Error("Thinking block should be removed, not preserved")
}
if parts[0].Get("text").String() != "Answer" {
t.Errorf("Expected text 'Answer', got '%s'", parts[0].Get("text").String())
}
}
func TestConvertClaudeRequestToAntigravity_ToolDeclarations(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [],
"tools": [
{
"name": "test_tool",
"description": "A test tool",
"input_schema": {
"type": "object",
"properties": {
"name": {"type": "string"}
},
"required": ["name"]
}
}
]
}`)
output := ConvertClaudeRequestToAntigravity("gemini-1.5-pro", inputJSON, false)
outputStr := string(output)
// Check tools structure
tools := gjson.Get(outputStr, "request.tools")
if !tools.Exists() {
t.Error("Tools should exist in output")
}
funcDecl := gjson.Get(outputStr, "request.tools.0.functionDeclarations.0")
if funcDecl.Get("name").String() != "test_tool" {
t.Errorf("Expected tool name 'test_tool', got '%s'", funcDecl.Get("name").String())
}
// Check input_schema renamed to parametersJsonSchema
if funcDecl.Get("parametersJsonSchema").Exists() {
t.Log("parametersJsonSchema exists (expected)")
}
if funcDecl.Get("input_schema").Exists() {
t.Error("input_schema should be removed")
}
}
func TestConvertClaudeRequestToAntigravity_ToolUse(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "call_123",
"name": "get_weather",
"input": "{\"location\": \"Paris\"}"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
// Now we expect only 1 part (tool_use), no dummy thinking block injected
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(parts) != 1 {
t.Fatalf("Expected 1 part (tool only, no dummy injection), got %d", len(parts))
}
// Check function call conversion at parts[0]
funcCall := parts[0].Get("functionCall")
if !funcCall.Exists() {
t.Error("functionCall should exist at parts[0]")
}
if funcCall.Get("name").String() != "get_weather" {
t.Errorf("Expected function name 'get_weather', got '%s'", funcCall.Get("name").String())
}
if funcCall.Get("id").String() != "call_123" {
t.Errorf("Expected function id 'call_123', got '%s'", funcCall.Get("id").String())
}
// Verify skip_thought_signature_validator is added (bypass for tools without valid thinking)
expectedSig := "skip_thought_signature_validator"
actualSig := parts[0].Get("thoughtSignature").String()
if actualSig != expectedSig {
t.Errorf("Expected thoughtSignature '%s', got '%s'", expectedSig, actualSig)
}
}
func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Let me think...", "signature": "` + validSignature + `"},
{
"type": "tool_use",
"id": "call_123",
"name": "get_weather",
"input": "{\"location\": \"Paris\"}"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// Check function call has the signature from the preceding thinking block
part := gjson.Get(outputStr, "request.contents.0.parts.1")
if part.Get("functionCall.name").String() != "get_weather" {
t.Errorf("Expected functionCall, got %s", part.Raw)
}
if part.Get("thoughtSignature").String() != validSignature {
t.Errorf("Expected thoughtSignature '%s' on tool_use, got '%s'", validSignature, part.Get("thoughtSignature").String())
}
}
func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
// Case: text block followed by thinking block -> should be reordered to thinking first
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "assistant",
"content": [
{"type": "text", "text": "Here is the plan."},
{"type": "thinking", "thinking": "Planning...", "signature": "` + validSignature + `"}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// Verify order: Thinking block MUST be first
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(parts) != 2 {
t.Fatalf("Expected 2 parts, got %d", len(parts))
}
if !parts[0].Get("thought").Bool() {
t.Error("First part should be thinking block after reordering")
}
if parts[1].Get("text").String() != "Here is the plan." {
t.Error("Second part should be text block")
}
}
func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "get_weather-call-123",
"content": "22C sunny"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
// Check function response conversion
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Error("functionResponse should exist")
}
if funcResp.Get("id").String() != "get_weather-call-123" {
t.Errorf("Expected function id, got '%s'", funcResp.Get("id").String())
}
}
func TestConvertClaudeRequestToAntigravity_ThinkingConfig(t *testing.T) {
// Note: This test requires the model to be registered in the registry
// with Thinking metadata. If the registry is not populated in test environment,
// thinkingConfig won't be added. We'll test the basic structure only.
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [],
"thinking": {
"type": "enabled",
"budget_tokens": 8000
}
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// Check thinking config conversion (only if model supports thinking in registry)
thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig")
if thinkingConfig.Exists() {
if thinkingConfig.Get("thinkingBudget").Int() != 8000 {
t.Errorf("Expected thinkingBudget 8000, got %d", thinkingConfig.Get("thinkingBudget").Int())
}
if !thinkingConfig.Get("include_thoughts").Bool() {
t.Error("include_thoughts should be true")
}
} else {
t.Log("thinkingConfig not present - model may not be registered in test registry")
}
}
func TestConvertClaudeRequestToAntigravity_ImageContent(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": "iVBORw0KGgoAAAANSUhEUg=="
}
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
// Check inline data conversion
inlineData := gjson.Get(outputStr, "request.contents.0.parts.0.inlineData")
if !inlineData.Exists() {
t.Error("inlineData should exist")
}
if inlineData.Get("mime_type").String() != "image/png" {
t.Error("mime_type mismatch")
}
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
t.Error("data mismatch")
}
}
func TestConvertClaudeRequestToAntigravity_GenerationConfig(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [],
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"max_tokens": 2000
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
genConfig := gjson.Get(outputStr, "request.generationConfig")
if genConfig.Get("temperature").Float() != 0.7 {
t.Errorf("Expected temperature 0.7, got %f", genConfig.Get("temperature").Float())
}
if genConfig.Get("topP").Float() != 0.9 {
t.Errorf("Expected topP 0.9, got %f", genConfig.Get("topP").Float())
}
if genConfig.Get("topK").Float() != 40 {
t.Errorf("Expected topK 40, got %f", genConfig.Get("topK").Float())
}
if genConfig.Get("maxOutputTokens").Float() != 2000 {
t.Errorf("Expected maxOutputTokens 2000, got %f", genConfig.Get("maxOutputTokens").Float())
}
}
// ============================================================================
// Trailing Unsigned Thinking Block Removal
// ============================================================================
func TestConvertClaudeRequestToAntigravity_TrailingUnsignedThinking_Removed(t *testing.T) {
// Last assistant message ends with unsigned thinking block - should be removed
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": "Hello"}]
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Here is my answer"},
{"type": "thinking", "thinking": "I should think more..."}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// The last part of the last assistant message should NOT be a thinking block
lastMessageParts := gjson.Get(outputStr, "request.contents.1.parts")
if !lastMessageParts.IsArray() {
t.Fatal("Last message should have parts array")
}
parts := lastMessageParts.Array()
if len(parts) == 0 {
t.Fatal("Last message should have at least one part")
}
// The unsigned thinking should be removed, leaving only the text
lastPart := parts[len(parts)-1]
if lastPart.Get("thought").Bool() {
t.Error("Trailing unsigned thinking block should be removed")
}
}
func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testing.T) {
// Last assistant message ends with signed thinking block - should be kept
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": "Hello"}]
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Here is my answer"},
{"type": "thinking", "thinking": "Valid thinking...", "signature": "abc123validSignature1234567890123456789012345678901234567890"}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// The signed thinking block should be preserved
lastMessageParts := gjson.Get(outputStr, "request.contents.1.parts")
parts := lastMessageParts.Array()
if len(parts) < 2 {
t.Error("Signed thinking block should be preserved")
}
}
func TestConvertClaudeRequestToAntigravity_MiddleUnsignedThinking_Removed(t *testing.T) {
// Middle message has unsigned thinking - should be removed entirely
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Middle thinking..."},
{"type": "text", "text": "Answer"}
]
},
{
"role": "user",
"content": [{"type": "text", "text": "Follow up"}]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// Unsigned thinking should be removed entirely
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(parts) != 1 {
t.Fatalf("Expected 1 part (thinking removed), got %d", len(parts))
}
// Only text part should remain
if parts[0].Get("thought").Bool() {
t.Error("Thinking block should be removed, not preserved")
}
if parts[0].Get("text").String() != "Answer" {
t.Errorf("Expected text 'Answer', got '%s'", parts[0].Get("text").String())
}
}
// ============================================================================
// Tool + Thinking System Hint Injection
// ============================================================================
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_HintInjected(t *testing.T) {
// When both tools and thinking are enabled, hint should be injected into system instruction
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
"system": [{"type": "text", "text": "You are helpful."}],
"tools": [
{
"name": "get_weather",
"description": "Get weather",
"input_schema": {"type": "object", "properties": {"location": {"type": "string"}}}
}
],
"thinking": {"type": "enabled", "budget_tokens": 8000}
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// System instruction should contain the interleaved thinking hint
sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
if !sysInstruction.Exists() {
t.Fatal("systemInstruction should exist")
}
// Check if hint is appended
sysText := sysInstruction.Get("parts").Array()
found := false
for _, part := range sysText {
if strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled") {
found = true
break
}
}
if !found {
t.Errorf("Interleaved thinking hint should be injected when tools and thinking are both active, got: %v", sysInstruction.Raw)
}
}
func TestConvertClaudeRequestToAntigravity_ToolsOnly_NoHint(t *testing.T) {
// When only tools are present (no thinking), hint should NOT be injected
inputJSON := []byte(`{
"model": "claude-sonnet-4-5",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
"system": [{"type": "text", "text": "You are helpful."}],
"tools": [
{
"name": "get_weather",
"description": "Get weather",
"input_schema": {"type": "object", "properties": {"location": {"type": "string"}}}
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
// System instruction should NOT contain the hint
sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
if sysInstruction.Exists() {
for _, part := range sysInstruction.Get("parts").Array() {
if strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled") {
t.Error("Hint should NOT be injected when only tools are present (no thinking)")
}
}
}
}
func TestConvertClaudeRequestToAntigravity_ThinkingOnly_NoHint(t *testing.T) {
// When only thinking is enabled (no tools), hint should NOT be injected
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
"system": [{"type": "text", "text": "You are helpful."}],
"thinking": {"type": "enabled", "budget_tokens": 8000}
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// System instruction should NOT contain the hint (no tools)
sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
if sysInstruction.Exists() {
for _, part := range sysInstruction.Get("parts").Array() {
if strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled") {
t.Error("Hint should NOT be injected when only thinking is present (no tools)")
}
}
}
}
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
// When tools + thinking but no system instruction, should create one with hint
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
"tools": [
{
"name": "get_weather",
"description": "Get weather",
"input_schema": {"type": "object", "properties": {"location": {"type": "string"}}}
}
],
"thinking": {"type": "enabled", "budget_tokens": 8000}
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// System instruction should be created with hint
sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
if !sysInstruction.Exists() {
t.Fatal("systemInstruction should be created when tools + thinking are active")
}
sysText := sysInstruction.Get("parts").Array()
found := false
for _, part := range sysText {
if strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled") {
found = true
break
}
}
if !found {
t.Errorf("Interleaved thinking hint should be in created systemInstruction, got: %v", sysInstruction.Raw)
}
}

View File

@@ -14,6 +14,9 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
@@ -35,6 +38,10 @@ type Params struct {
HasSentFinalEvents bool // Indicates if final content/message events have been sent HasSentFinalEvents bool // Indicates if final content/message events have been sent
HasToolUse bool // Indicates if tool use was observed in the stream HasToolUse bool // Indicates if tool use was observed in the stream
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
// Signature caching support
SessionID string // Session ID derived from request for signature caching
CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching
} }
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers. // toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
@@ -62,6 +69,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
HasFirstResponse: false, HasFirstResponse: false,
ResponseType: 0, ResponseType: 0,
ResponseIndex: 0, ResponseIndex: 0,
SessionID: deriveSessionID(originalRequestRawJSON),
} }
} }
@@ -119,11 +127,20 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// Process thinking content (internal reasoning) // Process thinking content (internal reasoning)
if partResult.Get("thought").Bool() { if partResult.Get("thought").Bool() {
if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" { if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" {
log.Debug("Branch: signature_delta")
if params.SessionID != "" && params.CurrentThinkingText.Len() > 0 {
cache.CacheSignature(params.SessionID, params.CurrentThinkingText.String(), thoughtSignature.String())
log.Debugf("Cached signature for thinking block (sessionID=%s, textLen=%d)", params.SessionID, params.CurrentThinkingText.Len())
params.CurrentThinkingText.Reset()
}
output = output + "event: content_block_delta\n" output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", thoughtSignature.String()) data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", thoughtSignature.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data) output = output + fmt.Sprintf("data: %s\n\n\n", data)
params.HasContent = true params.HasContent = true
} else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state } else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state
params.CurrentThinkingText.WriteString(partTextResult.String())
output = output + "event: content_block_delta\n" output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String()) data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data) output = output + fmt.Sprintf("data: %s\n\n\n", data)
@@ -152,6 +169,9 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
output = output + fmt.Sprintf("data: %s\n\n\n", data) output = output + fmt.Sprintf("data: %s\n\n\n", data)
params.ResponseType = 2 // Set state to thinking params.ResponseType = 2 // Set state to thinking
params.HasContent = true params.HasContent = true
// Start accumulating thinking text for signature caching
params.CurrentThinkingText.Reset()
params.CurrentThinkingText.WriteString(partTextResult.String())
} }
} else { } else {
finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason") finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason")
@@ -432,7 +452,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter)) toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
toolBlock, _ = sjson.Set(toolBlock, "name", name) toolBlock, _ = sjson.Set(toolBlock, "name", name)
if args := functionCall.Get("args"); args.Exists() && args.Raw != "" && gjson.Valid(args.Raw) { if args := functionCall.Get("args"); args.Exists() && args.Raw != "" && gjson.Valid(args.Raw) && args.IsObject() {
toolBlock, _ = sjson.SetRaw(toolBlock, "input", args.Raw) toolBlock, _ = sjson.SetRaw(toolBlock, "input", args.Raw)
} }

View File

@@ -0,0 +1,316 @@
package claude
import (
"context"
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
)
// ============================================================================
// Signature Caching Tests
// ============================================================================
func TestConvertAntigravityResponseToClaude_SessionIDDerived(t *testing.T) {
cache.ClearSignatureCache("")
// Request with user message - should derive session ID
requestJSON := []byte(`{
"messages": [
{"role": "user", "content": [{"type": "text", "text": "Hello world"}]}
]
}`)
// First response chunk with thinking
responseJSON := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "Let me think...", "thought": true}]
}
}]
}
}`)
var param any
ctx := context.Background()
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, responseJSON, &param)
// Verify session ID was set
params := param.(*Params)
if params.SessionID == "" {
t.Error("SessionID should be derived from request")
}
}
func TestConvertAntigravityResponseToClaude_ThinkingTextAccumulated(t *testing.T) {
cache.ClearSignatureCache("")
requestJSON := []byte(`{
"messages": [{"role": "user", "content": [{"type": "text", "text": "Test"}]}]
}`)
// First thinking chunk
chunk1 := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "First part of thinking...", "thought": true}]
}
}]
}
}`)
// Second thinking chunk (continuation)
chunk2 := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": " Second part of thinking...", "thought": true}]
}
}]
}
}`)
var param any
ctx := context.Background()
// Process first chunk - starts new thinking block
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, chunk1, &param)
params := param.(*Params)
if params.CurrentThinkingText.Len() == 0 {
t.Error("Thinking text should be accumulated after first chunk")
}
// Process second chunk - continues thinking block
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, chunk2, &param)
text := params.CurrentThinkingText.String()
if !strings.Contains(text, "First part") || !strings.Contains(text, "Second part") {
t.Errorf("Thinking text should accumulate both parts, got: %s", text)
}
}
func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) {
cache.ClearSignatureCache("")
requestJSON := []byte(`{
"messages": [{"role": "user", "content": [{"type": "text", "text": "Cache test"}]}]
}`)
// Thinking chunk
thinkingChunk := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "My thinking process here", "thought": true}]
}
}]
}
}`)
// Signature chunk
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
signatureChunk := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "", "thought": true, "thoughtSignature": "` + validSignature + `"}]
}
}]
}
}`)
var param any
ctx := context.Background()
// Process thinking chunk
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, thinkingChunk, &param)
params := param.(*Params)
sessionID := params.SessionID
thinkingText := params.CurrentThinkingText.String()
if sessionID == "" {
t.Fatal("SessionID should be set")
}
if thinkingText == "" {
t.Fatal("Thinking text should be accumulated")
}
// Process signature chunk - should cache the signature
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, signatureChunk, &param)
// Verify signature was cached
cachedSig := cache.GetCachedSignature(sessionID, thinkingText)
if cachedSig != validSignature {
t.Errorf("Expected cached signature '%s', got '%s'", validSignature, cachedSig)
}
// Verify thinking text was reset after caching
if params.CurrentThinkingText.Len() != 0 {
t.Error("Thinking text should be reset after signature is cached")
}
}
func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t *testing.T) {
cache.ClearSignatureCache("")
requestJSON := []byte(`{
"messages": [{"role": "user", "content": [{"type": "text", "text": "Multi block test"}]}]
}`)
validSig1 := "signature1_12345678901234567890123456789012345678901234567"
validSig2 := "signature2_12345678901234567890123456789012345678901234567"
// First thinking block with signature
block1Thinking := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "First thinking block", "thought": true}]
}
}]
}
}`)
block1Sig := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "", "thought": true, "thoughtSignature": "` + validSig1 + `"}]
}
}]
}
}`)
// Text content (breaks thinking)
textBlock := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "Regular text output"}]
}
}]
}
}`)
// Second thinking block with signature
block2Thinking := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "Second thinking block", "thought": true}]
}
}]
}
}`)
block2Sig := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "", "thought": true, "thoughtSignature": "` + validSig2 + `"}]
}
}]
}
}`)
var param any
ctx := context.Background()
// Process first thinking block
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Thinking, &param)
params := param.(*Params)
sessionID := params.SessionID
firstThinkingText := params.CurrentThinkingText.String()
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Sig, &param)
// Verify first signature cached
if cache.GetCachedSignature(sessionID, firstThinkingText) != validSig1 {
t.Error("First thinking block signature should be cached")
}
// Process text (transitions out of thinking)
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, textBlock, &param)
// Process second thinking block
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block2Thinking, &param)
secondThinkingText := params.CurrentThinkingText.String()
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block2Sig, &param)
// Verify second signature cached
if cache.GetCachedSignature(sessionID, secondThinkingText) != validSig2 {
t.Error("Second thinking block signature should be cached")
}
}
func TestDeriveSessionIDFromRequest(t *testing.T) {
tests := []struct {
name string
input []byte
wantEmpty bool
}{
{
name: "valid user message",
input: []byte(`{"messages": [{"role": "user", "content": "Hello"}]}`),
wantEmpty: false,
},
{
name: "user message with content array",
input: []byte(`{"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]}`),
wantEmpty: false,
},
{
name: "no user message",
input: []byte(`{"messages": [{"role": "assistant", "content": "Hi"}]}`),
wantEmpty: true,
},
{
name: "empty messages",
input: []byte(`{"messages": []}`),
wantEmpty: true,
},
{
name: "no messages field",
input: []byte(`{}`),
wantEmpty: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := deriveSessionID(tt.input)
if tt.wantEmpty && result != "" {
t.Errorf("Expected empty session ID, got '%s'", result)
}
if !tt.wantEmpty && result == "" {
t.Error("Expected non-empty session ID")
}
})
}
}
func TestDeriveSessionIDFromRequest_Deterministic(t *testing.T) {
input := []byte(`{"messages": [{"role": "user", "content": "Same message"}]}`)
id1 := deriveSessionID(input)
id2 := deriveSessionID(input)
if id1 != id2 {
t.Errorf("Session ID should be deterministic: '%s' != '%s'", id1, id2)
}
}
func TestDeriveSessionIDFromRequest_DifferentMessages(t *testing.T) {
input1 := []byte(`{"messages": [{"role": "user", "content": "Message A"}]}`)
input2 := []byte(`{"messages": [{"role": "user", "content": "Message B"}]}`)
id1 := deriveSessionID(input1)
id2 := deriveSessionID(input2)
if id1 == id2 {
t.Error("Different messages should produce different session IDs")
}
}

View File

@@ -7,7 +7,6 @@ package gemini
import ( import (
"bytes" "bytes"
"encoding/json"
"fmt" "fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common" "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
@@ -98,16 +97,34 @@ func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) []
} }
} }
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(key, content gjson.Result) bool { // Gemini-specific handling: add skip_thought_signature_validator to functionCall parts
// and remove thinking blocks entirely (Gemini doesn't need to preserve them)
const skipSentinel = "skip_thought_signature_validator"
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool {
if content.Get("role").String() == "model" { if content.Get("role").String() == "model" {
content.Get("parts").ForEach(func(partKey, part gjson.Result) bool { // First pass: collect indices of thinking parts to remove
var thinkingIndicesToRemove []int64
content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool {
// Mark thinking blocks for removal
if part.Get("thought").Bool() {
thinkingIndicesToRemove = append(thinkingIndicesToRemove, partIdx.Int())
}
// Add skip sentinel to functionCall parts
if part.Get("functionCall").Exists() { if part.Get("functionCall").Exists() {
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator") existingSig := part.Get("thoughtSignature").String()
} else if part.Get("thoughtSignature").Exists() { if existingSig == "" || len(existingSig) < 50 {
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator") rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel)
}
} }
return true return true
}) })
// Remove thinking blocks in reverse order to preserve indices
for i := len(thinkingIndicesToRemove) - 1; i >= 0; i-- {
idx := thinkingIndicesToRemove[i]
rawJSON, _ = sjson.DeleteBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d", contentIdx.Int(), idx))
}
} }
return true return true
}) })
@@ -117,41 +134,31 @@ func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) []
// FunctionCallGroup represents a group of function calls and their responses // FunctionCallGroup represents a group of function calls and their responses
type FunctionCallGroup struct { type FunctionCallGroup struct {
ModelContent map[string]interface{}
FunctionCalls []gjson.Result
ResponsesNeeded int ResponsesNeeded int
} }
// parseFunctionResponse attempts to unmarshal a function response part. // parseFunctionResponseRaw attempts to normalize a function response part into a JSON object string.
// Falls back to gjson extraction if standard json.Unmarshal fails. // Falls back to a minimal "functionResponse" object when parsing fails.
func parseFunctionResponse(response gjson.Result) map[string]interface{} { func parseFunctionResponseRaw(response gjson.Result) string {
var responseMap map[string]interface{} if response.IsObject() && gjson.Valid(response.Raw) {
err := json.Unmarshal([]byte(response.Raw), &responseMap) return response.Raw
if err == nil {
return responseMap
} }
log.Debugf("unmarshal function response failed, using fallback: %v", err) log.Debugf("parse function response failed, using fallback")
funcResp := response.Get("functionResponse") funcResp := response.Get("functionResponse")
if funcResp.Exists() { if funcResp.Exists() {
fr := map[string]interface{}{ fr := `{"functionResponse":{"name":"","response":{"result":""}}}`
"name": funcResp.Get("name").String(), fr, _ = sjson.Set(fr, "functionResponse.name", funcResp.Get("name").String())
"response": map[string]interface{}{ fr, _ = sjson.Set(fr, "functionResponse.response.result", funcResp.Get("response").String())
"result": funcResp.Get("response").String(),
},
}
if id := funcResp.Get("id").String(); id != "" { if id := funcResp.Get("id").String(); id != "" {
fr["id"] = id fr, _ = sjson.Set(fr, "functionResponse.id", id)
} }
return map[string]interface{}{"functionResponse": fr} return fr
} }
return map[string]interface{}{ fr := `{"functionResponse":{"name":"unknown","response":{"result":""}}}`
"functionResponse": map[string]interface{}{ fr, _ = sjson.Set(fr, "functionResponse.response.result", response.String())
"name": "unknown", return fr
"response": map[string]interface{}{"result": response.String()},
},
}
} }
// fixCLIToolResponse performs sophisticated tool response format conversion and grouping. // fixCLIToolResponse performs sophisticated tool response format conversion and grouping.
@@ -178,7 +185,7 @@ func fixCLIToolResponse(input string) (string, error) {
} }
// Initialize data structures for processing and grouping // Initialize data structures for processing and grouping
var newContents []interface{} // Final processed contents array contentsWrapper := `{"contents":[]}`
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
var collectedResponses []gjson.Result // Standalone responses to be matched var collectedResponses []gjson.Result // Standalone responses to be matched
@@ -210,17 +217,16 @@ func fixCLIToolResponse(input string) (string, error) {
collectedResponses = collectedResponses[group.ResponsesNeeded:] collectedResponses = collectedResponses[group.ResponsesNeeded:]
// Create merged function response content // Create merged function response content
var responseParts []interface{} functionResponseContent := `{"parts":[],"role":"function"}`
for _, response := range groupResponses { for _, response := range groupResponses {
responseParts = append(responseParts, parseFunctionResponse(response)) partRaw := parseFunctionResponseRaw(response)
if partRaw != "" {
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
}
} }
if len(responseParts) > 0 { if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
functionResponseContent := map[string]interface{}{ contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
} }
// Remove this group as it's been satisfied // Remove this group as it's been satisfied
@@ -234,50 +240,42 @@ func fixCLIToolResponse(input string) (string, error) {
// If this is a model with function calls, create a new group // If this is a model with function calls, create a new group
if role == "model" { if role == "model" {
var functionCallsInThisModel []gjson.Result functionCallsCount := 0
parts.ForEach(func(_, part gjson.Result) bool { parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionCall").Exists() { if part.Get("functionCall").Exists() {
functionCallsInThisModel = append(functionCallsInThisModel, part) functionCallsCount++
} }
return true return true
}) })
if len(functionCallsInThisModel) > 0 { if functionCallsCount > 0 {
// Add the model content // Add the model content
var contentMap map[string]interface{} if !value.IsObject() {
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap) log.Warnf("failed to parse model content")
if errUnmarshal != nil {
log.Warnf("failed to unmarshal model content: %v\n", errUnmarshal)
return true return true
} }
newContents = append(newContents, contentMap) contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
// Create a new group for tracking responses // Create a new group for tracking responses
group := &FunctionCallGroup{ group := &FunctionCallGroup{
ModelContent: contentMap, ResponsesNeeded: functionCallsCount,
FunctionCalls: functionCallsInThisModel,
ResponsesNeeded: len(functionCallsInThisModel),
} }
pendingGroups = append(pendingGroups, group) pendingGroups = append(pendingGroups, group)
} else { } else {
// Regular model content without function calls // Regular model content without function calls
var contentMap map[string]interface{} if !value.IsObject() {
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap) log.Warnf("failed to parse content")
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
return true return true
} }
newContents = append(newContents, contentMap) contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
} }
} else { } else {
// Non-model content (user, etc.) // Non-model content (user, etc.)
var contentMap map[string]interface{} if !value.IsObject() {
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap) log.Warnf("failed to parse content")
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
return true return true
} }
newContents = append(newContents, contentMap) contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
} }
return true return true
@@ -289,25 +287,23 @@ func fixCLIToolResponse(input string) (string, error) {
groupResponses := collectedResponses[:group.ResponsesNeeded] groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:] collectedResponses = collectedResponses[group.ResponsesNeeded:]
var responseParts []interface{} functionResponseContent := `{"parts":[],"role":"function"}`
for _, response := range groupResponses { for _, response := range groupResponses {
responseParts = append(responseParts, parseFunctionResponse(response)) partRaw := parseFunctionResponseRaw(response)
if partRaw != "" {
functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", partRaw)
}
} }
if len(responseParts) > 0 { if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
functionResponseContent := map[string]interface{}{ contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
} }
} }
} }
// Update the original JSON with the new contents // Update the original JSON with the new contents
result := input result := input
newContentsJSON, _ := json.Marshal(newContents) result, _ = sjson.SetRaw(result, "request.contents", gjson.Get(contentsWrapper, "contents").Raw)
result, _ = sjson.Set(result, "request.contents", json.RawMessage(newContentsJSON))
return result, nil return result, nil
} }

View File

@@ -0,0 +1,129 @@
package gemini
import (
"fmt"
"testing"
"github.com/tidwall/gjson"
)
func TestConvertGeminiRequestToAntigravity_PreserveValidSignature(t *testing.T) {
// Valid signature on functionCall should be preserved
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
inputJSON := []byte(fmt.Sprintf(`{
"model": "gemini-3-pro-preview",
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "test_tool", "args": {}}, "thoughtSignature": "%s"}
]
}
]
}`, validSignature))
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
outputStr := string(output)
// Check that valid thoughtSignature is preserved
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(parts) != 1 {
t.Fatalf("Expected 1 part, got %d", len(parts))
}
sig := parts[0].Get("thoughtSignature").String()
if sig != validSignature {
t.Errorf("Expected thoughtSignature '%s', got '%s'", validSignature, sig)
}
}
func TestConvertGeminiRequestToAntigravity_AddSkipSentinelToFunctionCall(t *testing.T) {
// functionCall without signature should get skip_thought_signature_validator
inputJSON := []byte(`{
"model": "gemini-3-pro-preview",
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "test_tool", "args": {}}}
]
}
]
}`)
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
outputStr := string(output)
// Check that skip_thought_signature_validator is added to functionCall
sig := gjson.Get(outputStr, "request.contents.0.parts.0.thoughtSignature").String()
expectedSig := "skip_thought_signature_validator"
if sig != expectedSig {
t.Errorf("Expected skip sentinel '%s', got '%s'", expectedSig, sig)
}
}
func TestConvertGeminiRequestToAntigravity_RemoveThinkingBlocks(t *testing.T) {
// Thinking blocks should be removed entirely for Gemini
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
inputJSON := []byte(fmt.Sprintf(`{
"model": "gemini-3-pro-preview",
"contents": [
{
"role": "model",
"parts": [
{"thought": true, "text": "Thinking...", "thoughtSignature": "%s"},
{"text": "Here is my response"}
]
}
]
}`, validSignature))
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
outputStr := string(output)
// Check that thinking block is removed
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(parts) != 1 {
t.Fatalf("Expected 1 part (thinking removed), got %d", len(parts))
}
// Only text part should remain
if parts[0].Get("thought").Bool() {
t.Error("Thinking block should be removed for Gemini")
}
if parts[0].Get("text").String() != "Here is my response" {
t.Errorf("Expected text 'Here is my response', got '%s'", parts[0].Get("text").String())
}
}
func TestConvertGeminiRequestToAntigravity_ParallelFunctionCalls(t *testing.T) {
// Multiple functionCalls should all get skip_thought_signature_validator
inputJSON := []byte(`{
"model": "gemini-3-pro-preview",
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "tool_one", "args": {"a": "1"}}},
{"functionCall": {"name": "tool_two", "args": {"b": "2"}}}
]
}
]
}`)
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
outputStr := string(output)
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(parts) != 2 {
t.Fatalf("Expected 2 parts, got %d", len(parts))
}
expectedSig := "skip_thought_signature_validator"
for i, part := range parts {
sig := part.Get("thoughtSignature").String()
if sig != expectedSig {
t.Errorf("Part %d: Expected '%s', got '%s'", i, expectedSig, sig)
}
}
}

View File

@@ -192,6 +192,14 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
} else if content.IsObject() && content.Get("type").String() == "text" { } else if content.IsObject() && content.Get("type").String() == "text" {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user") out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
out, _ = sjson.SetBytes(out, "request.systemInstruction.parts.0.text", content.Get("text").String()) out, _ = sjson.SetBytes(out, "request.systemInstruction.parts.0.text", content.Get("text").String())
} else if content.IsArray() {
contents := content.Array()
if len(contents) > 0 {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
for j := 0; j < len(contents); j++ {
out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", j), contents[j].Get("text").String())
}
}
} }
} else if role == "user" || (role == "system" && len(arr) == 1) { } else if role == "user" || (role == "system" && len(arr) == 1) {
// Build single user content node to avoid splitting into multiple contents // Build single user content node to avoid splitting into multiple contents
@@ -258,7 +266,11 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
fargs := tc.Get("function.arguments").String() fargs := tc.Get("function.arguments").String()
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid) node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname) node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs)) if gjson.Valid(fargs) {
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
} else {
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.args.params", []byte(fargs))
}
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature) node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
p++ p++
if fid != "" { if fid != "" {
@@ -319,7 +331,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
} }
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{}) fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
if errSet != nil { if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
@@ -334,7 +346,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
} }
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{}) fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
if errSet != nil { if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue

View File

@@ -8,7 +8,6 @@ package chat_completions
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"strings" "strings"
"sync/atomic" "sync/atomic"
@@ -171,21 +170,14 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
mimeType = "image/png" mimeType = "image/png"
} }
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data) imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload, err := json.Marshal(map[string]any{ imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
"type": "image_url", imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
"image_url": map[string]string{
"url": imageURL,
},
})
if err != nil {
continue
}
imagesResult := gjson.Get(template, "choices.0.delta.images") imagesResult := gjson.Get(template, "choices.0.delta.images")
if !imagesResult.Exists() || !imagesResult.IsArray() { if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`) template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
} }
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", string(imagePayload)) template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
} }
} }
} }

View File

@@ -194,7 +194,7 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
if name := fc.Get("name"); name.Exists() { if name := fc.Get("name"); name.Exists() {
toolUse, _ = sjson.Set(toolUse, "name", name.String()) toolUse, _ = sjson.Set(toolUse, "name", name.String())
} }
if args := fc.Get("args"); args.Exists() { if args := fc.Get("args"); args.Exists() && args.IsObject() {
toolUse, _ = sjson.SetRaw(toolUse, "input", args.Raw) toolUse, _ = sjson.SetRaw(toolUse, "input", args.Raw)
} }
msg, _ = sjson.SetRaw(msg, "content.-1", toolUse) msg, _ = sjson.SetRaw(msg, "content.-1", toolUse)
@@ -314,11 +314,11 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
if mode := funcCalling.Get("mode"); mode.Exists() { if mode := funcCalling.Get("mode"); mode.Exists() {
switch mode.String() { switch mode.String() {
case "AUTO": case "AUTO":
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "auto"}) out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
case "NONE": case "NONE":
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "none"}) out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"none"}`)
case "ANY": case "ANY":
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "any"}) out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
} }
} }
} }

View File

@@ -263,51 +263,6 @@ func ConvertClaudeResponseToGemini(_ context.Context, modelName string, original
} }
} }
// convertArrayToJSON converts []interface{} to JSON array string
func convertArrayToJSON(arr []interface{}) string {
result := "[]"
for _, item := range arr {
switch itemData := item.(type) {
case map[string]interface{}:
itemJSON := convertMapToJSON(itemData)
result, _ = sjson.SetRaw(result, "-1", itemJSON)
case string:
result, _ = sjson.Set(result, "-1", itemData)
case bool:
result, _ = sjson.Set(result, "-1", itemData)
case float64, int, int64:
result, _ = sjson.Set(result, "-1", itemData)
default:
result, _ = sjson.Set(result, "-1", itemData)
}
}
return result
}
// convertMapToJSON converts map[string]interface{} to JSON object string
func convertMapToJSON(m map[string]interface{}) string {
result := "{}"
for key, value := range m {
switch val := value.(type) {
case map[string]interface{}:
nestedJSON := convertMapToJSON(val)
result, _ = sjson.SetRaw(result, key, nestedJSON)
case []interface{}:
arrayJSON := convertArrayToJSON(val)
result, _ = sjson.SetRaw(result, key, arrayJSON)
case string:
result, _ = sjson.Set(result, key, val)
case bool:
result, _ = sjson.Set(result, key, val)
case float64, int, int64:
result, _ = sjson.Set(result, key, val)
default:
result, _ = sjson.Set(result, key, val)
}
}
return result
}
// ConvertClaudeResponseToGeminiNonStream converts a non-streaming Claude Code response to a non-streaming Gemini response. // ConvertClaudeResponseToGeminiNonStream converts a non-streaming Claude Code response to a non-streaming Gemini response.
// This function processes the complete Claude Code response and transforms it into a single Gemini-compatible // This function processes the complete Claude Code response and transforms it into a single Gemini-compatible
// JSON response. It handles message content, tool calls, reasoning content, and usage metadata, combining all // JSON response. It handles message content, tool calls, reasoning content, and usage metadata, combining all
@@ -356,8 +311,8 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
} }
// Process each streaming event and collect parts // Process each streaming event and collect parts
var allParts []interface{} var allParts []string
var finalUsage map[string]interface{} var finalUsageJSON string
var responseID string var responseID string
var createdAt int64 var createdAt int64
@@ -407,16 +362,14 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
if text := delta.Get("text"); text.Exists() && text.String() != "" { if text := delta.Get("text"); text.Exists() && text.String() != "" {
partJSON := `{"text":""}` partJSON := `{"text":""}`
partJSON, _ = sjson.Set(partJSON, "text", text.String()) partJSON, _ = sjson.Set(partJSON, "text", text.String())
part := gjson.Parse(partJSON).Value().(map[string]interface{}) allParts = append(allParts, partJSON)
allParts = append(allParts, part)
} }
case "thinking_delta": case "thinking_delta":
// Process reasoning/thinking content // Process reasoning/thinking content
if text := delta.Get("thinking"); text.Exists() && text.String() != "" { if text := delta.Get("thinking"); text.Exists() && text.String() != "" {
partJSON := `{"thought":true,"text":""}` partJSON := `{"thought":true,"text":""}`
partJSON, _ = sjson.Set(partJSON, "text", text.String()) partJSON, _ = sjson.Set(partJSON, "text", text.String())
part := gjson.Parse(partJSON).Value().(map[string]interface{}) allParts = append(allParts, partJSON)
allParts = append(allParts, part)
} }
case "input_json_delta": case "input_json_delta":
// accumulate args partial_json for this index // accumulate args partial_json for this index
@@ -456,9 +409,7 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
if argsTrim != "" { if argsTrim != "" {
functionCallJSON, _ = sjson.SetRaw(functionCallJSON, "functionCall.args", argsTrim) functionCallJSON, _ = sjson.SetRaw(functionCallJSON, "functionCall.args", argsTrim)
} }
// Parse back to interface{} for allParts allParts = append(allParts, functionCallJSON)
functionCall := gjson.Parse(functionCallJSON).Value().(map[string]interface{})
allParts = append(allParts, functionCall)
// cleanup used state for this index // cleanup used state for this index
if newParam.ToolUseArgs != nil { if newParam.ToolUseArgs != nil {
delete(newParam.ToolUseArgs, idx) delete(newParam.ToolUseArgs, idx)
@@ -501,8 +452,7 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
// Set traffic type (required by Gemini API) // Set traffic type (required by Gemini API)
usageJSON, _ = sjson.Set(usageJSON, "trafficType", "PROVISIONED_THROUGHPUT") usageJSON, _ = sjson.Set(usageJSON, "trafficType", "PROVISIONED_THROUGHPUT")
// Convert to map[string]interface{} using gjson finalUsageJSON = usageJSON
finalUsage = gjson.Parse(usageJSON).Value().(map[string]interface{})
} }
} }
} }
@@ -520,12 +470,16 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string,
// Set the consolidated parts array // Set the consolidated parts array
if len(consolidatedParts) > 0 { if len(consolidatedParts) > 0 {
template, _ = sjson.SetRaw(template, "candidates.0.content.parts", convertToJSONString(consolidatedParts)) partsJSON := "[]"
for _, partJSON := range consolidatedParts {
partsJSON, _ = sjson.SetRaw(partsJSON, "-1", partJSON)
}
template, _ = sjson.SetRaw(template, "candidates.0.content.parts", partsJSON)
} }
// Set usage metadata // Set usage metadata
if finalUsage != nil { if finalUsageJSON != "" {
template, _ = sjson.SetRaw(template, "usageMetadata", convertToJSONString(finalUsage)) template, _ = sjson.SetRaw(template, "usageMetadata", finalUsageJSON)
} }
return template return template
@@ -539,12 +493,12 @@ func GeminiTokenCount(ctx context.Context, count int64) string {
// This function processes the parts array to combine adjacent text elements and thinking elements // This function processes the parts array to combine adjacent text elements and thinking elements
// into single consolidated parts, which results in a more readable and efficient response structure. // into single consolidated parts, which results in a more readable and efficient response structure.
// Tool calls and other non-text parts are preserved as separate elements. // Tool calls and other non-text parts are preserved as separate elements.
func consolidateParts(parts []interface{}) []interface{} { func consolidateParts(parts []string) []string {
if len(parts) == 0 { if len(parts) == 0 {
return parts return parts
} }
var consolidated []interface{} var consolidated []string
var currentTextPart strings.Builder var currentTextPart strings.Builder
var currentThoughtPart strings.Builder var currentThoughtPart strings.Builder
var hasText, hasThought bool var hasText, hasThought bool
@@ -554,8 +508,7 @@ func consolidateParts(parts []interface{}) []interface{} {
if hasText && currentTextPart.Len() > 0 { if hasText && currentTextPart.Len() > 0 {
textPartJSON := `{"text":""}` textPartJSON := `{"text":""}`
textPartJSON, _ = sjson.Set(textPartJSON, "text", currentTextPart.String()) textPartJSON, _ = sjson.Set(textPartJSON, "text", currentTextPart.String())
textPart := gjson.Parse(textPartJSON).Value().(map[string]interface{}) consolidated = append(consolidated, textPartJSON)
consolidated = append(consolidated, textPart)
currentTextPart.Reset() currentTextPart.Reset()
hasText = false hasText = false
} }
@@ -566,42 +519,42 @@ func consolidateParts(parts []interface{}) []interface{} {
if hasThought && currentThoughtPart.Len() > 0 { if hasThought && currentThoughtPart.Len() > 0 {
thoughtPartJSON := `{"thought":true,"text":""}` thoughtPartJSON := `{"thought":true,"text":""}`
thoughtPartJSON, _ = sjson.Set(thoughtPartJSON, "text", currentThoughtPart.String()) thoughtPartJSON, _ = sjson.Set(thoughtPartJSON, "text", currentThoughtPart.String())
thoughtPart := gjson.Parse(thoughtPartJSON).Value().(map[string]interface{}) consolidated = append(consolidated, thoughtPartJSON)
consolidated = append(consolidated, thoughtPart)
currentThoughtPart.Reset() currentThoughtPart.Reset()
hasThought = false hasThought = false
} }
} }
for _, part := range parts { for _, partJSON := range parts {
partMap, ok := part.(map[string]interface{}) part := gjson.Parse(partJSON)
if !ok { if !part.Exists() || !part.IsObject() {
// Flush any pending parts and add this non-text part // Flush any pending parts and add this non-text part
flushText() flushText()
flushThought() flushThought()
consolidated = append(consolidated, part) consolidated = append(consolidated, partJSON)
continue continue
} }
if thought, isThought := partMap["thought"]; isThought && thought == true { thought := part.Get("thought")
if thought.Exists() && thought.Type == gjson.True {
// This is a thinking part - flush any pending text first // This is a thinking part - flush any pending text first
flushText() // Flush any pending text first flushText() // Flush any pending text first
if text, hasTextContent := partMap["text"].(string); hasTextContent { if text := part.Get("text"); text.Exists() && text.Type == gjson.String {
currentThoughtPart.WriteString(text) currentThoughtPart.WriteString(text.String())
hasThought = true hasThought = true
} }
} else if text, hasTextContent := partMap["text"].(string); hasTextContent { } else if text := part.Get("text"); text.Exists() && text.Type == gjson.String {
// This is a regular text part - flush any pending thought first // This is a regular text part - flush any pending thought first
flushThought() // Flush any pending thought first flushThought() // Flush any pending thought first
currentTextPart.WriteString(text) currentTextPart.WriteString(text.String())
hasText = true hasText = true
} else { } else {
// This is some other type of part (like function call) - flush both text and thought // This is some other type of part (like function call) - flush both text and thought
flushText() flushText()
flushThought() flushThought()
consolidated = append(consolidated, part) consolidated = append(consolidated, partJSON)
} }
} }
@@ -611,20 +564,3 @@ func consolidateParts(parts []interface{}) []interface{} {
return consolidated return consolidated
} }
// convertToJSONString converts interface{} to JSON string using sjson/gjson.
// This function provides a consistent way to serialize different data types to JSON strings
// for inclusion in the Gemini API response structure.
func convertToJSONString(v interface{}) string {
switch val := v.(type) {
case []interface{}:
return convertArrayToJSON(val)
case map[string]interface{}:
return convertMapToJSON(val)
default:
// For simple types, create a temporary JSON and extract the value
temp := `{"temp":null}`
temp, _ = sjson.Set(temp, "temp", val)
return gjson.Get(temp, "temp").Raw
}
}

View File

@@ -10,7 +10,6 @@ import (
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"encoding/json"
"fmt" "fmt"
"math/big" "math/big"
"strings" "strings"
@@ -137,9 +136,6 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
out, _ = sjson.Set(out, "stream", stream) out, _ = sjson.Set(out, "stream", stream)
// Process messages and transform them to Claude Code format // Process messages and transform them to Claude Code format
var anthropicMessages []interface{}
var toolCallIDs []string // Track tool call IDs for matching with tool results
if messages := root.Get("messages"); messages.Exists() && messages.IsArray() { if messages := root.Get("messages"); messages.Exists() && messages.IsArray() {
messages.ForEach(func(_, message gjson.Result) bool { messages.ForEach(func(_, message gjson.Result) bool {
role := message.Get("role").String() role := message.Get("role").String()
@@ -152,33 +148,23 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
role = "user" role = "user"
} }
msg := map[string]interface{}{ msg := `{"role":"","content":[]}`
"role": role, msg, _ = sjson.Set(msg, "role", role)
"content": []interface{}{},
}
// Handle content based on its type (string or array) // Handle content based on its type (string or array)
if contentResult.Exists() && contentResult.Type == gjson.String && contentResult.String() != "" { if contentResult.Exists() && contentResult.Type == gjson.String && contentResult.String() != "" {
// Simple text content conversion part := `{"type":"text","text":""}`
msg["content"] = []interface{}{ part, _ = sjson.Set(part, "text", contentResult.String())
map[string]interface{}{ msg, _ = sjson.SetRaw(msg, "content.-1", part)
"type": "text",
"text": contentResult.String(),
},
}
} else if contentResult.Exists() && contentResult.IsArray() { } else if contentResult.Exists() && contentResult.IsArray() {
// Array of content parts processing
var contentParts []interface{}
contentResult.ForEach(func(_, part gjson.Result) bool { contentResult.ForEach(func(_, part gjson.Result) bool {
partType := part.Get("type").String() partType := part.Get("type").String()
switch partType { switch partType {
case "text": case "text":
// Text part conversion textPart := `{"type":"text","text":""}`
contentParts = append(contentParts, map[string]interface{}{ textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
"type": "text", msg, _ = sjson.SetRaw(msg, "content.-1", textPart)
"text": part.Get("text").String(),
})
case "image_url": case "image_url":
// Convert OpenAI image format to Claude Code format // Convert OpenAI image format to Claude Code format
@@ -191,132 +177,95 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
mediaType := strings.TrimPrefix(mediaTypePart, "data:") mediaType := strings.TrimPrefix(mediaTypePart, "data:")
data := parts[1] data := parts[1]
contentParts = append(contentParts, map[string]interface{}{ imagePart := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
"type": "image", imagePart, _ = sjson.Set(imagePart, "source.media_type", mediaType)
"source": map[string]interface{}{ imagePart, _ = sjson.Set(imagePart, "source.data", data)
"type": "base64", msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
"media_type": mediaType,
"data": data,
},
})
} }
} }
} }
return true return true
}) })
if len(contentParts) > 0 {
msg["content"] = contentParts
}
} else {
// Initialize empty content array for tool calls
msg["content"] = []interface{}{}
} }
// Handle tool calls (for assistant messages) // Handle tool calls (for assistant messages)
if toolCalls := message.Get("tool_calls"); toolCalls.Exists() && toolCalls.IsArray() && role == "assistant" { if toolCalls := message.Get("tool_calls"); toolCalls.Exists() && toolCalls.IsArray() && role == "assistant" {
var contentParts []interface{}
// Add existing text content if any
if existingContent, ok := msg["content"].([]interface{}); ok {
contentParts = existingContent
}
toolCalls.ForEach(func(_, toolCall gjson.Result) bool { toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
if toolCall.Get("type").String() == "function" { if toolCall.Get("type").String() == "function" {
toolCallID := toolCall.Get("id").String() toolCallID := toolCall.Get("id").String()
if toolCallID == "" { if toolCallID == "" {
toolCallID = genToolCallID() toolCallID = genToolCallID()
} }
toolCallIDs = append(toolCallIDs, toolCallID)
function := toolCall.Get("function") function := toolCall.Get("function")
toolUse := map[string]interface{}{ toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
"type": "tool_use", toolUse, _ = sjson.Set(toolUse, "id", toolCallID)
"id": toolCallID, toolUse, _ = sjson.Set(toolUse, "name", function.Get("name").String())
"name": function.Get("name").String(),
}
// Parse arguments for the tool call // Parse arguments for the tool call
if args := function.Get("arguments"); args.Exists() { if args := function.Get("arguments"); args.Exists() {
argsStr := args.String() argsStr := args.String()
if argsStr != "" { if argsStr != "" && gjson.Valid(argsStr) {
var argsMap map[string]interface{} argsJSON := gjson.Parse(argsStr)
if err := json.Unmarshal([]byte(argsStr), &argsMap); err == nil { if argsJSON.IsObject() {
toolUse["input"] = argsMap toolUse, _ = sjson.SetRaw(toolUse, "input", argsJSON.Raw)
} else { } else {
toolUse["input"] = map[string]interface{}{} toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
} }
} else { } else {
toolUse["input"] = map[string]interface{}{} toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
} }
} else { } else {
toolUse["input"] = map[string]interface{}{} toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
} }
contentParts = append(contentParts, toolUse) msg, _ = sjson.SetRaw(msg, "content.-1", toolUse)
} }
return true return true
}) })
msg["content"] = contentParts
} }
anthropicMessages = append(anthropicMessages, msg) out, _ = sjson.SetRaw(out, "messages.-1", msg)
case "tool": case "tool":
// Handle tool result messages conversion // Handle tool result messages conversion
toolCallID := message.Get("tool_call_id").String() toolCallID := message.Get("tool_call_id").String()
content := message.Get("content").String() content := message.Get("content").String()
// Create tool result message in Claude Code format msg := `{"role":"user","content":[{"type":"tool_result","tool_use_id":"","content":""}]}`
msg := map[string]interface{}{ msg, _ = sjson.Set(msg, "content.0.tool_use_id", toolCallID)
"role": "user", msg, _ = sjson.Set(msg, "content.0.content", content)
"content": []interface{}{ out, _ = sjson.SetRaw(out, "messages.-1", msg)
map[string]interface{}{
"type": "tool_result",
"tool_use_id": toolCallID,
"content": content,
},
},
}
anthropicMessages = append(anthropicMessages, msg)
} }
return true return true
}) })
} }
// Set messages in the output template
if len(anthropicMessages) > 0 {
messagesJSON, _ := json.Marshal(anthropicMessages)
out, _ = sjson.SetRaw(out, "messages", string(messagesJSON))
}
// Tools mapping: OpenAI tools -> Claude Code tools // Tools mapping: OpenAI tools -> Claude Code tools
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 { if tools := root.Get("tools"); tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
var anthropicTools []interface{} hasAnthropicTools := false
tools.ForEach(func(_, tool gjson.Result) bool { tools.ForEach(func(_, tool gjson.Result) bool {
if tool.Get("type").String() == "function" { if tool.Get("type").String() == "function" {
function := tool.Get("function") function := tool.Get("function")
anthropicTool := map[string]interface{}{ anthropicTool := `{"name":"","description":""}`
"name": function.Get("name").String(), anthropicTool, _ = sjson.Set(anthropicTool, "name", function.Get("name").String())
"description": function.Get("description").String(), anthropicTool, _ = sjson.Set(anthropicTool, "description", function.Get("description").String())
}
// Convert parameters schema for the tool // Convert parameters schema for the tool
if parameters := function.Get("parameters"); parameters.Exists() { if parameters := function.Get("parameters"); parameters.Exists() {
anthropicTool["input_schema"] = parameters.Value() anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", parameters.Raw)
} else if parameters = function.Get("parametersJsonSchema"); parameters.Exists() { } else if parameters := function.Get("parametersJsonSchema"); parameters.Exists() {
anthropicTool["input_schema"] = parameters.Value() anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", parameters.Raw)
} }
anthropicTools = append(anthropicTools, anthropicTool) out, _ = sjson.SetRaw(out, "tools.-1", anthropicTool)
hasAnthropicTools = true
} }
return true return true
}) })
if len(anthropicTools) > 0 { if !hasAnthropicTools {
toolsJSON, _ := json.Marshal(anthropicTools) out, _ = sjson.Delete(out, "tools")
out, _ = sjson.SetRaw(out, "tools", string(toolsJSON))
} }
} }
@@ -329,18 +278,17 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
case "none": case "none":
// Don't set tool_choice, Claude Code will not use tools // Don't set tool_choice, Claude Code will not use tools
case "auto": case "auto":
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "auto"}) out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
case "required": case "required":
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "any"}) out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
} }
case gjson.JSON: case gjson.JSON:
// Specific tool choice mapping // Specific tool choice mapping
if toolChoice.Get("type").String() == "function" { if toolChoice.Get("type").String() == "function" {
functionName := toolChoice.Get("function.name").String() functionName := toolChoice.Get("function.name").String()
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{ toolChoiceJSON := `{"type":"tool","name":""}`
"type": "tool", toolChoiceJSON, _ = sjson.Set(toolChoiceJSON, "name", functionName)
"name": functionName, out, _ = sjson.SetRaw(out, "tool_choice", toolChoiceJSON)
})
} }
default: default:
} }

View File

@@ -8,7 +8,7 @@ package chat_completions
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json" "fmt"
"strings" "strings"
"time" "time"
@@ -182,18 +182,11 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
if arguments == "" { if arguments == "" {
arguments = "{}" arguments = "{}"
} }
template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.index", index)
toolCall := map[string]interface{}{ template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.id", accumulator.ID)
"index": index, template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.type", "function")
"id": accumulator.ID, template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.function.name", accumulator.Name)
"type": "function", template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.function.arguments", arguments)
"function": map[string]interface{}{
"name": accumulator.Name,
"arguments": arguments,
},
}
template, _ = sjson.Set(template, "choices.0.delta.tool_calls", []interface{}{toolCall})
// Clean up the accumulator for this index // Clean up the accumulator for this index
delete((*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator, index) delete((*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator, index)
@@ -214,12 +207,11 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
// Handle usage information for token counts // Handle usage information for token counts
if usage := root.Get("usage"); usage.Exists() { if usage := root.Get("usage"); usage.Exists() {
usageObj := map[string]interface{}{ inputTokens := usage.Get("input_tokens").Int()
"prompt_tokens": usage.Get("input_tokens").Int(), outputTokens := usage.Get("output_tokens").Int()
"completion_tokens": usage.Get("output_tokens").Int(), template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokens)
"total_tokens": usage.Get("input_tokens").Int() + usage.Get("output_tokens").Int(), template, _ = sjson.Set(template, "usage.completion_tokens", outputTokens)
} template, _ = sjson.Set(template, "usage.total_tokens", inputTokens+outputTokens)
template, _ = sjson.Set(template, "usage", usageObj)
} }
return []string{template} return []string{template}
@@ -234,14 +226,10 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
case "error": case "error":
// Error event - format and return error response // Error event - format and return error response
if errorData := root.Get("error"); errorData.Exists() { if errorData := root.Get("error"); errorData.Exists() {
errorResponse := map[string]interface{}{ errorJSON := `{"error":{"message":"","type":""}}`
"error": map[string]interface{}{ errorJSON, _ = sjson.Set(errorJSON, "error.message", errorData.Get("message").String())
"message": errorData.Get("message").String(), errorJSON, _ = sjson.Set(errorJSON, "error.type", errorData.Get("type").String())
"type": errorData.Get("type").String(), return []string{errorJSON}
},
}
errorJSON, _ := json.Marshal(errorResponse)
return []string{string(errorJSON)}
} }
return []string{} return []string{}
@@ -302,10 +290,7 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
var stopReason string var stopReason string
var contentParts []string var contentParts []string
var reasoningParts []string var reasoningParts []string
// Use map to track tool calls by index for proper merging toolCallsAccumulator := make(map[int]*ToolCallAccumulator)
toolCallsMap := make(map[int]map[string]interface{})
// Track tool call arguments accumulation
toolCallArgsMap := make(map[int]strings.Builder)
for _, chunk := range chunks { for _, chunk := range chunks {
root := gjson.ParseBytes(chunk) root := gjson.ParseBytes(chunk)
@@ -331,18 +316,12 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
// Start of thinking/reasoning content - skip for now as it's handled in delta // Start of thinking/reasoning content - skip for now as it's handled in delta
continue continue
} else if blockType == "tool_use" { } else if blockType == "tool_use" {
// Initialize tool call tracking for this index // Initialize tool call accumulator for this index
index := int(root.Get("index").Int()) index := int(root.Get("index").Int())
toolCallsMap[index] = map[string]interface{}{ toolCallsAccumulator[index] = &ToolCallAccumulator{
"id": contentBlock.Get("id").String(), ID: contentBlock.Get("id").String(),
"type": "function", Name: contentBlock.Get("name").String(),
"function": map[string]interface{}{
"name": contentBlock.Get("name").String(),
"arguments": "",
},
} }
// Initialize arguments builder for this tool call
toolCallArgsMap[index] = strings.Builder{}
} }
} }
@@ -365,9 +344,8 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
// Accumulate tool call arguments // Accumulate tool call arguments
if partialJSON := delta.Get("partial_json"); partialJSON.Exists() { if partialJSON := delta.Get("partial_json"); partialJSON.Exists() {
index := int(root.Get("index").Int()) index := int(root.Get("index").Int())
if builder, exists := toolCallArgsMap[index]; exists { if accumulator, exists := toolCallsAccumulator[index]; exists {
builder.WriteString(partialJSON.String()) accumulator.Arguments.WriteString(partialJSON.String())
toolCallArgsMap[index] = builder
} }
} }
} }
@@ -376,14 +354,9 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
case "content_block_stop": case "content_block_stop":
// Finalize tool call arguments for this index when content block ends // Finalize tool call arguments for this index when content block ends
index := int(root.Get("index").Int()) index := int(root.Get("index").Int())
if toolCall, exists := toolCallsMap[index]; exists { if accumulator, exists := toolCallsAccumulator[index]; exists {
if builder, argsExists := toolCallArgsMap[index]; argsExists { if accumulator.Arguments.Len() == 0 {
// Set the accumulated arguments for the tool call accumulator.Arguments.WriteString("{}")
arguments := builder.String()
if arguments == "" {
arguments = "{}"
}
toolCall["function"].(map[string]interface{})["arguments"] = arguments
} }
} }
@@ -421,24 +394,35 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
} }
// Set tool calls if any were accumulated during processing // Set tool calls if any were accumulated during processing
if len(toolCallsMap) > 0 { if len(toolCallsAccumulator) > 0 {
// Convert tool calls map to array, preserving order by index toolCallsCount := 0
var toolCallsArray []interface{}
// Find the maximum index to determine the range
maxIndex := -1 maxIndex := -1
for index := range toolCallsMap { for index := range toolCallsAccumulator {
if index > maxIndex { if index > maxIndex {
maxIndex = index maxIndex = index
} }
} }
// Iterate through all possible indices up to maxIndex
for i := 0; i <= maxIndex; i++ { for i := 0; i <= maxIndex; i++ {
if toolCall, exists := toolCallsMap[i]; exists { accumulator, exists := toolCallsAccumulator[i]
toolCallsArray = append(toolCallsArray, toolCall) if !exists {
continue
} }
arguments := accumulator.Arguments.String()
idPath := fmt.Sprintf("choices.0.message.tool_calls.%d.id", toolCallsCount)
typePath := fmt.Sprintf("choices.0.message.tool_calls.%d.type", toolCallsCount)
namePath := fmt.Sprintf("choices.0.message.tool_calls.%d.function.name", toolCallsCount)
argumentsPath := fmt.Sprintf("choices.0.message.tool_calls.%d.function.arguments", toolCallsCount)
out, _ = sjson.Set(out, idPath, accumulator.ID)
out, _ = sjson.Set(out, typePath, "function")
out, _ = sjson.Set(out, namePath, accumulator.Name)
out, _ = sjson.Set(out, argumentsPath, arguments)
toolCallsCount++
} }
if len(toolCallsArray) > 0 { if toolCallsCount > 0 {
out, _ = sjson.Set(out, "choices.0.message.tool_calls", toolCallsArray)
out, _ = sjson.Set(out, "choices.0.finish_reason", "tool_calls") out, _ = sjson.Set(out, "choices.0.finish_reason", "tool_calls")
} else { } else {
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason)) out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))

View File

@@ -254,7 +254,10 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
toolUse, _ = sjson.Set(toolUse, "id", callID) toolUse, _ = sjson.Set(toolUse, "id", callID)
toolUse, _ = sjson.Set(toolUse, "name", name) toolUse, _ = sjson.Set(toolUse, "name", name)
if argsStr != "" && gjson.Valid(argsStr) { if argsStr != "" && gjson.Valid(argsStr) {
toolUse, _ = sjson.SetRaw(toolUse, "input", argsStr) argsJSON := gjson.Parse(argsStr)
if argsJSON.IsObject() {
toolUse, _ = sjson.SetRaw(toolUse, "input", argsJSON.Raw)
}
} }
asst := `{"role":"assistant","content":[]}` asst := `{"role":"assistant","content":[]}`
@@ -309,16 +312,18 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
case gjson.String: case gjson.String:
switch toolChoice.String() { switch toolChoice.String() {
case "auto": case "auto":
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "auto"}) out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
case "none": case "none":
// Leave unset; implies no tools // Leave unset; implies no tools
case "required": case "required":
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "any"}) out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
} }
case gjson.JSON: case gjson.JSON:
if toolChoice.Get("type").String() == "function" { if toolChoice.Get("type").String() == "function" {
fn := toolChoice.Get("function.name").String() fn := toolChoice.Get("function.name").String()
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{"type": "tool", "name": fn}) toolChoiceJSON := `{"name":"","type":"tool"}`
toolChoiceJSON, _ = sjson.Set(toolChoiceJSON, "name", fn)
out, _ = sjson.SetRaw(out, "tool_choice", toolChoiceJSON)
} }
default: default:

View File

@@ -344,31 +344,20 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
} }
// Build response.output from aggregated state // Build response.output from aggregated state
var outputs []interface{} outputsWrapper := `{"arr":[]}`
// reasoning item (if any) // reasoning item (if any)
if st.ReasoningBuf.Len() > 0 || st.ReasoningPartAdded { if st.ReasoningBuf.Len() > 0 || st.ReasoningPartAdded {
r := map[string]interface{}{ item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
"id": st.ReasoningItemID, item, _ = sjson.Set(item, "id", st.ReasoningItemID)
"type": "reasoning", item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
"summary": []interface{}{map[string]interface{}{"type": "summary_text", "text": st.ReasoningBuf.String()}}, outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
}
outputs = append(outputs, r)
} }
// assistant message item (if any text) // assistant message item (if any text)
if st.TextBuf.Len() > 0 || st.InTextBlock || st.CurrentMsgID != "" { if st.TextBuf.Len() > 0 || st.InTextBlock || st.CurrentMsgID != "" {
m := map[string]interface{}{ item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
"id": st.CurrentMsgID, item, _ = sjson.Set(item, "id", st.CurrentMsgID)
"type": "message", item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String())
"status": "completed", outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
"content": []interface{}{map[string]interface{}{
"type": "output_text",
"annotations": []interface{}{},
"logprobs": []interface{}{},
"text": st.TextBuf.String(),
}},
"role": "assistant",
}
outputs = append(outputs, m)
} }
// function_call items (in ascending index order for determinism) // function_call items (in ascending index order for determinism)
if len(st.FuncArgsBuf) > 0 { if len(st.FuncArgsBuf) > 0 {
@@ -395,19 +384,16 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
if callID == "" && st.CurrentFCID != "" { if callID == "" && st.CurrentFCID != "" {
callID = st.CurrentFCID callID = st.CurrentFCID
} }
item := map[string]interface{}{ item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
"id": fmt.Sprintf("fc_%s", callID), item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
"type": "function_call", item, _ = sjson.Set(item, "arguments", args)
"status": "completed", item, _ = sjson.Set(item, "call_id", callID)
"arguments": args, item, _ = sjson.Set(item, "name", name)
"call_id": callID, outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
"name": name,
}
outputs = append(outputs, item)
} }
} }
if len(outputs) > 0 { if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
completed, _ = sjson.Set(completed, "response.output", outputs) completed, _ = sjson.SetRaw(completed, "response.output", gjson.Get(outputsWrapper, "arr").Raw)
} }
reasoningTokens := int64(0) reasoningTokens := int64(0)
@@ -628,27 +614,18 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
} }
// Build output array // Build output array
var outputs []interface{} outputsWrapper := `{"arr":[]}`
if reasoningBuf.Len() > 0 { if reasoningBuf.Len() > 0 {
outputs = append(outputs, map[string]interface{}{ item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
"id": reasoningItemID, item, _ = sjson.Set(item, "id", reasoningItemID)
"type": "reasoning", item, _ = sjson.Set(item, "summary.0.text", reasoningBuf.String())
"summary": []interface{}{map[string]interface{}{"type": "summary_text", "text": reasoningBuf.String()}}, outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
})
} }
if currentMsgID != "" || textBuf.Len() > 0 { if currentMsgID != "" || textBuf.Len() > 0 {
outputs = append(outputs, map[string]interface{}{ item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
"id": currentMsgID, item, _ = sjson.Set(item, "id", currentMsgID)
"type": "message", item, _ = sjson.Set(item, "content.0.text", textBuf.String())
"status": "completed", outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
"content": []interface{}{map[string]interface{}{
"type": "output_text",
"annotations": []interface{}{},
"logprobs": []interface{}{},
"text": textBuf.String(),
}},
"role": "assistant",
})
} }
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
// Preserve index order // Preserve index order
@@ -669,18 +646,16 @@ func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string
if args == "" { if args == "" {
args = "{}" args = "{}"
} }
outputs = append(outputs, map[string]interface{}{ item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
"id": fmt.Sprintf("fc_%s", st.id), item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", st.id))
"type": "function_call", item, _ = sjson.Set(item, "arguments", args)
"status": "completed", item, _ = sjson.Set(item, "call_id", st.id)
"arguments": args, item, _ = sjson.Set(item, "name", st.name)
"call_id": st.id, outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
"name": st.name,
})
} }
} }
if len(outputs) > 0 { if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
out, _ = sjson.Set(out, "output", outputs) out, _ = sjson.SetRaw(out, "output", gjson.Get(outputsWrapper, "arr").Raw)
} }
// Usage // Usage

View File

@@ -9,7 +9,6 @@ package claude
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"strings" "strings"
@@ -191,21 +190,12 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
return "" return ""
} }
response := map[string]interface{}{ out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
"id": responseData.Get("id").String(), out, _ = sjson.Set(out, "id", responseData.Get("id").String())
"type": "message", out, _ = sjson.Set(out, "model", responseData.Get("model").String())
"role": "assistant", out, _ = sjson.Set(out, "usage.input_tokens", responseData.Get("usage.input_tokens").Int())
"model": responseData.Get("model").String(), out, _ = sjson.Set(out, "usage.output_tokens", responseData.Get("usage.output_tokens").Int())
"content": []interface{}{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]interface{}{
"input_tokens": responseData.Get("usage.input_tokens").Int(),
"output_tokens": responseData.Get("usage.output_tokens").Int(),
},
}
var contentBlocks []interface{}
hasToolCall := false hasToolCall := false
if output := responseData.Get("output"); output.Exists() && output.IsArray() { if output := responseData.Get("output"); output.Exists() && output.IsArray() {
@@ -244,10 +234,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
} }
} }
if thinkingBuilder.Len() > 0 { if thinkingBuilder.Len() > 0 {
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"thinking","thinking":""}`
"type": "thinking", block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
"thinking": thinkingBuilder.String(), out, _ = sjson.SetRaw(out, "content.-1", block)
})
} }
case "message": case "message":
if content := item.Get("content"); content.Exists() { if content := item.Get("content"); content.Exists() {
@@ -256,10 +245,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
if part.Get("type").String() == "output_text" { if part.Get("type").String() == "output_text" {
text := part.Get("text").String() text := part.Get("text").String()
if text != "" { if text != "" {
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"text","text":""}`
"type": "text", block, _ = sjson.Set(block, "text", text)
"text": text, out, _ = sjson.SetRaw(out, "content.-1", block)
})
} }
} }
return true return true
@@ -267,10 +255,9 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
} else { } else {
text := content.String() text := content.String()
if text != "" { if text != "" {
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"text","text":""}`
"type": "text", block, _ = sjson.Set(block, "text", text)
"text": text, out, _ = sjson.SetRaw(out, "content.-1", block)
})
} }
} }
} }
@@ -281,54 +268,41 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
name = original name = original
} }
toolBlock := map[string]interface{}{ toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
"type": "tool_use", toolBlock, _ = sjson.Set(toolBlock, "id", item.Get("call_id").String())
"id": item.Get("call_id").String(), toolBlock, _ = sjson.Set(toolBlock, "name", name)
"name": name, inputRaw := "{}"
"input": map[string]interface{}{}, if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) {
} argsJSON := gjson.Parse(argsStr)
if argsJSON.IsObject() {
if argsStr := item.Get("arguments").String(); argsStr != "" { inputRaw = argsJSON.Raw
var args interface{}
if err := json.Unmarshal([]byte(argsStr), &args); err == nil {
toolBlock["input"] = args
} }
} }
toolBlock, _ = sjson.SetRaw(toolBlock, "input", inputRaw)
contentBlocks = append(contentBlocks, toolBlock) out, _ = sjson.SetRaw(out, "content.-1", toolBlock)
} }
return true return true
}) })
} }
if len(contentBlocks) > 0 {
response["content"] = contentBlocks
}
if stopReason := responseData.Get("stop_reason"); stopReason.Exists() && stopReason.String() != "" { if stopReason := responseData.Get("stop_reason"); stopReason.Exists() && stopReason.String() != "" {
response["stop_reason"] = stopReason.String() out, _ = sjson.Set(out, "stop_reason", stopReason.String())
} else if hasToolCall { } else if hasToolCall {
response["stop_reason"] = "tool_use" out, _ = sjson.Set(out, "stop_reason", "tool_use")
} else { } else {
response["stop_reason"] = "end_turn" out, _ = sjson.Set(out, "stop_reason", "end_turn")
} }
if stopSequence := responseData.Get("stop_sequence"); stopSequence.Exists() && stopSequence.String() != "" { if stopSequence := responseData.Get("stop_sequence"); stopSequence.Exists() && stopSequence.String() != "" {
response["stop_sequence"] = stopSequence.Value() out, _ = sjson.SetRaw(out, "stop_sequence", stopSequence.Raw)
} }
if responseData.Get("usage.input_tokens").Exists() || responseData.Get("usage.output_tokens").Exists() { if responseData.Get("usage.input_tokens").Exists() || responseData.Get("usage.output_tokens").Exists() {
response["usage"] = map[string]interface{}{ out, _ = sjson.Set(out, "usage.input_tokens", responseData.Get("usage.input_tokens").Int())
"input_tokens": responseData.Get("usage.input_tokens").Int(), out, _ = sjson.Set(out, "usage.output_tokens", responseData.Get("usage.output_tokens").Int())
"output_tokens": responseData.Get("usage.output_tokens").Int(),
}
} }
responseJSON, err := json.Marshal(response) return out
if err != nil {
return ""
}
return string(responseJSON)
} }
// buildReverseMapFromClaudeOriginalShortToOriginal builds a map[short]original from original Claude request tools. // buildReverseMapFromClaudeOriginalShortToOriginal builds a map[short]original from original Claude request tools.

View File

@@ -7,7 +7,6 @@ package gemini
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"time" "time"
@@ -190,19 +189,19 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
} }
// Process output content to build parts array // Process output content to build parts array
var parts []interface{}
hasToolCall := false hasToolCall := false
var pendingFunctionCalls []interface{} var pendingFunctionCalls []string
flushPendingFunctionCalls := func() { flushPendingFunctionCalls := func() {
if len(pendingFunctionCalls) > 0 { if len(pendingFunctionCalls) == 0 {
// Add all pending function calls as individual parts return
// This maintains the original Gemini API format while ensuring consecutive calls are grouped together
for _, fc := range pendingFunctionCalls {
parts = append(parts, fc)
}
pendingFunctionCalls = nil
} }
// Add all pending function calls as individual parts
// This maintains the original Gemini API format while ensuring consecutive calls are grouped together
for _, fc := range pendingFunctionCalls {
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", fc)
}
pendingFunctionCalls = nil
} }
if output := responseData.Get("output"); output.Exists() && output.IsArray() { if output := responseData.Get("output"); output.Exists() && output.IsArray() {
@@ -216,11 +215,9 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
// Add thinking content // Add thinking content
if content := value.Get("content"); content.Exists() { if content := value.Get("content"); content.Exists() {
part := map[string]interface{}{ part := `{"text":"","thought":true}`
"thought": true, part, _ = sjson.Set(part, "text", content.String())
"text": content.String(), template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
}
parts = append(parts, part)
} }
case "message": case "message":
@@ -232,10 +229,9 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
content.ForEach(func(_, contentItem gjson.Result) bool { content.ForEach(func(_, contentItem gjson.Result) bool {
if contentItem.Get("type").String() == "output_text" { if contentItem.Get("type").String() == "output_text" {
if text := contentItem.Get("text"); text.Exists() { if text := contentItem.Get("text"); text.Exists() {
part := map[string]interface{}{ part := `{"text":""}`
"text": text.String(), part, _ = sjson.Set(part, "text", text.String())
} template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
parts = append(parts, part)
} }
} }
return true return true
@@ -245,28 +241,21 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
case "function_call": case "function_call":
// Collect function call for potential merging with consecutive ones // Collect function call for potential merging with consecutive ones
hasToolCall = true hasToolCall = true
functionCall := map[string]interface{}{ functionCall := `{"functionCall":{"args":{},"name":""}}`
"functionCall": map[string]interface{}{ {
"name": func() string { n := value.Get("name").String()
n := value.Get("name").String() rev := buildReverseMapFromGeminiOriginal(originalRequestRawJSON)
rev := buildReverseMapFromGeminiOriginal(originalRequestRawJSON) if orig, ok := rev[n]; ok {
if orig, ok := rev[n]; ok { n = orig
return orig }
} functionCall, _ = sjson.Set(functionCall, "functionCall.name", n)
return n
}(),
"args": map[string]interface{}{},
},
} }
// Parse and set arguments // Parse and set arguments
if argsStr := value.Get("arguments").String(); argsStr != "" { if argsStr := value.Get("arguments").String(); argsStr != "" {
argsResult := gjson.Parse(argsStr) argsResult := gjson.Parse(argsStr)
if argsResult.IsObject() { if argsResult.IsObject() {
var args map[string]interface{} functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsStr)
if err := json.Unmarshal([]byte(argsStr), &args); err == nil {
functionCall["functionCall"].(map[string]interface{})["args"] = args
}
} }
} }
@@ -279,11 +268,6 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string,
flushPendingFunctionCalls() flushPendingFunctionCalls()
} }
// Set the parts array
if len(parts) > 0 {
template, _ = sjson.SetRaw(template, "candidates.0.content.parts", mustMarshalJSON(parts))
}
// Set finish reason based on whether there were tool calls // Set finish reason based on whether there were tool calls
if hasToolCall { if hasToolCall {
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP") template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
@@ -323,15 +307,6 @@ func buildReverseMapFromGeminiOriginal(original []byte) map[string]string {
return rev return rev
} }
// mustMarshalJSON marshals a value to JSON, panicking on error.
func mustMarshalJSON(v interface{}) string {
data, err := json.Marshal(v)
if err != nil {
return ""
}
return string(data)
}
func GeminiTokenCount(ctx context.Context, count int64) string { func GeminiTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count) return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
} }

View File

@@ -7,10 +7,8 @@ package claude
import ( import (
"bytes" "bytes"
"encoding/json"
"strings" "strings"
client "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common" "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
@@ -41,92 +39,102 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
rawJSON := bytes.Clone(inputRawJSON) rawJSON := bytes.Clone(inputRawJSON)
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1) rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
// Build output Gemini CLI request JSON
out := `{"model":"","request":{"contents":[]}}`
out, _ = sjson.Set(out, "model", modelName)
// system instruction // system instruction
var systemInstruction *client.Content if systemResult := gjson.GetBytes(rawJSON, "system"); systemResult.IsArray() {
systemResult := gjson.GetBytes(rawJSON, "system") systemInstruction := `{"role":"user","parts":[]}`
if systemResult.IsArray() { hasSystemParts := false
systemResults := systemResult.Array() systemResult.ForEach(func(_, systemPromptResult gjson.Result) bool {
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}} if systemPromptResult.Get("type").String() == "text" {
for i := 0; i < len(systemResults); i++ { textResult := systemPromptResult.Get("text")
systemPromptResult := systemResults[i] if textResult.Type == gjson.String {
systemTypePromptResult := systemPromptResult.Get("type") part := `{"text":""}`
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" { part, _ = sjson.Set(part, "text", textResult.String())
systemPrompt := systemPromptResult.Get("text").String() systemInstruction, _ = sjson.SetRaw(systemInstruction, "parts.-1", part)
systemPart := client.Part{Text: systemPrompt} hasSystemParts = true
systemInstruction.Parts = append(systemInstruction.Parts, systemPart) }
} }
return true
})
if hasSystemParts {
out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstruction)
} }
if len(systemInstruction.Parts) == 0 { } else if systemResult.Type == gjson.String {
systemInstruction = nil out, _ = sjson.Set(out, "request.systemInstruction.parts.-1.text", systemResult.String())
}
} }
// contents // contents
contents := make([]client.Content, 0) if messagesResult := gjson.GetBytes(rawJSON, "messages"); messagesResult.IsArray() {
messagesResult := gjson.GetBytes(rawJSON, "messages") messagesResult.ForEach(func(_, messageResult gjson.Result) bool {
if messagesResult.IsArray() {
messageResults := messagesResult.Array()
for i := 0; i < len(messageResults); i++ {
messageResult := messageResults[i]
roleResult := messageResult.Get("role") roleResult := messageResult.Get("role")
if roleResult.Type != gjson.String { if roleResult.Type != gjson.String {
continue return true
} }
role := roleResult.String() role := roleResult.String()
if role == "assistant" { if role == "assistant" {
role = "model" role = "model"
} }
clientContent := client.Content{Role: role, Parts: []client.Part{}}
contentJSON := `{"role":"","parts":[]}`
contentJSON, _ = sjson.Set(contentJSON, "role", role)
contentsResult := messageResult.Get("content") contentsResult := messageResult.Get("content")
if contentsResult.IsArray() { if contentsResult.IsArray() {
contentResults := contentsResult.Array() contentsResult.ForEach(func(_, contentResult gjson.Result) bool {
for j := 0; j < len(contentResults); j++ { switch contentResult.Get("type").String() {
contentResult := contentResults[j] case "text":
contentTypeResult := contentResult.Get("type") part := `{"text":""}`
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" { part, _ = sjson.Set(part, "text", contentResult.Get("text").String())
prompt := contentResult.Get("text").String() contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" { case "tool_use":
functionName := contentResult.Get("name").String() functionName := contentResult.Get("name").String()
functionArgs := contentResult.Get("input").String() functionArgs := contentResult.Get("input").String()
var args map[string]any argsResult := gjson.Parse(functionArgs)
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil { if argsResult.IsObject() && gjson.Valid(functionArgs) {
clientContent.Parts = append(clientContent.Parts, client.Part{ part := `{"thoughtSignature":"","functionCall":{"name":"","args":{}}}`
FunctionCall: &client.FunctionCall{Name: functionName, Args: args}, part, _ = sjson.Set(part, "thoughtSignature", geminiCLIClaudeThoughtSignature)
ThoughtSignature: geminiCLIClaudeThoughtSignature, part, _ = sjson.Set(part, "functionCall.name", functionName)
}) part, _ = sjson.SetRaw(part, "functionCall.args", functionArgs)
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
} }
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
case "tool_result":
toolCallID := contentResult.Get("tool_use_id").String() toolCallID := contentResult.Get("tool_use_id").String()
if toolCallID != "" { if toolCallID == "" {
funcName := toolCallID return true
toolCallIDs := strings.Split(toolCallID, "-")
if len(toolCallIDs) > 1 {
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
}
responseData := contentResult.Get("content").Raw
functionResponse := client.FunctionResponse{Name: funcName, Response: map[string]interface{}{"result": responseData}}
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
} }
funcName := toolCallID
toolCallIDs := strings.Split(toolCallID, "-")
if len(toolCallIDs) > 1 {
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
}
responseData := contentResult.Get("content").Raw
part := `{"functionResponse":{"name":"","response":{"result":""}}}`
part, _ = sjson.Set(part, "functionResponse.name", funcName)
part, _ = sjson.Set(part, "functionResponse.response.result", responseData)
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
} }
} return true
contents = append(contents, clientContent) })
out, _ = sjson.SetRaw(out, "request.contents.-1", contentJSON)
} else if contentsResult.Type == gjson.String { } else if contentsResult.Type == gjson.String {
prompt := contentsResult.String() part := `{"text":""}`
contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}}) part, _ = sjson.Set(part, "text", contentsResult.String())
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
out, _ = sjson.SetRaw(out, "request.contents.-1", contentJSON)
} }
} return true
})
} }
// tools // tools
var tools []client.ToolDeclaration if toolsResult := gjson.GetBytes(rawJSON, "tools"); toolsResult.IsArray() {
toolsResult := gjson.GetBytes(rawJSON, "tools") hasTools := false
if toolsResult.IsArray() { toolsResult.ForEach(func(_, toolResult gjson.Result) bool {
tools = make([]client.ToolDeclaration, 1)
tools[0].FunctionDeclarations = make([]any, 0)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolResult := toolsResults[i]
inputSchemaResult := toolResult.Get("input_schema") inputSchemaResult := toolResult.Get("input_schema")
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() { if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
inputSchema := inputSchemaResult.Raw inputSchema := inputSchemaResult.Raw
@@ -136,30 +144,19 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
tool, _ = sjson.Delete(tool, "input_examples") tool, _ = sjson.Delete(tool, "input_examples")
tool, _ = sjson.Delete(tool, "type") tool, _ = sjson.Delete(tool, "type")
tool, _ = sjson.Delete(tool, "cache_control") tool, _ = sjson.Delete(tool, "cache_control")
var toolDeclaration any if gjson.Valid(tool) && gjson.Parse(tool).IsObject() {
if err := json.Unmarshal([]byte(tool), &toolDeclaration); err == nil { if !hasTools {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration) out, _ = sjson.SetRaw(out, "request.tools", `[{"functionDeclarations":[]}]`)
hasTools = true
}
out, _ = sjson.SetRaw(out, "request.tools.0.functionDeclarations.-1", tool)
} }
} }
return true
})
if !hasTools {
out, _ = sjson.Delete(out, "request.tools")
} }
} else {
tools = make([]client.ToolDeclaration, 0)
}
// Build output Gemini CLI request JSON
out := `{"model":"","request":{"contents":[]}}`
out, _ = sjson.Set(out, "model", modelName)
if systemInstruction != nil {
b, _ := json.Marshal(systemInstruction)
out, _ = sjson.SetRaw(out, "request.systemInstruction", string(b))
}
if len(contents) > 0 {
b, _ := json.Marshal(contents)
out, _ = sjson.SetRaw(out, "request.contents", string(b))
}
if len(tools) > 0 && len(tools[0].FunctionDeclarations) > 0 {
b, _ := json.Marshal(tools)
out, _ = sjson.SetRaw(out, "request.tools", string(b))
} }
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled // Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled

View File

@@ -9,7 +9,6 @@ package claude
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"strings" "strings"
"sync/atomic" "sync/atomic"
@@ -276,22 +275,16 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
root := gjson.ParseBytes(rawJSON) root := gjson.ParseBytes(rawJSON)
response := map[string]interface{}{ out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
"id": root.Get("response.responseId").String(), out, _ = sjson.Set(out, "id", root.Get("response.responseId").String())
"type": "message", out, _ = sjson.Set(out, "model", root.Get("response.modelVersion").String())
"role": "assistant",
"model": root.Get("response.modelVersion").String(), inputTokens := root.Get("response.usageMetadata.promptTokenCount").Int()
"content": []interface{}{}, outputTokens := root.Get("response.usageMetadata.candidatesTokenCount").Int() + root.Get("response.usageMetadata.thoughtsTokenCount").Int()
"stop_reason": nil, out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
"stop_sequence": nil, out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
"usage": map[string]interface{}{
"input_tokens": root.Get("response.usageMetadata.promptTokenCount").Int(),
"output_tokens": root.Get("response.usageMetadata.candidatesTokenCount").Int() + root.Get("response.usageMetadata.thoughtsTokenCount").Int(),
},
}
parts := root.Get("response.candidates.0.content.parts") parts := root.Get("response.candidates.0.content.parts")
var contentBlocks []interface{}
textBuilder := strings.Builder{} textBuilder := strings.Builder{}
thinkingBuilder := strings.Builder{} thinkingBuilder := strings.Builder{}
toolIDCounter := 0 toolIDCounter := 0
@@ -301,10 +294,9 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
if textBuilder.Len() == 0 { if textBuilder.Len() == 0 {
return return
} }
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"text","text":""}`
"type": "text", block, _ = sjson.Set(block, "text", textBuilder.String())
"text": textBuilder.String(), out, _ = sjson.SetRaw(out, "content.-1", block)
})
textBuilder.Reset() textBuilder.Reset()
} }
@@ -312,10 +304,9 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
if thinkingBuilder.Len() == 0 { if thinkingBuilder.Len() == 0 {
return return
} }
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"thinking","thinking":""}`
"type": "thinking", block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
"thinking": thinkingBuilder.String(), out, _ = sjson.SetRaw(out, "content.-1", block)
})
thinkingBuilder.Reset() thinkingBuilder.Reset()
} }
@@ -339,21 +330,15 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
name := functionCall.Get("name").String() name := functionCall.Get("name").String()
toolIDCounter++ toolIDCounter++
toolBlock := map[string]interface{}{ toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
"type": "tool_use", toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
"id": fmt.Sprintf("tool_%d", toolIDCounter), toolBlock, _ = sjson.Set(toolBlock, "name", name)
"name": name, inputRaw := "{}"
"input": map[string]interface{}{}, if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() {
inputRaw = args.Raw
} }
toolBlock, _ = sjson.SetRaw(toolBlock, "input", inputRaw)
if args := functionCall.Get("args"); args.Exists() { out, _ = sjson.SetRaw(out, "content.-1", toolBlock)
var parsed interface{}
if err := json.Unmarshal([]byte(args.Raw), &parsed); err == nil {
toolBlock["input"] = parsed
}
}
contentBlocks = append(contentBlocks, toolBlock)
continue continue
} }
} }
@@ -362,8 +347,6 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
flushThinking() flushThinking()
flushText() flushText()
response["content"] = contentBlocks
stopReason := "end_turn" stopReason := "end_turn"
if hasToolCall { if hasToolCall {
stopReason = "tool_use" stopReason = "tool_use"
@@ -379,19 +362,13 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
} }
} }
} }
response["stop_reason"] = stopReason out, _ = sjson.Set(out, "stop_reason", stopReason)
if usage := response["usage"].(map[string]interface{}); usage["input_tokens"] == int64(0) && usage["output_tokens"] == int64(0) { if inputTokens == int64(0) && outputTokens == int64(0) && !root.Get("response.usageMetadata").Exists() {
if usageMeta := root.Get("response.usageMetadata"); !usageMeta.Exists() { out, _ = sjson.Delete(out, "usage")
delete(response, "usage")
}
} }
encoded, err := json.Marshal(response) return out
if err != nil {
return ""
}
return string(encoded)
} }
func ClaudeTokenCount(ctx context.Context, count int64) string { func ClaudeTokenCount(ctx context.Context, count int64) string {

View File

@@ -7,7 +7,6 @@ package gemini
import ( import (
"bytes" "bytes"
"encoding/json"
"fmt" "fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common" "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
@@ -117,8 +116,6 @@ func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []by
// FunctionCallGroup represents a group of function calls and their responses // FunctionCallGroup represents a group of function calls and their responses
type FunctionCallGroup struct { type FunctionCallGroup struct {
ModelContent map[string]interface{}
FunctionCalls []gjson.Result
ResponsesNeeded int ResponsesNeeded int
} }
@@ -146,7 +143,7 @@ func fixCLIToolResponse(input string) (string, error) {
} }
// Initialize data structures for processing and grouping // Initialize data structures for processing and grouping
var newContents []interface{} // Final processed contents array contentsWrapper := `{"contents":[]}`
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
var collectedResponses []gjson.Result // Standalone responses to be matched var collectedResponses []gjson.Result // Standalone responses to be matched
@@ -178,23 +175,17 @@ func fixCLIToolResponse(input string) (string, error) {
collectedResponses = collectedResponses[group.ResponsesNeeded:] collectedResponses = collectedResponses[group.ResponsesNeeded:]
// Create merged function response content // Create merged function response content
var responseParts []interface{} functionResponseContent := `{"parts":[],"role":"function"}`
for _, response := range groupResponses { for _, response := range groupResponses {
var responseMap map[string]interface{} if !response.IsObject() {
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap) log.Warnf("failed to parse function response")
if errUnmarshal != nil {
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
continue continue
} }
responseParts = append(responseParts, responseMap) functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", response.Raw)
} }
if len(responseParts) > 0 { if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
functionResponseContent := map[string]interface{}{ contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
} }
// Remove this group as it's been satisfied // Remove this group as it's been satisfied
@@ -208,50 +199,42 @@ func fixCLIToolResponse(input string) (string, error) {
// If this is a model with function calls, create a new group // If this is a model with function calls, create a new group
if role == "model" { if role == "model" {
var functionCallsInThisModel []gjson.Result functionCallsCount := 0
parts.ForEach(func(_, part gjson.Result) bool { parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionCall").Exists() { if part.Get("functionCall").Exists() {
functionCallsInThisModel = append(functionCallsInThisModel, part) functionCallsCount++
} }
return true return true
}) })
if len(functionCallsInThisModel) > 0 { if functionCallsCount > 0 {
// Add the model content // Add the model content
var contentMap map[string]interface{} if !value.IsObject() {
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap) log.Warnf("failed to parse model content")
if errUnmarshal != nil {
log.Warnf("failed to unmarshal model content: %v\n", errUnmarshal)
return true return true
} }
newContents = append(newContents, contentMap) contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
// Create a new group for tracking responses // Create a new group for tracking responses
group := &FunctionCallGroup{ group := &FunctionCallGroup{
ModelContent: contentMap, ResponsesNeeded: functionCallsCount,
FunctionCalls: functionCallsInThisModel,
ResponsesNeeded: len(functionCallsInThisModel),
} }
pendingGroups = append(pendingGroups, group) pendingGroups = append(pendingGroups, group)
} else { } else {
// Regular model content without function calls // Regular model content without function calls
var contentMap map[string]interface{} if !value.IsObject() {
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap) log.Warnf("failed to parse content")
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
return true return true
} }
newContents = append(newContents, contentMap) contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
} }
} else { } else {
// Non-model content (user, etc.) // Non-model content (user, etc.)
var contentMap map[string]interface{} if !value.IsObject() {
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap) log.Warnf("failed to parse content")
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
return true return true
} }
newContents = append(newContents, contentMap) contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", value.Raw)
} }
return true return true
@@ -263,31 +246,24 @@ func fixCLIToolResponse(input string) (string, error) {
groupResponses := collectedResponses[:group.ResponsesNeeded] groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:] collectedResponses = collectedResponses[group.ResponsesNeeded:]
var responseParts []interface{} functionResponseContent := `{"parts":[],"role":"function"}`
for _, response := range groupResponses { for _, response := range groupResponses {
var responseMap map[string]interface{} if !response.IsObject() {
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap) log.Warnf("failed to parse function response")
if errUnmarshal != nil {
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
continue continue
} }
responseParts = append(responseParts, responseMap) functionResponseContent, _ = sjson.SetRaw(functionResponseContent, "parts.-1", response.Raw)
} }
if len(responseParts) > 0 { if gjson.Get(functionResponseContent, "parts.#").Int() > 0 {
functionResponseContent := map[string]interface{}{ contentsWrapper, _ = sjson.SetRaw(contentsWrapper, "contents.-1", functionResponseContent)
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
} }
} }
} }
// Update the original JSON with the new contents // Update the original JSON with the new contents
result := input result := input
newContentsJSON, _ := json.Marshal(newContents) result, _ = sjson.SetRaw(result, "request.contents", gjson.Get(contentsWrapper, "contents").Raw)
result, _ = sjson.Set(result, "request.contents", json.RawMessage(newContentsJSON))
return result, nil return result, nil
} }

View File

@@ -160,6 +160,14 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
} else if content.IsObject() && content.Get("type").String() == "text" { } else if content.IsObject() && content.Get("type").String() == "text" {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user") out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
out, _ = sjson.SetBytes(out, "request.systemInstruction.parts.0.text", content.Get("text").String()) out, _ = sjson.SetBytes(out, "request.systemInstruction.parts.0.text", content.Get("text").String())
} else if content.IsArray() {
contents := content.Array()
if len(contents) > 0 {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
for j := 0; j < len(contents); j++ {
out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", j), contents[j].Get("text").String())
}
}
} }
} else if role == "user" || (role == "system" && len(arr) == 1) { } else if role == "user" || (role == "system" && len(arr) == 1) {
// Build single user content node to avoid splitting into multiple contents // Build single user content node to avoid splitting into multiple contents
@@ -278,7 +286,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
} }
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{}) fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
if errSet != nil { if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
@@ -293,7 +301,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
} }
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{}) fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
if errSet != nil { if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue

View File

@@ -8,7 +8,6 @@ package chat_completions
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"strings" "strings"
"sync/atomic" "sync/atomic"
@@ -171,21 +170,14 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
mimeType = "image/png" mimeType = "image/png"
} }
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data) imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload, err := json.Marshal(map[string]any{ imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
"type": "image_url", imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
"image_url": map[string]string{
"url": imageURL,
},
})
if err != nil {
continue
}
imagesResult := gjson.Get(template, "choices.0.delta.images") imagesResult := gjson.Get(template, "choices.0.delta.images")
if !imagesResult.Exists() || !imagesResult.IsArray() { if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`) template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
} }
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", string(imagePayload)) template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
} }
} }
} }

View File

@@ -7,10 +7,8 @@ package claude
import ( import (
"bytes" "bytes"
"encoding/json"
"strings" "strings"
client "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common" "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
@@ -34,92 +32,102 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
rawJSON := bytes.Clone(inputRawJSON) rawJSON := bytes.Clone(inputRawJSON)
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1) rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
// Build output Gemini CLI request JSON
out := `{"contents":[]}`
out, _ = sjson.Set(out, "model", modelName)
// system instruction // system instruction
var systemInstruction *client.Content if systemResult := gjson.GetBytes(rawJSON, "system"); systemResult.IsArray() {
systemResult := gjson.GetBytes(rawJSON, "system") systemInstruction := `{"role":"user","parts":[]}`
if systemResult.IsArray() { hasSystemParts := false
systemResults := systemResult.Array() systemResult.ForEach(func(_, systemPromptResult gjson.Result) bool {
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}} if systemPromptResult.Get("type").String() == "text" {
for i := 0; i < len(systemResults); i++ { textResult := systemPromptResult.Get("text")
systemPromptResult := systemResults[i] if textResult.Type == gjson.String {
systemTypePromptResult := systemPromptResult.Get("type") part := `{"text":""}`
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" { part, _ = sjson.Set(part, "text", textResult.String())
systemPrompt := systemPromptResult.Get("text").String() systemInstruction, _ = sjson.SetRaw(systemInstruction, "parts.-1", part)
systemPart := client.Part{Text: systemPrompt} hasSystemParts = true
systemInstruction.Parts = append(systemInstruction.Parts, systemPart) }
} }
return true
})
if hasSystemParts {
out, _ = sjson.SetRaw(out, "system_instruction", systemInstruction)
} }
if len(systemInstruction.Parts) == 0 { } else if systemResult.Type == gjson.String {
systemInstruction = nil out, _ = sjson.Set(out, "request.system_instruction.parts.-1.text", systemResult.String())
}
} }
// contents // contents
contents := make([]client.Content, 0) if messagesResult := gjson.GetBytes(rawJSON, "messages"); messagesResult.IsArray() {
messagesResult := gjson.GetBytes(rawJSON, "messages") messagesResult.ForEach(func(_, messageResult gjson.Result) bool {
if messagesResult.IsArray() {
messageResults := messagesResult.Array()
for i := 0; i < len(messageResults); i++ {
messageResult := messageResults[i]
roleResult := messageResult.Get("role") roleResult := messageResult.Get("role")
if roleResult.Type != gjson.String { if roleResult.Type != gjson.String {
continue return true
} }
role := roleResult.String() role := roleResult.String()
if role == "assistant" { if role == "assistant" {
role = "model" role = "model"
} }
clientContent := client.Content{Role: role, Parts: []client.Part{}}
contentJSON := `{"role":"","parts":[]}`
contentJSON, _ = sjson.Set(contentJSON, "role", role)
contentsResult := messageResult.Get("content") contentsResult := messageResult.Get("content")
if contentsResult.IsArray() { if contentsResult.IsArray() {
contentResults := contentsResult.Array() contentsResult.ForEach(func(_, contentResult gjson.Result) bool {
for j := 0; j < len(contentResults); j++ { switch contentResult.Get("type").String() {
contentResult := contentResults[j] case "text":
contentTypeResult := contentResult.Get("type") part := `{"text":""}`
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" { part, _ = sjson.Set(part, "text", contentResult.Get("text").String())
prompt := contentResult.Get("text").String() contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" { case "tool_use":
functionName := contentResult.Get("name").String() functionName := contentResult.Get("name").String()
functionArgs := contentResult.Get("input").String() functionArgs := contentResult.Get("input").String()
var args map[string]any argsResult := gjson.Parse(functionArgs)
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil { if argsResult.IsObject() && gjson.Valid(functionArgs) {
clientContent.Parts = append(clientContent.Parts, client.Part{ part := `{"thoughtSignature":"","functionCall":{"name":"","args":{}}}`
FunctionCall: &client.FunctionCall{Name: functionName, Args: args}, part, _ = sjson.Set(part, "thoughtSignature", geminiClaudeThoughtSignature)
ThoughtSignature: geminiClaudeThoughtSignature, part, _ = sjson.Set(part, "functionCall.name", functionName)
}) part, _ = sjson.SetRaw(part, "functionCall.args", functionArgs)
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
} }
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
case "tool_result":
toolCallID := contentResult.Get("tool_use_id").String() toolCallID := contentResult.Get("tool_use_id").String()
if toolCallID != "" { if toolCallID == "" {
funcName := toolCallID return true
toolCallIDs := strings.Split(toolCallID, "-")
if len(toolCallIDs) > 1 {
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
}
responseData := contentResult.Get("content").Raw
functionResponse := client.FunctionResponse{Name: funcName, Response: map[string]interface{}{"result": responseData}}
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
} }
funcName := toolCallID
toolCallIDs := strings.Split(toolCallID, "-")
if len(toolCallIDs) > 1 {
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
}
responseData := contentResult.Get("content").Raw
part := `{"functionResponse":{"name":"","response":{"result":""}}}`
part, _ = sjson.Set(part, "functionResponse.name", funcName)
part, _ = sjson.Set(part, "functionResponse.response.result", responseData)
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
} }
} return true
contents = append(contents, clientContent) })
out, _ = sjson.SetRaw(out, "contents.-1", contentJSON)
} else if contentsResult.Type == gjson.String { } else if contentsResult.Type == gjson.String {
prompt := contentsResult.String() part := `{"text":""}`
contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}}) part, _ = sjson.Set(part, "text", contentsResult.String())
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
out, _ = sjson.SetRaw(out, "contents.-1", contentJSON)
} }
} return true
})
} }
// tools // tools
var tools []client.ToolDeclaration if toolsResult := gjson.GetBytes(rawJSON, "tools"); toolsResult.IsArray() {
toolsResult := gjson.GetBytes(rawJSON, "tools") hasTools := false
if toolsResult.IsArray() { toolsResult.ForEach(func(_, toolResult gjson.Result) bool {
tools = make([]client.ToolDeclaration, 1)
tools[0].FunctionDeclarations = make([]any, 0)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolResult := toolsResults[i]
inputSchemaResult := toolResult.Get("input_schema") inputSchemaResult := toolResult.Get("input_schema")
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() { if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
inputSchema := inputSchemaResult.Raw inputSchema := inputSchemaResult.Raw
@@ -129,30 +137,19 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
tool, _ = sjson.Delete(tool, "input_examples") tool, _ = sjson.Delete(tool, "input_examples")
tool, _ = sjson.Delete(tool, "type") tool, _ = sjson.Delete(tool, "type")
tool, _ = sjson.Delete(tool, "cache_control") tool, _ = sjson.Delete(tool, "cache_control")
var toolDeclaration any if gjson.Valid(tool) && gjson.Parse(tool).IsObject() {
if err := json.Unmarshal([]byte(tool), &toolDeclaration); err == nil { if !hasTools {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration) out, _ = sjson.SetRaw(out, "tools", `[{"functionDeclarations":[]}]`)
hasTools = true
}
out, _ = sjson.SetRaw(out, "tools.0.functionDeclarations.-1", tool)
} }
} }
return true
})
if !hasTools {
out, _ = sjson.Delete(out, "tools")
} }
} else {
tools = make([]client.ToolDeclaration, 0)
}
// Build output Gemini CLI request JSON
out := `{"contents":[]}`
out, _ = sjson.Set(out, "model", modelName)
if systemInstruction != nil {
b, _ := json.Marshal(systemInstruction)
out, _ = sjson.SetRaw(out, "system_instruction", string(b))
}
if len(contents) > 0 {
b, _ := json.Marshal(contents)
out, _ = sjson.SetRaw(out, "contents", string(b))
}
if len(tools) > 0 && len(tools[0].FunctionDeclarations) > 0 {
b, _ := json.Marshal(tools)
out, _ = sjson.SetRaw(out, "tools", string(b))
} }
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when enabled // Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when enabled

View File

@@ -9,7 +9,6 @@ package claude
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"strings" "strings"
"sync/atomic" "sync/atomic"
@@ -282,22 +281,16 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
root := gjson.ParseBytes(rawJSON) root := gjson.ParseBytes(rawJSON)
response := map[string]interface{}{ out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
"id": root.Get("responseId").String(), out, _ = sjson.Set(out, "id", root.Get("responseId").String())
"type": "message", out, _ = sjson.Set(out, "model", root.Get("modelVersion").String())
"role": "assistant",
"model": root.Get("modelVersion").String(), inputTokens := root.Get("usageMetadata.promptTokenCount").Int()
"content": []interface{}{}, outputTokens := root.Get("usageMetadata.candidatesTokenCount").Int() + root.Get("usageMetadata.thoughtsTokenCount").Int()
"stop_reason": nil, out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
"stop_sequence": nil, out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
"usage": map[string]interface{}{
"input_tokens": root.Get("usageMetadata.promptTokenCount").Int(),
"output_tokens": root.Get("usageMetadata.candidatesTokenCount").Int() + root.Get("usageMetadata.thoughtsTokenCount").Int(),
},
}
parts := root.Get("candidates.0.content.parts") parts := root.Get("candidates.0.content.parts")
var contentBlocks []interface{}
textBuilder := strings.Builder{} textBuilder := strings.Builder{}
thinkingBuilder := strings.Builder{} thinkingBuilder := strings.Builder{}
toolIDCounter := 0 toolIDCounter := 0
@@ -307,10 +300,9 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
if textBuilder.Len() == 0 { if textBuilder.Len() == 0 {
return return
} }
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"text","text":""}`
"type": "text", block, _ = sjson.Set(block, "text", textBuilder.String())
"text": textBuilder.String(), out, _ = sjson.SetRaw(out, "content.-1", block)
})
textBuilder.Reset() textBuilder.Reset()
} }
@@ -318,10 +310,9 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
if thinkingBuilder.Len() == 0 { if thinkingBuilder.Len() == 0 {
return return
} }
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"thinking","thinking":""}`
"type": "thinking", block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
"thinking": thinkingBuilder.String(), out, _ = sjson.SetRaw(out, "content.-1", block)
})
thinkingBuilder.Reset() thinkingBuilder.Reset()
} }
@@ -345,21 +336,15 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
name := functionCall.Get("name").String() name := functionCall.Get("name").String()
toolIDCounter++ toolIDCounter++
toolBlock := map[string]interface{}{ toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
"type": "tool_use", toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
"id": fmt.Sprintf("tool_%d", toolIDCounter), toolBlock, _ = sjson.Set(toolBlock, "name", name)
"name": name, inputRaw := "{}"
"input": map[string]interface{}{}, if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() {
inputRaw = args.Raw
} }
toolBlock, _ = sjson.SetRaw(toolBlock, "input", inputRaw)
if args := functionCall.Get("args"); args.Exists() { out, _ = sjson.SetRaw(out, "content.-1", toolBlock)
var parsed interface{}
if err := json.Unmarshal([]byte(args.Raw), &parsed); err == nil {
toolBlock["input"] = parsed
}
}
contentBlocks = append(contentBlocks, toolBlock)
continue continue
} }
} }
@@ -368,8 +353,6 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
flushThinking() flushThinking()
flushText() flushText()
response["content"] = contentBlocks
stopReason := "end_turn" stopReason := "end_turn"
if hasToolCall { if hasToolCall {
stopReason = "tool_use" stopReason = "tool_use"
@@ -385,19 +368,13 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
} }
} }
} }
response["stop_reason"] = stopReason out, _ = sjson.Set(out, "stop_reason", stopReason)
if usage := response["usage"].(map[string]interface{}); usage["input_tokens"] == int64(0) && usage["output_tokens"] == int64(0) { if inputTokens == int64(0) && outputTokens == int64(0) && !root.Get("usageMetadata").Exists() {
if usageMeta := root.Get("usageMetadata"); !usageMeta.Exists() { out, _ = sjson.Delete(out, "usage")
delete(response, "usage")
}
} }
encoded, err := json.Marshal(response) return out
if err != nil {
return ""
}
return string(encoded)
} }
func ClaudeTokenCount(ctx context.Context, count int64) string { func ClaudeTokenCount(ctx context.Context, count int64) string {

View File

@@ -178,6 +178,14 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
} else if content.IsObject() && content.Get("type").String() == "text" { } else if content.IsObject() && content.Get("type").String() == "text" {
out, _ = sjson.SetBytes(out, "system_instruction.role", "user") out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
out, _ = sjson.SetBytes(out, "system_instruction.parts.0.text", content.Get("text").String()) out, _ = sjson.SetBytes(out, "system_instruction.parts.0.text", content.Get("text").String())
} else if content.IsArray() {
contents := content.Array()
if len(contents) > 0 {
out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
for j := 0; j < len(contents); j++ {
out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", j), contents[j].Get("text").String())
}
}
} }
} else if role == "user" || (role == "system" && len(arr) == 1) { } else if role == "user" || (role == "system" && len(arr) == 1) {
// Build single user content node to avoid splitting into multiple contents // Build single user content node to avoid splitting into multiple contents
@@ -320,7 +328,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
} }
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{}) fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
if errSet != nil { if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
@@ -335,7 +343,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue
} }
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{}) fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
if errSet != nil { if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet) log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue continue

View File

@@ -8,7 +8,6 @@ package chat_completions
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"strings" "strings"
"sync/atomic" "sync/atomic"
@@ -173,21 +172,14 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
mimeType = "image/png" mimeType = "image/png"
} }
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data) imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload, err := json.Marshal(map[string]any{ imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
"type": "image_url", imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
"image_url": map[string]string{
"url": imageURL,
},
})
if err != nil {
continue
}
imagesResult := gjson.Get(template, "choices.0.delta.images") imagesResult := gjson.Get(template, "choices.0.delta.images")
if !imagesResult.Exists() || !imagesResult.IsArray() { if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`) template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
} }
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", string(imagePayload)) template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
} }
} }
} }
@@ -305,21 +297,14 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
mimeType = "image/png" mimeType = "image/png"
} }
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data) imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload, err := json.Marshal(map[string]any{ imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
"type": "image_url", imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
"image_url": map[string]string{
"url": imageURL,
},
})
if err != nil {
continue
}
imagesResult := gjson.Get(template, "choices.0.message.images") imagesResult := gjson.Get(template, "choices.0.message.images")
if !imagesResult.Exists() || !imagesResult.IsArray() { if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.message.images", `[]`) template, _ = sjson.SetRaw(template, "choices.0.message.images", `[]`)
} }
template, _ = sjson.Set(template, "choices.0.message.role", "assistant") template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.message.images.-1", string(imagePayload)) template, _ = sjson.SetRaw(template, "choices.0.message.images.-1", imagePayload)
} }
} }
} }

View File

@@ -377,27 +377,18 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
} }
// Compose outputs in encountered order: reasoning, message, function_calls // Compose outputs in encountered order: reasoning, message, function_calls
var outputs []interface{} outputsWrapper := `{"arr":[]}`
if st.ReasoningOpened { if st.ReasoningOpened {
outputs = append(outputs, map[string]interface{}{ item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
"id": st.ReasoningItemID, item, _ = sjson.Set(item, "id", st.ReasoningItemID)
"type": "reasoning", item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
"summary": []interface{}{map[string]interface{}{"type": "summary_text", "text": st.ReasoningBuf.String()}}, outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
})
} }
if st.MsgOpened { if st.MsgOpened {
outputs = append(outputs, map[string]interface{}{ item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
"id": st.CurrentMsgID, item, _ = sjson.Set(item, "id", st.CurrentMsgID)
"type": "message", item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String())
"status": "completed", outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
"content": []interface{}{map[string]interface{}{
"type": "output_text",
"annotations": []interface{}{},
"logprobs": []interface{}{},
"text": st.TextBuf.String(),
}},
"role": "assistant",
})
} }
if len(st.FuncArgsBuf) > 0 { if len(st.FuncArgsBuf) > 0 {
idxs := make([]int, 0, len(st.FuncArgsBuf)) idxs := make([]int, 0, len(st.FuncArgsBuf))
@@ -416,18 +407,16 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
if b := st.FuncArgsBuf[idx]; b != nil { if b := st.FuncArgsBuf[idx]; b != nil {
args = b.String() args = b.String()
} }
outputs = append(outputs, map[string]interface{}{ item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
"id": fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]), item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
"type": "function_call", item, _ = sjson.Set(item, "arguments", args)
"status": "completed", item, _ = sjson.Set(item, "call_id", st.FuncCallIDs[idx])
"arguments": args, item, _ = sjson.Set(item, "name", st.FuncNames[idx])
"call_id": st.FuncCallIDs[idx], outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
"name": st.FuncNames[idx],
})
} }
} }
if len(outputs) > 0 { if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
completed, _ = sjson.Set(completed, "response.output", outputs) completed, _ = sjson.SetRaw(completed, "response.output", gjson.Get(outputsWrapper, "arr").Raw)
} }
// usage mapping // usage mapping
@@ -558,11 +547,24 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
} }
// Build outputs from candidates[0].content.parts // Build outputs from candidates[0].content.parts
var outputs []interface{}
var reasoningText strings.Builder var reasoningText strings.Builder
var reasoningEncrypted string var reasoningEncrypted string
var messageText strings.Builder var messageText strings.Builder
var haveMessage bool var haveMessage bool
haveOutput := false
ensureOutput := func() {
if haveOutput {
return
}
resp, _ = sjson.SetRaw(resp, "output", "[]")
haveOutput = true
}
appendOutput := func(itemJSON string) {
ensureOutput()
resp, _ = sjson.SetRaw(resp, "output.-1", itemJSON)
}
if parts := root.Get("candidates.0.content.parts"); parts.Exists() && parts.IsArray() { if parts := root.Get("candidates.0.content.parts"); parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, p gjson.Result) bool { parts.ForEach(func(_, p gjson.Result) bool {
if p.Get("thought").Bool() { if p.Get("thought").Bool() {
@@ -583,19 +585,16 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
name := fc.Get("name").String() name := fc.Get("name").String()
args := fc.Get("args") args := fc.Get("args")
callID := fmt.Sprintf("call_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&funcCallIDCounter, 1)) callID := fmt.Sprintf("call_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&funcCallIDCounter, 1))
outputs = append(outputs, map[string]interface{}{ itemJSON := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
"id": fmt.Sprintf("fc_%s", callID), itemJSON, _ = sjson.Set(itemJSON, "id", fmt.Sprintf("fc_%s", callID))
"type": "function_call", itemJSON, _ = sjson.Set(itemJSON, "call_id", callID)
"status": "completed", itemJSON, _ = sjson.Set(itemJSON, "name", name)
"arguments": func() string { argsStr := ""
if args.Exists() { if args.Exists() {
return args.Raw argsStr = args.Raw
} }
return "" itemJSON, _ = sjson.Set(itemJSON, "arguments", argsStr)
}(), appendOutput(itemJSON)
"call_id": callID,
"name": name,
})
return true return true
} }
return true return true
@@ -605,42 +604,24 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
// Reasoning output item // Reasoning output item
if reasoningText.Len() > 0 || reasoningEncrypted != "" { if reasoningText.Len() > 0 || reasoningEncrypted != "" {
rid := strings.TrimPrefix(id, "resp_") rid := strings.TrimPrefix(id, "resp_")
item := map[string]interface{}{ itemJSON := `{"id":"","type":"reasoning","encrypted_content":""}`
"id": fmt.Sprintf("rs_%s", rid), itemJSON, _ = sjson.Set(itemJSON, "id", fmt.Sprintf("rs_%s", rid))
"type": "reasoning", itemJSON, _ = sjson.Set(itemJSON, "encrypted_content", reasoningEncrypted)
"encrypted_content": reasoningEncrypted,
}
var summaries []interface{}
if reasoningText.Len() > 0 { if reasoningText.Len() > 0 {
summaries = append(summaries, map[string]interface{}{ summaryJSON := `{"type":"summary_text","text":""}`
"type": "summary_text", summaryJSON, _ = sjson.Set(summaryJSON, "text", reasoningText.String())
"text": reasoningText.String(), itemJSON, _ = sjson.SetRaw(itemJSON, "summary", "[]")
}) itemJSON, _ = sjson.SetRaw(itemJSON, "summary.-1", summaryJSON)
} }
if summaries != nil { appendOutput(itemJSON)
item["summary"] = summaries
}
outputs = append(outputs, item)
} }
// Assistant message output item // Assistant message output item
if haveMessage { if haveMessage {
outputs = append(outputs, map[string]interface{}{ itemJSON := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
"id": fmt.Sprintf("msg_%s_0", strings.TrimPrefix(id, "resp_")), itemJSON, _ = sjson.Set(itemJSON, "id", fmt.Sprintf("msg_%s_0", strings.TrimPrefix(id, "resp_")))
"type": "message", itemJSON, _ = sjson.Set(itemJSON, "content.0.text", messageText.String())
"status": "completed", appendOutput(itemJSON)
"content": []interface{}{map[string]interface{}{
"type": "output_text",
"annotations": []interface{}{},
"logprobs": []interface{}{},
"text": messageText.String(),
}},
"role": "assistant",
})
}
if len(outputs) > 0 {
resp, _ = sjson.Set(resp, "output", outputs)
} }
// usage mapping // usage mapping

View File

@@ -7,7 +7,6 @@ package claude
import ( import (
"bytes" "bytes"
"encoding/json"
"strings" "strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
@@ -138,11 +137,7 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
// Convert input to arguments JSON string // Convert input to arguments JSON string
if input := part.Get("input"); input.Exists() { if input := part.Get("input"); input.Exists() {
if inputJSON, err := json.Marshal(input.Value()); err == nil { toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", input.Raw)
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", string(inputJSON))
} else {
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", "{}")
}
} else { } else {
toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", "{}") toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", "{}")
} }
@@ -191,8 +186,7 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
// Emit tool calls in a separate assistant message // Emit tool calls in a separate assistant message
if role == "assistant" && len(toolCalls) > 0 { if role == "assistant" && len(toolCalls) > 0 {
toolCallMsgJSON := `{"role":"assistant","tool_calls":[]}` toolCallMsgJSON := `{"role":"assistant","tool_calls":[]}`
toolCallsJSON, _ := json.Marshal(toolCalls) toolCallMsgJSON, _ = sjson.Set(toolCallMsgJSON, "tool_calls", toolCalls)
toolCallMsgJSON, _ = sjson.SetRaw(toolCallMsgJSON, "tool_calls", string(toolCallsJSON))
messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(toolCallMsgJSON).Value()) messagesJSON, _ = sjson.Set(messagesJSON, "-1", gjson.Parse(toolCallMsgJSON).Value())
} }

View File

@@ -8,7 +8,6 @@ package claude
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"strings" "strings"
@@ -133,24 +132,10 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
if delta := root.Get("choices.0.delta"); delta.Exists() { if delta := root.Get("choices.0.delta"); delta.Exists() {
if !param.MessageStarted { if !param.MessageStarted {
// Send message_start event // Send message_start event
messageStart := map[string]interface{}{ messageStartJSON := `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}`
"type": "message_start", messageStartJSON, _ = sjson.Set(messageStartJSON, "message.id", param.MessageID)
"message": map[string]interface{}{ messageStartJSON, _ = sjson.Set(messageStartJSON, "message.model", param.Model)
"id": param.MessageID, results = append(results, "event: message_start\ndata: "+messageStartJSON+"\n\n")
"type": "message",
"role": "assistant",
"model": param.Model,
"content": []interface{}{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]interface{}{
"input_tokens": 0,
"output_tokens": 0,
},
},
}
messageStartJSON, _ := json.Marshal(messageStart)
results = append(results, "event: message_start\ndata: "+string(messageStartJSON)+"\n\n")
param.MessageStarted = true param.MessageStarted = true
// Don't send content_block_start for text here - wait for actual content // Don't send content_block_start for text here - wait for actual content
@@ -168,29 +153,16 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
param.ThinkingContentBlockIndex = param.NextContentBlockIndex param.ThinkingContentBlockIndex = param.NextContentBlockIndex
param.NextContentBlockIndex++ param.NextContentBlockIndex++
} }
contentBlockStart := map[string]interface{}{ contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
"type": "content_block_start", contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "index", param.ThinkingContentBlockIndex)
"index": param.ThinkingContentBlockIndex, results = append(results, "event: content_block_start\ndata: "+contentBlockStartJSON+"\n\n")
"content_block": map[string]interface{}{
"type": "thinking",
"thinking": "",
},
}
contentBlockStartJSON, _ := json.Marshal(contentBlockStart)
results = append(results, "event: content_block_start\ndata: "+string(contentBlockStartJSON)+"\n\n")
param.ThinkingContentBlockStarted = true param.ThinkingContentBlockStarted = true
} }
thinkingDelta := map[string]interface{}{ thinkingDeltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
"type": "content_block_delta", thinkingDeltaJSON, _ = sjson.Set(thinkingDeltaJSON, "index", param.ThinkingContentBlockIndex)
"index": param.ThinkingContentBlockIndex, thinkingDeltaJSON, _ = sjson.Set(thinkingDeltaJSON, "delta.thinking", reasoningText)
"delta": map[string]interface{}{ results = append(results, "event: content_block_delta\ndata: "+thinkingDeltaJSON+"\n\n")
"type": "thinking_delta",
"thinking": reasoningText,
},
}
thinkingDeltaJSON, _ := json.Marshal(thinkingDelta)
results = append(results, "event: content_block_delta\ndata: "+string(thinkingDeltaJSON)+"\n\n")
} }
} }
@@ -203,29 +175,16 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
param.TextContentBlockIndex = param.NextContentBlockIndex param.TextContentBlockIndex = param.NextContentBlockIndex
param.NextContentBlockIndex++ param.NextContentBlockIndex++
} }
contentBlockStart := map[string]interface{}{ contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
"type": "content_block_start", contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "index", param.TextContentBlockIndex)
"index": param.TextContentBlockIndex, results = append(results, "event: content_block_start\ndata: "+contentBlockStartJSON+"\n\n")
"content_block": map[string]interface{}{
"type": "text",
"text": "",
},
}
contentBlockStartJSON, _ := json.Marshal(contentBlockStart)
results = append(results, "event: content_block_start\ndata: "+string(contentBlockStartJSON)+"\n\n")
param.TextContentBlockStarted = true param.TextContentBlockStarted = true
} }
contentDelta := map[string]interface{}{ contentDeltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
"type": "content_block_delta", contentDeltaJSON, _ = sjson.Set(contentDeltaJSON, "index", param.TextContentBlockIndex)
"index": param.TextContentBlockIndex, contentDeltaJSON, _ = sjson.Set(contentDeltaJSON, "delta.text", content.String())
"delta": map[string]interface{}{ results = append(results, "event: content_block_delta\ndata: "+contentDeltaJSON+"\n\n")
"type": "text_delta",
"text": content.String(),
},
}
contentDeltaJSON, _ := json.Marshal(contentDelta)
results = append(results, "event: content_block_delta\ndata: "+string(contentDeltaJSON)+"\n\n")
// Accumulate content // Accumulate content
param.ContentAccumulator.WriteString(content.String()) param.ContentAccumulator.WriteString(content.String())
@@ -263,18 +222,11 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
stopTextContentBlock(param, &results) stopTextContentBlock(param, &results)
// Send content_block_start for tool_use // Send content_block_start for tool_use
contentBlockStart := map[string]interface{}{ contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
"type": "content_block_start", contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "index", blockIndex)
"index": blockIndex, contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.id", accumulator.ID)
"content_block": map[string]interface{}{ contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.name", accumulator.Name)
"type": "tool_use", results = append(results, "event: content_block_start\ndata: "+contentBlockStartJSON+"\n\n")
"id": accumulator.ID,
"name": accumulator.Name,
"input": map[string]interface{}{},
},
}
contentBlockStartJSON, _ := json.Marshal(contentBlockStart)
results = append(results, "event: content_block_start\ndata: "+string(contentBlockStartJSON)+"\n\n")
} }
// Handle function arguments // Handle function arguments
@@ -298,12 +250,9 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
// Send content_block_stop for thinking content if needed // Send content_block_stop for thinking content if needed
if param.ThinkingContentBlockStarted { if param.ThinkingContentBlockStarted {
contentBlockStop := map[string]interface{}{ contentBlockStopJSON := `{"type":"content_block_stop","index":0}`
"type": "content_block_stop", contentBlockStopJSON, _ = sjson.Set(contentBlockStopJSON, "index", param.ThinkingContentBlockIndex)
"index": param.ThinkingContentBlockIndex, results = append(results, "event: content_block_stop\ndata: "+contentBlockStopJSON+"\n\n")
}
contentBlockStopJSON, _ := json.Marshal(contentBlockStop)
results = append(results, "event: content_block_stop\ndata: "+string(contentBlockStopJSON)+"\n\n")
param.ThinkingContentBlockStarted = false param.ThinkingContentBlockStarted = false
param.ThinkingContentBlockIndex = -1 param.ThinkingContentBlockIndex = -1
} }
@@ -319,24 +268,15 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
// Send complete input_json_delta with all accumulated arguments // Send complete input_json_delta with all accumulated arguments
if accumulator.Arguments.Len() > 0 { if accumulator.Arguments.Len() > 0 {
inputDelta := map[string]interface{}{ inputDeltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
"type": "content_block_delta", inputDeltaJSON, _ = sjson.Set(inputDeltaJSON, "index", blockIndex)
"index": blockIndex, inputDeltaJSON, _ = sjson.Set(inputDeltaJSON, "delta.partial_json", util.FixJSON(accumulator.Arguments.String()))
"delta": map[string]interface{}{ results = append(results, "event: content_block_delta\ndata: "+inputDeltaJSON+"\n\n")
"type": "input_json_delta",
"partial_json": util.FixJSON(accumulator.Arguments.String()),
},
}
inputDeltaJSON, _ := json.Marshal(inputDelta)
results = append(results, "event: content_block_delta\ndata: "+string(inputDeltaJSON)+"\n\n")
} }
contentBlockStop := map[string]interface{}{ contentBlockStopJSON := `{"type":"content_block_stop","index":0}`
"type": "content_block_stop", contentBlockStopJSON, _ = sjson.Set(contentBlockStopJSON, "index", blockIndex)
"index": blockIndex, results = append(results, "event: content_block_stop\ndata: "+contentBlockStopJSON+"\n\n")
}
contentBlockStopJSON, _ := json.Marshal(contentBlockStop)
results = append(results, "event: content_block_stop\ndata: "+string(contentBlockStopJSON)+"\n\n")
delete(param.ToolCallBlockIndexes, index) delete(param.ToolCallBlockIndexes, index)
} }
param.ContentBlocksStopped = true param.ContentBlocksStopped = true
@@ -361,20 +301,11 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
} }
} }
// Send message_delta with usage // Send message_delta with usage
messageDelta := map[string]interface{}{ messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
"type": "message_delta", messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason))
"delta": map[string]interface{}{ messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.input_tokens", inputTokens)
"stop_reason": mapOpenAIFinishReasonToAnthropic(param.FinishReason), messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.output_tokens", outputTokens)
"stop_sequence": nil, results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n")
},
"usage": map[string]interface{}{
"input_tokens": inputTokens,
"output_tokens": outputTokens,
},
}
messageDeltaJSON, _ := json.Marshal(messageDelta)
results = append(results, "event: message_delta\ndata: "+string(messageDeltaJSON)+"\n\n")
param.MessageDeltaSent = true param.MessageDeltaSent = true
emitMessageStopIfNeeded(param, &results) emitMessageStopIfNeeded(param, &results)
@@ -390,12 +321,9 @@ func convertOpenAIDoneToAnthropic(param *ConvertOpenAIResponseToAnthropicParams)
// Ensure all content blocks are stopped before final events // Ensure all content blocks are stopped before final events
if param.ThinkingContentBlockStarted { if param.ThinkingContentBlockStarted {
contentBlockStop := map[string]interface{}{ contentBlockStopJSON := `{"type":"content_block_stop","index":0}`
"type": "content_block_stop", contentBlockStopJSON, _ = sjson.Set(contentBlockStopJSON, "index", param.ThinkingContentBlockIndex)
"index": param.ThinkingContentBlockIndex, results = append(results, "event: content_block_stop\ndata: "+contentBlockStopJSON+"\n\n")
}
contentBlockStopJSON, _ := json.Marshal(contentBlockStop)
results = append(results, "event: content_block_stop\ndata: "+string(contentBlockStopJSON)+"\n\n")
param.ThinkingContentBlockStarted = false param.ThinkingContentBlockStarted = false
param.ThinkingContentBlockIndex = -1 param.ThinkingContentBlockIndex = -1
} }
@@ -408,24 +336,15 @@ func convertOpenAIDoneToAnthropic(param *ConvertOpenAIResponseToAnthropicParams)
blockIndex := param.toolContentBlockIndex(index) blockIndex := param.toolContentBlockIndex(index)
if accumulator.Arguments.Len() > 0 { if accumulator.Arguments.Len() > 0 {
inputDelta := map[string]interface{}{ inputDeltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
"type": "content_block_delta", inputDeltaJSON, _ = sjson.Set(inputDeltaJSON, "index", blockIndex)
"index": blockIndex, inputDeltaJSON, _ = sjson.Set(inputDeltaJSON, "delta.partial_json", util.FixJSON(accumulator.Arguments.String()))
"delta": map[string]interface{}{ results = append(results, "event: content_block_delta\ndata: "+inputDeltaJSON+"\n\n")
"type": "input_json_delta",
"partial_json": util.FixJSON(accumulator.Arguments.String()),
},
}
inputDeltaJSON, _ := json.Marshal(inputDelta)
results = append(results, "event: content_block_delta\ndata: "+string(inputDeltaJSON)+"\n\n")
} }
contentBlockStop := map[string]interface{}{ contentBlockStopJSON := `{"type":"content_block_stop","index":0}`
"type": "content_block_stop", contentBlockStopJSON, _ = sjson.Set(contentBlockStopJSON, "index", blockIndex)
"index": blockIndex, results = append(results, "event: content_block_stop\ndata: "+contentBlockStopJSON+"\n\n")
}
contentBlockStopJSON, _ := json.Marshal(contentBlockStop)
results = append(results, "event: content_block_stop\ndata: "+string(contentBlockStopJSON)+"\n\n")
delete(param.ToolCallBlockIndexes, index) delete(param.ToolCallBlockIndexes, index)
} }
param.ContentBlocksStopped = true param.ContentBlocksStopped = true
@@ -433,16 +352,9 @@ func convertOpenAIDoneToAnthropic(param *ConvertOpenAIResponseToAnthropicParams)
// If we haven't sent message_delta yet (no usage info was received), send it now // If we haven't sent message_delta yet (no usage info was received), send it now
if param.FinishReason != "" && !param.MessageDeltaSent { if param.FinishReason != "" && !param.MessageDeltaSent {
messageDelta := map[string]interface{}{ messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null}}`
"type": "message_delta", messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason))
"delta": map[string]interface{}{ results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n")
"stop_reason": mapOpenAIFinishReasonToAnthropic(param.FinishReason),
"stop_sequence": nil,
},
}
messageDeltaJSON, _ := json.Marshal(messageDelta)
results = append(results, "event: message_delta\ndata: "+string(messageDeltaJSON)+"\n\n")
param.MessageDeltaSent = true param.MessageDeltaSent = true
} }
@@ -455,105 +367,73 @@ func convertOpenAIDoneToAnthropic(param *ConvertOpenAIResponseToAnthropicParams)
func convertOpenAINonStreamingToAnthropic(rawJSON []byte) []string { func convertOpenAINonStreamingToAnthropic(rawJSON []byte) []string {
root := gjson.ParseBytes(rawJSON) root := gjson.ParseBytes(rawJSON)
// Build Anthropic response out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
response := map[string]interface{}{ out, _ = sjson.Set(out, "id", root.Get("id").String())
"id": root.Get("id").String(), out, _ = sjson.Set(out, "model", root.Get("model").String())
"type": "message",
"role": "assistant",
"model": root.Get("model").String(),
"content": []interface{}{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]interface{}{
"input_tokens": 0,
"output_tokens": 0,
},
}
// Process message content and tool calls // Process message content and tool calls
var contentBlocks []interface{} if choices := root.Get("choices"); choices.Exists() && choices.IsArray() && len(choices.Array()) > 0 {
if choices := root.Get("choices"); choices.Exists() && choices.IsArray() {
choice := choices.Array()[0] // Take first choice choice := choices.Array()[0] // Take first choice
reasoningNode := choice.Get("message.reasoning_content")
allReasoning := collectOpenAIReasoningTexts(reasoningNode)
for _, reasoningText := range allReasoning { reasoningNode := choice.Get("message.reasoning_content")
for _, reasoningText := range collectOpenAIReasoningTexts(reasoningNode) {
if reasoningText == "" { if reasoningText == "" {
continue continue
} }
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"thinking","thinking":""}`
"type": "thinking", block, _ = sjson.Set(block, "thinking", reasoningText)
"thinking": reasoningText, out, _ = sjson.SetRaw(out, "content.-1", block)
})
} }
// Handle text content // Handle text content
if content := choice.Get("message.content"); content.Exists() && content.String() != "" { if content := choice.Get("message.content"); content.Exists() && content.String() != "" {
textBlock := map[string]interface{}{ block := `{"type":"text","text":""}`
"type": "text", block, _ = sjson.Set(block, "text", content.String())
"text": content.String(), out, _ = sjson.SetRaw(out, "content.-1", block)
}
contentBlocks = append(contentBlocks, textBlock)
} }
// Handle tool calls // Handle tool calls
if toolCalls := choice.Get("message.tool_calls"); toolCalls.Exists() && toolCalls.IsArray() { if toolCalls := choice.Get("message.tool_calls"); toolCalls.Exists() && toolCalls.IsArray() {
toolCalls.ForEach(func(_, toolCall gjson.Result) bool { toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
toolUseBlock := map[string]interface{}{ toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
"type": "tool_use", toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String())
"id": toolCall.Get("id").String(), toolUseBlock, _ = sjson.Set(toolUseBlock, "name", toolCall.Get("function.name").String())
"name": toolCall.Get("function.name").String(),
}
// Parse arguments argsStr := util.FixJSON(toolCall.Get("function.arguments").String())
argsStr := toolCall.Get("function.arguments").String() if argsStr != "" && gjson.Valid(argsStr) {
argsStr = util.FixJSON(argsStr) argsJSON := gjson.Parse(argsStr)
if argsStr != "" { if argsJSON.IsObject() {
var args interface{} toolUseBlock, _ = sjson.SetRaw(toolUseBlock, "input", argsJSON.Raw)
if err := json.Unmarshal([]byte(argsStr), &args); err == nil {
toolUseBlock["input"] = args
} else { } else {
toolUseBlock["input"] = map[string]interface{}{} toolUseBlock, _ = sjson.SetRaw(toolUseBlock, "input", "{}")
} }
} else { } else {
toolUseBlock["input"] = map[string]interface{}{} toolUseBlock, _ = sjson.SetRaw(toolUseBlock, "input", "{}")
} }
contentBlocks = append(contentBlocks, toolUseBlock) out, _ = sjson.SetRaw(out, "content.-1", toolUseBlock)
return true return true
}) })
} }
// Set stop reason // Set stop reason
if finishReason := choice.Get("finish_reason"); finishReason.Exists() { if finishReason := choice.Get("finish_reason"); finishReason.Exists() {
response["stop_reason"] = mapOpenAIFinishReasonToAnthropic(finishReason.String()) out, _ = sjson.Set(out, "stop_reason", mapOpenAIFinishReasonToAnthropic(finishReason.String()))
} }
} }
response["content"] = contentBlocks
// Set usage information // Set usage information
if usage := root.Get("usage"); usage.Exists() { if usage := root.Get("usage"); usage.Exists() {
response["usage"] = map[string]interface{}{ out, _ = sjson.Set(out, "usage.input_tokens", usage.Get("prompt_tokens").Int())
"input_tokens": usage.Get("prompt_tokens").Int(), out, _ = sjson.Set(out, "usage.output_tokens", usage.Get("completion_tokens").Int())
"output_tokens": usage.Get("completion_tokens").Int(), reasoningTokens := int64(0)
"reasoning_tokens": func() int64 { if v := usage.Get("completion_tokens_details.reasoning_tokens"); v.Exists() {
if v := usage.Get("completion_tokens_details.reasoning_tokens"); v.Exists() { reasoningTokens = v.Int()
return v.Int()
}
return 0
}(),
}
} else {
response["usage"] = map[string]interface{}{
"input_tokens": 0,
"output_tokens": 0,
} }
out, _ = sjson.Set(out, "usage.reasoning_tokens", reasoningTokens)
} }
responseJSON, _ := json.Marshal(response) return []string{out}
return []string{string(responseJSON)}
} }
// mapOpenAIFinishReasonToAnthropic maps OpenAI finish reasons to Anthropic equivalents // mapOpenAIFinishReasonToAnthropic maps OpenAI finish reasons to Anthropic equivalents
@@ -620,12 +500,9 @@ func stopThinkingContentBlock(param *ConvertOpenAIResponseToAnthropicParams, res
if !param.ThinkingContentBlockStarted { if !param.ThinkingContentBlockStarted {
return return
} }
contentBlockStop := map[string]interface{}{ contentBlockStopJSON := `{"type":"content_block_stop","index":0}`
"type": "content_block_stop", contentBlockStopJSON, _ = sjson.Set(contentBlockStopJSON, "index", param.ThinkingContentBlockIndex)
"index": param.ThinkingContentBlockIndex, *results = append(*results, "event: content_block_stop\ndata: "+contentBlockStopJSON+"\n\n")
}
contentBlockStopJSON, _ := json.Marshal(contentBlockStop)
*results = append(*results, "event: content_block_stop\ndata: "+string(contentBlockStopJSON)+"\n\n")
param.ThinkingContentBlockStarted = false param.ThinkingContentBlockStarted = false
param.ThinkingContentBlockIndex = -1 param.ThinkingContentBlockIndex = -1
} }
@@ -642,12 +519,9 @@ func stopTextContentBlock(param *ConvertOpenAIResponseToAnthropicParams, results
if !param.TextContentBlockStarted { if !param.TextContentBlockStarted {
return return
} }
contentBlockStop := map[string]interface{}{ contentBlockStopJSON := `{"type":"content_block_stop","index":0}`
"type": "content_block_stop", contentBlockStopJSON, _ = sjson.Set(contentBlockStopJSON, "index", param.TextContentBlockIndex)
"index": param.TextContentBlockIndex, *results = append(*results, "event: content_block_stop\ndata: "+contentBlockStopJSON+"\n\n")
}
contentBlockStopJSON, _ := json.Marshal(contentBlockStop)
*results = append(*results, "event: content_block_stop\ndata: "+string(contentBlockStopJSON)+"\n\n")
param.TextContentBlockStarted = false param.TextContentBlockStarted = false
param.TextContentBlockIndex = -1 param.TextContentBlockIndex = -1
} }
@@ -667,29 +541,19 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
_ = requestRawJSON _ = requestRawJSON
root := gjson.ParseBytes(rawJSON) root := gjson.ParseBytes(rawJSON)
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
out, _ = sjson.Set(out, "id", root.Get("id").String())
out, _ = sjson.Set(out, "model", root.Get("model").String())
response := map[string]interface{}{
"id": root.Get("id").String(),
"type": "message",
"role": "assistant",
"model": root.Get("model").String(),
"content": []interface{}{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]interface{}{
"input_tokens": 0,
"output_tokens": 0,
},
}
contentBlocks := make([]interface{}, 0)
hasToolCall := false hasToolCall := false
stopReasonSet := false
if choices := root.Get("choices"); choices.Exists() && choices.IsArray() && len(choices.Array()) > 0 { if choices := root.Get("choices"); choices.Exists() && choices.IsArray() && len(choices.Array()) > 0 {
choice := choices.Array()[0] choice := choices.Array()[0]
if finishReason := choice.Get("finish_reason"); finishReason.Exists() { if finishReason := choice.Get("finish_reason"); finishReason.Exists() {
response["stop_reason"] = mapOpenAIFinishReasonToAnthropic(finishReason.String()) out, _ = sjson.Set(out, "stop_reason", mapOpenAIFinishReasonToAnthropic(finishReason.String()))
stopReasonSet = true
} }
if message := choice.Get("message"); message.Exists() { if message := choice.Get("message"); message.Exists() {
@@ -702,10 +566,9 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
if textBuilder.Len() == 0 { if textBuilder.Len() == 0 {
return return
} }
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"text","text":""}`
"type": "text", block, _ = sjson.Set(block, "text", textBuilder.String())
"text": textBuilder.String(), out, _ = sjson.SetRaw(out, "content.-1", block)
})
textBuilder.Reset() textBuilder.Reset()
} }
@@ -713,16 +576,14 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
if thinkingBuilder.Len() == 0 { if thinkingBuilder.Len() == 0 {
return return
} }
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"thinking","thinking":""}`
"type": "thinking", block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
"thinking": thinkingBuilder.String(), out, _ = sjson.SetRaw(out, "content.-1", block)
})
thinkingBuilder.Reset() thinkingBuilder.Reset()
} }
for _, item := range contentResult.Array() { for _, item := range contentResult.Array() {
typeStr := item.Get("type").String() switch item.Get("type").String() {
switch typeStr {
case "text": case "text":
flushThinking() flushThinking()
textBuilder.WriteString(item.Get("text").String()) textBuilder.WriteString(item.Get("text").String())
@@ -733,25 +594,23 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
if toolCalls.IsArray() { if toolCalls.IsArray() {
toolCalls.ForEach(func(_, tc gjson.Result) bool { toolCalls.ForEach(func(_, tc gjson.Result) bool {
hasToolCall = true hasToolCall = true
toolUse := map[string]interface{}{ toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
"type": "tool_use", toolUse, _ = sjson.Set(toolUse, "id", tc.Get("id").String())
"id": tc.Get("id").String(), toolUse, _ = sjson.Set(toolUse, "name", tc.Get("function.name").String())
"name": tc.Get("function.name").String(),
}
argsStr := util.FixJSON(tc.Get("function.arguments").String()) argsStr := util.FixJSON(tc.Get("function.arguments").String())
if argsStr != "" { if argsStr != "" && gjson.Valid(argsStr) {
var parsed interface{} argsJSON := gjson.Parse(argsStr)
if err := json.Unmarshal([]byte(argsStr), &parsed); err == nil { if argsJSON.IsObject() {
toolUse["input"] = parsed toolUse, _ = sjson.SetRaw(toolUse, "input", argsJSON.Raw)
} else { } else {
toolUse["input"] = map[string]interface{}{} toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
} }
} else { } else {
toolUse["input"] = map[string]interface{}{} toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
} }
contentBlocks = append(contentBlocks, toolUse) out, _ = sjson.SetRaw(out, "content.-1", toolUse)
return true return true
}) })
} }
@@ -771,10 +630,9 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
} else if contentResult.Type == gjson.String { } else if contentResult.Type == gjson.String {
textContent := contentResult.String() textContent := contentResult.String()
if textContent != "" { if textContent != "" {
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"text","text":""}`
"type": "text", block, _ = sjson.Set(block, "text", textContent)
"text": textContent, out, _ = sjson.SetRaw(out, "content.-1", block)
})
} }
} }
} }
@@ -784,81 +642,52 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
if reasoningText == "" { if reasoningText == "" {
continue continue
} }
contentBlocks = append(contentBlocks, map[string]interface{}{ block := `{"type":"thinking","thinking":""}`
"type": "thinking", block, _ = sjson.Set(block, "thinking", reasoningText)
"thinking": reasoningText, out, _ = sjson.SetRaw(out, "content.-1", block)
})
} }
} }
if toolCalls := message.Get("tool_calls"); toolCalls.Exists() && toolCalls.IsArray() { if toolCalls := message.Get("tool_calls"); toolCalls.Exists() && toolCalls.IsArray() {
toolCalls.ForEach(func(_, toolCall gjson.Result) bool { toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
hasToolCall = true hasToolCall = true
toolUseBlock := map[string]interface{}{ toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
"type": "tool_use", toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String())
"id": toolCall.Get("id").String(), toolUseBlock, _ = sjson.Set(toolUseBlock, "name", toolCall.Get("function.name").String())
"name": toolCall.Get("function.name").String(),
}
argsStr := toolCall.Get("function.arguments").String() argsStr := util.FixJSON(toolCall.Get("function.arguments").String())
argsStr = util.FixJSON(argsStr) if argsStr != "" && gjson.Valid(argsStr) {
if argsStr != "" { argsJSON := gjson.Parse(argsStr)
var args interface{} if argsJSON.IsObject() {
if err := json.Unmarshal([]byte(argsStr), &args); err == nil { toolUseBlock, _ = sjson.SetRaw(toolUseBlock, "input", argsJSON.Raw)
toolUseBlock["input"] = args
} else { } else {
toolUseBlock["input"] = map[string]interface{}{} toolUseBlock, _ = sjson.SetRaw(toolUseBlock, "input", "{}")
} }
} else { } else {
toolUseBlock["input"] = map[string]interface{}{} toolUseBlock, _ = sjson.SetRaw(toolUseBlock, "input", "{}")
} }
contentBlocks = append(contentBlocks, toolUseBlock) out, _ = sjson.SetRaw(out, "content.-1", toolUseBlock)
return true return true
}) })
} }
} }
} }
response["content"] = contentBlocks
if respUsage := root.Get("usage"); respUsage.Exists() { if respUsage := root.Get("usage"); respUsage.Exists() {
usageJSON := `{}` out, _ = sjson.Set(out, "usage.input_tokens", respUsage.Get("prompt_tokens").Int())
usageJSON, _ = sjson.Set(usageJSON, "input_tokens", respUsage.Get("prompt_tokens").Int()) out, _ = sjson.Set(out, "usage.output_tokens", respUsage.Get("completion_tokens").Int())
usageJSON, _ = sjson.Set(usageJSON, "output_tokens", respUsage.Get("completion_tokens").Int())
parsedUsage := gjson.Parse(usageJSON).Value().(map[string]interface{})
response["usage"] = parsedUsage
} else {
response["usage"] = `{"input_tokens":0,"output_tokens":0}`
} }
if response["stop_reason"] == nil { if !stopReasonSet {
if hasToolCall { if hasToolCall {
response["stop_reason"] = "tool_use" out, _ = sjson.Set(out, "stop_reason", "tool_use")
} else { } else {
response["stop_reason"] = "end_turn" out, _ = sjson.Set(out, "stop_reason", "end_turn")
} }
} }
if !hasToolCall { return out
if toolBlocks := response["content"].([]interface{}); len(toolBlocks) > 0 {
for _, block := range toolBlocks {
if m, ok := block.(map[string]interface{}); ok && m["type"] == "tool_use" {
hasToolCall = true
break
}
}
}
if hasToolCall {
response["stop_reason"] = "tool_use"
}
}
responseJSON, err := json.Marshal(response)
if err != nil {
return ""
}
return string(responseJSON)
} }
func ClaudeTokenCount(ctx context.Context, count int64) string { func ClaudeTokenCount(ctx context.Context, count int64) string {

View File

@@ -8,7 +8,6 @@ package gemini
import ( import (
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"encoding/json"
"fmt" "fmt"
"math/big" "math/big"
"strings" "strings"
@@ -94,7 +93,6 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
out, _ = sjson.Set(out, "stream", stream) out, _ = sjson.Set(out, "stream", stream)
// Process contents (Gemini messages) -> OpenAI messages // Process contents (Gemini messages) -> OpenAI messages
var openAIMessages []interface{}
var toolCallIDs []string // Track tool call IDs for matching with tool results var toolCallIDs []string // Track tool call IDs for matching with tool results
// System instruction -> OpenAI system message // System instruction -> OpenAI system message
@@ -105,22 +103,17 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
} }
if systemInstruction.Exists() { if systemInstruction.Exists() {
parts := systemInstruction.Get("parts") parts := systemInstruction.Get("parts")
msg := map[string]interface{}{ msg := `{"role":"system","content":[]}`
"role": "system", hasContent := false
"content": []interface{}{},
}
var aggregatedParts []interface{}
if parts.Exists() && parts.IsArray() { if parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, part gjson.Result) bool { parts.ForEach(func(_, part gjson.Result) bool {
// Handle text parts // Handle text parts
if text := part.Get("text"); text.Exists() { if text := part.Get("text"); text.Exists() {
formattedText := text.String() contentPart := `{"type":"text","text":""}`
aggregatedParts = append(aggregatedParts, map[string]interface{}{ contentPart, _ = sjson.Set(contentPart, "text", text.String())
"type": "text", msg, _ = sjson.SetRaw(msg, "content.-1", contentPart)
"text": formattedText, hasContent = true
})
} }
// Handle inline data (e.g., images) // Handle inline data (e.g., images)
@@ -132,20 +125,17 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
data := inlineData.Get("data").String() data := inlineData.Get("data").String()
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data) imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
aggregatedParts = append(aggregatedParts, map[string]interface{}{ contentPart := `{"type":"image_url","image_url":{"url":""}}`
"type": "image_url", contentPart, _ = sjson.Set(contentPart, "image_url.url", imageURL)
"image_url": map[string]interface{}{ msg, _ = sjson.SetRaw(msg, "content.-1", contentPart)
"url": imageURL, hasContent = true
},
})
} }
return true return true
}) })
} }
if len(aggregatedParts) > 0 { if hasContent {
msg["content"] = aggregatedParts out, _ = sjson.SetRaw(out, "messages.-1", msg)
openAIMessages = append(openAIMessages, msg)
} }
} }
@@ -159,16 +149,15 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
role = "assistant" role = "assistant"
} }
// Create OpenAI message msg := `{"role":"","content":""}`
msg := map[string]interface{}{ msg, _ = sjson.Set(msg, "role", role)
"role": role,
"content": "",
}
var textBuilder strings.Builder var textBuilder strings.Builder
var aggregatedParts []interface{} contentWrapper := `{"arr":[]}`
contentPartsCount := 0
onlyTextContent := true onlyTextContent := true
var toolCalls []interface{} toolCallsWrapper := `{"arr":[]}`
toolCallsCount := 0
if parts.Exists() && parts.IsArray() { if parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, part gjson.Result) bool { parts.ForEach(func(_, part gjson.Result) bool {
@@ -176,10 +165,10 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
if text := part.Get("text"); text.Exists() { if text := part.Get("text"); text.Exists() {
formattedText := text.String() formattedText := text.String()
textBuilder.WriteString(formattedText) textBuilder.WriteString(formattedText)
aggregatedParts = append(aggregatedParts, map[string]interface{}{ contentPart := `{"type":"text","text":""}`
"type": "text", contentPart, _ = sjson.Set(contentPart, "text", formattedText)
"text": formattedText, contentWrapper, _ = sjson.SetRaw(contentWrapper, "arr.-1", contentPart)
}) contentPartsCount++
} }
// Handle inline data (e.g., images) // Handle inline data (e.g., images)
@@ -193,12 +182,10 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
data := inlineData.Get("data").String() data := inlineData.Get("data").String()
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data) imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
aggregatedParts = append(aggregatedParts, map[string]interface{}{ contentPart := `{"type":"image_url","image_url":{"url":""}}`
"type": "image_url", contentPart, _ = sjson.Set(contentPart, "image_url.url", imageURL)
"image_url": map[string]interface{}{ contentWrapper, _ = sjson.SetRaw(contentWrapper, "arr.-1", contentPart)
"url": imageURL, contentPartsCount++
},
})
} }
// Handle function calls (Gemini) -> tool calls (OpenAI) // Handle function calls (Gemini) -> tool calls (OpenAI)
@@ -206,44 +193,32 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
toolCallID := genToolCallID() toolCallID := genToolCallID()
toolCallIDs = append(toolCallIDs, toolCallID) toolCallIDs = append(toolCallIDs, toolCallID)
toolCall := map[string]interface{}{ toolCall := `{"id":"","type":"function","function":{"name":"","arguments":""}}`
"id": toolCallID, toolCall, _ = sjson.Set(toolCall, "id", toolCallID)
"type": "function", toolCall, _ = sjson.Set(toolCall, "function.name", functionCall.Get("name").String())
"function": map[string]interface{}{
"name": functionCall.Get("name").String(),
},
}
// Convert args to arguments JSON string // Convert args to arguments JSON string
if args := functionCall.Get("args"); args.Exists() { if args := functionCall.Get("args"); args.Exists() {
argsJSON, _ := json.Marshal(args.Value()) toolCall, _ = sjson.Set(toolCall, "function.arguments", args.Raw)
toolCall["function"].(map[string]interface{})["arguments"] = string(argsJSON)
} else { } else {
toolCall["function"].(map[string]interface{})["arguments"] = "{}" toolCall, _ = sjson.Set(toolCall, "function.arguments", "{}")
} }
toolCalls = append(toolCalls, toolCall) toolCallsWrapper, _ = sjson.SetRaw(toolCallsWrapper, "arr.-1", toolCall)
toolCallsCount++
} }
// Handle function responses (Gemini) -> tool role messages (OpenAI) // Handle function responses (Gemini) -> tool role messages (OpenAI)
if functionResponse := part.Get("functionResponse"); functionResponse.Exists() { if functionResponse := part.Get("functionResponse"); functionResponse.Exists() {
// Create tool message for function response // Create tool message for function response
toolMsg := map[string]interface{}{ toolMsg := `{"role":"tool","tool_call_id":"","content":""}`
"role": "tool",
"tool_call_id": "", // Will be set based on context
"content": "",
}
// Convert response.content to JSON string // Convert response.content to JSON string
if response := functionResponse.Get("response"); response.Exists() { if response := functionResponse.Get("response"); response.Exists() {
if content = response.Get("content"); content.Exists() { if contentField := response.Get("content"); contentField.Exists() {
// Use the content field from the response toolMsg, _ = sjson.Set(toolMsg, "content", contentField.Raw)
contentJSON, _ := json.Marshal(content.Value())
toolMsg["content"] = string(contentJSON)
} else { } else {
// Fallback to entire response toolMsg, _ = sjson.Set(toolMsg, "content", response.Raw)
responseJSON, _ := json.Marshal(response.Value())
toolMsg["content"] = string(responseJSON)
} }
} }
@@ -252,13 +227,13 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
if len(toolCallIDs) > 0 { if len(toolCallIDs) > 0 {
// Use the last tool call ID (simple matching by function name) // Use the last tool call ID (simple matching by function name)
// In a real implementation, you might want more sophisticated matching // In a real implementation, you might want more sophisticated matching
toolMsg["tool_call_id"] = toolCallIDs[len(toolCallIDs)-1] toolMsg, _ = sjson.Set(toolMsg, "tool_call_id", toolCallIDs[len(toolCallIDs)-1])
} else { } else {
// Generate a tool call ID if none available // Generate a tool call ID if none available
toolMsg["tool_call_id"] = genToolCallID() toolMsg, _ = sjson.Set(toolMsg, "tool_call_id", genToolCallID())
} }
openAIMessages = append(openAIMessages, toolMsg) out, _ = sjson.SetRaw(out, "messages.-1", toolMsg)
} }
return true return true
@@ -266,170 +241,46 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
} }
// Set content // Set content
if len(aggregatedParts) > 0 { if contentPartsCount > 0 {
if onlyTextContent { if onlyTextContent {
msg["content"] = textBuilder.String() msg, _ = sjson.Set(msg, "content", textBuilder.String())
} else { } else {
msg["content"] = aggregatedParts msg, _ = sjson.SetRaw(msg, "content", gjson.Get(contentWrapper, "arr").Raw)
} }
} }
// Set tool calls if any // Set tool calls if any
if len(toolCalls) > 0 { if toolCallsCount > 0 {
msg["tool_calls"] = toolCalls msg, _ = sjson.SetRaw(msg, "tool_calls", gjson.Get(toolCallsWrapper, "arr").Raw)
} }
openAIMessages = append(openAIMessages, msg) out, _ = sjson.SetRaw(out, "messages.-1", msg)
// switch role {
// case "user", "model":
// // Convert role: model -> assistant
// if role == "model" {
// role = "assistant"
// }
//
// // Create OpenAI message
// msg := map[string]interface{}{
// "role": role,
// "content": "",
// }
//
// var contentParts []string
// var toolCalls []interface{}
//
// if parts.Exists() && parts.IsArray() {
// parts.ForEach(func(_, part gjson.Result) bool {
// // Handle text parts
// if text := part.Get("text"); text.Exists() {
// contentParts = append(contentParts, text.String())
// }
//
// // Handle function calls (Gemini) -> tool calls (OpenAI)
// if functionCall := part.Get("functionCall"); functionCall.Exists() {
// toolCallID := genToolCallID()
// toolCallIDs = append(toolCallIDs, toolCallID)
//
// toolCall := map[string]interface{}{
// "id": toolCallID,
// "type": "function",
// "function": map[string]interface{}{
// "name": functionCall.Get("name").String(),
// },
// }
//
// // Convert args to arguments JSON string
// if args := functionCall.Get("args"); args.Exists() {
// argsJSON, _ := json.Marshal(args.Value())
// toolCall["function"].(map[string]interface{})["arguments"] = string(argsJSON)
// } else {
// toolCall["function"].(map[string]interface{})["arguments"] = "{}"
// }
//
// toolCalls = append(toolCalls, toolCall)
// }
//
// return true
// })
// }
//
// // Set content
// if len(contentParts) > 0 {
// msg["content"] = strings.Join(contentParts, "")
// }
//
// // Set tool calls if any
// if len(toolCalls) > 0 {
// msg["tool_calls"] = toolCalls
// }
//
// openAIMessages = append(openAIMessages, msg)
//
// case "function":
// // Handle Gemini function role -> OpenAI tool role
// if parts.Exists() && parts.IsArray() {
// parts.ForEach(func(_, part gjson.Result) bool {
// // Handle function responses (Gemini) -> tool role messages (OpenAI)
// if functionResponse := part.Get("functionResponse"); functionResponse.Exists() {
// // Create tool message for function response
// toolMsg := map[string]interface{}{
// "role": "tool",
// "tool_call_id": "", // Will be set based on context
// "content": "",
// }
//
// // Convert response.content to JSON string
// if response := functionResponse.Get("response"); response.Exists() {
// if content = response.Get("content"); content.Exists() {
// // Use the content field from the response
// contentJSON, _ := json.Marshal(content.Value())
// toolMsg["content"] = string(contentJSON)
// } else {
// // Fallback to entire response
// responseJSON, _ := json.Marshal(response.Value())
// toolMsg["content"] = string(responseJSON)
// }
// }
//
// // Try to match with previous tool call ID
// _ = functionResponse.Get("name").String() // functionName not used for now
// if len(toolCallIDs) > 0 {
// // Use the last tool call ID (simple matching by function name)
// // In a real implementation, you might want more sophisticated matching
// toolMsg["tool_call_id"] = toolCallIDs[len(toolCallIDs)-1]
// } else {
// // Generate a tool call ID if none available
// toolMsg["tool_call_id"] = genToolCallID()
// }
//
// openAIMessages = append(openAIMessages, toolMsg)
// }
//
// return true
// })
// }
// }
return true return true
}) })
} }
// Set messages
if len(openAIMessages) > 0 {
messagesJSON, _ := json.Marshal(openAIMessages)
out, _ = sjson.SetRaw(out, "messages", string(messagesJSON))
}
// Tools mapping: Gemini tools -> OpenAI tools // Tools mapping: Gemini tools -> OpenAI tools
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() { if tools := root.Get("tools"); tools.Exists() && tools.IsArray() {
var openAITools []interface{}
tools.ForEach(func(_, tool gjson.Result) bool { tools.ForEach(func(_, tool gjson.Result) bool {
if functionDeclarations := tool.Get("functionDeclarations"); functionDeclarations.Exists() && functionDeclarations.IsArray() { if functionDeclarations := tool.Get("functionDeclarations"); functionDeclarations.Exists() && functionDeclarations.IsArray() {
functionDeclarations.ForEach(func(_, funcDecl gjson.Result) bool { functionDeclarations.ForEach(func(_, funcDecl gjson.Result) bool {
openAITool := map[string]interface{}{ openAITool := `{"type":"function","function":{"name":"","description":""}}`
"type": "function", openAITool, _ = sjson.Set(openAITool, "function.name", funcDecl.Get("name").String())
"function": map[string]interface{}{ openAITool, _ = sjson.Set(openAITool, "function.description", funcDecl.Get("description").String())
"name": funcDecl.Get("name").String(),
"description": funcDecl.Get("description").String(),
},
}
// Convert parameters schema // Convert parameters schema
if parameters := funcDecl.Get("parameters"); parameters.Exists() { if parameters := funcDecl.Get("parameters"); parameters.Exists() {
openAITool["function"].(map[string]interface{})["parameters"] = parameters.Value() openAITool, _ = sjson.SetRaw(openAITool, "function.parameters", parameters.Raw)
} else if parameters = funcDecl.Get("parametersJsonSchema"); parameters.Exists() { } else if parameters := funcDecl.Get("parametersJsonSchema"); parameters.Exists() {
openAITool["function"].(map[string]interface{})["parameters"] = parameters.Value() openAITool, _ = sjson.SetRaw(openAITool, "function.parameters", parameters.Raw)
} }
openAITools = append(openAITools, openAITool) out, _ = sjson.SetRaw(out, "tools.-1", openAITool)
return true return true
}) })
} }
return true return true
}) })
if len(openAITools) > 0 {
toolsJSON, _ := json.Marshal(openAITools)
out, _ = sjson.SetRaw(out, "tools", string(toolsJSON))
}
} }
// Tool choice mapping (Gemini doesn't have direct equivalent, but we can handle it) // Tool choice mapping (Gemini doesn't have direct equivalent, but we can handle it)

View File

@@ -8,7 +8,6 @@ package gemini
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
@@ -84,15 +83,12 @@ func ConvertOpenAIResponseToGemini(_ context.Context, _ string, originalRequestR
template, _ = sjson.Set(template, "model", model.String()) template, _ = sjson.Set(template, "model", model.String())
} }
usageObj := map[string]interface{}{ template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", usage.Get("prompt_tokens").Int())
"promptTokenCount": usage.Get("prompt_tokens").Int(), template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", usage.Get("completion_tokens").Int())
"candidatesTokenCount": usage.Get("completion_tokens").Int(), template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", usage.Get("total_tokens").Int())
"totalTokenCount": usage.Get("total_tokens").Int(),
}
if reasoningTokens := reasoningTokensFromUsage(usage); reasoningTokens > 0 { if reasoningTokens := reasoningTokensFromUsage(usage); reasoningTokens > 0 {
usageObj["thoughtsTokenCount"] = reasoningTokens template, _ = sjson.Set(template, "usageMetadata.thoughtsTokenCount", reasoningTokens)
} }
template, _ = sjson.Set(template, "usageMetadata", usageObj)
return []string{template} return []string{template}
} }
return []string{} return []string{}
@@ -133,13 +129,8 @@ func ConvertOpenAIResponseToGemini(_ context.Context, _ string, originalRequestR
continue continue
} }
reasoningTemplate := baseTemplate reasoningTemplate := baseTemplate
parts := []interface{}{ reasoningTemplate, _ = sjson.Set(reasoningTemplate, "candidates.0.content.parts.0.thought", true)
map[string]interface{}{ reasoningTemplate, _ = sjson.Set(reasoningTemplate, "candidates.0.content.parts.0.text", reasoningText)
"thought": true,
"text": reasoningText,
},
}
reasoningTemplate, _ = sjson.Set(reasoningTemplate, "candidates.0.content.parts", parts)
chunkOutputs = append(chunkOutputs, reasoningTemplate) chunkOutputs = append(chunkOutputs, reasoningTemplate)
} }
} }
@@ -150,13 +141,8 @@ func ConvertOpenAIResponseToGemini(_ context.Context, _ string, originalRequestR
(*param).(*ConvertOpenAIResponseToGeminiParams).ContentAccumulator.WriteString(contentText) (*param).(*ConvertOpenAIResponseToGeminiParams).ContentAccumulator.WriteString(contentText)
// Create text part for this delta // Create text part for this delta
parts := []interface{}{
map[string]interface{}{
"text": contentText,
},
}
contentTemplate := baseTemplate contentTemplate := baseTemplate
contentTemplate, _ = sjson.Set(contentTemplate, "candidates.0.content.parts", parts) contentTemplate, _ = sjson.Set(contentTemplate, "candidates.0.content.parts.0.text", contentText)
chunkOutputs = append(chunkOutputs, contentTemplate) chunkOutputs = append(chunkOutputs, contentTemplate)
} }
@@ -225,24 +211,13 @@ func ConvertOpenAIResponseToGemini(_ context.Context, _ string, originalRequestR
// If we have accumulated tool calls, output them now // If we have accumulated tool calls, output them now
if len((*param).(*ConvertOpenAIResponseToGeminiParams).ToolCallsAccumulator) > 0 { if len((*param).(*ConvertOpenAIResponseToGeminiParams).ToolCallsAccumulator) > 0 {
var parts []interface{} partIndex := 0
for _, accumulator := range (*param).(*ConvertOpenAIResponseToGeminiParams).ToolCallsAccumulator { for _, accumulator := range (*param).(*ConvertOpenAIResponseToGeminiParams).ToolCallsAccumulator {
argsStr := accumulator.Arguments.String() namePath := fmt.Sprintf("candidates.0.content.parts.%d.functionCall.name", partIndex)
var argsMap map[string]interface{} argsPath := fmt.Sprintf("candidates.0.content.parts.%d.functionCall.args", partIndex)
template, _ = sjson.Set(template, namePath, accumulator.Name)
argsMap = parseArgsToMap(argsStr) template, _ = sjson.SetRaw(template, argsPath, parseArgsToObjectRaw(accumulator.Arguments.String()))
partIndex++
functionCallPart := map[string]interface{}{
"functionCall": map[string]interface{}{
"name": accumulator.Name,
"args": argsMap,
},
}
parts = append(parts, functionCallPart)
}
if len(parts) > 0 {
template, _ = sjson.Set(template, "candidates.0.content.parts", parts)
} }
// Clear accumulators // Clear accumulators
@@ -255,15 +230,12 @@ func ConvertOpenAIResponseToGemini(_ context.Context, _ string, originalRequestR
// Handle usage information // Handle usage information
if usage := root.Get("usage"); usage.Exists() { if usage := root.Get("usage"); usage.Exists() {
usageObj := map[string]interface{}{ template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", usage.Get("prompt_tokens").Int())
"promptTokenCount": usage.Get("prompt_tokens").Int(), template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", usage.Get("completion_tokens").Int())
"candidatesTokenCount": usage.Get("completion_tokens").Int(), template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", usage.Get("total_tokens").Int())
"totalTokenCount": usage.Get("total_tokens").Int(),
}
if reasoningTokens := reasoningTokensFromUsage(usage); reasoningTokens > 0 { if reasoningTokens := reasoningTokensFromUsage(usage); reasoningTokens > 0 {
usageObj["thoughtsTokenCount"] = reasoningTokens template, _ = sjson.Set(template, "usageMetadata.thoughtsTokenCount", reasoningTokens)
} }
template, _ = sjson.Set(template, "usageMetadata", usageObj)
results = append(results, template) results = append(results, template)
return true return true
} }
@@ -291,46 +263,54 @@ func mapOpenAIFinishReasonToGemini(openAIReason string) string {
} }
} }
// parseArgsToMap safely parses a JSON string of function arguments into a map. // parseArgsToObjectRaw safely parses a JSON string of function arguments into an object JSON string.
// It returns an empty map if the input is empty or cannot be parsed as a JSON object. // It returns "{}" if the input is empty or cannot be parsed as a JSON object.
func parseArgsToMap(argsStr string) map[string]interface{} { func parseArgsToObjectRaw(argsStr string) string {
trimmed := strings.TrimSpace(argsStr) trimmed := strings.TrimSpace(argsStr)
if trimmed == "" || trimmed == "{}" { if trimmed == "" || trimmed == "{}" {
return map[string]interface{}{} return "{}"
} }
// First try strict JSON // First try strict JSON
var out map[string]interface{} if gjson.Valid(trimmed) {
if errUnmarshal := json.Unmarshal([]byte(trimmed), &out); errUnmarshal == nil { strict := gjson.Parse(trimmed)
return out if strict.IsObject() {
return strict.Raw
}
} }
// Tolerant parse: handle streams where values are barewords (e.g., 北京, celsius) // Tolerant parse: handle streams where values are barewords (e.g., 北京, celsius)
tolerant := tolerantParseJSONMap(trimmed) tolerant := tolerantParseJSONObjectRaw(trimmed)
if len(tolerant) > 0 { if tolerant != "{}" {
return tolerant return tolerant
} }
// Fallback: return empty object when parsing fails // Fallback: return empty object when parsing fails
return map[string]interface{}{} return "{}"
} }
// tolerantParseJSONMap attempts to parse a JSON-like object string into a map, tolerating func escapeSjsonPathKey(key string) string {
key = strings.ReplaceAll(key, `\`, `\\`)
key = strings.ReplaceAll(key, `.`, `\.`)
return key
}
// tolerantParseJSONObjectRaw attempts to parse a JSON-like object string into a JSON object string, tolerating
// bareword values (unquoted strings) commonly seen during streamed tool calls. // bareword values (unquoted strings) commonly seen during streamed tool calls.
// Example input: {"location": 北京, "unit": celsius} // Example input: {"location": 北京, "unit": celsius}
func tolerantParseJSONMap(s string) map[string]interface{} { func tolerantParseJSONObjectRaw(s string) string {
// Ensure we operate within the outermost braces if present // Ensure we operate within the outermost braces if present
start := strings.Index(s, "{") start := strings.Index(s, "{")
end := strings.LastIndex(s, "}") end := strings.LastIndex(s, "}")
if start == -1 || end == -1 || start >= end { if start == -1 || end == -1 || start >= end {
return map[string]interface{}{} return "{}"
} }
content := s[start+1 : end] content := s[start+1 : end]
runes := []rune(content) runes := []rune(content)
n := len(runes) n := len(runes)
i := 0 i := 0
result := make(map[string]interface{}) result := "{}"
for i < n { for i < n {
// Skip whitespace and commas // Skip whitespace and commas
@@ -356,6 +336,7 @@ func tolerantParseJSONMap(s string) map[string]interface{} {
break break
} }
keyName := jsonStringTokenToRawString(keyToken) keyName := jsonStringTokenToRawString(keyToken)
sjsonKey := escapeSjsonPathKey(keyName)
i = nextIdx i = nextIdx
// Skip whitespace // Skip whitespace
@@ -375,17 +356,16 @@ func tolerantParseJSONMap(s string) map[string]interface{} {
} }
// Parse value (string, number, object/array, bareword) // Parse value (string, number, object/array, bareword)
var value interface{}
switch runes[i] { switch runes[i] {
case '"': case '"':
// JSON string // JSON string
valToken, ni := parseJSONStringRunes(runes, i) valToken, ni := parseJSONStringRunes(runes, i)
if ni == -1 { if ni == -1 {
// Malformed; treat as empty string // Malformed; treat as empty string
value = "" result, _ = sjson.Set(result, sjsonKey, "")
i = n i = n
} else { } else {
value = jsonStringTokenToRawString(valToken) result, _ = sjson.Set(result, sjsonKey, jsonStringTokenToRawString(valToken))
i = ni i = ni
} }
case '{', '[': case '{', '[':
@@ -394,11 +374,10 @@ func tolerantParseJSONMap(s string) map[string]interface{} {
if ni == -1 { if ni == -1 {
i = n i = n
} else { } else {
var anyVal interface{} if gjson.Valid(seg) {
if errUnmarshal := json.Unmarshal([]byte(seg), &anyVal); errUnmarshal == nil { result, _ = sjson.SetRaw(result, sjsonKey, seg)
value = anyVal
} else { } else {
value = seg result, _ = sjson.Set(result, sjsonKey, seg)
} }
i = ni i = ni
} }
@@ -411,21 +390,19 @@ func tolerantParseJSONMap(s string) map[string]interface{} {
token := strings.TrimSpace(string(runes[i:j])) token := strings.TrimSpace(string(runes[i:j]))
// Interpret common JSON atoms and numbers; otherwise treat as string // Interpret common JSON atoms and numbers; otherwise treat as string
if token == "true" { if token == "true" {
value = true result, _ = sjson.Set(result, sjsonKey, true)
} else if token == "false" { } else if token == "false" {
value = false result, _ = sjson.Set(result, sjsonKey, false)
} else if token == "null" { } else if token == "null" {
value = nil result, _ = sjson.Set(result, sjsonKey, nil)
} else if numVal, ok := tryParseNumber(token); ok { } else if numVal, ok := tryParseNumber(token); ok {
value = numVal result, _ = sjson.Set(result, sjsonKey, numVal)
} else { } else {
value = token result, _ = sjson.Set(result, sjsonKey, token)
} }
i = j i = j
} }
result[keyName] = value
// Skip trailing whitespace and optional comma before next pair // Skip trailing whitespace and optional comma before next pair
for i < n && (runes[i] == ' ' || runes[i] == '\n' || runes[i] == '\r' || runes[i] == '\t') { for i < n && (runes[i] == ' ' || runes[i] == '\n' || runes[i] == '\r' || runes[i] == '\t') {
i++ i++
@@ -463,9 +440,9 @@ func parseJSONStringRunes(runes []rune, start int) (string, int) {
// jsonStringTokenToRawString converts a JSON string token (including quotes) to a raw Go string value. // jsonStringTokenToRawString converts a JSON string token (including quotes) to a raw Go string value.
func jsonStringTokenToRawString(token string) string { func jsonStringTokenToRawString(token string) string {
var s string r := gjson.Parse(token)
if errUnmarshal := json.Unmarshal([]byte(token), &s); errUnmarshal == nil { if r.Type == gjson.String {
return s return r.String()
} }
// Fallback: strip surrounding quotes if present // Fallback: strip surrounding quotes if present
if len(token) >= 2 && token[0] == '"' && token[len(token)-1] == '"' { if len(token) >= 2 && token[0] == '"' && token[len(token)-1] == '"' {
@@ -579,7 +556,7 @@ func ConvertOpenAIResponseToGeminiNonStream(_ context.Context, _ string, origina
} }
} }
var parts []interface{} partIndex := 0
// Handle reasoning content before visible text // Handle reasoning content before visible text
if reasoning := message.Get("reasoning_content"); reasoning.Exists() { if reasoning := message.Get("reasoning_content"); reasoning.Exists() {
@@ -587,18 +564,16 @@ func ConvertOpenAIResponseToGeminiNonStream(_ context.Context, _ string, origina
if reasoningText == "" { if reasoningText == "" {
continue continue
} }
parts = append(parts, map[string]interface{}{ out, _ = sjson.Set(out, fmt.Sprintf("candidates.0.content.parts.%d.thought", partIndex), true)
"thought": true, out, _ = sjson.Set(out, fmt.Sprintf("candidates.0.content.parts.%d.text", partIndex), reasoningText)
"text": reasoningText, partIndex++
})
} }
} }
// Handle content first // Handle content first
if content := message.Get("content"); content.Exists() && content.String() != "" { if content := message.Get("content"); content.Exists() && content.String() != "" {
parts = append(parts, map[string]interface{}{ out, _ = sjson.Set(out, fmt.Sprintf("candidates.0.content.parts.%d.text", partIndex), content.String())
"text": content.String(), partIndex++
})
} }
// Handle tool calls // Handle tool calls
@@ -609,27 +584,16 @@ func ConvertOpenAIResponseToGeminiNonStream(_ context.Context, _ string, origina
functionName := function.Get("name").String() functionName := function.Get("name").String()
functionArgs := function.Get("arguments").String() functionArgs := function.Get("arguments").String()
// Parse arguments namePath := fmt.Sprintf("candidates.0.content.parts.%d.functionCall.name", partIndex)
var argsMap map[string]interface{} argsPath := fmt.Sprintf("candidates.0.content.parts.%d.functionCall.args", partIndex)
argsMap = parseArgsToMap(functionArgs) out, _ = sjson.Set(out, namePath, functionName)
out, _ = sjson.SetRaw(out, argsPath, parseArgsToObjectRaw(functionArgs))
functionCallPart := map[string]interface{}{ partIndex++
"functionCall": map[string]interface{}{
"name": functionName,
"args": argsMap,
},
}
parts = append(parts, functionCallPart)
} }
return true return true
}) })
} }
// Set parts
if len(parts) > 0 {
out, _ = sjson.Set(out, "candidates.0.content.parts", parts)
}
// Handle finish reason // Handle finish reason
if finishReason := choice.Get("finish_reason"); finishReason.Exists() { if finishReason := choice.Get("finish_reason"); finishReason.Exists() {
geminiFinishReason := mapOpenAIFinishReasonToGemini(finishReason.String()) geminiFinishReason := mapOpenAIFinishReasonToGemini(finishReason.String())
@@ -645,15 +609,12 @@ func ConvertOpenAIResponseToGeminiNonStream(_ context.Context, _ string, origina
// Handle usage information // Handle usage information
if usage := root.Get("usage"); usage.Exists() { if usage := root.Get("usage"); usage.Exists() {
usageObj := map[string]interface{}{ out, _ = sjson.Set(out, "usageMetadata.promptTokenCount", usage.Get("prompt_tokens").Int())
"promptTokenCount": usage.Get("prompt_tokens").Int(), out, _ = sjson.Set(out, "usageMetadata.candidatesTokenCount", usage.Get("completion_tokens").Int())
"candidatesTokenCount": usage.Get("completion_tokens").Int(), out, _ = sjson.Set(out, "usageMetadata.totalTokenCount", usage.Get("total_tokens").Int())
"totalTokenCount": usage.Get("total_tokens").Int(),
}
if reasoningTokens := reasoningTokensFromUsage(usage); reasoningTokens > 0 { if reasoningTokens := reasoningTokensFromUsage(usage); reasoningTokens > 0 {
usageObj["thoughtsTokenCount"] = reasoningTokens out, _ = sjson.Set(out, "usageMetadata.thoughtsTokenCount", reasoningTokens)
} }
out, _ = sjson.Set(out, "usageMetadata", usageObj)
} }
return out return out

View File

@@ -484,16 +484,12 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
} }
} }
// Build response.output using aggregated buffers // Build response.output using aggregated buffers
var outputs []interface{} outputsWrapper := `{"arr":[]}`
if st.ReasoningBuf.Len() > 0 { if st.ReasoningBuf.Len() > 0 {
outputs = append(outputs, map[string]interface{}{ item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
"id": st.ReasoningID, item, _ = sjson.Set(item, "id", st.ReasoningID)
"type": "reasoning", item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
"summary": []interface{}{map[string]interface{}{ outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
"type": "summary_text",
"text": st.ReasoningBuf.String(),
}},
})
} }
// Append message items in ascending index order // Append message items in ascending index order
if len(st.MsgItemAdded) > 0 { if len(st.MsgItemAdded) > 0 {
@@ -513,18 +509,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
if b := st.MsgTextBuf[i]; b != nil { if b := st.MsgTextBuf[i]; b != nil {
txt = b.String() txt = b.String()
} }
outputs = append(outputs, map[string]interface{}{ item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
"id": fmt.Sprintf("msg_%s_%d", st.ResponseID, i), item, _ = sjson.Set(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
"type": "message", item, _ = sjson.Set(item, "content.0.text", txt)
"status": "completed", outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
"content": []interface{}{map[string]interface{}{
"type": "output_text",
"annotations": []interface{}{},
"logprobs": []interface{}{},
"text": txt,
}},
"role": "assistant",
})
} }
} }
if len(st.FuncArgsBuf) > 0 { if len(st.FuncArgsBuf) > 0 {
@@ -547,18 +535,16 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
} }
callID := st.FuncCallIDs[i] callID := st.FuncCallIDs[i]
name := st.FuncNames[i] name := st.FuncNames[i]
outputs = append(outputs, map[string]interface{}{ item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
"id": fmt.Sprintf("fc_%s", callID), item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
"type": "function_call", item, _ = sjson.Set(item, "arguments", args)
"status": "completed", item, _ = sjson.Set(item, "call_id", callID)
"arguments": args, item, _ = sjson.Set(item, "name", name)
"call_id": callID, outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
"name": name,
})
} }
} }
if len(outputs) > 0 { if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
completed, _ = sjson.Set(completed, "response.output", outputs) completed, _ = sjson.SetRaw(completed, "response.output", gjson.Get(outputsWrapper, "arr").Raw)
} }
if st.UsageSeen { if st.UsageSeen {
completed, _ = sjson.Set(completed, "response.usage.input_tokens", st.PromptTokens) completed, _ = sjson.Set(completed, "response.usage.input_tokens", st.PromptTokens)
@@ -681,7 +667,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(_ context.Co
} }
// Build output list from choices[...] // Build output list from choices[...]
var outputs []interface{} outputsWrapper := `{"arr":[]}`
// Detect and capture reasoning content if present // Detect and capture reasoning content if present
rcText := gjson.GetBytes(rawJSON, "choices.0.message.reasoning_content").String() rcText := gjson.GetBytes(rawJSON, "choices.0.message.reasoning_content").String()
includeReasoning := rcText != "" includeReasoning := rcText != ""
@@ -693,21 +679,14 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(_ context.Co
if strings.HasPrefix(rid, "resp_") { if strings.HasPrefix(rid, "resp_") {
rid = strings.TrimPrefix(rid, "resp_") rid = strings.TrimPrefix(rid, "resp_")
} }
reasoningItem := map[string]interface{}{
"id": fmt.Sprintf("rs_%s", rid),
"type": "reasoning",
"encrypted_content": "",
}
// Prefer summary_text from reasoning_content; encrypted_content is optional // Prefer summary_text from reasoning_content; encrypted_content is optional
var summaries []interface{} reasoningItem := `{"id":"","type":"reasoning","encrypted_content":"","summary":[]}`
reasoningItem, _ = sjson.Set(reasoningItem, "id", fmt.Sprintf("rs_%s", rid))
if rcText != "" { if rcText != "" {
summaries = append(summaries, map[string]interface{}{ reasoningItem, _ = sjson.Set(reasoningItem, "summary.0.type", "summary_text")
"type": "summary_text", reasoningItem, _ = sjson.Set(reasoningItem, "summary.0.text", rcText)
"text": rcText,
})
} }
reasoningItem["summary"] = summaries outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", reasoningItem)
outputs = append(outputs, reasoningItem)
} }
if choices := root.Get("choices"); choices.Exists() && choices.IsArray() { if choices := root.Get("choices"); choices.Exists() && choices.IsArray() {
@@ -716,18 +695,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(_ context.Co
if msg.Exists() { if msg.Exists() {
// Text message part // Text message part
if c := msg.Get("content"); c.Exists() && c.String() != "" { if c := msg.Get("content"); c.Exists() && c.String() != "" {
outputs = append(outputs, map[string]interface{}{ item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
"id": fmt.Sprintf("msg_%s_%d", id, int(choice.Get("index").Int())), item, _ = sjson.Set(item, "id", fmt.Sprintf("msg_%s_%d", id, int(choice.Get("index").Int())))
"type": "message", item, _ = sjson.Set(item, "content.0.text", c.String())
"status": "completed", outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
"content": []interface{}{map[string]interface{}{
"type": "output_text",
"annotations": []interface{}{},
"logprobs": []interface{}{},
"text": c.String(),
}},
"role": "assistant",
})
} }
// Function/tool calls // Function/tool calls
@@ -736,14 +707,12 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(_ context.Co
callID := tc.Get("id").String() callID := tc.Get("id").String()
name := tc.Get("function.name").String() name := tc.Get("function.name").String()
args := tc.Get("function.arguments").String() args := tc.Get("function.arguments").String()
outputs = append(outputs, map[string]interface{}{ item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
"id": fmt.Sprintf("fc_%s", callID), item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
"type": "function_call", item, _ = sjson.Set(item, "arguments", args)
"status": "completed", item, _ = sjson.Set(item, "call_id", callID)
"arguments": args, item, _ = sjson.Set(item, "name", name)
"call_id": callID, outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
"name": name,
})
return true return true
}) })
} }
@@ -751,8 +720,8 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(_ context.Co
return true return true
}) })
} }
if len(outputs) > 0 { if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
resp, _ = sjson.Set(resp, "output", outputs) resp, _ = sjson.SetRaw(resp, "output", gjson.Get(outputsWrapper, "arr").Raw)
} }
// usage mapping // usage mapping

View File

@@ -0,0 +1,10 @@
package util
import "strings"
// IsClaudeThinkingModel checks if the model is a Claude thinking model
// that requires the interleaved-thinking beta header.
func IsClaudeThinkingModel(model string) bool {
lower := strings.ToLower(model)
return strings.Contains(lower, "claude") && strings.Contains(lower, "thinking")
}

View File

@@ -0,0 +1,41 @@
package util
import "testing"
func TestIsClaudeThinkingModel(t *testing.T) {
tests := []struct {
name string
model string
expected bool
}{
// Claude thinking models - should return true
{"claude-sonnet-4-5-thinking", "claude-sonnet-4-5-thinking", true},
{"claude-opus-4-5-thinking", "claude-opus-4-5-thinking", true},
{"Claude-Sonnet-Thinking uppercase", "Claude-Sonnet-4-5-Thinking", true},
{"claude thinking mixed case", "Claude-THINKING-Model", true},
// Non-thinking Claude models - should return false
{"claude-sonnet-4-5 (no thinking)", "claude-sonnet-4-5", false},
{"claude-opus-4-5 (no thinking)", "claude-opus-4-5", false},
{"claude-3-5-sonnet", "claude-3-5-sonnet-20240620", false},
// Non-Claude models - should return false
{"gemini-3-pro-preview", "gemini-3-pro-preview", false},
{"gemini-thinking model", "gemini-3-pro-thinking", false}, // not Claude
{"gpt-4o", "gpt-4o", false},
{"empty string", "", false},
// Edge cases
{"thinking without claude", "thinking-model", false},
{"claude without thinking", "claude-model", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsClaudeThinkingModel(tt.model)
if result != tt.expected {
t.Errorf("IsClaudeThinkingModel(%q) = %v, expected %v", tt.model, result, tt.expected)
}
})
}
}

View File

@@ -12,10 +12,10 @@ import (
var gjsonPathKeyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?") var gjsonPathKeyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?")
// CleanJSONSchemaForGemini transforms a JSON schema to be compatible with Gemini/Antigravity API. // CleanJSONSchemaForAntigravity transforms a JSON schema to be compatible with Antigravity API.
// It handles unsupported keywords, type flattening, and schema simplification while preserving // It handles unsupported keywords, type flattening, and schema simplification while preserving
// semantic information as description hints. // semantic information as description hints.
func CleanJSONSchemaForGemini(jsonStr string) string { func CleanJSONSchemaForAntigravity(jsonStr string) string {
// Phase 1: Convert and add hints // Phase 1: Convert and add hints
jsonStr = convertRefsToHints(jsonStr) jsonStr = convertRefsToHints(jsonStr)
jsonStr = convertConstToEnum(jsonStr) jsonStr = convertConstToEnum(jsonStr)
@@ -32,6 +32,9 @@ func CleanJSONSchemaForGemini(jsonStr string) string {
jsonStr = removeUnsupportedKeywords(jsonStr) jsonStr = removeUnsupportedKeywords(jsonStr)
jsonStr = cleanupRequiredFields(jsonStr) jsonStr = cleanupRequiredFields(jsonStr)
// Phase 4: Add placeholder for empty object schemas (Claude VALIDATED mode requirement)
jsonStr = addEmptySchemaPlaceholder(jsonStr)
return jsonStr return jsonStr
} }
@@ -105,7 +108,8 @@ func addAdditionalPropertiesHints(jsonStr string) string {
var unsupportedConstraints = []string{ var unsupportedConstraints = []string{
"minLength", "maxLength", "exclusiveMinimum", "exclusiveMaximum", "minLength", "maxLength", "exclusiveMinimum", "exclusiveMaximum",
"pattern", "minItems", "maxItems", "pattern", "minItems", "maxItems", "format",
"default", "examples", // Claude rejects these in VALIDATED mode
} }
func moveConstraintsToDescription(jsonStr string) string { func moveConstraintsToDescription(jsonStr string) string {
@@ -339,6 +343,52 @@ func cleanupRequiredFields(jsonStr string) string {
return jsonStr return jsonStr
} }
// addEmptySchemaPlaceholder adds a placeholder "reason" property to empty object schemas.
// Claude VALIDATED mode requires at least one property in tool schemas.
func addEmptySchemaPlaceholder(jsonStr string) string {
// Find all "type" fields
paths := findPaths(jsonStr, "type")
// Process from deepest to shallowest (to handle nested objects properly)
sortByDepth(paths)
for _, p := range paths {
typeVal := gjson.Get(jsonStr, p)
if typeVal.String() != "object" {
continue
}
// Get the parent path (the object containing "type")
parentPath := trimSuffix(p, ".type")
// Check if properties exists and is empty or missing
propsPath := joinPath(parentPath, "properties")
propsVal := gjson.Get(jsonStr, propsPath)
needsPlaceholder := false
if !propsVal.Exists() {
// No properties field at all
needsPlaceholder = true
} else if propsVal.IsObject() && len(propsVal.Map()) == 0 {
// Empty properties object
needsPlaceholder = true
}
if needsPlaceholder {
// Add placeholder "reason" property
reasonPath := joinPath(propsPath, "reason")
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".type", "string")
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", "Brief explanation of why you are calling this tool")
// Add to required array
reqPath := joinPath(parentPath, "required")
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"reason"})
}
}
return jsonStr
}
// --- Helpers --- // --- Helpers ---
func findPaths(jsonStr, field string) []string { func findPaths(jsonStr, field string) []string {

View File

@@ -5,9 +5,11 @@ import (
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
"github.com/tidwall/gjson"
) )
func TestCleanJSONSchemaForGemini_ConstToEnum(t *testing.T) { func TestCleanJSONSchemaForAntigravity_ConstToEnum(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -28,11 +30,11 @@ func TestCleanJSONSchemaForGemini_ConstToEnum(t *testing.T) {
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
} }
func TestCleanJSONSchemaForGemini_TypeFlattening_Nullable(t *testing.T) { func TestCleanJSONSchemaForAntigravity_TypeFlattening_Nullable(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -60,11 +62,11 @@ func TestCleanJSONSchemaForGemini_TypeFlattening_Nullable(t *testing.T) {
"required": ["other"] "required": ["other"]
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
} }
func TestCleanJSONSchemaForGemini_ConstraintsToDescription(t *testing.T) { func TestCleanJSONSchemaForAntigravity_ConstraintsToDescription(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -81,7 +83,7 @@ func TestCleanJSONSchemaForGemini_ConstraintsToDescription(t *testing.T) {
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
// minItems should be REMOVED and moved to description // minItems should be REMOVED and moved to description
if strings.Contains(result, `"minItems"`) { if strings.Contains(result, `"minItems"`) {
@@ -100,7 +102,7 @@ func TestCleanJSONSchemaForGemini_ConstraintsToDescription(t *testing.T) {
} }
} }
func TestCleanJSONSchemaForGemini_AnyOfFlattening_SmartSelection(t *testing.T) { func TestCleanJSONSchemaForAntigravity_AnyOfFlattening_SmartSelection(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -131,11 +133,11 @@ func TestCleanJSONSchemaForGemini_AnyOfFlattening_SmartSelection(t *testing.T) {
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
} }
func TestCleanJSONSchemaForGemini_OneOfFlattening(t *testing.T) { func TestCleanJSONSchemaForAntigravity_OneOfFlattening(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -158,11 +160,11 @@ func TestCleanJSONSchemaForGemini_OneOfFlattening(t *testing.T) {
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
} }
func TestCleanJSONSchemaForGemini_AllOfMerging(t *testing.T) { func TestCleanJSONSchemaForAntigravity_AllOfMerging(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"allOf": [ "allOf": [
@@ -190,11 +192,11 @@ func TestCleanJSONSchemaForGemini_AllOfMerging(t *testing.T) {
"required": ["a", "b"] "required": ["a", "b"]
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
} }
func TestCleanJSONSchemaForGemini_RefHandling(t *testing.T) { func TestCleanJSONSchemaForAntigravity_RefHandling(t *testing.T) {
input := `{ input := `{
"definitions": { "definitions": {
"User": { "User": {
@@ -210,21 +212,29 @@ func TestCleanJSONSchemaForGemini_RefHandling(t *testing.T) {
} }
}` }`
// After $ref is converted to placeholder object, empty schema placeholder is also added
expected := `{ expected := `{
"type": "object", "type": "object",
"properties": { "properties": {
"customer": { "customer": {
"type": "object", "type": "object",
"description": "See: User" "description": "See: User",
"properties": {
"reason": {
"type": "string",
"description": "Brief explanation of why you are calling this tool"
}
},
"required": ["reason"]
} }
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
} }
func TestCleanJSONSchemaForGemini_RefHandling_DescriptionEscaping(t *testing.T) { func TestCleanJSONSchemaForAntigravity_RefHandling_DescriptionEscaping(t *testing.T) {
input := `{ input := `{
"definitions": { "definitions": {
"User": { "User": {
@@ -243,21 +253,29 @@ func TestCleanJSONSchemaForGemini_RefHandling_DescriptionEscaping(t *testing.T)
} }
}` }`
// After $ref is converted, empty schema placeholder is also added
expected := `{ expected := `{
"type": "object", "type": "object",
"properties": { "properties": {
"customer": { "customer": {
"type": "object", "type": "object",
"description": "He said \"hi\"\\nsecond line (See: User)" "description": "He said \"hi\"\\nsecond line (See: User)",
"properties": {
"reason": {
"type": "string",
"description": "Brief explanation of why you are calling this tool"
}
},
"required": ["reason"]
} }
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
} }
func TestCleanJSONSchemaForGemini_CyclicRefDefaults(t *testing.T) { func TestCleanJSONSchemaForAntigravity_CyclicRefDefaults(t *testing.T) {
input := `{ input := `{
"definitions": { "definitions": {
"Node": { "Node": {
@@ -270,7 +288,7 @@ func TestCleanJSONSchemaForGemini_CyclicRefDefaults(t *testing.T) {
"$ref": "#/definitions/Node" "$ref": "#/definitions/Node"
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
var resMap map[string]interface{} var resMap map[string]interface{}
json.Unmarshal([]byte(result), &resMap) json.Unmarshal([]byte(result), &resMap)
@@ -285,7 +303,7 @@ func TestCleanJSONSchemaForGemini_CyclicRefDefaults(t *testing.T) {
} }
} }
func TestCleanJSONSchemaForGemini_RequiredCleanup(t *testing.T) { func TestCleanJSONSchemaForAntigravity_RequiredCleanup(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -304,11 +322,11 @@ func TestCleanJSONSchemaForGemini_RequiredCleanup(t *testing.T) {
"required": ["a", "b"] "required": ["a", "b"]
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
} }
func TestCleanJSONSchemaForGemini_AllOfMerging_DotKeys(t *testing.T) { func TestCleanJSONSchemaForAntigravity_AllOfMerging_DotKeys(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"allOf": [ "allOf": [
@@ -336,11 +354,11 @@ func TestCleanJSONSchemaForGemini_AllOfMerging_DotKeys(t *testing.T) {
"required": ["my.param", "b"] "required": ["my.param", "b"]
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
} }
func TestCleanJSONSchemaForGemini_PropertyNameCollision(t *testing.T) { func TestCleanJSONSchemaForAntigravity_PropertyNameCollision(t *testing.T) {
// A tool has an argument named "pattern" - should NOT be treated as a constraint // A tool has an argument named "pattern" - should NOT be treated as a constraint
input := `{ input := `{
"type": "object", "type": "object",
@@ -364,7 +382,7 @@ func TestCleanJSONSchemaForGemini_PropertyNameCollision(t *testing.T) {
"required": ["pattern"] "required": ["pattern"]
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
var resMap map[string]interface{} var resMap map[string]interface{}
@@ -375,7 +393,7 @@ func TestCleanJSONSchemaForGemini_PropertyNameCollision(t *testing.T) {
} }
} }
func TestCleanJSONSchemaForGemini_DotKeys(t *testing.T) { func TestCleanJSONSchemaForAntigravity_DotKeys(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -389,7 +407,7 @@ func TestCleanJSONSchemaForGemini_DotKeys(t *testing.T) {
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
var resMap map[string]interface{} var resMap map[string]interface{}
if err := json.Unmarshal([]byte(result), &resMap); err != nil { if err := json.Unmarshal([]byte(result), &resMap); err != nil {
@@ -414,7 +432,7 @@ func TestCleanJSONSchemaForGemini_DotKeys(t *testing.T) {
} }
} }
func TestCleanJSONSchemaForGemini_AnyOfAlternativeHints(t *testing.T) { func TestCleanJSONSchemaForAntigravity_AnyOfAlternativeHints(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -428,7 +446,7 @@ func TestCleanJSONSchemaForGemini_AnyOfAlternativeHints(t *testing.T) {
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
if !strings.Contains(result, "Accepts:") { if !strings.Contains(result, "Accepts:") {
t.Errorf("Expected alternative types hint, got: %s", result) t.Errorf("Expected alternative types hint, got: %s", result)
@@ -438,7 +456,7 @@ func TestCleanJSONSchemaForGemini_AnyOfAlternativeHints(t *testing.T) {
} }
} }
func TestCleanJSONSchemaForGemini_NullableHint(t *testing.T) { func TestCleanJSONSchemaForAntigravity_NullableHint(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -450,7 +468,7 @@ func TestCleanJSONSchemaForGemini_NullableHint(t *testing.T) {
"required": ["name"] "required": ["name"]
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
if !strings.Contains(result, "(nullable)") { if !strings.Contains(result, "(nullable)") {
t.Errorf("Expected nullable hint, got: %s", result) t.Errorf("Expected nullable hint, got: %s", result)
@@ -460,7 +478,7 @@ func TestCleanJSONSchemaForGemini_NullableHint(t *testing.T) {
} }
} }
func TestCleanJSONSchemaForGemini_TypeFlattening_Nullable_DotKey(t *testing.T) { func TestCleanJSONSchemaForAntigravity_TypeFlattening_Nullable_DotKey(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -488,11 +506,11 @@ func TestCleanJSONSchemaForGemini_TypeFlattening_Nullable_DotKey(t *testing.T) {
"required": ["other"] "required": ["other"]
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
} }
func TestCleanJSONSchemaForGemini_EnumHint(t *testing.T) { func TestCleanJSONSchemaForAntigravity_EnumHint(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -504,7 +522,7 @@ func TestCleanJSONSchemaForGemini_EnumHint(t *testing.T) {
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
if !strings.Contains(result, "Allowed:") { if !strings.Contains(result, "Allowed:") {
t.Errorf("Expected enum values hint, got: %s", result) t.Errorf("Expected enum values hint, got: %s", result)
@@ -514,7 +532,7 @@ func TestCleanJSONSchemaForGemini_EnumHint(t *testing.T) {
} }
} }
func TestCleanJSONSchemaForGemini_AdditionalPropertiesHint(t *testing.T) { func TestCleanJSONSchemaForAntigravity_AdditionalPropertiesHint(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -523,14 +541,14 @@ func TestCleanJSONSchemaForGemini_AdditionalPropertiesHint(t *testing.T) {
"additionalProperties": false "additionalProperties": false
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
if !strings.Contains(result, "No extra properties allowed") { if !strings.Contains(result, "No extra properties allowed") {
t.Errorf("Expected additionalProperties hint, got: %s", result) t.Errorf("Expected additionalProperties hint, got: %s", result)
} }
} }
func TestCleanJSONSchemaForGemini_AnyOfFlattening_PreservesDescription(t *testing.T) { func TestCleanJSONSchemaForAntigravity_AnyOfFlattening_PreservesDescription(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -554,11 +572,11 @@ func TestCleanJSONSchemaForGemini_AnyOfFlattening_PreservesDescription(t *testin
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
} }
func TestCleanJSONSchemaForGemini_SingleEnumNoHint(t *testing.T) { func TestCleanJSONSchemaForAntigravity_SingleEnumNoHint(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -569,14 +587,14 @@ func TestCleanJSONSchemaForGemini_SingleEnumNoHint(t *testing.T) {
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
if strings.Contains(result, "Allowed:") { if strings.Contains(result, "Allowed:") {
t.Errorf("Single value enum should not add Allowed hint, got: %s", result) t.Errorf("Single value enum should not add Allowed hint, got: %s", result)
} }
} }
func TestCleanJSONSchemaForGemini_MultipleNonNullTypes(t *testing.T) { func TestCleanJSONSchemaForAntigravity_MultipleNonNullTypes(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -586,7 +604,7 @@ func TestCleanJSONSchemaForGemini_MultipleNonNullTypes(t *testing.T) {
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
if !strings.Contains(result, "Accepts:") { if !strings.Contains(result, "Accepts:") {
t.Errorf("Expected multiple types hint, got: %s", result) t.Errorf("Expected multiple types hint, got: %s", result)
@@ -676,3 +694,190 @@ func compareJSON(t *testing.T, expectedJSON, actualJSON string) {
t.Errorf("JSON mismatch:\nExpected:\n%s\n\nActual:\n%s", string(expBytes), string(actBytes)) t.Errorf("JSON mismatch:\nExpected:\n%s\n\nActual:\n%s", string(expBytes), string(actBytes))
} }
} }
// ============================================================================
// Empty Schema Placeholder Tests
// ============================================================================
func TestCleanJSONSchemaForAntigravity_EmptySchemaPlaceholder(t *testing.T) {
// Empty object schema with no properties should get a placeholder
input := `{
"type": "object"
}`
result := CleanJSONSchemaForAntigravity(input)
// Should have placeholder property added
if !strings.Contains(result, `"reason"`) {
t.Errorf("Empty schema should have 'reason' placeholder property, got: %s", result)
}
if !strings.Contains(result, `"required"`) {
t.Errorf("Empty schema should have 'required' with 'reason', got: %s", result)
}
}
func TestCleanJSONSchemaForAntigravity_EmptyPropertiesPlaceholder(t *testing.T) {
// Object with empty properties object
input := `{
"type": "object",
"properties": {}
}`
result := CleanJSONSchemaForAntigravity(input)
// Should have placeholder property added
if !strings.Contains(result, `"reason"`) {
t.Errorf("Empty properties should have 'reason' placeholder, got: %s", result)
}
}
func TestCleanJSONSchemaForAntigravity_NonEmptySchemaUnchanged(t *testing.T) {
// Schema with properties should NOT get placeholder
input := `{
"type": "object",
"properties": {
"name": {"type": "string"}
},
"required": ["name"]
}`
result := CleanJSONSchemaForAntigravity(input)
// Should NOT have placeholder property
if strings.Contains(result, `"reason"`) {
t.Errorf("Non-empty schema should NOT have 'reason' placeholder, got: %s", result)
}
// Original properties should be preserved
if !strings.Contains(result, `"name"`) {
t.Errorf("Original property 'name' should be preserved, got: %s", result)
}
}
func TestCleanJSONSchemaForAntigravity_NestedEmptySchema(t *testing.T) {
// Nested empty object in items should also get placeholder
input := `{
"type": "object",
"properties": {
"items": {
"type": "array",
"items": {
"type": "object"
}
}
}
}`
result := CleanJSONSchemaForAntigravity(input)
// Nested empty object should also get placeholder
// Check that the nested object has a reason property
parsed := gjson.Parse(result)
nestedProps := parsed.Get("properties.items.items.properties")
if !nestedProps.Exists() || !nestedProps.Get("reason").Exists() {
t.Errorf("Nested empty object should have 'reason' placeholder, got: %s", result)
}
}
func TestCleanJSONSchemaForAntigravity_EmptySchemaWithDescription(t *testing.T) {
// Empty schema with description should preserve description and add placeholder
input := `{
"type": "object",
"description": "An empty object"
}`
result := CleanJSONSchemaForAntigravity(input)
// Should have both description and placeholder
if !strings.Contains(result, `"An empty object"`) {
t.Errorf("Description should be preserved, got: %s", result)
}
if !strings.Contains(result, `"reason"`) {
t.Errorf("Empty schema should have 'reason' placeholder, got: %s", result)
}
}
// ============================================================================
// Format field handling (ad-hoc patch removal)
// ============================================================================
func TestCleanJSONSchemaForAntigravity_FormatFieldRemoval(t *testing.T) {
// format:"uri" should be removed and added as hint
input := `{
"type": "object",
"properties": {
"url": {
"type": "string",
"format": "uri",
"description": "A URL"
}
}
}`
result := CleanJSONSchemaForAntigravity(input)
// format should be removed
if strings.Contains(result, `"format"`) {
t.Errorf("format field should be removed, got: %s", result)
}
// hint should be added to description
if !strings.Contains(result, "format: uri") {
t.Errorf("format hint should be added to description, got: %s", result)
}
// original description should be preserved
if !strings.Contains(result, "A URL") {
t.Errorf("Original description should be preserved, got: %s", result)
}
}
func TestCleanJSONSchemaForAntigravity_FormatFieldNoDescription(t *testing.T) {
// format without description should create description with hint
input := `{
"type": "object",
"properties": {
"email": {
"type": "string",
"format": "email"
}
}
}`
result := CleanJSONSchemaForAntigravity(input)
// format should be removed
if strings.Contains(result, `"format"`) {
t.Errorf("format field should be removed, got: %s", result)
}
// hint should be added
if !strings.Contains(result, "format: email") {
t.Errorf("format hint should be added, got: %s", result)
}
}
func TestCleanJSONSchemaForAntigravity_MultipleFormats(t *testing.T) {
// Multiple format fields should all be handled
input := `{
"type": "object",
"properties": {
"url": {"type": "string", "format": "uri"},
"email": {"type": "string", "format": "email"},
"date": {"type": "string", "format": "date-time"}
}
}`
result := CleanJSONSchemaForAntigravity(input)
// All format fields should be removed
if strings.Contains(result, `"format"`) {
t.Errorf("All format fields should be removed, got: %s", result)
}
// All hints should be added
if !strings.Contains(result, "format: uri") {
t.Errorf("uri format hint should be added, got: %s", result)
}
if !strings.Contains(result, "format: email") {
t.Errorf("email format hint should be added, got: %s", result)
}
if !strings.Contains(result, "format: date-time") {
t.Errorf("date-time format hint should be added, got: %s", result)
}
}

View File

@@ -0,0 +1,87 @@
package util
import (
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// GetThinkingText extracts the thinking text from a content part.
// Handles various formats:
// - Simple string: { "thinking": "text" } or { "text": "text" }
// - Wrapped object: { "thinking": { "text": "text", "cache_control": {...} } }
// - Gemini-style: { "thought": true, "text": "text" }
// Returns the extracted text string.
func GetThinkingText(part gjson.Result) string {
// Try direct text field first (Gemini-style)
if text := part.Get("text"); text.Exists() && text.Type == gjson.String {
return text.String()
}
// Try thinking field
thinkingField := part.Get("thinking")
if !thinkingField.Exists() {
return ""
}
// thinking is a string
if thinkingField.Type == gjson.String {
return thinkingField.String()
}
// thinking is an object with inner text/thinking
if thinkingField.IsObject() {
if inner := thinkingField.Get("text"); inner.Exists() && inner.Type == gjson.String {
return inner.String()
}
if inner := thinkingField.Get("thinking"); inner.Exists() && inner.Type == gjson.String {
return inner.String()
}
}
return ""
}
// GetThinkingTextFromJSON extracts thinking text from a raw JSON string.
func GetThinkingTextFromJSON(jsonStr string) string {
return GetThinkingText(gjson.Parse(jsonStr))
}
// SanitizeThinkingPart normalizes a thinking part to a canonical form.
// Strips cache_control and other non-essential fields.
// Returns the sanitized part as JSON string.
func SanitizeThinkingPart(part gjson.Result) string {
// Gemini-style: { thought: true, text, thoughtSignature }
if part.Get("thought").Bool() {
result := `{"thought":true}`
if text := GetThinkingText(part); text != "" {
result, _ = sjson.Set(result, "text", text)
}
if sig := part.Get("thoughtSignature"); sig.Exists() && sig.Type == gjson.String {
result, _ = sjson.Set(result, "thoughtSignature", sig.String())
}
return result
}
// Anthropic-style: { type: "thinking", thinking, signature }
if part.Get("type").String() == "thinking" || part.Get("thinking").Exists() {
result := `{"type":"thinking"}`
if text := GetThinkingText(part); text != "" {
result, _ = sjson.Set(result, "thinking", text)
}
if sig := part.Get("signature"); sig.Exists() && sig.Type == gjson.String {
result, _ = sjson.Set(result, "signature", sig.String())
}
return result
}
// Not a thinking part, return as-is but strip cache_control
return StripCacheControl(part.Raw)
}
// StripCacheControl removes cache_control and providerOptions from a JSON object.
func StripCacheControl(jsonStr string) string {
result := jsonStr
result, _ = sjson.Delete(result, "cache_control")
result, _ = sjson.Delete(result, "providerOptions")
return result
}

View File

@@ -17,6 +17,7 @@ import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant" . "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
responsesconverter "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/openai/responses"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
@@ -109,7 +110,17 @@ func (h *OpenAIAPIHandler) ChatCompletions(c *gin.Context) {
// Check if the client requested a streaming response. // Check if the client requested a streaming response.
streamResult := gjson.GetBytes(rawJSON, "stream") streamResult := gjson.GetBytes(rawJSON, "stream")
if streamResult.Type == gjson.True { stream := streamResult.Type == gjson.True
// Some clients send OpenAI Responses-format payloads to /v1/chat/completions.
// Convert them to Chat Completions so downstream translators preserve tool metadata.
if shouldTreatAsResponsesFormat(rawJSON) {
modelName := gjson.GetBytes(rawJSON, "model").String()
rawJSON = responsesconverter.ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName, rawJSON, stream)
stream = gjson.GetBytes(rawJSON, "stream").Bool()
}
if stream {
h.handleStreamingResponse(c, rawJSON) h.handleStreamingResponse(c, rawJSON)
} else { } else {
h.handleNonStreamingResponse(c, rawJSON) h.handleNonStreamingResponse(c, rawJSON)
@@ -117,6 +128,21 @@ func (h *OpenAIAPIHandler) ChatCompletions(c *gin.Context) {
} }
// shouldTreatAsResponsesFormat detects OpenAI Responses-style payloads that are
// accidentally sent to the Chat Completions endpoint.
func shouldTreatAsResponsesFormat(rawJSON []byte) bool {
if gjson.GetBytes(rawJSON, "messages").Exists() {
return false
}
if gjson.GetBytes(rawJSON, "input").Exists() {
return true
}
if gjson.GetBytes(rawJSON, "instructions").Exists() {
return true
}
return false
}
// Completions handles the /v1/completions endpoint. // Completions handles the /v1/completions endpoint.
// It determines whether the request is for a streaming or non-streaming response // It determines whether the request is for a streaming or non-streaming response
// and calls the appropriate handler based on the model provider. // and calls the appropriate handler based on the model provider.

View File

@@ -99,11 +99,54 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
fmt.Println("Waiting for antigravity authentication callback...") fmt.Println("Waiting for antigravity authentication callback...")
var cbRes callbackResult var cbRes callbackResult
select { timeoutTimer := time.NewTimer(5 * time.Minute)
case res := <-cbChan: defer timeoutTimer.Stop()
cbRes = res
case <-time.After(5 * time.Minute): var manualPromptTimer *time.Timer
return nil, fmt.Errorf("antigravity: authentication timed out") var manualPromptC <-chan time.Time
if opts.Prompt != nil {
manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
waitForCallback:
for {
select {
case res := <-cbChan:
cbRes = res
break waitForCallback
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
select {
case res := <-cbChan:
cbRes = res
break waitForCallback
default:
}
input, errPrompt := opts.Prompt("Paste the antigravity callback URL (or press Enter to keep waiting): ")
if errPrompt != nil {
return nil, errPrompt
}
parsed, errParse := misc.ParseOAuthCallback(input)
if errParse != nil {
return nil, errParse
}
if parsed == nil {
continue
}
cbRes = callbackResult{
Code: parsed.Code,
State: parsed.State,
Error: parsed.Error,
}
break waitForCallback
case <-timeoutTimer.C:
return nil, fmt.Errorf("antigravity: authentication timed out")
}
} }
if cbRes.Error != "" { if cbRes.Error != "" {

View File

@@ -98,16 +98,76 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
fmt.Println("Waiting for Claude authentication callback...") fmt.Println("Waiting for Claude authentication callback...")
result, err := oauthServer.WaitForCallback(5 * time.Minute) callbackCh := make(chan *claude.OAuthResult, 1)
if err != nil { callbackErrCh := make(chan error, 1)
if strings.Contains(err.Error(), "timeout") { manualDescription := ""
return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err)
go func() {
result, errWait := oauthServer.WaitForCallback(5 * time.Minute)
if errWait != nil {
callbackErrCh <- errWait
return
}
callbackCh <- result
}()
var result *claude.OAuthResult
var manualPromptTimer *time.Timer
var manualPromptC <-chan time.Time
if opts.Prompt != nil {
manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
waitForCallback:
for {
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
if strings.Contains(err.Error(), "timeout") {
return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err)
}
return nil, err
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
if strings.Contains(err.Error(), "timeout") {
return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err)
}
return nil, err
default:
}
input, errPrompt := opts.Prompt("Paste the Claude callback URL (or press Enter to keep waiting): ")
if errPrompt != nil {
return nil, errPrompt
}
parsed, errParse := misc.ParseOAuthCallback(input)
if errParse != nil {
return nil, errParse
}
if parsed == nil {
continue
}
manualDescription = parsed.ErrorDescription
result = &claude.OAuthResult{
Code: parsed.Code,
State: parsed.State,
Error: parsed.Error,
}
break waitForCallback
} }
return nil, err
} }
if result.Error != "" { if result.Error != "" {
return nil, claude.NewOAuthError(result.Error, "", http.StatusBadRequest) return nil, claude.NewOAuthError(result.Error, manualDescription, http.StatusBadRequest)
} }
if result.State != state { if result.State != state {

View File

@@ -97,16 +97,76 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
fmt.Println("Waiting for Codex authentication callback...") fmt.Println("Waiting for Codex authentication callback...")
result, err := oauthServer.WaitForCallback(5 * time.Minute) callbackCh := make(chan *codex.OAuthResult, 1)
if err != nil { callbackErrCh := make(chan error, 1)
if strings.Contains(err.Error(), "timeout") { manualDescription := ""
return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err)
go func() {
result, errWait := oauthServer.WaitForCallback(5 * time.Minute)
if errWait != nil {
callbackErrCh <- errWait
return
}
callbackCh <- result
}()
var result *codex.OAuthResult
var manualPromptTimer *time.Timer
var manualPromptC <-chan time.Time
if opts.Prompt != nil {
manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
waitForCallback:
for {
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
if strings.Contains(err.Error(), "timeout") {
return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err)
}
return nil, err
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
if strings.Contains(err.Error(), "timeout") {
return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err)
}
return nil, err
default:
}
input, errPrompt := opts.Prompt("Paste the Codex callback URL (or press Enter to keep waiting): ")
if errPrompt != nil {
return nil, errPrompt
}
parsed, errParse := misc.ParseOAuthCallback(input)
if errParse != nil {
return nil, errParse
}
if parsed == nil {
continue
}
manualDescription = parsed.ErrorDescription
result = &codex.OAuthResult{
Code: parsed.Code,
State: parsed.State,
Error: parsed.Error,
}
break waitForCallback
} }
return nil, err
} }
if result.Error != "" { if result.Error != "" {
return nil, codex.NewOAuthError(result.Error, "", http.StatusBadRequest) return nil, codex.NewOAuthError(result.Error, manualDescription, http.StatusBadRequest)
} }
if result.State != state { if result.State != state {

View File

@@ -44,7 +44,10 @@ func (a *GeminiAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
} }
geminiAuth := gemini.NewGeminiAuth() geminiAuth := gemini.NewGeminiAuth()
_, err := geminiAuth.GetAuthenticatedClient(ctx, &ts, cfg, opts.NoBrowser) _, err := geminiAuth.GetAuthenticatedClient(ctx, &ts, cfg, &gemini.WebLoginOptions{
NoBrowser: opts.NoBrowser,
Prompt: opts.Prompt,
})
if err != nil { if err != nil {
return nil, fmt.Errorf("gemini authentication failed: %w", err) return nil, fmt.Errorf("gemini authentication failed: %w", err)
} }

View File

@@ -84,9 +84,64 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
fmt.Println("Waiting for iFlow authentication callback...") fmt.Println("Waiting for iFlow authentication callback...")
result, err := oauthServer.WaitForCallback(5 * time.Minute) callbackCh := make(chan *iflow.OAuthResult, 1)
if err != nil { callbackErrCh := make(chan error, 1)
return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err)
go func() {
result, errWait := oauthServer.WaitForCallback(5 * time.Minute)
if errWait != nil {
callbackErrCh <- errWait
return
}
callbackCh <- result
}()
var result *iflow.OAuthResult
var manualPromptTimer *time.Timer
var manualPromptC <-chan time.Time
if opts.Prompt != nil {
manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
waitForCallback:
for {
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err)
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err)
default:
}
input, errPrompt := opts.Prompt("Paste the iFlow callback URL (or press Enter to keep waiting): ")
if errPrompt != nil {
return nil, errPrompt
}
parsed, errParse := misc.ParseOAuthCallback(input)
if errParse != nil {
return nil, errParse
}
if parsed == nil {
continue
}
result = &iflow.OAuthResult{
Code: parsed.Code,
State: parsed.State,
Error: parsed.Error,
}
break waitForCallback
}
} }
if result.Error != "" { if result.Error != "" {
return nil, fmt.Errorf("iflow auth: provider returned error %s", result.Error) return nil, fmt.Errorf("iflow auth: provider returned error %s", result.Error)

View File

@@ -135,6 +135,18 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
} }
} }
func (m *Manager) SetSelector(selector Selector) {
if m == nil {
return
}
if selector == nil {
selector = &RoundRobinSelector{}
}
m.mu.Lock()
m.selector = selector
m.mu.Unlock()
}
// SetStore swaps the underlying persistence store. // SetStore swaps the underlying persistence store.
func (m *Manager) SetStore(store Store) { func (m *Manager) SetStore(store Store) {
m.mu.Lock() m.mu.Lock()

View File

@@ -20,6 +20,11 @@ type RoundRobinSelector struct {
cursors map[string]int cursors map[string]int
} }
// FillFirstSelector selects the first available credential (deterministic ordering).
// This "burns" one account before moving to the next, which can help stagger
// rolling-window subscription caps (e.g. chat message limits).
type FillFirstSelector struct{}
type blockReason int type blockReason int
const ( const (
@@ -98,20 +103,8 @@ func (e *modelCooldownError) Headers() http.Header {
return headers return headers
} }
// Pick selects the next available auth for the provider in a round-robin manner. func collectAvailable(auths []*Auth, model string, now time.Time) (available []*Auth, cooldownCount int, earliest time.Time) {
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) { available = make([]*Auth, 0, len(auths))
_ = ctx
_ = opts
if len(auths) == 0 {
return nil, &Error{Code: "auth_not_found", Message: "no auth candidates"}
}
if s.cursors == nil {
s.cursors = make(map[string]int)
}
available := make([]*Auth, 0, len(auths))
now := time.Now()
cooldownCount := 0
var earliest time.Time
for i := 0; i < len(auths); i++ { for i := 0; i < len(auths); i++ {
candidate := auths[i] candidate := auths[i]
blocked, reason, next := isAuthBlockedForModel(candidate, model, now) blocked, reason, next := isAuthBlockedForModel(candidate, model, now)
@@ -126,6 +119,18 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
} }
} }
} }
if len(available) > 1 {
sort.Slice(available, func(i, j int) bool { return available[i].ID < available[j].ID })
}
return available, cooldownCount, earliest
}
func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]*Auth, error) {
if len(auths) == 0 {
return nil, &Error{Code: "auth_not_found", Message: "no auth candidates"}
}
available, cooldownCount, earliest := collectAvailable(auths, model, now)
if len(available) == 0 { if len(available) == 0 {
if cooldownCount == len(auths) && !earliest.IsZero() { if cooldownCount == len(auths) && !earliest.IsZero() {
resetIn := earliest.Sub(now) resetIn := earliest.Sub(now)
@@ -136,12 +141,24 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
} }
return nil, &Error{Code: "auth_unavailable", Message: "no auth available"} return nil, &Error{Code: "auth_unavailable", Message: "no auth available"}
} }
// Make round-robin deterministic even if caller's candidate order is unstable.
if len(available) > 1 { return available, nil
sort.Slice(available, func(i, j int) bool { return available[i].ID < available[j].ID }) }
// Pick selects the next available auth for the provider in a round-robin manner.
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
_ = ctx
_ = opts
now := time.Now()
available, err := getAvailableAuths(auths, provider, model, now)
if err != nil {
return nil, err
} }
key := provider + ":" + model key := provider + ":" + model
s.mu.Lock() s.mu.Lock()
if s.cursors == nil {
s.cursors = make(map[string]int)
}
index := s.cursors[key] index := s.cursors[key]
if index >= 2_147_483_640 { if index >= 2_147_483_640 {
@@ -154,6 +171,18 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
return available[index%len(available)], nil return available[index%len(available)], nil
} }
// Pick selects the first available auth for the provider in a deterministic manner.
func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
_ = ctx
_ = opts
now := time.Now()
available, err := getAvailableAuths(auths, provider, model, now)
if err != nil {
return nil, err
}
return available[0], nil
}
func isAuthBlockedForModel(auth *Auth, model string, now time.Time) (bool, blockReason, time.Time) { func isAuthBlockedForModel(auth *Auth, model string, now time.Time) (bool, blockReason, time.Time) {
if auth == nil { if auth == nil {
return true, blockReasonOther, time.Time{} return true, blockReasonOther, time.Time{}

View File

@@ -0,0 +1,113 @@
package auth
import (
"context"
"errors"
"sync"
"testing"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
func TestFillFirstSelectorPick_Deterministic(t *testing.T) {
t.Parallel()
selector := &FillFirstSelector{}
auths := []*Auth{
{ID: "b"},
{ID: "a"},
{ID: "c"},
}
got, err := selector.Pick(context.Background(), "gemini", "", cliproxyexecutor.Options{}, auths)
if err != nil {
t.Fatalf("Pick() error = %v", err)
}
if got == nil {
t.Fatalf("Pick() auth = nil")
}
if got.ID != "a" {
t.Fatalf("Pick() auth.ID = %q, want %q", got.ID, "a")
}
}
func TestRoundRobinSelectorPick_CyclesDeterministic(t *testing.T) {
t.Parallel()
selector := &RoundRobinSelector{}
auths := []*Auth{
{ID: "b"},
{ID: "a"},
{ID: "c"},
}
want := []string{"a", "b", "c", "a", "b"}
for i, id := range want {
got, err := selector.Pick(context.Background(), "gemini", "", cliproxyexecutor.Options{}, auths)
if err != nil {
t.Fatalf("Pick() #%d error = %v", i, err)
}
if got == nil {
t.Fatalf("Pick() #%d auth = nil", i)
}
if got.ID != id {
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, id)
}
}
}
func TestRoundRobinSelectorPick_Concurrent(t *testing.T) {
selector := &RoundRobinSelector{}
auths := []*Auth{
{ID: "b"},
{ID: "a"},
{ID: "c"},
}
start := make(chan struct{})
var wg sync.WaitGroup
errCh := make(chan error, 1)
goroutines := 32
iterations := 100
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
<-start
for j := 0; j < iterations; j++ {
got, err := selector.Pick(context.Background(), "gemini", "", cliproxyexecutor.Options{}, auths)
if err != nil {
select {
case errCh <- err:
default:
}
return
}
if got == nil {
select {
case errCh <- errors.New("Pick() returned nil auth"):
default:
}
return
}
if got.ID == "" {
select {
case errCh <- errors.New("Pick() returned auth with empty ID"):
default:
}
return
}
}
}()
}
close(start)
wg.Wait()
select {
case err := <-errCh:
t.Fatalf("concurrent Pick() error = %v", err)
default:
}
}

View File

@@ -5,6 +5,7 @@ package cliproxy
import ( import (
"fmt" "fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api" "github.com/router-for-me/CLIProxyAPI/v6/internal/api"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access" sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
@@ -197,7 +198,20 @@ func (b *Builder) Build() (*Service, error) {
if dirSetter, ok := tokenStore.(interface{ SetBaseDir(string) }); ok && b.cfg != nil { if dirSetter, ok := tokenStore.(interface{ SetBaseDir(string) }); ok && b.cfg != nil {
dirSetter.SetBaseDir(b.cfg.AuthDir) dirSetter.SetBaseDir(b.cfg.AuthDir)
} }
coreManager = coreauth.NewManager(tokenStore, nil, nil)
strategy := ""
if b.cfg != nil {
strategy = strings.ToLower(strings.TrimSpace(b.cfg.Routing.Strategy))
}
var selector coreauth.Selector
switch strategy {
case "fill-first", "fillfirst", "ff":
selector = &coreauth.FillFirstSelector{}
default:
selector = &coreauth.RoundRobinSelector{}
}
coreManager = coreauth.NewManager(tokenStore, selector, nil)
} }
// Attach a default RoundTripper provider so providers can opt-in per-auth transports. // Attach a default RoundTripper provider so providers can opt-in per-auth transports.
coreManager.SetRoundTripperProvider(newDefaultRoundTripperProvider()) coreManager.SetRoundTripperProvider(newDefaultRoundTripperProvider())

View File

@@ -510,6 +510,13 @@ func (s *Service) Run(ctx context.Context) error {
var watcherWrapper *WatcherWrapper var watcherWrapper *WatcherWrapper
reloadCallback := func(newCfg *config.Config) { reloadCallback := func(newCfg *config.Config) {
previousStrategy := ""
s.cfgMu.RLock()
if s.cfg != nil {
previousStrategy = strings.ToLower(strings.TrimSpace(s.cfg.Routing.Strategy))
}
s.cfgMu.RUnlock()
if newCfg == nil { if newCfg == nil {
s.cfgMu.RLock() s.cfgMu.RLock()
newCfg = s.cfg newCfg = s.cfg
@@ -518,6 +525,30 @@ func (s *Service) Run(ctx context.Context) error {
if newCfg == nil { if newCfg == nil {
return return
} }
nextStrategy := strings.ToLower(strings.TrimSpace(newCfg.Routing.Strategy))
normalizeStrategy := func(strategy string) string {
switch strategy {
case "fill-first", "fillfirst", "ff":
return "fill-first"
default:
return "round-robin"
}
}
previousStrategy = normalizeStrategy(previousStrategy)
nextStrategy = normalizeStrategy(nextStrategy)
if s.coreManager != nil && previousStrategy != nextStrategy {
var selector coreauth.Selector
switch nextStrategy {
case "fill-first":
selector = &coreauth.FillFirstSelector{}
default:
selector = &coreauth.RoundRobinSelector{}
}
s.coreManager.SetSelector(selector)
log.Infof("routing strategy updated to %s", nextStrategy)
}
s.applyRetryConfig(newCfg) s.applyRetryConfig(newCfg)
if s.server != nil { if s.server != nil {
s.server.UpdateClients(newCfg) s.server.UpdateClients(newCfg)