Compare commits

...

36 Commits

Author SHA1 Message Date
Luis Pater
ee0c24628f Merge branch 'router-for-me:main' into main 2026-03-07 20:42:22 +08:00
Luis Pater
ddcf1f279d Fixed: #1901
test(websocket): add tests for incremental input and prewarm handling logic

- Added test cases for incremental input support based on upstream capabilities.
- Introduced validation for prewarm handling of `response.create` messages locally.
- Enhanced test coverage for websocket executor behavior, including payload forwarding checks.
- Updated websocket implementation with prewarm and incremental input logic for better testability.
2026-03-07 13:11:28 +08:00
Luis Pater
7e6bb8fdc5 Merge origin/dev into pr-1774-review and resolve watcher conflicts 2026-03-07 11:12:42 +08:00
Luis Pater
9cee8ef87b Merge pull request #1684 from alexey-yanchenko/fix/input-audio-from-openai-to-antigravity
fix: preserve input_audio content parts when proxying to Antigravity
2026-03-07 10:12:28 +08:00
Luis Pater
93fb841bcb Fixed: #1670
test(translator): add unit tests for OpenAI to Claude requests and tool result handling

- Introduced tests for converting OpenAI requests to Claude with text, base64 images, and URL images in tool results.
- Refactored `convertClaudeToolResultContent` and related functionality to properly handle raw content with images and text.
- Updated conversion logic to streamline image handling for both base64 and URL formats.
2026-03-07 09:25:22 +08:00
Luis Pater
0c05131aeb Merge branch 'router-for-me:main' into main 2026-03-07 09:08:28 +08:00
Luis Pater
5ebc58fab4 refactor(executor): remove legacy connCreateSent logic and standardize response.create usage for all websocket events
- Simplified connection logic by removing `connCreateSent` and related state handling.
- Updated `buildCodexWebsocketRequestBody` to always use `response.create`.
- Added unit tests to validate `response.create` behavior and beta header preservation.
- Dropped unsupported `response.append` and outdated `response.done` event types.
2026-03-07 09:07:23 +08:00
Luis Pater
2b609dd891 Merge pull request #1912 from FradSer/main
feat(registry): add gemini 3.1 flash lite preview
2026-03-07 05:41:31 +08:00
Frad LEE
a8cbc68c3e feat(registry): add gemini 3.1 flash lite preview
- Add model to GetGeminiModels()
- Add model to GetGeminiVertexModels()
- Add model to GetGeminiCLIModels()
- Add model to GetAIStudioModels()
- Add to AntigravityModelConfig with thinking levels
- Update gemini-3-flash-preview description

Registers the new lightweight Gemini model across all provider
endpoints for cost-effective high-volume usage scenarios.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-06 20:52:28 +08:00
Luis Pater
89c428216e Merge branch 'router-for-me:main' into main 2026-03-06 11:09:31 +08:00
Luis Pater
2695a99623 fix(translator): conditionally remove service_tier from OpenAI response processing 2026-03-06 11:07:22 +08:00
Luis Pater
ad5253bd2b Merge branch 'router-for-me:main' into main 2026-03-06 04:15:55 +08:00
Luis Pater
9397f7049f fix(registry): simplify GPT 5.4 model description in static data 2026-03-06 02:32:56 +08:00
Luis Pater
a14d19b92c Merge branch 'router-for-me:main' into main 2026-03-06 02:25:19 +08:00
Luis Pater
8ae0c05ea6 Merge pull request #416 from ladeng07/main
feat(github-copilot): add /responses support for gpt-4o and gpt-4.1
2026-03-06 02:25:10 +08:00
Luis Pater
8822f20d17 feat(registry): add GPT 5.4 model definition to static data 2026-03-06 02:23:53 +08:00
Luis Pater
f0e5a5a367 test(watcher): add unit test for server update timer cancellation and immediate reload logic
- Add `TestTriggerServerUpdateCancelsPendingTimerOnImmediate` to verify proper handling of server update debounce and timer cancellation.
- Fix logic in `triggerServerUpdate` to prevent duplicate timers and ensure proper cleanup of pending state.
2026-03-05 23:48:50 +08:00
Luis Pater
f6dfea9357 Merge pull request #1874 from constansino/fix/watcher-auth-event-storm-debounce
fix(watcher): 合并 auth 事件风暴下的回调触发,降低高 CPU
2026-03-05 23:29:56 +08:00
Luis Pater
cc8dc7f62c Merge branch 'main' into dev 2026-03-05 23:13:21 +08:00
Luis Pater
a3846ea513 Merge pull request #1870 from sususu98/fix/remove-instructions-restore
cleanup(translator): remove leftover instructions restore in codex responses
2026-03-05 23:12:31 +08:00
Luis Pater
8d44be858e Merge pull request #1834 from DragonFSKY/fix/sse-streaming-accept-encoding
fix(claude): extend gzip fix to SSE success path and header-absent compression (#1763)
2026-03-05 22:57:27 +08:00
Luis Pater
0e6bb076e9 fix(translator): comment out service_tier removal from OpenAI response processing 2026-03-05 22:49:38 +08:00
Luis Pater
ac135fc7cb Fixed: #1815
**test(executor): add unit tests for prompt cache key generation in OpenAI `cacheHelper`**
2026-03-05 22:49:23 +08:00
Luis Pater
4e1d09809d Fixed: #1741
fix(translator): handle tool name mappings and improve tool call handling in OpenAI and Claude integrations
2026-03-05 22:24:50 +08:00
LMark
9e855f8100 feat(github-copilot): add /responses support for gpt-4o and gpt-4.1 2026-03-05 21:20:21 +08:00
constansino
ac95e92829 fix(watcher): guard debounced callback after Stop 2026-03-05 19:25:57 +08:00
constansino
8526c2da25 fix(watcher): debounce auth event callback storms 2026-03-05 19:12:57 +08:00
sususu98
68a6cabf8b style: blank unused params in codex responses translator 2026-03-05 16:42:48 +08:00
sususu98
ac0e387da1 cleanup(translator): remove leftover instructions restore in codex responses
The instructions restore logic was originally needed when the proxy
injected custom instructions (per-model system prompts) into requests.
Since ac802a46 removed the injection system, the proxy no longer
modifies instructions before forwarding. The upstream response's
instructions field now matches the client's original value, making
the restore a no-op.

Also removes unused sjson import.

Closes router-for-me/CLIProxyAPI#1868
2026-03-05 16:34:55 +08:00
DragonFSKY
419bf784ab fix(claude): prevent compressed SSE streams and add magic-byte decompression fallback
- Set Accept-Encoding: identity for SSE streams; upstream must not compress
  line-delimited SSE bodies that bufio.Scanner reads directly
- Re-enforce identity after ApplyCustomHeadersFromAttrs to prevent auth
  attribute injection from re-enabling compression on the stream path
- Add peekableBody type wrapping bufio.Reader for non-consuming magic-byte
  inspection of the first 4 bytes without affecting downstream readers
- Detect gzip (0x1f 0x8b) and zstd (0x28 0xb5 0x2f 0xfd) by magic bytes
  when Content-Encoding header is absent, covering misbehaving upstreams
- Remove if-Content-Encoding guard on all three error paths (Execute,
  ExecuteStream, CountTokens); unconditionally delegate to decodeResponseBody
  so magic-byte detection applies consistently to all response paths
- Add 10 tests covering stream identity enforcement, compressed success bodies,
  magic-byte detection without headers, error path decoding, and
  auth attribute override prevention
2026-03-05 06:38:38 +08:00
lyd123qw2008
dd44413ba5 refactor(watcher): make authSliceToMap always return map 2026-03-02 10:09:56 +08:00
lyd123qw2008
10fa0f2062 refactor(watcher): dedupe auth map conversion in incremental flow 2026-03-02 10:03:42 +08:00
lyd123qw2008
30338ecec4 perf(watcher): remove redundant auth clones in incremental path 2026-03-01 14:05:11 +08:00
lyd123qw2008
9a37defed3 test(watcher): restore main test names and max-retry callback coverage 2026-03-01 13:54:03 +08:00
lyd123qw2008
c83a057996 refactor(watcher): make auth file events fully incremental 2026-03-01 13:42:42 +08:00
Alexey Yanchenko
b7588428c5 fix: preserve input_audio content parts when proxying to Antigravity
- Add input_audio handling in chat/completions translator (antigravity_openai_request.go)
- Add input_audio handling in responses translator (gemini_openai-responses_request.go)
- Map OpenAI audio formats (mp3, wav, ogg, flac, aac, webm, pcm16, g711_ulaw, g711_alaw) to correct MIME types for Gemini inlineData
2026-02-23 20:50:28 +07:00
27 changed files with 2204 additions and 396 deletions

View File

@@ -151,6 +151,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
Description: "OpenAI GPT-4.1 via GitHub Copilot",
ContextLength: 128000,
MaxCompletionTokens: 16384,
SupportedEndpoints: []string{"/chat/completions", "/responses"},
},
}
@@ -165,6 +166,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
Description: entry.Description,
ContextLength: 128000,
MaxCompletionTokens: 16384,
SupportedEndpoints: []string{"/chat/completions", "/responses"},
})
}

View File

@@ -220,12 +220,27 @@ func GetGeminiModels() []*ModelInfo {
Name: "models/gemini-3-flash-preview",
Version: "3.0",
DisplayName: "Gemini 3 Flash Preview",
Description: "Gemini 3 Flash Preview",
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
},
{
ID: "gemini-3.1-flash-lite-preview",
Object: "model",
Created: 1776288000,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3.1-flash-lite-preview",
Version: "3.1",
DisplayName: "Gemini 3.1 Flash Lite Preview",
Description: "Our smallest and most cost effective model, built for at scale usage.",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}},
},
{
ID: "gemini-3-pro-image-preview",
Object: "model",
@@ -336,6 +351,21 @@ func GetGeminiVertexModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
},
{
ID: "gemini-3.1-flash-lite-preview",
Object: "model",
Created: 1776288000,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3.1-flash-lite-preview",
Version: "3.1",
DisplayName: "Gemini 3.1 Flash Lite Preview",
Description: "Our smallest and most cost effective model, built for at scale usage.",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}},
},
{
ID: "gemini-3-pro-image-preview",
Object: "model",
@@ -508,6 +538,21 @@ func GetGeminiCLIModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
},
{
ID: "gemini-3.1-flash-lite-preview",
Object: "model",
Created: 1776288000,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3.1-flash-lite-preview",
Version: "3.1",
DisplayName: "Gemini 3.1 Flash Lite Preview",
Description: "Our smallest and most cost effective model, built for at scale usage.",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}},
},
}
}
@@ -604,6 +649,21 @@ func GetAIStudioModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "gemini-3.1-flash-lite-preview",
Object: "model",
Created: 1776288000,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3.1-flash-lite-preview",
Version: "3.1",
DisplayName: "Gemini 3.1 Flash Lite Preview",
Description: "Our smallest and most cost effective model, built for at scale usage.",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}},
},
{
ID: "gemini-pro-latest",
Object: "model",
@@ -839,6 +899,20 @@ func GetOpenAIModels() []*ModelInfo {
SupportedParameters: []string{"tools"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
},
{
ID: "gpt-5.4",
Object: "model",
Created: 1772668800,
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5.4",
DisplayName: "GPT 5.4",
Description: "Stable version of GPT 5.4",
ContextLength: 1_050_000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
},
}
}
@@ -966,6 +1040,7 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
"gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3.1-pro-low": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3.1-flash-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}},
"gemini-3.1-flash-lite-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}},
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 64000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-sonnet-4-6": {Thinking: &ThinkingSupport{Min: 1024, Max: 64000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},

View File

