diff --git a/internal/registry/model_definitions_static_data.go b/internal/registry/model_definitions_static_data.go index 72f2e35f..fb019730 100644 --- a/internal/registry/model_definitions_static_data.go +++ b/internal/registry/model_definitions_static_data.go @@ -839,6 +839,20 @@ func GetOpenAIModels() []*ModelInfo { SupportedParameters: []string{"tools"}, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}}, }, + { + ID: "gpt-5.4", + Object: "model", + Created: 1772668800, + OwnedBy: "openai", + Type: "openai", + Version: "gpt-5.4", + DisplayName: "GPT 5.4", + Description: "Stable version of GPT 5.4 Codex, The best model for coding and agentic tasks across domains.", + ContextLength: 1_050_000, + MaxCompletionTokens: 128000, + SupportedParameters: []string{"tools"}, + Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}}, + }, } } diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 805d31dd..7d0ddcf2 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -187,17 +187,15 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r } recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { - // Decompress error responses (e.g. gzip-compressed 400 errors from Anthropic API). - errBody := httpResp.Body - if ce := httpResp.Header.Get("Content-Encoding"); ce != "" { - var decErr error - errBody, decErr = decodeResponseBody(httpResp.Body, ce) - if decErr != nil { - recordAPIResponseError(ctx, e.cfg, decErr) - msg := fmt.Sprintf("failed to decode error response body (encoding=%s): %v", ce, decErr) - logWithRequestID(ctx).Warn(msg) - return resp, statusErr{code: httpResp.StatusCode, msg: msg} - } + // Decompress error responses — pass the Content-Encoding value (may be empty) + // and let decodeResponseBody handle both header-declared and magic-byte-detected + // compression. This keeps error-path behaviour consistent with the success path. + errBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding")) + if decErr != nil { + recordAPIResponseError(ctx, e.cfg, decErr) + msg := fmt.Sprintf("failed to decode error response body: %v", decErr) + logWithRequestID(ctx).Warn(msg) + return resp, statusErr{code: httpResp.StatusCode, msg: msg} } b, readErr := io.ReadAll(errBody) if readErr != nil { @@ -352,17 +350,15 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A } recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { - // Decompress error responses (e.g. gzip-compressed 400 errors from Anthropic API). - errBody := httpResp.Body - if ce := httpResp.Header.Get("Content-Encoding"); ce != "" { - var decErr error - errBody, decErr = decodeResponseBody(httpResp.Body, ce) - if decErr != nil { - recordAPIResponseError(ctx, e.cfg, decErr) - msg := fmt.Sprintf("failed to decode error response body (encoding=%s): %v", ce, decErr) - logWithRequestID(ctx).Warn(msg) - return nil, statusErr{code: httpResp.StatusCode, msg: msg} - } + // Decompress error responses — pass the Content-Encoding value (may be empty) + // and let decodeResponseBody handle both header-declared and magic-byte-detected + // compression. This keeps error-path behaviour consistent with the success path. + errBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding")) + if decErr != nil { + recordAPIResponseError(ctx, e.cfg, decErr) + msg := fmt.Sprintf("failed to decode error response body: %v", decErr) + logWithRequestID(ctx).Warn(msg) + return nil, statusErr{code: httpResp.StatusCode, msg: msg} } b, readErr := io.ReadAll(errBody) if readErr != nil { @@ -521,17 +517,15 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut } recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone()) if resp.StatusCode < 200 || resp.StatusCode >= 300 { - // Decompress error responses (e.g. gzip-compressed 400 errors from Anthropic API). - errBody := resp.Body - if ce := resp.Header.Get("Content-Encoding"); ce != "" { - var decErr error - errBody, decErr = decodeResponseBody(resp.Body, ce) - if decErr != nil { - recordAPIResponseError(ctx, e.cfg, decErr) - msg := fmt.Sprintf("failed to decode error response body (encoding=%s): %v", ce, decErr) - logWithRequestID(ctx).Warn(msg) - return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: msg} - } + // Decompress error responses — pass the Content-Encoding value (may be empty) + // and let decodeResponseBody handle both header-declared and magic-byte-detected + // compression. This keeps error-path behaviour consistent with the success path. + errBody, decErr := decodeResponseBody(resp.Body, resp.Header.Get("Content-Encoding")) + if decErr != nil { + recordAPIResponseError(ctx, e.cfg, decErr) + msg := fmt.Sprintf("failed to decode error response body: %v", decErr) + logWithRequestID(ctx).Warn(msg) + return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: msg} } b, readErr := io.ReadAll(errBody) if readErr != nil { @@ -662,12 +656,61 @@ func (c *compositeReadCloser) Close() error { return firstErr } +// peekableBody wraps a bufio.Reader around the original ReadCloser so that +// magic bytes can be inspected without consuming them from the stream. +type peekableBody struct { + *bufio.Reader + closer io.Closer +} + +func (p *peekableBody) Close() error { + return p.closer.Close() +} + func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadCloser, error) { if body == nil { return nil, fmt.Errorf("response body is nil") } if contentEncoding == "" { - return body, nil + // No Content-Encoding header. Attempt best-effort magic-byte detection to + // handle misbehaving upstreams that compress without setting the header. + // Only gzip (1f 8b) and zstd (28 b5 2f fd) have reliable magic sequences; + // br and deflate have none and are left as-is. + // The bufio wrapper preserves unread bytes so callers always see the full + // stream regardless of whether decompression was applied. + pb := &peekableBody{Reader: bufio.NewReader(body), closer: body} + magic, peekErr := pb.Peek(4) + if peekErr == nil || (peekErr == io.EOF && len(magic) >= 2) { + switch { + case len(magic) >= 2 && magic[0] == 0x1f && magic[1] == 0x8b: + gzipReader, gzErr := gzip.NewReader(pb) + if gzErr != nil { + _ = pb.Close() + return nil, fmt.Errorf("magic-byte gzip: failed to create reader: %w", gzErr) + } + return &compositeReadCloser{ + Reader: gzipReader, + closers: []func() error{ + gzipReader.Close, + pb.Close, + }, + }, nil + case len(magic) >= 4 && magic[0] == 0x28 && magic[1] == 0xb5 && magic[2] == 0x2f && magic[3] == 0xfd: + decoder, zdErr := zstd.NewReader(pb) + if zdErr != nil { + _ = pb.Close() + return nil, fmt.Errorf("magic-byte zstd: failed to create reader: %w", zdErr) + } + return &compositeReadCloser{ + Reader: decoder, + closers: []func() error{ + func() error { decoder.Close(); return nil }, + pb.Close, + }, + }, nil + } + } + return pb, nil } encodings := strings.Split(contentEncoding, ",") for _, raw := range encodings { @@ -844,11 +887,15 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, r.Header.Set("User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.63 (external, cli)")) } r.Header.Set("Connection", "keep-alive") - r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd") if stream { r.Header.Set("Accept", "text/event-stream") + // SSE streams must not be compressed: the downstream scanner reads + // line-delimited text and cannot parse compressed bytes. Using + // "identity" tells the upstream to send an uncompressed stream. + r.Header.Set("Accept-Encoding", "identity") } else { r.Header.Set("Accept", "application/json") + r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd") } // Keep OS/Arch mapping dynamic (not configurable). // They intentionally continue to derive from runtime.GOOS/runtime.GOARCH. @@ -857,6 +904,12 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, attrs = auth.Attributes } util.ApplyCustomHeadersFromAttrs(r, attrs) + // Re-enforce Accept-Encoding: identity after ApplyCustomHeadersFromAttrs, which + // may override it with a user-configured value. Compressed SSE breaks the line + // scanner regardless of user preference, so this is non-negotiable for streams. + if stream { + r.Header.Set("Accept-Encoding", "identity") + } } func claudeCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) { diff --git a/internal/runtime/executor/claude_executor_test.go b/internal/runtime/executor/claude_executor_test.go index f9553f9a..c4a4d644 100644 --- a/internal/runtime/executor/claude_executor_test.go +++ b/internal/runtime/executor/claude_executor_test.go @@ -2,6 +2,7 @@ package executor import ( "bytes" + "compress/gzip" "context" "io" "net/http" @@ -9,6 +10,7 @@ import ( "strings" "testing" + "github.com/klauspost/compress/zstd" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" @@ -583,3 +585,385 @@ func testClaudeExecutorInvalidCompressedErrorBody( t.Fatalf("expected status code 400, got: %v", err) } } + +// TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding verifies that streaming +// requests use Accept-Encoding: identity so the upstream cannot respond with a +// compressed SSE body that would silently break the line scanner. +func TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding(t *testing.T) { + var gotEncoding, gotAccept string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotEncoding = r.Header.Get("Accept-Encoding") + gotAccept = r.Header.Get("Accept") + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n")) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`) + + result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("claude"), + }) + if err != nil { + t.Fatalf("ExecuteStream error: %v", err) + } + for chunk := range result.Chunks { + if chunk.Err != nil { + t.Fatalf("unexpected chunk error: %v", chunk.Err) + } + } + + if gotEncoding != "identity" { + t.Errorf("Accept-Encoding = %q, want %q", gotEncoding, "identity") + } + if gotAccept != "text/event-stream" { + t.Errorf("Accept = %q, want %q", gotAccept, "text/event-stream") + } +} + +// TestClaudeExecutor_Execute_SetsCompressedAcceptEncoding verifies that non-streaming +// requests keep the full accept-encoding to allow response compression (which +// decodeResponseBody handles correctly). +func TestClaudeExecutor_Execute_SetsCompressedAcceptEncoding(t *testing.T) { + var gotEncoding, gotAccept string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotEncoding = r.Header.Get("Accept-Encoding") + gotAccept = r.Header.Get("Accept") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet-20241022","role":"assistant","content":[{"type":"text","text":"hi"}],"usage":{"input_tokens":1,"output_tokens":1}}`)) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`) + + _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("claude"), + }) + if err != nil { + t.Fatalf("Execute error: %v", err) + } + + if gotEncoding != "gzip, deflate, br, zstd" { + t.Errorf("Accept-Encoding = %q, want %q", gotEncoding, "gzip, deflate, br, zstd") + } + if gotAccept != "application/json" { + t.Errorf("Accept = %q, want %q", gotAccept, "application/json") + } +} + +// TestClaudeExecutor_ExecuteStream_GzipSuccessBodyDecoded verifies that a streaming +// HTTP 200 response with Content-Encoding: gzip is correctly decompressed before +// the line scanner runs, so SSE chunks are not silently dropped. +func TestClaudeExecutor_ExecuteStream_GzipSuccessBodyDecoded(t *testing.T) { + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + _, _ = gz.Write([]byte("data: {\"type\":\"message_stop\"}\n")) + _ = gz.Close() + compressedBody := buf.Bytes() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Content-Encoding", "gzip") + _, _ = w.Write(compressedBody) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`) + + result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("claude"), + }) + if err != nil { + t.Fatalf("ExecuteStream error: %v", err) + } + + var combined strings.Builder + for chunk := range result.Chunks { + if chunk.Err != nil { + t.Fatalf("chunk error: %v", chunk.Err) + } + combined.Write(chunk.Payload) + } + + if combined.Len() == 0 { + t.Fatal("expected at least one chunk from gzip-encoded SSE body, got none (body was not decompressed)") + } + if !strings.Contains(combined.String(), "message_stop") { + t.Errorf("expected SSE content in chunks, got: %q", combined.String()) + } +} + +// TestDecodeResponseBody_MagicByteGzipNoHeader verifies that decodeResponseBody +// detects gzip-compressed content via magic bytes even when Content-Encoding is absent. +func TestDecodeResponseBody_MagicByteGzipNoHeader(t *testing.T) { + const plaintext = "data: {\"type\":\"message_stop\"}\n" + + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + _, _ = gz.Write([]byte(plaintext)) + _ = gz.Close() + + rc := io.NopCloser(&buf) + decoded, err := decodeResponseBody(rc, "") + if err != nil { + t.Fatalf("decodeResponseBody error: %v", err) + } + defer decoded.Close() + + got, err := io.ReadAll(decoded) + if err != nil { + t.Fatalf("ReadAll error: %v", err) + } + if string(got) != plaintext { + t.Errorf("decoded = %q, want %q", got, plaintext) + } +} + +// TestDecodeResponseBody_PlainTextNoHeader verifies that decodeResponseBody returns +// plain text untouched when Content-Encoding is absent and no magic bytes match. +func TestDecodeResponseBody_PlainTextNoHeader(t *testing.T) { + const plaintext = "data: {\"type\":\"message_stop\"}\n" + rc := io.NopCloser(strings.NewReader(plaintext)) + decoded, err := decodeResponseBody(rc, "") + if err != nil { + t.Fatalf("decodeResponseBody error: %v", err) + } + defer decoded.Close() + + got, err := io.ReadAll(decoded) + if err != nil { + t.Fatalf("ReadAll error: %v", err) + } + if string(got) != plaintext { + t.Errorf("decoded = %q, want %q", got, plaintext) + } +} + +// TestClaudeExecutor_ExecuteStream_GzipNoContentEncodingHeader verifies the full +// pipeline: when the upstream returns a gzip-compressed SSE body WITHOUT setting +// Content-Encoding (a misbehaving upstream), the magic-byte sniff in +// decodeResponseBody still decompresses it, so chunks reach the caller. +func TestClaudeExecutor_ExecuteStream_GzipNoContentEncodingHeader(t *testing.T) { + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + _, _ = gz.Write([]byte("data: {\"type\":\"message_stop\"}\n")) + _ = gz.Close() + compressedBody := buf.Bytes() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + // Intentionally omit Content-Encoding to simulate misbehaving upstream. + _, _ = w.Write(compressedBody) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`) + + result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("claude"), + }) + if err != nil { + t.Fatalf("ExecuteStream error: %v", err) + } + + var combined strings.Builder + for chunk := range result.Chunks { + if chunk.Err != nil { + t.Fatalf("chunk error: %v", chunk.Err) + } + combined.Write(chunk.Payload) + } + + if combined.Len() == 0 { + t.Fatal("expected chunks from gzip body without Content-Encoding header, got none (magic-byte sniff failed)") + } + if !strings.Contains(combined.String(), "message_stop") { + t.Errorf("unexpected chunk content: %q", combined.String()) + } +} + +// TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity verifies +// that injecting Accept-Encoding via auth.Attributes cannot override the stream +// path's enforced identity encoding. +func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity(t *testing.T) { + var gotEncoding string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotEncoding = r.Header.Get("Accept-Encoding") + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n")) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + // Inject Accept-Encoding via the custom header attribute mechanism. + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + "header:Accept-Encoding": "gzip, deflate, br, zstd", + }} + payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`) + + result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("claude"), + }) + if err != nil { + t.Fatalf("ExecuteStream error: %v", err) + } + for chunk := range result.Chunks { + if chunk.Err != nil { + t.Fatalf("unexpected chunk error: %v", chunk.Err) + } + } + + if gotEncoding != "identity" { + t.Errorf("Accept-Encoding = %q; stream path must enforce identity regardless of auth.Attributes override", gotEncoding) + } +} + +// TestDecodeResponseBody_MagicByteZstdNoHeader verifies that decodeResponseBody +// detects zstd-compressed content via magic bytes (28 b5 2f fd) even when +// Content-Encoding is absent. +func TestDecodeResponseBody_MagicByteZstdNoHeader(t *testing.T) { + const plaintext = "data: {\"type\":\"message_stop\"}\n" + + var buf bytes.Buffer + enc, err := zstd.NewWriter(&buf) + if err != nil { + t.Fatalf("zstd.NewWriter: %v", err) + } + _, _ = enc.Write([]byte(plaintext)) + _ = enc.Close() + + rc := io.NopCloser(&buf) + decoded, err := decodeResponseBody(rc, "") + if err != nil { + t.Fatalf("decodeResponseBody error: %v", err) + } + defer decoded.Close() + + got, err := io.ReadAll(decoded) + if err != nil { + t.Fatalf("ReadAll error: %v", err) + } + if string(got) != plaintext { + t.Errorf("decoded = %q, want %q", got, plaintext) + } +} + +// TestClaudeExecutor_Execute_GzipErrorBodyNoContentEncodingHeader verifies that the +// error path (4xx) correctly decompresses a gzip body even when the upstream omits +// the Content-Encoding header. This closes the gap left by PR #1771, which only +// fixed header-declared compression on the error path. +func TestClaudeExecutor_Execute_GzipErrorBodyNoContentEncodingHeader(t *testing.T) { + const errJSON = `{"type":"error","error":{"type":"invalid_request_error","message":"test error"}}` + + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + _, _ = gz.Write([]byte(errJSON)) + _ = gz.Close() + compressedBody := buf.Bytes() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + // Intentionally omit Content-Encoding to simulate misbehaving upstream. + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write(compressedBody) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`) + + _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("claude"), + }) + if err == nil { + t.Fatal("expected an error for 400 response, got nil") + } + if !strings.Contains(err.Error(), "test error") { + t.Errorf("error message should contain decompressed JSON, got: %q", err.Error()) + } +} + +// TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader verifies +// the same for the streaming executor: 4xx gzip body without Content-Encoding is +// decoded and the error message is readable. +func TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader(t *testing.T) { + const errJSON = `{"type":"error","error":{"type":"invalid_request_error","message":"stream test error"}}` + + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + _, _ = gz.Write([]byte(errJSON)) + _ = gz.Close() + compressedBody := buf.Bytes() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + // Intentionally omit Content-Encoding to simulate misbehaving upstream. + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write(compressedBody) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`) + + _, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("claude"), + }) + if err == nil { + t.Fatal("expected an error for 400 response, got nil") + } + if !strings.Contains(err.Error(), "stream test error") { + t.Errorf("error message should contain decompressed JSON, got: %q", err.Error()) + } +} diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index a0cbc0d5..30092ec7 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -616,6 +616,10 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form if promptCacheKey.Exists() { cache.ID = promptCacheKey.String() } + } else if from == "openai" { + if apiKey := strings.TrimSpace(apiKeyFromContext(ctx)); apiKey != "" { + cache.ID = uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+apiKey)).String() + } } if cache.ID != "" { diff --git a/internal/runtime/executor/codex_executor_cache_test.go b/internal/runtime/executor/codex_executor_cache_test.go new file mode 100644 index 00000000..d6dca031 --- /dev/null +++ b/internal/runtime/executor/codex_executor_cache_test.go @@ -0,0 +1,64 @@ +package executor + +import ( + "context" + "io" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + "github.com/tidwall/gjson" +) + +func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFromAPIKey(t *testing.T) { + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + ginCtx.Set("apiKey", "test-api-key") + + ctx := context.WithValue(context.Background(), "gin", ginCtx) + executor := &CodexExecutor{} + rawJSON := []byte(`{"model":"gpt-5.3-codex","stream":true}`) + req := cliproxyexecutor.Request{ + Model: "gpt-5.3-codex", + Payload: []byte(`{"model":"gpt-5.3-codex"}`), + } + url := "https://example.com/responses" + + httpReq, err := executor.cacheHelper(ctx, sdktranslator.FromString("openai"), url, req, rawJSON) + if err != nil { + t.Fatalf("cacheHelper error: %v", err) + } + + body, errRead := io.ReadAll(httpReq.Body) + if errRead != nil { + t.Fatalf("read request body: %v", errRead) + } + + expectedKey := uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:test-api-key")).String() + gotKey := gjson.GetBytes(body, "prompt_cache_key").String() + if gotKey != expectedKey { + t.Fatalf("prompt_cache_key = %q, want %q", gotKey, expectedKey) + } + if gotConversation := httpReq.Header.Get("Conversation_id"); gotConversation != expectedKey { + t.Fatalf("Conversation_id = %q, want %q", gotConversation, expectedKey) + } + if gotSession := httpReq.Header.Get("Session_id"); gotSession != expectedKey { + t.Fatalf("Session_id = %q, want %q", gotSession, expectedKey) + } + + httpReq2, err := executor.cacheHelper(ctx, sdktranslator.FromString("openai"), url, req, rawJSON) + if err != nil { + t.Fatalf("cacheHelper error (second call): %v", err) + } + body2, errRead2 := io.ReadAll(httpReq2.Body) + if errRead2 != nil { + t.Fatalf("read request body (second call): %v", errRead2) + } + gotKey2 := gjson.GetBytes(body2, "prompt_cache_key").String() + if gotKey2 != expectedKey { + t.Fatalf("prompt_cache_key (second call) = %q, want %q", gotKey2, expectedKey) + } +} diff --git a/internal/translator/codex/openai/responses/codex_openai-responses_request.go b/internal/translator/codex/openai/responses/codex_openai-responses_request.go index 1161c515..87566e79 100644 --- a/internal/translator/codex/openai/responses/codex_openai-responses_request.go +++ b/internal/translator/codex/openai/responses/codex_openai-responses_request.go @@ -25,7 +25,7 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte, rawJSON, _ = sjson.DeleteBytes(rawJSON, "max_completion_tokens") rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature") rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p") - rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier") + // rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier") rawJSON, _ = sjson.DeleteBytes(rawJSON, "truncation") rawJSON = applyResponsesCompactionCompatibility(rawJSON) diff --git a/internal/translator/codex/openai/responses/codex_openai-responses_response.go b/internal/translator/codex/openai/responses/codex_openai-responses_response.go index 4287206a..e84b817b 100644 --- a/internal/translator/codex/openai/responses/codex_openai-responses_response.go +++ b/internal/translator/codex/openai/responses/codex_openai-responses_response.go @@ -6,24 +6,14 @@ import ( "fmt" "github.com/tidwall/gjson" - "github.com/tidwall/sjson" ) // ConvertCodexResponseToOpenAIResponses converts OpenAI Chat Completions streaming chunks // to OpenAI Responses SSE events (response.*). -func ConvertCodexResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string { +func ConvertCodexResponseToOpenAIResponses(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) []string { if bytes.HasPrefix(rawJSON, []byte("data:")) { rawJSON = bytes.TrimSpace(rawJSON[5:]) - if typeResult := gjson.GetBytes(rawJSON, "type"); typeResult.Exists() { - typeStr := typeResult.String() - if typeStr == "response.created" || typeStr == "response.in_progress" || typeStr == "response.completed" { - if gjson.GetBytes(rawJSON, "response.instructions").Exists() { - instructions := gjson.GetBytes(originalRequestRawJSON, "instructions").String() - rawJSON, _ = sjson.SetBytes(rawJSON, "response.instructions", instructions) - } - } - } out := fmt.Sprintf("data: %s", string(rawJSON)) return []string{out} } @@ -32,17 +22,12 @@ func ConvertCodexResponseToOpenAIResponses(ctx context.Context, modelName string // ConvertCodexResponseToOpenAIResponsesNonStream builds a single Responses JSON // from a non-streaming OpenAI Chat Completions response. -func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string { +func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) string { rootResult := gjson.ParseBytes(rawJSON) // Verify this is a response.completed event if rootResult.Get("type").String() != "response.completed" { return "" } responseResult := rootResult.Get("response") - template := responseResult.Raw - if responseResult.Get("instructions").Exists() { - instructions := gjson.GetBytes(originalRequestRawJSON, "instructions").String() - template, _ = sjson.Set(template, "instructions", instructions) - } - return template + return responseResult.Raw } diff --git a/internal/translator/gemini/claude/gemini_claude_request.go b/internal/translator/gemini/claude/gemini_claude_request.go index ff276ce3..b13955bb 100644 --- a/internal/translator/gemini/claude/gemini_claude_request.go +++ b/internal/translator/gemini/claude/gemini_claude_request.go @@ -85,6 +85,11 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) case "tool_use": functionName := contentResult.Get("name").String() + if toolUseID := contentResult.Get("id").String(); toolUseID != "" { + if derived := toolNameFromClaudeToolUseID(toolUseID); derived != "" { + functionName = derived + } + } functionArgs := contentResult.Get("input").String() argsResult := gjson.Parse(functionArgs) if argsResult.IsObject() && gjson.Valid(functionArgs) { @@ -100,10 +105,9 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) if toolCallID == "" { return true } - funcName := toolCallID - toolCallIDs := strings.Split(toolCallID, "-") - if len(toolCallIDs) > 1 { - funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-") + funcName := toolNameFromClaudeToolUseID(toolCallID) + if funcName == "" { + funcName = toolCallID } responseData := contentResult.Get("content").Raw part := `{"functionResponse":{"name":"","response":{"result":""}}}` @@ -230,3 +234,11 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) return result } + +func toolNameFromClaudeToolUseID(toolUseID string) string { + parts := strings.Split(toolUseID, "-") + if len(parts) <= 1 { + return "" + } + return strings.Join(parts[0:len(parts)-1], "-") +} diff --git a/internal/translator/gemini/claude/gemini_claude_response.go b/internal/translator/gemini/claude/gemini_claude_response.go index cfc06921..e5adcb5e 100644 --- a/internal/translator/gemini/claude/gemini_claude_response.go +++ b/internal/translator/gemini/claude/gemini_claude_response.go @@ -12,8 +12,8 @@ import ( "fmt" "strings" "sync/atomic" - "time" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -25,6 +25,8 @@ type Params struct { ResponseType int ResponseIndex int HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output + ToolNameMap map[string]string + SawToolCall bool } // toolUseIDCounter provides a process-wide unique counter for tool use identifiers. @@ -53,6 +55,8 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR HasFirstResponse: false, ResponseType: 0, ResponseIndex: 0, + ToolNameMap: util.ToolNameMapFromClaudeRequest(originalRequestRawJSON), + SawToolCall: false, } } @@ -66,8 +70,6 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR return []string{} } - // Track whether tools are being used in this response chunk - usedTool := false output := "" // Initialize the streaming session with a message_start event @@ -175,12 +177,13 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR } else if functionCallResult.Exists() { // Handle function/tool calls from the AI model // This processes tool usage requests and formats them for Claude API compatibility - usedTool = true - fcName := functionCallResult.Get("name").String() + (*param).(*Params).SawToolCall = true + upstreamToolName := functionCallResult.Get("name").String() + clientToolName := util.MapToolName((*param).(*Params).ToolNameMap, upstreamToolName) // FIX: Handle streaming split/delta where name might be empty in subsequent chunks. // If we are already in tool use mode and name is empty, treat as continuation (delta). - if (*param).(*Params).ResponseType == 3 && fcName == "" { + if (*param).(*Params).ResponseType == 3 && upstreamToolName == "" { if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { output = output + "event: content_block_delta\n" data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex), "delta.partial_json", fcArgsResult.Raw) @@ -221,8 +224,8 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR // Create the tool use block with unique ID and function details data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex) - data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))) - data, _ = sjson.Set(data, "content_block.name", fcName) + data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", upstreamToolName, atomic.AddUint64(&toolUseIDCounter, 1))) + data, _ = sjson.Set(data, "content_block.name", clientToolName) output = output + fmt.Sprintf("data: %s\n\n\n", data) if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { @@ -249,7 +252,7 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR output = output + `data: ` template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` - if usedTool { + if (*param).(*Params).SawToolCall { template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` } else if finish := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finish.Exists() && finish.String() == "MAX_TOKENS" { template = `{"type":"message_delta","delta":{"stop_reason":"max_tokens","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` @@ -278,10 +281,10 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR // Returns: // - string: A Claude-compatible JSON response. func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string { - _ = originalRequestRawJSON _ = requestRawJSON root := gjson.ParseBytes(rawJSON) + toolNameMap := util.ToolNameMapFromClaudeRequest(originalRequestRawJSON) out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}` out, _ = sjson.Set(out, "id", root.Get("responseId").String()) @@ -336,11 +339,12 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina flushText() hasToolCall = true - name := functionCall.Get("name").String() + upstreamToolName := functionCall.Get("name").String() + clientToolName := util.MapToolName(toolNameMap, upstreamToolName) toolIDCounter++ toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}` - toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter)) - toolBlock, _ = sjson.Set(toolBlock, "name", name) + toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("%s-%d", upstreamToolName, toolIDCounter)) + toolBlock, _ = sjson.Set(toolBlock, "name", clientToolName) inputRaw := "{}" if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() { inputRaw = args.Raw diff --git a/internal/translator/openai/claude/openai_claude_response.go b/internal/translator/openai/claude/openai_claude_response.go index ca20c848..7bb496a2 100644 --- a/internal/translator/openai/claude/openai_claude_response.go +++ b/internal/translator/openai/claude/openai_claude_response.go @@ -22,9 +22,11 @@ var ( // ConvertOpenAIResponseToAnthropicParams holds parameters for response conversion type ConvertOpenAIResponseToAnthropicParams struct { - MessageID string - Model string - CreatedAt int64 + MessageID string + Model string + CreatedAt int64 + ToolNameMap map[string]string + SawToolCall bool // Content accumulator for streaming ContentAccumulator strings.Builder // Tool calls accumulator for streaming @@ -78,6 +80,8 @@ func ConvertOpenAIResponseToClaude(_ context.Context, _ string, originalRequestR MessageID: "", Model: "", CreatedAt: 0, + ToolNameMap: nil, + SawToolCall: false, ContentAccumulator: strings.Builder{}, ToolCallsAccumulator: nil, TextContentBlockStarted: false, @@ -97,6 +101,10 @@ func ConvertOpenAIResponseToClaude(_ context.Context, _ string, originalRequestR } rawJSON = bytes.TrimSpace(rawJSON[5:]) + if (*param).(*ConvertOpenAIResponseToAnthropicParams).ToolNameMap == nil { + (*param).(*ConvertOpenAIResponseToAnthropicParams).ToolNameMap = util.ToolNameMapFromClaudeRequest(originalRequestRawJSON) + } + // Check if this is the [DONE] marker rawStr := strings.TrimSpace(string(rawJSON)) if rawStr == "[DONE]" { @@ -111,6 +119,16 @@ func ConvertOpenAIResponseToClaude(_ context.Context, _ string, originalRequestR } } +func effectiveOpenAIFinishReason(param *ConvertOpenAIResponseToAnthropicParams) string { + if param == nil { + return "" + } + if param.SawToolCall { + return "tool_calls" + } + return param.FinishReason +} + // convertOpenAIStreamingChunkToAnthropic converts OpenAI streaming chunk to Anthropic streaming events func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAIResponseToAnthropicParams) []string { root := gjson.ParseBytes(rawJSON) @@ -197,6 +215,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI } toolCalls.ForEach(func(_, toolCall gjson.Result) bool { + param.SawToolCall = true index := int(toolCall.Get("index").Int()) blockIndex := param.toolContentBlockIndex(index) @@ -215,7 +234,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI // Handle function name if function := toolCall.Get("function"); function.Exists() { if name := function.Get("name"); name.Exists() { - accumulator.Name = name.String() + accumulator.Name = util.MapToolName(param.ToolNameMap, name.String()) stopThinkingContentBlock(param, &results) @@ -246,7 +265,11 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI // Handle finish_reason (but don't send message_delta/message_stop yet) if finishReason := root.Get("choices.0.finish_reason"); finishReason.Exists() && finishReason.String() != "" { reason := finishReason.String() - param.FinishReason = reason + if param.SawToolCall { + param.FinishReason = "tool_calls" + } else { + param.FinishReason = reason + } // Send content_block_stop for thinking content if needed if param.ThinkingContentBlockStarted { @@ -294,7 +317,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI inputTokens, outputTokens, cachedTokens = extractOpenAIUsage(usage) // Send message_delta with usage messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` - messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason)) + messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(effectiveOpenAIFinishReason(param))) messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.input_tokens", inputTokens) messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.output_tokens", outputTokens) if cachedTokens > 0 { @@ -348,7 +371,7 @@ func convertOpenAIDoneToAnthropic(param *ConvertOpenAIResponseToAnthropicParams) // If we haven't sent message_delta yet (no usage info was received), send it now if param.FinishReason != "" && !param.MessageDeltaSent { messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` - messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason)) + messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(effectiveOpenAIFinishReason(param))) results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n") param.MessageDeltaSent = true } @@ -531,10 +554,10 @@ func stopTextContentBlock(param *ConvertOpenAIResponseToAnthropicParams, results // Returns: // - string: An Anthropic-compatible JSON response. func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string { - _ = originalRequestRawJSON _ = requestRawJSON root := gjson.ParseBytes(rawJSON) + toolNameMap := util.ToolNameMapFromClaudeRequest(originalRequestRawJSON) out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}` out, _ = sjson.Set(out, "id", root.Get("id").String()) out, _ = sjson.Set(out, "model", root.Get("model").String()) @@ -590,7 +613,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina hasToolCall = true toolUse := `{"type":"tool_use","id":"","name":"","input":{}}` toolUse, _ = sjson.Set(toolUse, "id", tc.Get("id").String()) - toolUse, _ = sjson.Set(toolUse, "name", tc.Get("function.name").String()) + toolUse, _ = sjson.Set(toolUse, "name", util.MapToolName(toolNameMap, tc.Get("function.name").String())) argsStr := util.FixJSON(tc.Get("function.arguments").String()) if argsStr != "" && gjson.Valid(argsStr) { @@ -647,7 +670,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina hasToolCall = true toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}` toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String()) - toolUseBlock, _ = sjson.Set(toolUseBlock, "name", toolCall.Get("function.name").String()) + toolUseBlock, _ = sjson.Set(toolUseBlock, "name", util.MapToolName(toolNameMap, toolCall.Get("function.name").String())) argsStr := util.FixJSON(toolCall.Get("function.arguments").String()) if argsStr != "" && gjson.Valid(argsStr) { diff --git a/internal/util/translator.go b/internal/util/translator.go index 51ecb748..669ba745 100644 --- a/internal/util/translator.go +++ b/internal/util/translator.go @@ -6,6 +6,7 @@ package util import ( "bytes" "fmt" + "strings" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -219,3 +220,54 @@ func FixJSON(input string) string { return out.String() } + +func CanonicalToolName(name string) string { + canonical := strings.TrimSpace(name) + canonical = strings.TrimLeft(canonical, "_") + return strings.ToLower(canonical) +} + +// ToolNameMapFromClaudeRequest returns a canonical-name -> original-name map extracted from a Claude request. +// It is used to restore exact tool name casing for clients that require strict tool name matching (e.g. Claude Code). +func ToolNameMapFromClaudeRequest(rawJSON []byte) map[string]string { + if len(rawJSON) == 0 || !gjson.ValidBytes(rawJSON) { + return nil + } + + tools := gjson.GetBytes(rawJSON, "tools") + if !tools.Exists() || !tools.IsArray() { + return nil + } + + toolResults := tools.Array() + out := make(map[string]string, len(toolResults)) + tools.ForEach(func(_, tool gjson.Result) bool { + name := strings.TrimSpace(tool.Get("name").String()) + if name == "" { + return true + } + key := CanonicalToolName(name) + if key == "" { + return true + } + if _, exists := out[key]; !exists { + out[key] = name + } + return true + }) + + if len(out) == 0 { + return nil + } + return out +} + +func MapToolName(toolNameMap map[string]string, name string) string { + if name == "" || toolNameMap == nil { + return name + } + if mapped, ok := toolNameMap[CanonicalToolName(name)]; ok && mapped != "" { + return mapped + } + return name +} diff --git a/internal/watcher/clients.go b/internal/watcher/clients.go index cf0ed076..2697fa05 100644 --- a/internal/watcher/clients.go +++ b/internal/watcher/clients.go @@ -183,7 +183,7 @@ func (w *Watcher) addOrUpdateClient(path string) { if w.reloadCallback != nil { log.Debugf("triggering server update callback after add/update") - w.reloadCallback(cfg) + w.triggerServerUpdate(cfg) } w.persistAuthAsync(fmt.Sprintf("Sync auth %s", filepath.Base(path)), path) } @@ -202,7 +202,7 @@ func (w *Watcher) removeClient(path string) { if w.reloadCallback != nil { log.Debugf("triggering server update callback after removal") - w.reloadCallback(cfg) + w.triggerServerUpdate(cfg) } w.persistAuthAsync(fmt.Sprintf("Remove auth %s", filepath.Base(path)), path) } @@ -303,3 +303,79 @@ func (w *Watcher) persistAuthAsync(message string, paths ...string) { } }() } + +func (w *Watcher) stopServerUpdateTimer() { + w.serverUpdateMu.Lock() + defer w.serverUpdateMu.Unlock() + if w.serverUpdateTimer != nil { + w.serverUpdateTimer.Stop() + w.serverUpdateTimer = nil + } + w.serverUpdatePend = false +} + +func (w *Watcher) triggerServerUpdate(cfg *config.Config) { + if w == nil || w.reloadCallback == nil || cfg == nil { + return + } + if w.stopped.Load() { + return + } + + now := time.Now() + + w.serverUpdateMu.Lock() + if w.serverUpdateLast.IsZero() || now.Sub(w.serverUpdateLast) >= serverUpdateDebounce { + w.serverUpdateLast = now + if w.serverUpdateTimer != nil { + w.serverUpdateTimer.Stop() + w.serverUpdateTimer = nil + } + w.serverUpdatePend = false + w.serverUpdateMu.Unlock() + w.reloadCallback(cfg) + return + } + + if w.serverUpdatePend { + w.serverUpdateMu.Unlock() + return + } + + delay := serverUpdateDebounce - now.Sub(w.serverUpdateLast) + if delay < 10*time.Millisecond { + delay = 10 * time.Millisecond + } + w.serverUpdatePend = true + if w.serverUpdateTimer != nil { + w.serverUpdateTimer.Stop() + w.serverUpdateTimer = nil + } + var timer *time.Timer + timer = time.AfterFunc(delay, func() { + if w.stopped.Load() { + return + } + w.clientsMutex.RLock() + latestCfg := w.config + w.clientsMutex.RUnlock() + + w.serverUpdateMu.Lock() + if w.serverUpdateTimer != timer || !w.serverUpdatePend { + w.serverUpdateMu.Unlock() + return + } + w.serverUpdateTimer = nil + w.serverUpdatePend = false + if latestCfg == nil || w.reloadCallback == nil || w.stopped.Load() { + w.serverUpdateMu.Unlock() + return + } + + w.serverUpdateLast = time.Now() + w.serverUpdateMu.Unlock() + w.reloadCallback(latestCfg) + }) + w.serverUpdateTimer = timer + w.serverUpdateMu.Unlock() +} diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index a451ef6e..3dbfcc81 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -6,6 +6,7 @@ import ( "context" "strings" "sync" + "sync/atomic" "time" "github.com/fsnotify/fsnotify" @@ -35,6 +36,11 @@ type Watcher struct { clientsMutex sync.RWMutex configReloadMu sync.Mutex configReloadTimer *time.Timer + serverUpdateMu sync.Mutex + serverUpdateTimer *time.Timer + serverUpdateLast time.Time + serverUpdatePend bool + stopped atomic.Bool reloadCallback func(*config.Config) watcher *fsnotify.Watcher lastAuthHashes map[string]string @@ -76,6 +82,7 @@ const ( replaceCheckDelay = 50 * time.Millisecond configReloadDebounce = 150 * time.Millisecond authRemoveDebounceWindow = 1 * time.Second + serverUpdateDebounce = 1 * time.Second ) // NewWatcher creates a new file watcher instance @@ -114,8 +121,10 @@ func (w *Watcher) Start(ctx context.Context) error { // Stop stops the file watcher func (w *Watcher) Stop() error { + w.stopped.Store(true) w.stopDispatch() w.stopConfigReloadTimer() + w.stopServerUpdateTimer() return w.watcher.Close() } diff --git a/internal/watcher/watcher_test.go b/internal/watcher/watcher_test.go index a3be5877..0f9cd019 100644 --- a/internal/watcher/watcher_test.go +++ b/internal/watcher/watcher_test.go @@ -441,6 +441,46 @@ func TestRemoveClientRemovesHash(t *testing.T) { } } +func TestTriggerServerUpdateCancelsPendingTimerOnImmediate(t *testing.T) { + tmpDir := t.TempDir() + cfg := &config.Config{AuthDir: tmpDir} + + var reloads int32 + w := &Watcher{ + reloadCallback: func(*config.Config) { + atomic.AddInt32(&reloads, 1) + }, + } + w.SetConfig(cfg) + + w.serverUpdateMu.Lock() + w.serverUpdateLast = time.Now().Add(-(serverUpdateDebounce - 100*time.Millisecond)) + w.serverUpdateMu.Unlock() + w.triggerServerUpdate(cfg) + + if got := atomic.LoadInt32(&reloads); got != 0 { + t.Fatalf("expected no immediate reload, got %d", got) + } + + w.serverUpdateMu.Lock() + if !w.serverUpdatePend || w.serverUpdateTimer == nil { + w.serverUpdateMu.Unlock() + t.Fatal("expected a pending server update timer") + } + w.serverUpdateLast = time.Now().Add(-(serverUpdateDebounce + 10*time.Millisecond)) + w.serverUpdateMu.Unlock() + + w.triggerServerUpdate(cfg) + if got := atomic.LoadInt32(&reloads); got != 1 { + t.Fatalf("expected immediate reload once, got %d", got) + } + + time.Sleep(250 * time.Millisecond) + if got := atomic.LoadInt32(&reloads); got != 1 { + t.Fatalf("expected pending timer to be cancelled, got %d reloads", got) + } +} + func TestShouldDebounceRemove(t *testing.T) { w := &Watcher{} path := filepath.Clean("test.json")