@@ -187,17 +187,15 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
// Decompress error responses (e.g. gzip-compressed 400 errors from Anthropic API).
errBody := httpResp.Body
if ce := httpResp.Header.Get("Content-Encoding"); ce != "" {
var decErr error
errBody, decErr = decodeResponseBody(httpResp.Body, ce)
if decErr != nil {
recordAPIResponseError(ctx, e.cfg, decErr)
msg := fmt.Sprintf("failed to decode error response body (encoding=%s): %v", ce, decErr)
logWithRequestID(ctx).Warn(msg)
return resp, statusErr{code: httpResp.StatusCode, msg: msg}
}
// Decompress error responses — pass the Content-Encoding value (may be empty)
// and let decodeResponseBody handle both header-declared and magic-byte-detected
// compression. This keeps error-path behaviour consistent with the success path.
errBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
if decErr != nil {
recordAPIResponseError(ctx, e.cfg, decErr)
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
logWithRequestID(ctx).Warn(msg)
return resp, statusErr{code: httpResp.StatusCode, msg: msg}
}
b, readErr := io.ReadAll(errBody)
if readErr != nil {
@@ -352,17 +350,15 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
// Decompress error responses (e.g. gzip-compressed 400 errors from Anthropic API).
errBody := httpResp.Body
if ce := httpResp.Header.Get("Content-Encoding"); ce != "" {
var decErr error
errBody, decErr = decodeResponseBody(httpResp.Body, ce)
if decErr != nil {
recordAPIResponseError(ctx, e.cfg, decErr)
msg := fmt.Sprintf("failed to decode error response body (encoding=%s): %v", ce, decErr)
logWithRequestID(ctx).Warn(msg)
return nil, statusErr{code: httpResp.StatusCode, msg: msg}
}
// Decompress error responses — pass the Content-Encoding value (may be empty)
// and let decodeResponseBody handle both header-declared and magic-byte-detected
// compression. This keeps error-path behaviour consistent with the success path.
errBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
if decErr != nil {
recordAPIResponseError(ctx, e.cfg, decErr)
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
logWithRequestID(ctx).Warn(msg)
return nil, statusErr{code: httpResp.StatusCode, msg: msg}
}
b, readErr := io.ReadAll(errBody)
if readErr != nil {
@@ -521,17 +517,15 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
}
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
// Decompress error responses (e.g. gzip-compressed 400 errors from Anthropic API).
errBody := resp.Body
if ce := resp.Header.Get("Content-Encoding"); ce != "" {
var decErr error
errBody, decErr = decodeResponseBody(resp.Body, ce)
if decErr != nil {
recordAPIResponseError(ctx, e.cfg, decErr)
msg := fmt.Sprintf("failed to decode error response body (encoding=%s): %v", ce, decErr)
logWithRequestID(ctx).Warn(msg)
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: msg}
}
// Decompress error responses — pass the Content-Encoding value (may be empty)
// and let decodeResponseBody handle both header-declared and magic-byte-detected
// compression. This keeps error-path behaviour consistent with the success path.
errBody, decErr := decodeResponseBody(resp.Body, resp.Header.Get("Content-Encoding"))
if decErr != nil {
recordAPIResponseError(ctx, e.cfg, decErr)
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
logWithRequestID(ctx).Warn(msg)
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: msg}
}
b, readErr := io.ReadAll(errBody)
if readErr != nil {
@@ -662,12 +656,61 @@ func (c *compositeReadCloser) Close() error {
return firstErr
}
// peekableBody wraps a bufio.Reader around the original ReadCloser so that
// magic bytes can be inspected without consuming them from the stream.
type peekableBody struct {
*bufio.Reader
closer io.Closer
}
func (p *peekableBody) Close() error {
return p.closer.Close()
}
func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadCloser, error) {
if body == nil {
return nil, fmt.Errorf("response body is nil")
}
if contentEncoding == "" {
return body, nil
// No Content-Encoding header. Attempt best-effort magic-byte detection to
// handle misbehaving upstreams that compress without setting the header.
// Only gzip (1f 8b) and zstd (28 b5 2f fd) have reliable magic sequences;
// br and deflate have none and are left as-is.
// The bufio wrapper preserves unread bytes so callers always see the full
// stream regardless of whether decompression was applied.
pb := &peekableBody{Reader: bufio.NewReader(body), closer: body}
magic, peekErr := pb.Peek(4)
if peekErr == nil || (peekErr == io.EOF && len(magic) >= 2) {
switch {
case len(magic) >= 2 && magic[0] == 0x1f && magic[1] == 0x8b:
gzipReader, gzErr := gzip.NewReader(pb)
if gzErr != nil {
_ = pb.Close()
return nil, fmt.Errorf("magic-byte gzip: failed to create reader: %w", gzErr)
}
return &compositeReadCloser{
Reader: gzipReader,
closers: []func() error{
gzipReader.Close,
pb.Close,
},
}, nil
case len(magic) >= 4 && magic[0] == 0x28 && magic[1] == 0xb5 && magic[2] == 0x2f && magic[3] == 0xfd:
decoder, zdErr := zstd.NewReader(pb)
if zdErr != nil {
_ = pb.Close()
return nil, fmt.Errorf("magic-byte zstd: failed to create reader: %w", zdErr)
}
return &compositeReadCloser{
Reader: decoder,
closers: []func() error{
func() error { decoder.Close(); return nil },
pb.Close,
},
}, nil
}
}
return pb, nil
}
encodings := strings.Split(contentEncoding, ",")
for _, raw := range encodings {
@@ -844,11 +887,15 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
r.Header.Set("User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.63 (external, cli)"))
}
r.Header.Set("Connection", "keep-alive")
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
if stream {
r.Header.Set("Accept", "text/event-stream")
// SSE streams must not be compressed: the downstream scanner reads
// line-delimited text and cannot parse compressed bytes. Using
// "identity" tells the upstream to send an uncompressed stream.
r.Header.Set("Accept-Encoding", "identity")
} else {
r.Header.Set("Accept", "application/json")
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
}
// Keep OS/Arch mapping dynamic (not configurable).
// They intentionally continue to derive from runtime.GOOS/runtime.GOARCH.
@@ -857,6 +904,12 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(r, attrs)
// Re-enforce Accept-Encoding: identity after ApplyCustomHeadersFromAttrs, which
// may override it with a user-configured value. Compressed SSE breaks the line
// scanner regardless of user preference, so this is non-negotiable for streams.
if stream {
r.Header.Set("Accept-Encoding", "identity")
}
}
func claudeCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {

View File

@@ -2,6 +2,7 @@ package executor
import (
"bytes"
"compress/gzip"
"context"
"io"
"net/http"
@@ -9,6 +10,7 @@ import (
"strings"
"testing"
"github.com/klauspost/compress/zstd"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
@@ -583,3 +585,385 @@ func testClaudeExecutorInvalidCompressedErrorBody(
t.Fatalf("expected status code 400, got: %v", err)
}
}
// TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding verifies that streaming
// requests use Accept-Encoding: identity so the upstream cannot respond with a
// compressed SSE body that would silently break the line scanner.
func TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding(t *testing.T) {
var gotEncoding, gotAccept string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotEncoding = r.Header.Get("Accept-Encoding")
gotAccept = r.Header.Get("Accept")
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
})
if err != nil {
t.Fatalf("ExecuteStream error: %v", err)
}
for chunk := range result.Chunks {
if chunk.Err != nil {
t.Fatalf("unexpected chunk error: %v", chunk.Err)
}
}
if gotEncoding != "identity" {
t.Errorf("Accept-Encoding = %q, want %q", gotEncoding, "identity")
}
if gotAccept != "text/event-stream" {
t.Errorf("Accept = %q, want %q", gotAccept, "text/event-stream")
}
}
// TestClaudeExecutor_Execute_SetsCompressedAcceptEncoding verifies that non-streaming
// requests keep the full accept-encoding to allow response compression (which
// decodeResponseBody handles correctly).
func TestClaudeExecutor_Execute_SetsCompressedAcceptEncoding(t *testing.T) {
var gotEncoding, gotAccept string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotEncoding = r.Header.Get("Accept-Encoding")
gotAccept = r.Header.Get("Accept")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet-20241022","role":"assistant","content":[{"type":"text","text":"hi"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
})
if err != nil {
t.Fatalf("Execute error: %v", err)
}
if gotEncoding != "gzip, deflate, br, zstd" {
t.Errorf("Accept-Encoding = %q, want %q", gotEncoding, "gzip, deflate, br, zstd")
}
if gotAccept != "application/json" {
t.Errorf("Accept = %q, want %q", gotAccept, "application/json")
}
}
// TestClaudeExecutor_ExecuteStream_GzipSuccessBodyDecoded verifies that a streaming
// HTTP 200 response with Content-Encoding: gzip is correctly decompressed before
// the line scanner runs, so SSE chunks are not silently dropped.
func TestClaudeExecutor_ExecuteStream_GzipSuccessBodyDecoded(t *testing.T) {
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
_, _ = gz.Write([]byte("data: {\"type\":\"message_stop\"}\n"))
_ = gz.Close()
compressedBody := buf.Bytes()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Content-Encoding", "gzip")
_, _ = w.Write(compressedBody)
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
})
if err != nil {
t.Fatalf("ExecuteStream error: %v", err)
}
var combined strings.Builder
for chunk := range result.Chunks {
if chunk.Err != nil {
t.Fatalf("chunk error: %v", chunk.Err)
}
combined.Write(chunk.Payload)
}
if combined.Len() == 0 {
t.Fatal("expected at least one chunk from gzip-encoded SSE body, got none (body was not decompressed)")
}
if !strings.Contains(combined.String(), "message_stop") {
t.Errorf("expected SSE content in chunks, got: %q", combined.String())
}
}
// TestDecodeResponseBody_MagicByteGzipNoHeader verifies that decodeResponseBody
// detects gzip-compressed content via magic bytes even when Content-Encoding is absent.
func TestDecodeResponseBody_MagicByteGzipNoHeader(t *testing.T) {
const plaintext = "data: {\"type\":\"message_stop\"}\n"
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
_, _ = gz.Write([]byte(plaintext))
_ = gz.Close()
rc := io.NopCloser(&buf)
decoded, err := decodeResponseBody(rc, "")
if err != nil {
t.Fatalf("decodeResponseBody error: %v", err)
}
defer decoded.Close()
got, err := io.ReadAll(decoded)
if err != nil {
t.Fatalf("ReadAll error: %v", err)
}
if string(got) != plaintext {
t.Errorf("decoded = %q, want %q", got, plaintext)
}
}
// TestDecodeResponseBody_PlainTextNoHeader verifies that decodeResponseBody returns
// plain text untouched when Content-Encoding is absent and no magic bytes match.
func TestDecodeResponseBody_PlainTextNoHeader(t *testing.T) {
const plaintext = "data: {\"type\":\"message_stop\"}\n"
rc := io.NopCloser(strings.NewReader(plaintext))
decoded, err := decodeResponseBody(rc, "")
if err != nil {
t.Fatalf("decodeResponseBody error: %v", err)
}
defer decoded.Close()
got, err := io.ReadAll(decoded)
if err != nil {
t.Fatalf("ReadAll error: %v", err)
}
if string(got) != plaintext {
t.Errorf("decoded = %q, want %q", got, plaintext)
}
}
// TestClaudeExecutor_ExecuteStream_GzipNoContentEncodingHeader verifies the full
// pipeline: when the upstream returns a gzip-compressed SSE body WITHOUT setting
// Content-Encoding (a misbehaving upstream), the magic-byte sniff in
// decodeResponseBody still decompresses it, so chunks reach the caller.
func TestClaudeExecutor_ExecuteStream_GzipNoContentEncodingHeader(t *testing.T) {
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
_, _ = gz.Write([]byte("data: {\"type\":\"message_stop\"}\n"))
_ = gz.Close()
compressedBody := buf.Bytes()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
// Intentionally omit Content-Encoding to simulate misbehaving upstream.
_, _ = w.Write(compressedBody)
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
})
if err != nil {
t.Fatalf("ExecuteStream error: %v", err)
}
var combined strings.Builder
for chunk := range result.Chunks {
if chunk.Err != nil {
t.Fatalf("chunk error: %v", chunk.Err)
}
combined.Write(chunk.Payload)
}
if combined.Len() == 0 {
t.Fatal("expected chunks from gzip body without Content-Encoding header, got none (magic-byte sniff failed)")
}
if !strings.Contains(combined.String(), "message_stop") {
t.Errorf("unexpected chunk content: %q", combined.String())
}
}
// TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity verifies
// that injecting Accept-Encoding via auth.Attributes cannot override the stream
// path's enforced identity encoding.
func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity(t *testing.T) {
var gotEncoding string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotEncoding = r.Header.Get("Accept-Encoding")
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
// Inject Accept-Encoding via the custom header attribute mechanism.
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
"header:Accept-Encoding": "gzip, deflate, br, zstd",
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
})
if err != nil {
t.Fatalf("ExecuteStream error: %v", err)
}
for chunk := range result.Chunks {
if chunk.Err != nil {
t.Fatalf("unexpected chunk error: %v", chunk.Err)
}
}
if gotEncoding != "identity" {
t.Errorf("Accept-Encoding = %q; stream path must enforce identity regardless of auth.Attributes override", gotEncoding)
}
}
// TestDecodeResponseBody_MagicByteZstdNoHeader verifies that decodeResponseBody
// detects zstd-compressed content via magic bytes (28 b5 2f fd) even when
// Content-Encoding is absent.
func TestDecodeResponseBody_MagicByteZstdNoHeader(t *testing.T) {
const plaintext = "data: {\"type\":\"message_stop\"}\n"
var buf bytes.Buffer
enc, err := zstd.NewWriter(&buf)
if err != nil {
t.Fatalf("zstd.NewWriter: %v", err)
}
_, _ = enc.Write([]byte(plaintext))
_ = enc.Close()
rc := io.NopCloser(&buf)
decoded, err := decodeResponseBody(rc, "")
if err != nil {
t.Fatalf("decodeResponseBody error: %v", err)
}
defer decoded.Close()
got, err := io.ReadAll(decoded)
if err != nil {
t.Fatalf("ReadAll error: %v", err)
}
if string(got) != plaintext {
t.Errorf("decoded = %q, want %q", got, plaintext)
}
}
// TestClaudeExecutor_Execute_GzipErrorBodyNoContentEncodingHeader verifies that the
// error path (4xx) correctly decompresses a gzip body even when the upstream omits
// the Content-Encoding header. This closes the gap left by PR #1771, which only
// fixed header-declared compression on the error path.
func TestClaudeExecutor_Execute_GzipErrorBodyNoContentEncodingHeader(t *testing.T) {
const errJSON = `{"type":"error","error":{"type":"invalid_request_error","message":"test error"}}`
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
_, _ = gz.Write([]byte(errJSON))
_ = gz.Close()
compressedBody := buf.Bytes()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
// Intentionally omit Content-Encoding to simulate misbehaving upstream.
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write(compressedBody)
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
})
if err == nil {
t.Fatal("expected an error for 400 response, got nil")
}
if !strings.Contains(err.Error(), "test error") {
t.Errorf("error message should contain decompressed JSON, got: %q", err.Error())
}
}
// TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader verifies
// the same for the streaming executor: 4xx gzip body without Content-Encoding is
// decoded and the error message is readable.
func TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader(t *testing.T) {
const errJSON = `{"type":"error","error":{"type":"invalid_request_error","message":"stream test error"}}`
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
_, _ = gz.Write([]byte(errJSON))
_ = gz.Close()
compressedBody := buf.Bytes()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
// Intentionally omit Content-Encoding to simulate misbehaving upstream.
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write(compressedBody)
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
_, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
})
if err == nil {
t.Fatal("expected an error for 400 response, got nil")
}
if !strings.Contains(err.Error(), "stream test error") {
t.Errorf("error message should contain decompressed JSON, got: %q", err.Error())
}
}

View File

@@ -616,6 +616,10 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
if promptCacheKey.Exists() {
cache.ID = promptCacheKey.String()
}
} else if from == "openai" {
if apiKey := strings.TrimSpace(apiKeyFromContext(ctx)); apiKey != "" {
cache.ID = uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+apiKey)).String()
}
}
if cache.ID != "" {

View File

@@ -0,0 +1,64 @@
package executor
import (
"context"
"io"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFromAPIKey(t *testing.T) {
recorder := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(recorder)
ginCtx.Set("apiKey", "test-api-key")
ctx := context.WithValue(context.Background(), "gin", ginCtx)
executor := &CodexExecutor{}
rawJSON := []byte(`{"model":"gpt-5.3-codex","stream":true}`)
req := cliproxyexecutor.Request{
Model: "gpt-5.3-codex",
Payload: []byte(`{"model":"gpt-5.3-codex"}`),
}
url := "https://example.com/responses"
httpReq, err := executor.cacheHelper(ctx, sdktranslator.FromString("openai"), url, req, rawJSON)
if err != nil {
t.Fatalf("cacheHelper error: %v", err)
}
body, errRead := io.ReadAll(httpReq.Body)
if errRead != nil {
t.Fatalf("read request body: %v", errRead)
}
expectedKey := uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:test-api-key")).String()
gotKey := gjson.GetBytes(body, "prompt_cache_key").String()
if gotKey != expectedKey {
t.Fatalf("prompt_cache_key = %q, want %q", gotKey, expectedKey)
}
if gotConversation := httpReq.Header.Get("Conversation_id"); gotConversation != expectedKey {
t.Fatalf("Conversation_id = %q, want %q", gotConversation, expectedKey)
}
if gotSession := httpReq.Header.Get("Session_id"); gotSession != expectedKey {
t.Fatalf("Session_id = %q, want %q", gotSession, expectedKey)
}
httpReq2, err := executor.cacheHelper(ctx, sdktranslator.FromString("openai"), url, req, rawJSON)
if err != nil {
t.Fatalf("cacheHelper error (second call): %v", err)
}
body2, errRead2 := io.ReadAll(httpReq2.Body)
if errRead2 != nil {
t.Fatalf("read request body (second call): %v", errRead2)
}
gotKey2 := gjson.GetBytes(body2, "prompt_cache_key").String()
if gotKey2 != expectedKey {
t.Fatalf("prompt_cache_key (second call) = %q, want %q", gotKey2, expectedKey)
}
}

View File

@@ -31,7 +31,7 @@ import (
)
const (
codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-04"
codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-06"
codexResponsesWebsocketIdleTimeout = 5 * time.Minute
codexResponsesWebsocketHandshakeTO = 30 * time.Second
)
@@ -57,11 +57,6 @@ type codexWebsocketSession struct {
wsURL string
authID string
// connCreateSent tracks whether a `response.create` message has been successfully sent
// on the current websocket connection. The upstream expects the first message on each
// connection to be `response.create`.
connCreateSent bool
writeMu sync.Mutex
activeMu sync.Mutex
@@ -212,13 +207,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
defer sess.reqMu.Unlock()
}
allowAppend := true
if sess != nil {
sess.connMu.Lock()
allowAppend = sess.connCreateSent
sess.connMu.Unlock()
}
wsReqBody := buildCodexWebsocketRequestBody(body, allowAppend)
wsReqBody := buildCodexWebsocketRequestBody(body)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
@@ -280,10 +269,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
// execution session.
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
if errDialRetry == nil && connRetry != nil {
sess.connMu.Lock()
allowAppend = sess.connCreateSent
sess.connMu.Unlock()
wsReqBodyRetry := buildCodexWebsocketRequestBody(body, allowAppend)
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
@@ -312,7 +298,6 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
return resp, errSend
}
}
markCodexWebsocketCreateSent(sess, conn, wsReqBody)
for {
if ctx != nil && ctx.Err() != nil {
@@ -403,26 +388,20 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
executionSessionID := executionSessionIDFromOptions(opts)
var sess *codexWebsocketSession
if executionSessionID != "" {
sess = e.getOrCreateSession(executionSessionID)
sess.reqMu.Lock()
if sess != nil {
sess.reqMu.Lock()
}
}
allowAppend := true
if sess != nil {
sess.connMu.Lock()
allowAppend = sess.connCreateSent
sess.connMu.Unlock()
}
wsReqBody := buildCodexWebsocketRequestBody(body, allowAppend)
wsReqBody := buildCodexWebsocketRequestBody(body)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
@@ -483,10 +462,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
sess.reqMu.Unlock()
return nil, errDialRetry
}
sess.connMu.Lock()
allowAppend = sess.connCreateSent
sess.connMu.Unlock()
wsReqBodyRetry := buildCodexWebsocketRequestBody(body, allowAppend)
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
@@ -515,7 +491,6 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
return nil, errSend
}
}
markCodexWebsocketCreateSent(sess, conn, wsReqBody)
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
@@ -657,31 +632,14 @@ func writeCodexWebsocketMessage(sess *codexWebsocketSession, conn *websocket.Con
return conn.WriteMessage(websocket.TextMessage, payload)
}
func buildCodexWebsocketRequestBody(body []byte, allowAppend bool) []byte {
func buildCodexWebsocketRequestBody(body []byte) []byte {
if len(body) == 0 {
return nil
}
// Codex CLI websocket v2 uses `response.create` with `previous_response_id` for incremental turns.
// The upstream ChatGPT Codex websocket currently rejects that with close 1008 (policy violation).
// Fall back to v1 `response.append` semantics on the same websocket connection to keep the session alive.
//
// NOTE: The upstream expects the first websocket event on each connection to be `response.create`,
// so we only use `response.append` after we have initialized the current connection.
if allowAppend {
if prev := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String()); prev != "" {
inputNode := gjson.GetBytes(body, "input")
wsReqBody := []byte(`{}`)
wsReqBody, _ = sjson.SetBytes(wsReqBody, "type", "response.append")
if inputNode.Exists() && inputNode.IsArray() && strings.TrimSpace(inputNode.Raw) != "" {
wsReqBody, _ = sjson.SetRawBytes(wsReqBody, "input", []byte(inputNode.Raw))
return wsReqBody
}
wsReqBody, _ = sjson.SetRawBytes(wsReqBody, "input", []byte("[]"))
return wsReqBody
}
}
// Match codex-rs websocket v2 semantics: every request is `response.create`.
// Incremental follow-up turns continue on the same websocket using
// `previous_response_id` + incremental `input`, not `response.append`.
wsReqBody, errSet := sjson.SetBytes(bytes.Clone(body), "type", "response.create")
if errSet == nil && len(wsReqBody) > 0 {
return wsReqBody
@@ -725,21 +683,6 @@ func readCodexWebsocketMessage(ctx context.Context, sess *codexWebsocketSession,
}
}
func markCodexWebsocketCreateSent(sess *codexWebsocketSession, conn *websocket.Conn, payload []byte) {
if sess == nil || conn == nil || len(payload) == 0 {
return
}
if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "response.create" {
return
}
sess.connMu.Lock()
if sess.conn == conn {
sess.connCreateSent = true
}
sess.connMu.Unlock()
}
func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *websocket.Dialer {
dialer := &websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
@@ -1017,36 +960,6 @@ func closeHTTPResponseBody(resp *http.Response, logPrefix string) {
}
}
func closeOnContextDone(ctx context.Context, conn *websocket.Conn) chan struct{} {
done := make(chan struct{})
if ctx == nil || conn == nil {
return done
}
go func() {
select {
case <-done:
case <-ctx.Done():
_ = conn.Close()
}
}()
return done
}
func cancelReadOnContextDone(ctx context.Context, conn *websocket.Conn) chan struct{} {
done := make(chan struct{})
if ctx == nil || conn == nil {
return done
}
go func() {
select {
case <-done:
case <-ctx.Done():
_ = conn.SetReadDeadline(time.Now())
}
}()
return done
}
func executionSessionIDFromOptions(opts cliproxyexecutor.Options) string {
if len(opts.Metadata) == 0 {
return ""
@@ -1120,7 +1033,6 @@ func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *
sess.conn = conn
sess.wsURL = wsURL
sess.authID = authID
sess.connCreateSent = false
sess.readerConn = conn
sess.connMu.Unlock()
@@ -1206,7 +1118,6 @@ func (e *CodexWebsocketsExecutor) invalidateUpstreamConn(sess *codexWebsocketSes
return
}
sess.conn = nil
sess.connCreateSent = false
if sess.readerConn == conn {
sess.readerConn = nil
}
@@ -1273,7 +1184,6 @@ func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSess
authID := sess.authID
wsURL := sess.wsURL
sess.conn = nil
sess.connCreateSent = false
if sess.readerConn == conn {
sess.readerConn = nil
}

View File

@@ -0,0 +1,36 @@
package executor
import (
"context"
"net/http"
"testing"
"github.com/tidwall/gjson"
)
func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T) {
body := []byte(`{"model":"gpt-5-codex","previous_response_id":"resp-1","input":[{"type":"message","id":"msg-1"}]}`)
wsReqBody := buildCodexWebsocketRequestBody(body)
if got := gjson.GetBytes(wsReqBody, "type").String(); got != "response.create" {
t.Fatalf("type = %s, want response.create", got)
}
if got := gjson.GetBytes(wsReqBody, "previous_response_id").String(); got != "resp-1" {
t.Fatalf("previous_response_id = %s, want resp-1", got)
}
if gjson.GetBytes(wsReqBody, "input.0.id").String() != "msg-1" {
t.Fatalf("input item id mismatch")
}
if got := gjson.GetBytes(wsReqBody, "type").String(); got == "response.append" {
t.Fatalf("unexpected websocket request type: %s", got)
}
}
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "")
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
}
}

View File

@@ -212,6 +212,33 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
} else {
log.Warnf("Unknown file name extension '%s' in user message, skip", ext)
}
case "input_audio":
audioData := item.Get("input_audio.data").String()
audioFormat := item.Get("input_audio.format").String()
if audioData != "" {
audioMimeMap := map[string]string{
"mp3": "audio/mpeg",
"wav": "audio/wav",
"ogg": "audio/ogg",
"flac": "audio/flac",
"aac": "audio/aac",
"webm": "audio/webm",
"pcm16": "audio/pcm",
"g711_ulaw": "audio/basic",
"g711_alaw": "audio/basic",
}
mimeType := "audio/wav"
if audioFormat != "" {
if mapped, ok := audioMimeMap[audioFormat]; ok {
mimeType = mapped
} else {
mimeType = "audio/" + audioFormat
}
}
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", audioData)
p++
}
}
}
}

View File

@@ -203,46 +203,9 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
msg, _ = sjson.SetRaw(msg, "content.-1", part)
} else if contentResult.Exists() && contentResult.IsArray() {
contentResult.ForEach(func(_, part gjson.Result) bool {
partType := part.Get("type").String()
switch partType {
case "text":
textPart := `{"type":"text","text":""}`
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
msg, _ = sjson.SetRaw(msg, "content.-1", textPart)
case "image_url":
// Convert OpenAI image format to Claude Code format
imageURL := part.Get("image_url.url").String()
if strings.HasPrefix(imageURL, "data:") {
// Extract base64 data and media type from data URL
parts := strings.Split(imageURL, ",")
if len(parts) == 2 {
mediaTypePart := strings.Split(parts[0], ";")[0]
mediaType := strings.TrimPrefix(mediaTypePart, "data:")
data := parts[1]
imagePart := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
imagePart, _ = sjson.Set(imagePart, "source.media_type", mediaType)
imagePart, _ = sjson.Set(imagePart, "source.data", data)
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
}
}
case "file":
fileData := part.Get("file.file_data").String()
if strings.HasPrefix(fileData, "data:") {
semicolonIdx := strings.Index(fileData, ";")
commaIdx := strings.Index(fileData, ",")
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
data := fileData[commaIdx+1:]
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
docPart, _ = sjson.Set(docPart, "source.data", data)
msg, _ = sjson.SetRaw(msg, "content.-1", docPart)
}
}
claudePart := convertOpenAIContentPartToClaudePart(part)
if claudePart != "" {
msg, _ = sjson.SetRaw(msg, "content.-1", claudePart)
}
return true
})
@@ -291,11 +254,16 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
case "tool":
// Handle tool result messages conversion
toolCallID := message.Get("tool_call_id").String()
content := message.Get("content").String()
toolContentResult := message.Get("content")
msg := `{"role":"user","content":[{"type":"tool_result","tool_use_id":"","content":""}]}`
msg, _ = sjson.Set(msg, "content.0.tool_use_id", toolCallID)
msg, _ = sjson.Set(msg, "content.0.content", content)
toolResultContent, toolResultContentRaw := convertOpenAIToolResultContent(toolContentResult)
if toolResultContentRaw {
msg, _ = sjson.SetRaw(msg, "content.0.content", toolResultContent)
} else {
msg, _ = sjson.Set(msg, "content.0.content", toolResultContent)
}
out, _ = sjson.SetRaw(out, "messages.-1", msg)
messageIndex++
}
@@ -358,3 +326,110 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
return []byte(out)
}
func convertOpenAIContentPartToClaudePart(part gjson.Result) string {
switch part.Get("type").String() {
case "text":
textPart := `{"type":"text","text":""}`
textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
return textPart
case "image_url":
return convertOpenAIImageURLToClaudePart(part.Get("image_url.url").String())
case "file":
fileData := part.Get("file.file_data").String()
if strings.HasPrefix(fileData, "data:") {
semicolonIdx := strings.Index(fileData, ";")
commaIdx := strings.Index(fileData, ",")
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
data := fileData[commaIdx+1:]
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
docPart, _ = sjson.Set(docPart, "source.data", data)
return docPart
}
}
}
return ""
}
func convertOpenAIImageURLToClaudePart(imageURL string) string {
if imageURL == "" {
return ""
}
if strings.HasPrefix(imageURL, "data:") {
parts := strings.SplitN(imageURL, ",", 2)
if len(parts) != 2 {
return ""
}
mediaTypePart := strings.SplitN(parts[0], ";", 2)[0]
mediaType := strings.TrimPrefix(mediaTypePart, "data:")
if mediaType == "" {
mediaType = "application/octet-stream"
}
imagePart := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
imagePart, _ = sjson.Set(imagePart, "source.media_type", mediaType)
imagePart, _ = sjson.Set(imagePart, "source.data", parts[1])
return imagePart
}
imagePart := `{"type":"image","source":{"type":"url","url":""}}`
imagePart, _ = sjson.Set(imagePart, "source.url", imageURL)
return imagePart
}
func convertOpenAIToolResultContent(content gjson.Result) (string, bool) {
if !content.Exists() {
return "", false
}
if content.Type == gjson.String {
return content.String(), false
}
if content.IsArray() {
claudeContent := "[]"
partCount := 0
content.ForEach(func(_, part gjson.Result) bool {
if part.Type == gjson.String {
textPart := `{"type":"text","text":""}`
textPart, _ = sjson.Set(textPart, "text", part.String())
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", textPart)
partCount++
return true
}
claudePart := convertOpenAIContentPartToClaudePart(part)
if claudePart != "" {
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", claudePart)
partCount++
}
return true
})
if partCount > 0 || len(content.Array()) == 0 {
return claudeContent, true
}
return content.Raw, false
}
if content.IsObject() {
claudePart := convertOpenAIContentPartToClaudePart(content)
if claudePart != "" {
claudeContent := "[]"
claudeContent, _ = sjson.SetRaw(claudeContent, "-1", claudePart)
return claudeContent, true
}
return content.Raw, false
}
return content.Raw, false
}

View File

@@ -0,0 +1,137 @@
package chat_completions
import (
"testing"
"github.com/tidwall/gjson"
)
func TestConvertOpenAIRequestToClaude_ToolResultTextAndBase64Image(t *testing.T) {
inputJSON := `{
"model": "gpt-4.1",
"messages": [
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "do_work",
"arguments": "{\"a\":1}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": [
{"type": "text", "text": "tool ok"},
{
"type": "image_url",
"image_url": {
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg=="
}
}
]
}
]
}`
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
if len(messages) != 2 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
toolResult := messages[1].Get("content.0")
if got := toolResult.Get("type").String(); got != "tool_result" {
t.Fatalf("Expected content[0].type %q, got %q", "tool_result", got)
}
if got := toolResult.Get("tool_use_id").String(); got != "call_1" {
t.Fatalf("Expected tool_use_id %q, got %q", "call_1", got)
}
toolContent := toolResult.Get("content")
if !toolContent.IsArray() {
t.Fatalf("Expected tool_result content array, got %s", toolContent.Raw)
}
if got := toolContent.Get("0.type").String(); got != "text" {
t.Fatalf("Expected first tool_result part type %q, got %q", "text", got)
}
if got := toolContent.Get("0.text").String(); got != "tool ok" {
t.Fatalf("Expected first tool_result part text %q, got %q", "tool ok", got)
}
if got := toolContent.Get("1.type").String(); got != "image" {
t.Fatalf("Expected second tool_result part type %q, got %q", "image", got)
}
if got := toolContent.Get("1.source.type").String(); got != "base64" {
t.Fatalf("Expected image source type %q, got %q", "base64", got)
}
if got := toolContent.Get("1.source.media_type").String(); got != "image/png" {
t.Fatalf("Expected image media type %q, got %q", "image/png", got)
}
if got := toolContent.Get("1.source.data").String(); got != "iVBORw0KGgoAAAANSUhEUg==" {
t.Fatalf("Unexpected base64 image data: %q", got)
}
}
func TestConvertOpenAIRequestToClaude_ToolResultURLImageOnly(t *testing.T) {
inputJSON := `{
"model": "gpt-4.1",
"messages": [
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "do_work",
"arguments": "{\"a\":1}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://example.com/tool.png"
}
}
]
}
]
}`
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
if len(messages) != 2 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
toolContent := messages[1].Get("content.0.content")
if !toolContent.IsArray() {
t.Fatalf("Expected tool_result content array, got %s", toolContent.Raw)
}
if got := toolContent.Get("0.type").String(); got != "image" {
t.Fatalf("Expected tool_result part type %q, got %q", "image", got)
}
if got := toolContent.Get("0.source.type").String(); got != "url" {
t.Fatalf("Expected image source type %q, got %q", "url", got)
}
if got := toolContent.Get("0.source.url").String(); got != "https://example.com/tool.png" {
t.Fatalf("Unexpected image URL: %q", got)
}
}

View File

@@ -25,7 +25,12 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
rawJSON, _ = sjson.DeleteBytes(rawJSON, "max_completion_tokens")
rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature")
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
if v := gjson.GetBytes(rawJSON, "service_tier"); v.Exists() {
if v.String() != "priority" {
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
}
}
rawJSON, _ = sjson.DeleteBytes(rawJSON, "truncation")
rawJSON = applyResponsesCompactionCompatibility(rawJSON)

View File

@@ -6,24 +6,14 @@ import (
"fmt"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertCodexResponseToOpenAIResponses converts OpenAI Chat Completions streaming chunks
// to OpenAI Responses SSE events (response.*).
func ConvertCodexResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
func ConvertCodexResponseToOpenAIResponses(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) []string {
if bytes.HasPrefix(rawJSON, []byte("data:")) {
rawJSON = bytes.TrimSpace(rawJSON[5:])
if typeResult := gjson.GetBytes(rawJSON, "type"); typeResult.Exists() {
typeStr := typeResult.String()
if typeStr == "response.created" || typeStr == "response.in_progress" || typeStr == "response.completed" {
if gjson.GetBytes(rawJSON, "response.instructions").Exists() {
instructions := gjson.GetBytes(originalRequestRawJSON, "instructions").String()
rawJSON, _ = sjson.SetBytes(rawJSON, "response.instructions", instructions)
}
}
}
out := fmt.Sprintf("data: %s", string(rawJSON))
return []string{out}
}
@@ -32,17 +22,12 @@ func ConvertCodexResponseToOpenAIResponses(ctx context.Context, modelName string
// ConvertCodexResponseToOpenAIResponsesNonStream builds a single Responses JSON
// from a non-streaming OpenAI Chat Completions response.
func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) string {
rootResult := gjson.ParseBytes(rawJSON)
// Verify this is a response.completed event
if rootResult.Get("type").String() != "response.completed" {
return ""
}
responseResult := rootResult.Get("response")
template := responseResult.Raw
if responseResult.Get("instructions").Exists() {
instructions := gjson.GetBytes(originalRequestRawJSON, "instructions").String()
template, _ = sjson.Set(template, "instructions", instructions)
}
return template
return responseResult.Raw
}

View File

@@ -85,6 +85,11 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
case "tool_use":
functionName := contentResult.Get("name").String()
if toolUseID := contentResult.Get("id").String(); toolUseID != "" {
if derived := toolNameFromClaudeToolUseID(toolUseID); derived != "" {
functionName = derived
}
}
functionArgs := contentResult.Get("input").String()
argsResult := gjson.Parse(functionArgs)
if argsResult.IsObject() && gjson.Valid(functionArgs) {
@@ -100,10 +105,9 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
if toolCallID == "" {
return true
}
funcName := toolCallID
toolCallIDs := strings.Split(toolCallID, "-")
if len(toolCallIDs) > 1 {
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
funcName := toolNameFromClaudeToolUseID(toolCallID)
if funcName == "" {
funcName = toolCallID
}
responseData := contentResult.Get("content").Raw
part := `{"functionResponse":{"name":"","response":{"result":""}}}`
@@ -230,3 +234,11 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
return result
}
func toolNameFromClaudeToolUseID(toolUseID string) string {
parts := strings.Split(toolUseID, "-")
if len(parts) <= 1 {
return ""
}
return strings.Join(parts[0:len(parts)-1], "-")
}

View File

@@ -12,8 +12,8 @@ import (
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -25,6 +25,8 @@ type Params struct {
ResponseType int
ResponseIndex int
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
ToolNameMap map[string]string
SawToolCall bool
}
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
@@ -53,6 +55,8 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
HasFirstResponse: false,
ResponseType: 0,
ResponseIndex: 0,
ToolNameMap: util.ToolNameMapFromClaudeRequest(originalRequestRawJSON),
SawToolCall: false,
}
}
@@ -66,8 +70,6 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
return []string{}
}
// Track whether tools are being used in this response chunk
usedTool := false
output := ""
// Initialize the streaming session with a message_start event
@@ -175,12 +177,13 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
} else if functionCallResult.Exists() {
// Handle function/tool calls from the AI model
// This processes tool usage requests and formats them for Claude API compatibility
usedTool = true
fcName := functionCallResult.Get("name").String()
(*param).(*Params).SawToolCall = true
upstreamToolName := functionCallResult.Get("name").String()
clientToolName := util.MapToolName((*param).(*Params).ToolNameMap, upstreamToolName)
// FIX: Handle streaming split/delta where name might be empty in subsequent chunks.
// If we are already in tool use mode and name is empty, treat as continuation (delta).
if (*param).(*Params).ResponseType == 3 && fcName == "" {
if (*param).(*Params).ResponseType == 3 && upstreamToolName == "" {
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex), "delta.partial_json", fcArgsResult.Raw)
@@ -221,8 +224,8 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
// Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
data, _ = sjson.Set(data, "content_block.name", fcName)
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", upstreamToolName, atomic.AddUint64(&toolUseIDCounter, 1)))
data, _ = sjson.Set(data, "content_block.name", clientToolName)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
@@ -249,7 +252,7 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
output = output + `data: `
template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
if usedTool {
if (*param).(*Params).SawToolCall {
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
} else if finish := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finish.Exists() && finish.String() == "MAX_TOKENS" {
template = `{"type":"message_delta","delta":{"stop_reason":"max_tokens","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
@@ -278,10 +281,10 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
// Returns:
// - string: A Claude-compatible JSON response.
func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
_ = originalRequestRawJSON
_ = requestRawJSON
root := gjson.ParseBytes(rawJSON)
toolNameMap := util.ToolNameMapFromClaudeRequest(originalRequestRawJSON)
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
out, _ = sjson.Set(out, "id", root.Get("responseId").String())
@@ -336,11 +339,12 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
flushText()
hasToolCall = true
name := functionCall.Get("name").String()
upstreamToolName := functionCall.Get("name").String()
clientToolName := util.MapToolName(toolNameMap, upstreamToolName)
toolIDCounter++
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
toolBlock, _ = sjson.Set(toolBlock, "name", name)
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("%s-%d", upstreamToolName, toolIDCounter))
toolBlock, _ = sjson.Set(toolBlock, "name", clientToolName)
inputRaw := "{}"
if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() {
inputRaw = args.Raw

View File

@@ -237,6 +237,33 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
partJSON, _ = sjson.Set(partJSON, "inline_data.data", data)
}
}
case "input_audio":
audioData := contentItem.Get("data").String()
audioFormat := contentItem.Get("format").String()
if audioData != "" {
audioMimeMap := map[string]string{
"mp3": "audio/mpeg",
"wav": "audio/wav",
"ogg": "audio/ogg",
"flac": "audio/flac",
"aac": "audio/aac",
"webm": "audio/webm",
"pcm16": "audio/pcm",
"g711_ulaw": "audio/basic",
"g711_alaw": "audio/basic",
}
mimeType := "audio/wav"
if audioFormat != "" {
if mapped, ok := audioMimeMap[audioFormat]; ok {
mimeType = mapped
} else {
mimeType = "audio/" + audioFormat
}
}
partJSON = `{"inline_data":{"mime_type":"","data":""}}`
partJSON, _ = sjson.Set(partJSON, "inline_data.mime_type", mimeType)
partJSON, _ = sjson.Set(partJSON, "inline_data.data", audioData)
}
}
if partJSON != "" {

View File

@@ -183,7 +183,12 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
// Collect tool_result to emit after the main message (ensures tool results follow tool_calls)
toolResultJSON := `{"role":"tool","tool_call_id":"","content":""}`
toolResultJSON, _ = sjson.Set(toolResultJSON, "tool_call_id", part.Get("tool_use_id").String())
toolResultJSON, _ = sjson.Set(toolResultJSON, "content", convertClaudeToolResultContentToString(part.Get("content")))
toolResultContent, toolResultContentRaw := convertClaudeToolResultContent(part.Get("content"))
if toolResultContentRaw {
toolResultJSON, _ = sjson.SetRaw(toolResultJSON, "content", toolResultContent)
} else {
toolResultJSON, _ = sjson.Set(toolResultJSON, "content", toolResultContent)
}
toolResults = append(toolResults, toolResultJSON)
}
return true
@@ -374,21 +379,41 @@ func convertClaudeContentPart(part gjson.Result) (string, bool) {
}
}
func convertClaudeToolResultContentToString(content gjson.Result) string {
func convertClaudeToolResultContent(content gjson.Result) (string, bool) {
if !content.Exists() {
return ""
return "", false
}
if content.Type == gjson.String {
return content.String()
return content.String(), false
}
if content.IsArray() {
var parts []string
contentJSON := "[]"
hasImagePart := false
content.ForEach(func(_, item gjson.Result) bool {
switch {
case item.Type == gjson.String:
parts = append(parts, item.String())
text := item.String()
parts = append(parts, text)
textContent := `{"type":"text","text":""}`
textContent, _ = sjson.Set(textContent, "text", text)
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", textContent)
case item.IsObject() && item.Get("type").String() == "text":
text := item.Get("text").String()
parts = append(parts, text)
textContent := `{"type":"text","text":""}`
textContent, _ = sjson.Set(textContent, "text", text)
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", textContent)
case item.IsObject() && item.Get("type").String() == "image":
contentItem, ok := convertClaudeContentPart(item)
if ok {
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", contentItem)
hasImagePart = true
} else {
parts = append(parts, item.Raw)
}
case item.IsObject() && item.Get("text").Exists() && item.Get("text").Type == gjson.String:
parts = append(parts, item.Get("text").String())
default:
@@ -397,19 +422,31 @@ func convertClaudeToolResultContentToString(content gjson.Result) string {
return true
})
if hasImagePart {
return contentJSON, true
}
joined := strings.Join(parts, "\n\n")
if strings.TrimSpace(joined) != "" {
return joined
return joined, false
}
return content.Raw
return content.Raw, false
}
if content.IsObject() {
if text := content.Get("text"); text.Exists() && text.Type == gjson.String {
return text.String()
if content.Get("type").String() == "image" {
contentItem, ok := convertClaudeContentPart(content)
if ok {
contentJSON := "[]"
contentJSON, _ = sjson.SetRaw(contentJSON, "-1", contentItem)
return contentJSON, true
}
}
return content.Raw
if text := content.Get("text"); text.Exists() && text.Type == gjson.String {
return text.String(), false
}
return content.Raw, false
}
return content.Raw
return content.Raw, false
}

View File

@@ -488,6 +488,114 @@ func TestConvertClaudeRequestToOpenAI_ToolResultObjectContent(t *testing.T) {
}
}
func TestConvertClaudeRequestToOpenAI_ToolResultTextAndImageContent(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "call_1",
"content": [
{"type": "text", "text": "tool ok"},
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": "iVBORw0KGgoAAAANSUhEUg=="
}
}
]
}
]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
if len(messages) != 2 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
toolContent := messages[1].Get("content")
if !toolContent.IsArray() {
t.Fatalf("Expected tool content array, got %s", toolContent.Raw)
}
if got := toolContent.Get("0.type").String(); got != "text" {
t.Fatalf("Expected first tool content type %q, got %q", "text", got)
}
if got := toolContent.Get("0.text").String(); got != "tool ok" {
t.Fatalf("Expected first tool content text %q, got %q", "tool ok", got)
}
if got := toolContent.Get("1.type").String(); got != "image_url" {
t.Fatalf("Expected second tool content type %q, got %q", "image_url", got)
}
if got := toolContent.Get("1.image_url.url").String(); got != "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg==" {
t.Fatalf("Unexpected image_url: %q", got)
}
}
func TestConvertClaudeRequestToOpenAI_ToolResultURLImageOnly(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "call_1",
"content": {
"type": "image",
"source": {
"type": "url",
"url": "https://example.com/tool.png"
}
}
}
]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
if len(messages) != 2 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
toolContent := messages[1].Get("content")
if !toolContent.IsArray() {
t.Fatalf("Expected tool content array, got %s", toolContent.Raw)
}
if got := toolContent.Get("0.type").String(); got != "image_url" {
t.Fatalf("Expected tool content type %q, got %q", "image_url", got)
}
if got := toolContent.Get("0.image_url.url").String(); got != "https://example.com/tool.png" {
t.Fatalf("Unexpected image_url: %q", got)
}
}
func TestConvertClaudeRequestToOpenAI_AssistantTextToolUseTextOrder(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",

View File

@@ -22,9 +22,11 @@ var (
// ConvertOpenAIResponseToAnthropicParams holds parameters for response conversion
type ConvertOpenAIResponseToAnthropicParams struct {
MessageID string
Model string
CreatedAt int64
MessageID string
Model string
CreatedAt int64
ToolNameMap map[string]string
SawToolCall bool
// Content accumulator for streaming
ContentAccumulator strings.Builder
// Tool calls accumulator for streaming
@@ -78,6 +80,8 @@ func ConvertOpenAIResponseToClaude(_ context.Context, _ string, originalRequestR
MessageID: "",
Model: "",
CreatedAt: 0,
ToolNameMap: nil,
SawToolCall: false,
ContentAccumulator: strings.Builder{},
ToolCallsAccumulator: nil,
TextContentBlockStarted: false,
@@ -97,6 +101,10 @@ func ConvertOpenAIResponseToClaude(_ context.Context, _ string, originalRequestR
}
rawJSON = bytes.TrimSpace(rawJSON[5:])
if (*param).(*ConvertOpenAIResponseToAnthropicParams).ToolNameMap == nil {
(*param).(*ConvertOpenAIResponseToAnthropicParams).ToolNameMap = util.ToolNameMapFromClaudeRequest(originalRequestRawJSON)
}
// Check if this is the [DONE] marker
rawStr := strings.TrimSpace(string(rawJSON))
if rawStr == "[DONE]" {
@@ -111,6 +119,16 @@ func ConvertOpenAIResponseToClaude(_ context.Context, _ string, originalRequestR
}
}
func effectiveOpenAIFinishReason(param *ConvertOpenAIResponseToAnthropicParams) string {
if param == nil {
return ""
}
if param.SawToolCall {
return "tool_calls"
}
return param.FinishReason
}
// convertOpenAIStreamingChunkToAnthropic converts OpenAI streaming chunk to Anthropic streaming events
func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAIResponseToAnthropicParams) []string {
root := gjson.ParseBytes(rawJSON)
@@ -197,6 +215,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
}
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
param.SawToolCall = true
index := int(toolCall.Get("index").Int())
blockIndex := param.toolContentBlockIndex(index)
@@ -215,7 +234,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
// Handle function name
if function := toolCall.Get("function"); function.Exists() {
if name := function.Get("name"); name.Exists() {
accumulator.Name = name.String()
accumulator.Name = util.MapToolName(param.ToolNameMap, name.String())
stopThinkingContentBlock(param, &results)
@@ -246,7 +265,11 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
// Handle finish_reason (but don't send message_delta/message_stop yet)
if finishReason := root.Get("choices.0.finish_reason"); finishReason.Exists() && finishReason.String() != "" {
reason := finishReason.String()
param.FinishReason = reason
if param.SawToolCall {
param.FinishReason = "tool_calls"
} else {
param.FinishReason = reason
}
// Send content_block_stop for thinking content if needed
if param.ThinkingContentBlockStarted {
@@ -294,7 +317,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
inputTokens, outputTokens, cachedTokens = extractOpenAIUsage(usage)
// Send message_delta with usage
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason))
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(effectiveOpenAIFinishReason(param)))
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.input_tokens", inputTokens)
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.output_tokens", outputTokens)
if cachedTokens > 0 {
@@ -348,7 +371,7 @@ func convertOpenAIDoneToAnthropic(param *ConvertOpenAIResponseToAnthropicParams)
// If we haven't sent message_delta yet (no usage info was received), send it now
if param.FinishReason != "" && !param.MessageDeltaSent {
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason))
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(effectiveOpenAIFinishReason(param)))
results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n")
param.MessageDeltaSent = true
}
@@ -531,10 +554,10 @@ func stopTextContentBlock(param *ConvertOpenAIResponseToAnthropicParams, results
// Returns:
// - string: An Anthropic-compatible JSON response.
func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
_ = originalRequestRawJSON
_ = requestRawJSON
root := gjson.ParseBytes(rawJSON)
toolNameMap := util.ToolNameMapFromClaudeRequest(originalRequestRawJSON)
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
out, _ = sjson.Set(out, "id", root.Get("id").String())
out, _ = sjson.Set(out, "model", root.Get("model").String())
@@ -590,7 +613,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
hasToolCall = true
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUse, _ = sjson.Set(toolUse, "id", tc.Get("id").String())
toolUse, _ = sjson.Set(toolUse, "name", tc.Get("function.name").String())
toolUse, _ = sjson.Set(toolUse, "name", util.MapToolName(toolNameMap, tc.Get("function.name").String()))
argsStr := util.FixJSON(tc.Get("function.arguments").String())
if argsStr != "" && gjson.Valid(argsStr) {
@@ -647,7 +670,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
hasToolCall = true
toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String())
toolUseBlock, _ = sjson.Set(toolUseBlock, "name", toolCall.Get("function.name").String())
toolUseBlock, _ = sjson.Set(toolUseBlock, "name", util.MapToolName(toolNameMap, toolCall.Get("function.name").String()))
argsStr := util.FixJSON(toolCall.Get("function.arguments").String())
if argsStr != "" && gjson.Valid(argsStr) {

View File

@@ -6,6 +6,7 @@ package util
import (
"bytes"
"fmt"
"strings"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -219,3 +220,54 @@ func FixJSON(input string) string {
return out.String()
}
func CanonicalToolName(name string) string {
canonical := strings.TrimSpace(name)
canonical = strings.TrimLeft(canonical, "_")
return strings.ToLower(canonical)
}
// ToolNameMapFromClaudeRequest returns a canonical-name -> original-name map extracted from a Claude request.
// It is used to restore exact tool name casing for clients that require strict tool name matching (e.g. Claude Code).
func ToolNameMapFromClaudeRequest(rawJSON []byte) map[string]string {
if len(rawJSON) == 0 || !gjson.ValidBytes(rawJSON) {
return nil
}
tools := gjson.GetBytes(rawJSON, "tools")
if !tools.Exists() || !tools.IsArray() {
return nil
}
toolResults := tools.Array()
out := make(map[string]string, len(toolResults))
tools.ForEach(func(_, tool gjson.Result) bool {
name := strings.TrimSpace(tool.Get("name").String())
if name == "" {
return true
}
key := CanonicalToolName(name)
if key == "" {
return true
}
if _, exists := out[key]; !exists {
out[key] = name
}
return true
})
if len(out) == 0 {
return nil
}
return out
}
func MapToolName(toolNameMap map[string]string, name string) string {
if name == "" || toolNameMap == nil {
return name
}
if mapped, ok := toolNameMap[CanonicalToolName(name)]; ok && mapped != "" {
return mapped
}
return name
}

View File

@@ -17,6 +17,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
@@ -75,6 +76,7 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
w.lastAuthHashes = make(map[string]string)
w.lastAuthContents = make(map[string]*coreauth.Auth)
w.fileAuthsByPath = make(map[string]map[string]*coreauth.Auth)
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir)
} else if resolvedAuthDir != "" {
@@ -92,6 +94,17 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
if errParse := json.Unmarshal(data, &auth); errParse == nil {
w.lastAuthContents[normalizedPath] = &auth
}
ctx := &synthesizer.SynthesisContext{
Config: cfg,
AuthDir: resolvedAuthDir,
Now: time.Now(),
IDGenerator: synthesizer.NewStableIDGenerator(),
}
if generated := synthesizer.SynthesizeAuthFile(ctx, path, data); len(generated) > 0 {
if pathAuths := authSliceToMap(generated); len(pathAuths) > 0 {
w.fileAuthsByPath[normalizedPath] = pathAuths
}
}
}
}
return nil
@@ -143,13 +156,14 @@ func (w *Watcher) addOrUpdateClient(path string) {
}
w.clientsMutex.Lock()
cfg := w.config
if cfg == nil {
if w.config == nil {
log.Error("config is nil, cannot add or update client")
w.clientsMutex.Unlock()
return
}
if w.fileAuthsByPath == nil {
w.fileAuthsByPath = make(map[string]map[string]*coreauth.Auth)
}
if prev, ok := w.lastAuthHashes[normalized]; ok && prev == curHash {
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(path))
w.clientsMutex.Unlock()
@@ -177,34 +191,86 @@ func (w *Watcher) addOrUpdateClient(path string) {
}
w.lastAuthContents[normalized] = &newAuth
w.clientsMutex.Unlock() // Unlock before the callback
w.refreshAuthState(false)
if w.reloadCallback != nil {
log.Debugf("triggering server update callback after add/update")
w.reloadCallback(cfg)
oldByID := make(map[string]*coreauth.Auth, len(w.fileAuthsByPath[normalized]))
for id, a := range w.fileAuthsByPath[normalized] {
oldByID[id] = a
}
// Build synthesized auth entries for this single file only.
sctx := &synthesizer.SynthesisContext{
Config: w.config,
AuthDir: w.authDir,
Now: time.Now(),
IDGenerator: synthesizer.NewStableIDGenerator(),
}
generated := synthesizer.SynthesizeAuthFile(sctx, path, data)
newByID := authSliceToMap(generated)
if len(newByID) > 0 {
w.fileAuthsByPath[normalized] = newByID
} else {
delete(w.fileAuthsByPath, normalized)
}
updates := w.computePerPathUpdatesLocked(oldByID, newByID)
w.clientsMutex.Unlock()
w.persistAuthAsync(fmt.Sprintf("Sync auth %s", filepath.Base(path)), path)
w.dispatchAuthUpdates(updates)
}
func (w *Watcher) removeClient(path string) {
normalized := w.normalizeAuthPath(path)
w.clientsMutex.Lock()
cfg := w.config
oldByID := make(map[string]*coreauth.Auth, len(w.fileAuthsByPath[normalized]))
for id, a := range w.fileAuthsByPath[normalized] {
oldByID[id] = a
}
delete(w.lastAuthHashes, normalized)
delete(w.lastAuthContents, normalized)
delete(w.fileAuthsByPath, normalized)
w.clientsMutex.Unlock() // Release the lock before the callback
updates := w.computePerPathUpdatesLocked(oldByID, map[string]*coreauth.Auth{})
w.clientsMutex.Unlock()
w.refreshAuthState(false)
if w.reloadCallback != nil {
log.Debugf("triggering server update callback after removal")
w.reloadCallback(cfg)
}
w.persistAuthAsync(fmt.Sprintf("Remove auth %s", filepath.Base(path)), path)
w.dispatchAuthUpdates(updates)
}
func (w *Watcher) computePerPathUpdatesLocked(oldByID, newByID map[string]*coreauth.Auth) []AuthUpdate {
if w.currentAuths == nil {
w.currentAuths = make(map[string]*coreauth.Auth)
}
updates := make([]AuthUpdate, 0, len(oldByID)+len(newByID))
for id, newAuth := range newByID {
existing, ok := w.currentAuths[id]
if !ok {
w.currentAuths[id] = newAuth.Clone()
updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: newAuth.Clone()})
continue
}
if !authEqual(existing, newAuth) {
w.currentAuths[id] = newAuth.Clone()
updates = append(updates, AuthUpdate{Action: AuthUpdateActionModify, ID: id, Auth: newAuth.Clone()})
}
}
for id := range oldByID {
if _, stillExists := newByID[id]; stillExists {
continue
}
delete(w.currentAuths, id)
updates = append(updates, AuthUpdate{Action: AuthUpdateActionDelete, ID: id})
}
return updates
}
func authSliceToMap(auths []*coreauth.Auth) map[string]*coreauth.Auth {
byID := make(map[string]*coreauth.Auth, len(auths))
for _, a := range auths {
if a == nil || strings.TrimSpace(a.ID) == "" {
continue
}
byID[a.ID] = a
}
return byID
}
func (w *Watcher) loadFileClients(cfg *config.Config) int {
@@ -303,3 +369,79 @@ func (w *Watcher) persistAuthAsync(message string, paths ...string) {
}
}()
}
func (w *Watcher) stopServerUpdateTimer() {
w.serverUpdateMu.Lock()
defer w.serverUpdateMu.Unlock()
if w.serverUpdateTimer != nil {
w.serverUpdateTimer.Stop()
w.serverUpdateTimer = nil
}
w.serverUpdatePend = false
}
func (w *Watcher) triggerServerUpdate(cfg *config.Config) {
if w == nil || w.reloadCallback == nil || cfg == nil {
return
}
if w.stopped.Load() {
return
}
now := time.Now()
w.serverUpdateMu.Lock()
if w.serverUpdateLast.IsZero() || now.Sub(w.serverUpdateLast) >= serverUpdateDebounce {
w.serverUpdateLast = now
if w.serverUpdateTimer != nil {
w.serverUpdateTimer.Stop()
w.serverUpdateTimer = nil
}
w.serverUpdatePend = false
w.serverUpdateMu.Unlock()
w.reloadCallback(cfg)
return
}
if w.serverUpdatePend {
w.serverUpdateMu.Unlock()
return
}
delay := serverUpdateDebounce - now.Sub(w.serverUpdateLast)
if delay < 10*time.Millisecond {
delay = 10 * time.Millisecond
}
w.serverUpdatePend = true
if w.serverUpdateTimer != nil {
w.serverUpdateTimer.Stop()
w.serverUpdateTimer = nil
}
var timer *time.Timer
timer = time.AfterFunc(delay, func() {
if w.stopped.Load() {
return
}
w.clientsMutex.RLock()
latestCfg := w.config
w.clientsMutex.RUnlock()
w.serverUpdateMu.Lock()
if w.serverUpdateTimer != timer || !w.serverUpdatePend {
w.serverUpdateMu.Unlock()
return
}
w.serverUpdateTimer = nil
w.serverUpdatePend = false
if latestCfg == nil || w.reloadCallback == nil || w.stopped.Load() {
w.serverUpdateMu.Unlock()
return
}
w.serverUpdateLast = time.Now()
w.serverUpdateMu.Unlock()
w.reloadCallback(latestCfg)
})
w.serverUpdateTimer = timer
w.serverUpdateMu.Unlock()
}

View File

@@ -14,6 +14,8 @@ import (
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
var snapshotCoreAuthsFunc = snapshotCoreAuths
func (w *Watcher) setAuthUpdateQueue(queue chan<- AuthUpdate) {
w.clientsMutex.Lock()
defer w.clientsMutex.Unlock()
@@ -76,7 +78,11 @@ func (w *Watcher) dispatchRuntimeAuthUpdate(update AuthUpdate) bool {
}
func (w *Watcher) refreshAuthState(force bool) {
auths := w.SnapshotCoreAuths()
w.clientsMutex.RLock()
cfg := w.config
authDir := w.authDir
w.clientsMutex.RUnlock()
auths := snapshotCoreAuthsFunc(cfg, authDir)
w.clientsMutex.Lock()
if len(w.runtimeAuths) > 0 {
for _, a := range w.runtimeAuths {

View File

@@ -36,9 +36,6 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e
return out, nil
}
now := ctx.Now
cfg := ctx.Config
for _, e := range entries {
if e.IsDir() {
continue
@@ -52,99 +49,120 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e
if errRead != nil || len(data) == 0 {
continue
}
var metadata map[string]any
if errUnmarshal := json.Unmarshal(data, &metadata); errUnmarshal != nil {
auths := synthesizeFileAuths(ctx, full, data)
if len(auths) == 0 {
continue
}
t, _ := metadata["type"].(string)
if t == "" {
continue
}
provider := strings.ToLower(t)
if provider == "gemini" {
provider = "gemini-cli"
}
label := provider
if email, _ := metadata["email"].(string); email != "" {
label = email
}
// Use relative path under authDir as ID to stay consistent with the file-based token store
id := full
if rel, errRel := filepath.Rel(ctx.AuthDir, full); errRel == nil && rel != "" {
id = rel
}
// On Windows, normalize ID casing to avoid duplicate auth entries caused by case-insensitive paths.
if runtime.GOOS == "windows" {
id = strings.ToLower(id)
}
proxyURL := ""
if p, ok := metadata["proxy_url"].(string); ok {
proxyURL = p
}
prefix := ""
if rawPrefix, ok := metadata["prefix"].(string); ok {
trimmed := strings.TrimSpace(rawPrefix)
trimmed = strings.Trim(trimmed, "/")
if trimmed != "" && !strings.Contains(trimmed, "/") {
prefix = trimmed
}
}
disabled, _ := metadata["disabled"].(bool)
status := coreauth.StatusActive
if disabled {
status = coreauth.StatusDisabled
}
// Read per-account excluded models from the OAuth JSON file
perAccountExcluded := extractExcludedModelsFromMetadata(metadata)
a := &coreauth.Auth{
ID: id,
Provider: provider,
Label: label,
Prefix: prefix,
Status: status,
Disabled: disabled,
Attributes: map[string]string{
"source": full,
"path": full,
},
ProxyURL: proxyURL,
Metadata: metadata,
CreatedAt: now,
UpdatedAt: now,
}
// Read priority from auth file
if rawPriority, ok := metadata["priority"]; ok {
switch v := rawPriority.(type) {
case float64:
a.Attributes["priority"] = strconv.Itoa(int(v))
case string:
priority := strings.TrimSpace(v)
if _, errAtoi := strconv.Atoi(priority); errAtoi == nil {
a.Attributes["priority"] = priority
}
}
}
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
if provider == "gemini-cli" {
if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
for _, v := range virtuals {
ApplyAuthExcludedModelsMeta(v, cfg, perAccountExcluded, "oauth")
}
out = append(out, a)
out = append(out, virtuals...)
continue
}
}
out = append(out, a)
out = append(out, auths...)
}
return out, nil
}
// SynthesizeAuthFile generates Auth entries for one auth JSON file payload.
// It shares exactly the same mapping behavior as FileSynthesizer.Synthesize.
func SynthesizeAuthFile(ctx *SynthesisContext, fullPath string, data []byte) []*coreauth.Auth {
return synthesizeFileAuths(ctx, fullPath, data)
}
func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) []*coreauth.Auth {
if ctx == nil || len(data) == 0 {
return nil
}
now := ctx.Now
cfg := ctx.Config
var metadata map[string]any
if errUnmarshal := json.Unmarshal(data, &metadata); errUnmarshal != nil {
return nil
}
t, _ := metadata["type"].(string)
if t == "" {
return nil
}
provider := strings.ToLower(t)
if provider == "gemini" {
provider = "gemini-cli"
}
label := provider
if email, _ := metadata["email"].(string); email != "" {
label = email
}
// Use relative path under authDir as ID to stay consistent with the file-based token store.
id := fullPath
if strings.TrimSpace(ctx.AuthDir) != "" {
if rel, errRel := filepath.Rel(ctx.AuthDir, fullPath); errRel == nil && rel != "" {
id = rel
}
}
if runtime.GOOS == "windows" {
id = strings.ToLower(id)
}
proxyURL := ""
if p, ok := metadata["proxy_url"].(string); ok {
proxyURL = p
}
prefix := ""
if rawPrefix, ok := metadata["prefix"].(string); ok {
trimmed := strings.TrimSpace(rawPrefix)
trimmed = strings.Trim(trimmed, "/")
if trimmed != "" && !strings.Contains(trimmed, "/") {
prefix = trimmed
}
}
disabled, _ := metadata["disabled"].(bool)
status := coreauth.StatusActive
if disabled {
status = coreauth.StatusDisabled
}
// Read per-account excluded models from the OAuth JSON file.
perAccountExcluded := extractExcludedModelsFromMetadata(metadata)
a := &coreauth.Auth{
ID: id,
Provider: provider,
Label: label,
Prefix: prefix,
Status: status,
Disabled: disabled,
Attributes: map[string]string{
"source": fullPath,
"path": fullPath,
},
ProxyURL: proxyURL,
Metadata: metadata,
CreatedAt: now,
UpdatedAt: now,
}
// Read priority from auth file.
if rawPriority, ok := metadata["priority"]; ok {
switch v := rawPriority.(type) {
case float64:
a.Attributes["priority"] = strconv.Itoa(int(v))
case string:
priority := strings.TrimSpace(v)
if _, errAtoi := strconv.Atoi(priority); errAtoi == nil {
a.Attributes["priority"] = priority
}
}
}
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
if provider == "gemini-cli" {
if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
for _, v := range virtuals {
ApplyAuthExcludedModelsMeta(v, cfg, perAccountExcluded, "oauth")
}
out := make([]*coreauth.Auth, 0, 1+len(virtuals))
out = append(out, a)
out = append(out, virtuals...)
return out
}
}
return []*coreauth.Auth{a}
}
// SynthesizeGeminiVirtualAuths creates virtual Auth entries for multi-project Gemini credentials.
// It disables the primary auth and creates one virtual auth per project.
func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]any, now time.Time) []*coreauth.Auth {

View File

@@ -6,6 +6,7 @@ import (
"context"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/fsnotify/fsnotify"
@@ -35,10 +36,16 @@ type Watcher struct {
clientsMutex sync.RWMutex
configReloadMu sync.Mutex
configReloadTimer *time.Timer
serverUpdateMu sync.Mutex
serverUpdateTimer *time.Timer
serverUpdateLast time.Time
serverUpdatePend bool
stopped atomic.Bool
reloadCallback func(*config.Config)
watcher *fsnotify.Watcher
lastAuthHashes map[string]string
lastAuthContents map[string]*coreauth.Auth
fileAuthsByPath map[string]map[string]*coreauth.Auth
lastRemoveTimes map[string]time.Time
lastConfigHash string
authQueue chan<- AuthUpdate
@@ -76,6 +83,7 @@ const (
replaceCheckDelay = 50 * time.Millisecond
configReloadDebounce = 150 * time.Millisecond
authRemoveDebounceWindow = 1 * time.Second
serverUpdateDebounce = 1 * time.Second
)
// NewWatcher creates a new file watcher instance
@@ -85,11 +93,12 @@ func NewWatcher(configPath, authDir string, reloadCallback func(*config.Config))
return nil, errNewWatcher
}
w := &Watcher{
configPath: configPath,
authDir: authDir,
reloadCallback: reloadCallback,
watcher: watcher,
lastAuthHashes: make(map[string]string),
configPath: configPath,
authDir: authDir,
reloadCallback: reloadCallback,
watcher: watcher,
lastAuthHashes: make(map[string]string),
fileAuthsByPath: make(map[string]map[string]*coreauth.Auth),
}
w.dispatchCond = sync.NewCond(&w.dispatchMu)
if store := sdkAuth.GetTokenStore(); store != nil {
@@ -114,8 +123,10 @@ func (w *Watcher) Start(ctx context.Context) error {
// Stop stops the file watcher
func (w *Watcher) Stop() error {
w.stopped.Store(true)
w.stopDispatch()
w.stopConfigReloadTimer()
w.stopServerUpdateTimer()
return w.watcher.Close()
}

View File

@@ -406,8 +406,8 @@ func TestAddOrUpdateClientTriggersReloadAndHash(t *testing.T) {
w.addOrUpdateClient(authFile)
if got := atomic.LoadInt32(&reloads); got != 1 {
t.Fatalf("expected reload callback once, got %d", got)
if got := atomic.LoadInt32(&reloads); got != 0 {
t.Fatalf("expected no reload callback for auth update, got %d", got)
}
// Use normalizeAuthPath to match how addOrUpdateClient stores the key
normalized := w.normalizeAuthPath(authFile)
@@ -436,8 +436,150 @@ func TestRemoveClientRemovesHash(t *testing.T) {
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
t.Fatal("expected hash to be removed after deletion")
}
if got := atomic.LoadInt32(&reloads); got != 0 {
t.Fatalf("expected no reload callback for auth removal, got %d", got)
}
}
func TestAuthFileEventsDoNotInvokeSnapshotCoreAuths(t *testing.T) {
tmpDir := t.TempDir()
authFile := filepath.Join(tmpDir, "sample.json")
if err := os.WriteFile(authFile, []byte(`{"type":"codex","email":"u@example.com"}`), 0o644); err != nil {
t.Fatalf("failed to create auth file: %v", err)
}
origSnapshot := snapshotCoreAuthsFunc
var snapshotCalls int32
snapshotCoreAuthsFunc = func(cfg *config.Config, authDir string) []*coreauth.Auth {
atomic.AddInt32(&snapshotCalls, 1)
return origSnapshot(cfg, authDir)
}
defer func() { snapshotCoreAuthsFunc = origSnapshot }()
w := &Watcher{
authDir: tmpDir,
lastAuthHashes: make(map[string]string),
lastAuthContents: make(map[string]*coreauth.Auth),
fileAuthsByPath: make(map[string]map[string]*coreauth.Auth),
}
w.SetConfig(&config.Config{AuthDir: tmpDir})
w.addOrUpdateClient(authFile)
w.removeClient(authFile)
if got := atomic.LoadInt32(&snapshotCalls); got != 0 {
t.Fatalf("expected auth file events to avoid full snapshot, got %d calls", got)
}
}
func TestAuthSliceToMap(t *testing.T) {
t.Parallel()
valid1 := &coreauth.Auth{ID: "a"}
valid2 := &coreauth.Auth{ID: "b"}
dupOld := &coreauth.Auth{ID: "dup", Label: "old"}
dupNew := &coreauth.Auth{ID: "dup", Label: "new"}
empty := &coreauth.Auth{ID: " "}
tests := []struct {
name string
in []*coreauth.Auth
want map[string]*coreauth.Auth
}{
{
name: "nil input",
in: nil,
want: map[string]*coreauth.Auth{},
},
{
name: "empty input",
in: []*coreauth.Auth{},
want: map[string]*coreauth.Auth{},
},
{
name: "filters invalid auths",
in: []*coreauth.Auth{nil, empty},
want: map[string]*coreauth.Auth{},
},
{
name: "keeps valid auths",
in: []*coreauth.Auth{valid1, nil, valid2},
want: map[string]*coreauth.Auth{"a": valid1, "b": valid2},
},
{
name: "last duplicate wins",
in: []*coreauth.Auth{dupOld, dupNew},
want: map[string]*coreauth.Auth{"dup": dupNew},
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := authSliceToMap(tc.in)
if len(tc.want) == 0 {
if got == nil {
t.Fatal("expected empty map, got nil")
}
if len(got) != 0 {
t.Fatalf("expected empty map, got %#v", got)
}
return
}
if len(got) != len(tc.want) {
t.Fatalf("unexpected map length: got %d, want %d", len(got), len(tc.want))
}
for id, wantAuth := range tc.want {
gotAuth, ok := got[id]
if !ok {
t.Fatalf("missing id %q in result map", id)
}
if !authEqual(gotAuth, wantAuth) {
t.Fatalf("unexpected auth for id %q: got %#v, want %#v", id, gotAuth, wantAuth)
}
}
})
}
}
func TestTriggerServerUpdateCancelsPendingTimerOnImmediate(t *testing.T) {
tmpDir := t.TempDir()
cfg := &config.Config{AuthDir: tmpDir}
var reloads int32
w := &Watcher{
reloadCallback: func(*config.Config) {
atomic.AddInt32(&reloads, 1)
},
}
w.SetConfig(cfg)
w.serverUpdateMu.Lock()
w.serverUpdateLast = time.Now().Add(-(serverUpdateDebounce - 100*time.Millisecond))
w.serverUpdateMu.Unlock()
w.triggerServerUpdate(cfg)
if got := atomic.LoadInt32(&reloads); got != 0 {
t.Fatalf("expected no immediate reload, got %d", got)
}
w.serverUpdateMu.Lock()
if !w.serverUpdatePend || w.serverUpdateTimer == nil {
w.serverUpdateMu.Unlock()
t.Fatal("expected a pending server update timer")
}
w.serverUpdateLast = time.Now().Add(-(serverUpdateDebounce + 10*time.Millisecond))
w.serverUpdateMu.Unlock()
w.triggerServerUpdate(cfg)
if got := atomic.LoadInt32(&reloads); got != 1 {
t.Fatalf("expected reload callback once, got %d", got)
t.Fatalf("expected immediate reload once, got %d", got)
}
time.Sleep(250 * time.Millisecond)
if got := atomic.LoadInt32(&reloads); got != 1 {
t.Fatalf("expected pending timer to be cancelled, got %d reloads", got)
}
}
@@ -655,8 +797,8 @@ func TestHandleEventRemovesAuthFile(t *testing.T) {
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
if atomic.LoadInt32(&reloads) != 1 {
t.Fatalf("expected reload callback once, got %d", reloads)
if atomic.LoadInt32(&reloads) != 0 {
t.Fatalf("expected no reload callback for auth removal, got %d", reloads)
}
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
t.Fatal("expected hash entry to be removed")
@@ -853,8 +995,8 @@ func TestHandleEventAuthWriteTriggersUpdate(t *testing.T) {
w.SetConfig(&config.Config{AuthDir: authDir})
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Write})
if atomic.LoadInt32(&reloads) != 1 {
t.Fatalf("expected auth write to trigger reload callback, got %d", reloads)
if atomic.LoadInt32(&reloads) != 0 {
t.Fatalf("expected auth write to avoid global reload callback, got %d", reloads)
}
}
@@ -950,8 +1092,8 @@ func TestHandleEventAtomicReplaceChangedTriggersUpdate(t *testing.T) {
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = hexString(oldSum[:])
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Rename})
if atomic.LoadInt32(&reloads) != 1 {
t.Fatalf("expected changed atomic replace to trigger update, got %d", reloads)
if atomic.LoadInt32(&reloads) != 0 {
t.Fatalf("expected changed atomic replace to avoid global reload, got %d", reloads)
}
}
@@ -1005,8 +1147,8 @@ func TestHandleEventRemoveKnownFileDeletes(t *testing.T) {
w.lastAuthHashes[w.normalizeAuthPath(authFile)] = "hash"
w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
if atomic.LoadInt32(&reloads) != 1 {
t.Fatalf("expected known remove to trigger reload, got %d", reloads)
if atomic.LoadInt32(&reloads) != 0 {
t.Fatalf("expected known remove to avoid global reload, got %d", reloads)
}
if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
t.Fatal("expected known auth hash to be deleted")

View File

@@ -14,7 +14,11 @@ import (
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
@@ -26,7 +30,6 @@ const (
wsRequestTypeAppend = "response.append"
wsEventTypeError = "error"
wsEventTypeCompleted = "response.completed"
wsEventTypeDone = "response.done"
wsDoneMarker = "[DONE]"
wsTurnStateHeader = "x-codex-turn-state"
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
@@ -101,11 +104,17 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
// )
appendWebsocketEvent(&wsBodyLog, "request", payload)
allowIncrementalInputWithPreviousResponseID := websocketUpstreamSupportsIncrementalInput(nil, nil)
allowIncrementalInputWithPreviousResponseID := false
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
}
} else {
requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
if requestModelName == "" {
requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
}
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
}
var requestJSON []byte
@@ -140,6 +149,22 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
continue
}
if shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) {
if updated, errDelete := sjson.DeleteBytes(requestJSON, "generate"); errDelete == nil {
requestJSON = updated
}
if updated, errDelete := sjson.DeleteBytes(updatedLastRequest, "generate"); errDelete == nil {
updatedLastRequest = updated
}
lastRequest = updatedLastRequest
lastResponseOutput = []byte("[]")
if errWrite := writeResponsesWebsocketSyntheticPrewarm(c, conn, requestJSON, &wsBodyLog, passthroughSessionID); errWrite != nil {
wsTerminateErr = errWrite
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errWrite.Error()))
return
}
continue
}
lastRequest = updatedLastRequest
modelName := gjson.GetBytes(requestJSON, "model").String()
@@ -340,6 +365,192 @@ func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, met
return false
}
func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsIncrementalInputForModel(modelName string) bool {
if h == nil || h.AuthManager == nil {
return false
}
resolvedModelName := modelName
initialSuffix := thinking.ParseSuffix(modelName)
if initialSuffix.ModelName == "auto" {
resolvedBase := util.ResolveAutoModel(initialSuffix.ModelName)
if initialSuffix.HasSuffix {
resolvedModelName = fmt.Sprintf("%s(%s)", resolvedBase, initialSuffix.RawSuffix)
} else {
resolvedModelName = resolvedBase
}
} else {
resolvedModelName = util.ResolveAutoModel(modelName)
}
parsed := thinking.ParseSuffix(resolvedModelName)
baseModel := strings.TrimSpace(parsed.ModelName)
providers := util.GetProviderName(baseModel)
if len(providers) == 0 && baseModel != resolvedModelName {
providers = util.GetProviderName(resolvedModelName)
}
if len(providers) == 0 {
return false
}
providerSet := make(map[string]struct{}, len(providers))
for i := 0; i < len(providers); i++ {
providerKey := strings.TrimSpace(strings.ToLower(providers[i]))
if providerKey == "" {
continue
}
providerSet[providerKey] = struct{}{}
}
if len(providerSet) == 0 {
return false
}
modelKey := baseModel
if modelKey == "" {
modelKey = strings.TrimSpace(resolvedModelName)
}
registryRef := registry.GetGlobalRegistry()
now := time.Now()
auths := h.AuthManager.List()
for i := 0; i < len(auths); i++ {
auth := auths[i]
if auth == nil {
continue
}
providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
if _, ok := providerSet[providerKey]; !ok {
continue
}
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(auth.ID, modelKey) {
continue
}
if !responsesWebsocketAuthAvailableForModel(auth, modelKey, now) {
continue
}
if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
return true
}
}
return false
}
func responsesWebsocketAuthAvailableForModel(auth *coreauth.Auth, modelName string, now time.Time) bool {
if auth == nil {
return false
}
if auth.Disabled || auth.Status == coreauth.StatusDisabled {
return false
}
if modelName != "" && len(auth.ModelStates) > 0 {
state, ok := auth.ModelStates[modelName]
if (!ok || state == nil) && modelName != "" {
baseModel := strings.TrimSpace(thinking.ParseSuffix(modelName).ModelName)
if baseModel != "" && baseModel != modelName {
state, ok = auth.ModelStates[baseModel]
}
}
if ok && state != nil {
if state.Status == coreauth.StatusDisabled {
return false
}
if state.Unavailable && !state.NextRetryAfter.IsZero() && state.NextRetryAfter.After(now) {
return false
}
return true
}
}
if auth.Unavailable && !auth.NextRetryAfter.IsZero() && auth.NextRetryAfter.After(now) {
return false
}
return true
}
func shouldHandleResponsesWebsocketPrewarmLocally(rawJSON []byte, lastRequest []byte, allowIncrementalInputWithPreviousResponseID bool) bool {
if allowIncrementalInputWithPreviousResponseID || len(lastRequest) != 0 {
return false
}
if strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String()) != wsRequestTypeCreate {
return false
}
generateResult := gjson.GetBytes(rawJSON, "generate")
return generateResult.Exists() && !generateResult.Bool()
}
func writeResponsesWebsocketSyntheticPrewarm(
c *gin.Context,
conn *websocket.Conn,
requestJSON []byte,
wsBodyLog *strings.Builder,
sessionID string,
) error {
payloads, errPayloads := syntheticResponsesWebsocketPrewarmPayloads(requestJSON)
if errPayloads != nil {
return errPayloads
}
for i := 0; i < len(payloads); i++ {
markAPIResponseTimestamp(c)
appendWebsocketEvent(wsBodyLog, "response", payloads[i])
// log.Infof(
// "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
// sessionID,
// websocket.TextMessage,
// websocketPayloadEventType(payloads[i]),
// websocketPayloadPreview(payloads[i]),
// )
if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil {
log.Warnf(
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
sessionID,
websocketPayloadEventType(payloads[i]),
errWrite,
)
return errWrite
}
}
return nil
}
func syntheticResponsesWebsocketPrewarmPayloads(requestJSON []byte) ([][]byte, error) {
responseID := "resp_prewarm_" + uuid.NewString()
createdAt := time.Now().Unix()
modelName := strings.TrimSpace(gjson.GetBytes(requestJSON, "model").String())
createdPayload := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`)
var errSet error
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.id", responseID)
if errSet != nil {
return nil, errSet
}
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.created_at", createdAt)
if errSet != nil {
return nil, errSet
}
if modelName != "" {
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.model", modelName)
if errSet != nil {
return nil, errSet
}
}
completedPayload := []byte(`{"type":"response.completed","sequence_number":1,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"output":[],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`)
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.id", responseID)
if errSet != nil {
return nil, errSet
}
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.created_at", createdAt)
if errSet != nil {
return nil, errSet
}
if modelName != "" {
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.model", modelName)
if errSet != nil {
return nil, errSet
}
}
return [][]byte{createdPayload, completedPayload}, nil
}
func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
existingRaw = strings.TrimSpace(existingRaw)
appendRaw = strings.TrimSpace(appendRaw)
@@ -469,9 +680,6 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
for i := range payloads {
eventType := gjson.GetBytes(payloads[i], "type").String()
if eventType == wsEventTypeCompleted {
// log.Infof("replace %s with %s", wsEventTypeCompleted, wsEventTypeDone)
payloads[i], _ = sjson.SetBytes(payloads[i], "type", wsEventTypeDone)
completed = true
completedOutput = responseCompletedOutputFromPayload(payloads[i])
}
@@ -554,47 +762,63 @@ func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.Error
}
body := handlers.BuildErrorResponseBody(status, errText)
payload := map[string]any{
"type": wsEventTypeError,
"status": status,
payload := []byte(`{}`)
var errSet error
payload, errSet = sjson.SetBytes(payload, "type", wsEventTypeError)
if errSet != nil {
return nil, errSet
}
payload, errSet = sjson.SetBytes(payload, "status", status)
if errSet != nil {
return nil, errSet
}
if errMsg != nil && errMsg.Addon != nil {
headers := map[string]any{}
headers := []byte(`{}`)
hasHeaders := false
for key, values := range errMsg.Addon {
if len(values) == 0 {
continue
}
headers[key] = values[0]
headerPath := strings.ReplaceAll(strings.ReplaceAll(key, `\\`, `\\\\`), ".", `\\.`)
headers, errSet = sjson.SetBytes(headers, headerPath, values[0])
if errSet != nil {
return nil, errSet
}
hasHeaders = true
}
if len(headers) > 0 {
payload["headers"] = headers
}
}
if len(body) > 0 && json.Valid(body) {
var decoded map[string]any
if errDecode := json.Unmarshal(body, &decoded); errDecode == nil {
if inner, ok := decoded["error"]; ok {
payload["error"] = inner
} else {
payload["error"] = decoded
if hasHeaders {
payload, errSet = sjson.SetRawBytes(payload, "headers", headers)
if errSet != nil {
return nil, errSet
}
}
}
if _, ok := payload["error"]; !ok {
payload["error"] = map[string]any{
"type": "server_error",
"message": errText,
if len(body) > 0 && json.Valid(body) {
errorNode := gjson.GetBytes(body, "error")
if errorNode.Exists() {
payload, errSet = sjson.SetRawBytes(payload, "error", []byte(errorNode.Raw))
} else {
payload, errSet = sjson.SetRawBytes(payload, "error", body)
}
if errSet != nil {
return nil, errSet
}
}
data, err := json.Marshal(payload)
if err != nil {
return nil, err
if !gjson.GetBytes(payload, "error").Exists() {
payload, errSet = sjson.SetBytes(payload, "error.type", "server_error")
if errSet != nil {
return nil, errSet
}
payload, errSet = sjson.SetBytes(payload, "error.message", errText)
if errSet != nil {
return nil, errSet
}
}
return data, conn.WriteMessage(websocket.TextMessage, data)
return payload, conn.WriteMessage(websocket.TextMessage, payload)
}
func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) {

View File

@@ -2,15 +2,57 @@ package openai
import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/tidwall/gjson"
)
type websocketCaptureExecutor struct {
streamCalls int
payloads [][]byte
}
func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" }
func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, errors.New("not implemented")
}
func (e *websocketCaptureExecutor) ExecuteStream(_ context.Context, _ *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
e.streamCalls++
e.payloads = append(e.payloads, bytes.Clone(req.Payload))
chunks := make(chan coreexecutor.StreamChunk, 1)
chunks <- coreexecutor.StreamChunk{Payload: []byte(`{"type":"response.completed","response":{"id":"resp-upstream","output":[{"type":"message","id":"out-1"}]}}`)}
close(chunks)
return &coreexecutor.StreamResult{Chunks: chunks}, nil
}
func (e *websocketCaptureExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
return auth, nil
}
func (e *websocketCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, errors.New("not implemented")
}
func (e *websocketCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
return nil, errors.New("not implemented")
}
func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) {
raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`)
@@ -247,3 +289,206 @@ func TestSetWebsocketRequestBody(t *testing.T) {
t.Fatalf("request body = %q, want %q", string(bodyBytes), "event body")
}
}
func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
gin.SetMode(gin.TestMode)
serverErrCh := make(chan error, 1)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := responsesWebsocketUpgrader.Upgrade(w, r, nil)
if err != nil {
serverErrCh <- err
return
}
defer func() {
errClose := conn.Close()
if errClose != nil {
serverErrCh <- errClose
}
}()
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
ctx.Request = r
data := make(chan []byte, 1)
errCh := make(chan *interfaces.ErrorMessage)
data <- []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[{\"type\":\"message\",\"id\":\"out-1\"}]}}\n\n")
close(data)
close(errCh)
var bodyLog strings.Builder
completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
ctx,
conn,
func(...interface{}) {},
data,
errCh,
&bodyLog,
"session-1",
)
if err != nil {
serverErrCh <- err
return
}
if gjson.GetBytes(completedOutput, "0.id").String() != "out-1" {
serverErrCh <- errors.New("completed output not captured")
return
}
serverErrCh <- nil
}))
defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("dial websocket: %v", err)
}
defer func() {
errClose := conn.Close()
if errClose != nil {
t.Fatalf("close websocket: %v", errClose)
}
}()
_, payload, errReadMessage := conn.ReadMessage()
if errReadMessage != nil {
t.Fatalf("read websocket message: %v", errReadMessage)
}
if gjson.GetBytes(payload, "type").String() != wsEventTypeCompleted {
t.Fatalf("payload type = %s, want %s", gjson.GetBytes(payload, "type").String(), wsEventTypeCompleted)
}
if strings.Contains(string(payload), "response.done") {
t.Fatalf("payload unexpectedly rewrote completed event: %s", payload)
}
if errServer := <-serverErrCh; errServer != nil {
t.Fatalf("server error: %v", errServer)
}
}
func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
manager := coreauth.NewManager(nil, nil, nil)
auth := &coreauth.Auth{
ID: "auth-ws",
Provider: "test-provider",
Status: coreauth.StatusActive,
Attributes: map[string]string{"websockets": "true"},
}
if _, err := manager.Register(context.Background(), auth); err != nil {
t.Fatalf("Register auth: %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
})
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
h := NewOpenAIResponsesAPIHandler(base)
if !h.websocketUpstreamSupportsIncrementalInputForModel("test-model") {
t.Fatalf("expected websocket-capable upstream for test-model")
}
}
func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
gin.SetMode(gin.TestMode)
executor := &websocketCaptureExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(executor)
auth := &coreauth.Auth{ID: "auth-sse", Provider: executor.Identifier(), Status: coreauth.StatusActive}
if _, err := manager.Register(context.Background(), auth); err != nil {
t.Fatalf("Register auth: %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
})
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
h := NewOpenAIResponsesAPIHandler(base)
router := gin.New()
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
server := httptest.NewServer(router)
defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("dial websocket: %v", err)
}
defer func() {
errClose := conn.Close()
if errClose != nil {
t.Fatalf("close websocket: %v", errClose)
}
}()
errWrite := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"response.create","model":"test-model","generate":false}`))
if errWrite != nil {
t.Fatalf("write prewarm websocket message: %v", errWrite)
}
_, createdPayload, errReadMessage := conn.ReadMessage()
if errReadMessage != nil {
t.Fatalf("read prewarm created message: %v", errReadMessage)
}
if gjson.GetBytes(createdPayload, "type").String() != "response.created" {
t.Fatalf("created payload type = %s, want response.created", gjson.GetBytes(createdPayload, "type").String())
}
prewarmResponseID := gjson.GetBytes(createdPayload, "response.id").String()
if prewarmResponseID == "" {
t.Fatalf("prewarm response id is empty")
}
if executor.streamCalls != 0 {
t.Fatalf("stream calls after prewarm = %d, want 0", executor.streamCalls)
}
_, completedPayload, errReadMessage := conn.ReadMessage()
if errReadMessage != nil {
t.Fatalf("read prewarm completed message: %v", errReadMessage)
}
if gjson.GetBytes(completedPayload, "type").String() != wsEventTypeCompleted {
t.Fatalf("completed payload type = %s, want %s", gjson.GetBytes(completedPayload, "type").String(), wsEventTypeCompleted)
}
if gjson.GetBytes(completedPayload, "response.id").String() != prewarmResponseID {
t.Fatalf("completed response id = %s, want %s", gjson.GetBytes(completedPayload, "response.id").String(), prewarmResponseID)
}
if gjson.GetBytes(completedPayload, "response.usage.total_tokens").Int() != 0 {
t.Fatalf("prewarm total tokens = %d, want 0", gjson.GetBytes(completedPayload, "response.usage.total_tokens").Int())
}
secondRequest := fmt.Sprintf(`{"type":"response.create","previous_response_id":%q,"input":[{"type":"message","id":"msg-1"}]}`, prewarmResponseID)
errWrite = conn.WriteMessage(websocket.TextMessage, []byte(secondRequest))
if errWrite != nil {
t.Fatalf("write follow-up websocket message: %v", errWrite)
}
_, upstreamPayload, errReadMessage := conn.ReadMessage()
if errReadMessage != nil {
t.Fatalf("read upstream completed message: %v", errReadMessage)
}
if gjson.GetBytes(upstreamPayload, "type").String() != wsEventTypeCompleted {
t.Fatalf("upstream payload type = %s, want %s", gjson.GetBytes(upstreamPayload, "type").String(), wsEventTypeCompleted)
}
if executor.streamCalls != 1 {
t.Fatalf("stream calls after follow-up = %d, want 1", executor.streamCalls)
}
if len(executor.payloads) != 1 {
t.Fatalf("captured upstream payloads = %d, want 1", len(executor.payloads))
}
forwarded := executor.payloads[0]
if gjson.GetBytes(forwarded, "previous_response_id").Exists() {
t.Fatalf("previous_response_id leaked upstream: %s", forwarded)
}
if gjson.GetBytes(forwarded, "generate").Exists() {
t.Fatalf("generate leaked upstream: %s", forwarded)
}
if gjson.GetBytes(forwarded, "model").String() != "test-model" {
t.Fatalf("forwarded model = %s, want test-model", gjson.GetBytes(forwarded, "model").String())
}
input := gjson.GetBytes(forwarded, "input").Array()
if len(input) != 1 || input[0].Get("id").String() != "msg-1" {
t.Fatalf("unexpected forwarded input: %s", forwarded)
}
}