mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-07 22:33:30 +00:00
Merge branch 'router-for-me:main' into main
This commit is contained in:
@@ -839,6 +839,20 @@ func GetOpenAIModels() []*ModelInfo {
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.4",
|
||||
Object: "model",
|
||||
Created: 1772668800,
|
||||
OwnedBy: "openai",
|
||||
Type: "openai",
|
||||
Version: "gpt-5.4",
|
||||
DisplayName: "GPT 5.4",
|
||||
Description: "Stable version of GPT 5.4 Codex, The best model for coding and agentic tasks across domains.",
|
||||
ContextLength: 1_050_000,
|
||||
MaxCompletionTokens: 128000,
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -187,17 +187,15 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
// Decompress error responses (e.g. gzip-compressed 400 errors from Anthropic API).
|
||||
errBody := httpResp.Body
|
||||
if ce := httpResp.Header.Get("Content-Encoding"); ce != "" {
|
||||
var decErr error
|
||||
errBody, decErr = decodeResponseBody(httpResp.Body, ce)
|
||||
if decErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, decErr)
|
||||
msg := fmt.Sprintf("failed to decode error response body (encoding=%s): %v", ce, decErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
return resp, statusErr{code: httpResp.StatusCode, msg: msg}
|
||||
}
|
||||
// Decompress error responses — pass the Content-Encoding value (may be empty)
|
||||
// and let decodeResponseBody handle both header-declared and magic-byte-detected
|
||||
// compression. This keeps error-path behaviour consistent with the success path.
|
||||
errBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
|
||||
if decErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, decErr)
|
||||
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
return resp, statusErr{code: httpResp.StatusCode, msg: msg}
|
||||
}
|
||||
b, readErr := io.ReadAll(errBody)
|
||||
if readErr != nil {
|
||||
@@ -352,17 +350,15 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
// Decompress error responses (e.g. gzip-compressed 400 errors from Anthropic API).
|
||||
errBody := httpResp.Body
|
||||
if ce := httpResp.Header.Get("Content-Encoding"); ce != "" {
|
||||
var decErr error
|
||||
errBody, decErr = decodeResponseBody(httpResp.Body, ce)
|
||||
if decErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, decErr)
|
||||
msg := fmt.Sprintf("failed to decode error response body (encoding=%s): %v", ce, decErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
return nil, statusErr{code: httpResp.StatusCode, msg: msg}
|
||||
}
|
||||
// Decompress error responses — pass the Content-Encoding value (may be empty)
|
||||
// and let decodeResponseBody handle both header-declared and magic-byte-detected
|
||||
// compression. This keeps error-path behaviour consistent with the success path.
|
||||
errBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
|
||||
if decErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, decErr)
|
||||
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
return nil, statusErr{code: httpResp.StatusCode, msg: msg}
|
||||
}
|
||||
b, readErr := io.ReadAll(errBody)
|
||||
if readErr != nil {
|
||||
@@ -521,17 +517,15 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
// Decompress error responses (e.g. gzip-compressed 400 errors from Anthropic API).
|
||||
errBody := resp.Body
|
||||
if ce := resp.Header.Get("Content-Encoding"); ce != "" {
|
||||
var decErr error
|
||||
errBody, decErr = decodeResponseBody(resp.Body, ce)
|
||||
if decErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, decErr)
|
||||
msg := fmt.Sprintf("failed to decode error response body (encoding=%s): %v", ce, decErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: msg}
|
||||
}
|
||||
// Decompress error responses — pass the Content-Encoding value (may be empty)
|
||||
// and let decodeResponseBody handle both header-declared and magic-byte-detected
|
||||
// compression. This keeps error-path behaviour consistent with the success path.
|
||||
errBody, decErr := decodeResponseBody(resp.Body, resp.Header.Get("Content-Encoding"))
|
||||
if decErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, decErr)
|
||||
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
|
||||
logWithRequestID(ctx).Warn(msg)
|
||||
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: msg}
|
||||
}
|
||||
b, readErr := io.ReadAll(errBody)
|
||||
if readErr != nil {
|
||||
@@ -662,12 +656,61 @@ func (c *compositeReadCloser) Close() error {
|
||||
return firstErr
|
||||
}
|
||||
|
||||
// peekableBody wraps a bufio.Reader around the original ReadCloser so that
|
||||
// magic bytes can be inspected without consuming them from the stream.
|
||||
type peekableBody struct {
|
||||
*bufio.Reader
|
||||
closer io.Closer
|
||||
}
|
||||
|
||||
func (p *peekableBody) Close() error {
|
||||
return p.closer.Close()
|
||||
}
|
||||
|
||||
func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadCloser, error) {
|
||||
if body == nil {
|
||||
return nil, fmt.Errorf("response body is nil")
|
||||
}
|
||||
if contentEncoding == "" {
|
||||
return body, nil
|
||||
// No Content-Encoding header. Attempt best-effort magic-byte detection to
|
||||
// handle misbehaving upstreams that compress without setting the header.
|
||||
// Only gzip (1f 8b) and zstd (28 b5 2f fd) have reliable magic sequences;
|
||||
// br and deflate have none and are left as-is.
|
||||
// The bufio wrapper preserves unread bytes so callers always see the full
|
||||
// stream regardless of whether decompression was applied.
|
||||
pb := &peekableBody{Reader: bufio.NewReader(body), closer: body}
|
||||
magic, peekErr := pb.Peek(4)
|
||||
if peekErr == nil || (peekErr == io.EOF && len(magic) >= 2) {
|
||||
switch {
|
||||
case len(magic) >= 2 && magic[0] == 0x1f && magic[1] == 0x8b:
|
||||
gzipReader, gzErr := gzip.NewReader(pb)
|
||||
if gzErr != nil {
|
||||
_ = pb.Close()
|
||||
return nil, fmt.Errorf("magic-byte gzip: failed to create reader: %w", gzErr)
|
||||
}
|
||||
return &compositeReadCloser{
|
||||
Reader: gzipReader,
|
||||
closers: []func() error{
|
||||
gzipReader.Close,
|
||||
pb.Close,
|
||||
},
|
||||
}, nil
|
||||
case len(magic) >= 4 && magic[0] == 0x28 && magic[1] == 0xb5 && magic[2] == 0x2f && magic[3] == 0xfd:
|
||||
decoder, zdErr := zstd.NewReader(pb)
|
||||
if zdErr != nil {
|
||||
_ = pb.Close()
|
||||
return nil, fmt.Errorf("magic-byte zstd: failed to create reader: %w", zdErr)
|
||||
}
|
||||
return &compositeReadCloser{
|
||||
Reader: decoder,
|
||||
closers: []func() error{
|
||||
func() error { decoder.Close(); return nil },
|
||||
pb.Close,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
return pb, nil
|
||||
}
|
||||
encodings := strings.Split(contentEncoding, ",")
|
||||
for _, raw := range encodings {
|
||||
@@ -844,11 +887,15 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
r.Header.Set("User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.63 (external, cli)"))
|
||||
}
|
||||
r.Header.Set("Connection", "keep-alive")
|
||||
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
// SSE streams must not be compressed: the downstream scanner reads
|
||||
// line-delimited text and cannot parse compressed bytes. Using
|
||||
// "identity" tells the upstream to send an uncompressed stream.
|
||||
r.Header.Set("Accept-Encoding", "identity")
|
||||
} else {
|
||||
r.Header.Set("Accept", "application/json")
|
||||
r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
|
||||
}
|
||||
// Keep OS/Arch mapping dynamic (not configurable).
|
||||
// They intentionally continue to derive from runtime.GOOS/runtime.GOARCH.
|
||||
@@ -857,6 +904,12 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
||||
// Re-enforce Accept-Encoding: identity after ApplyCustomHeadersFromAttrs, which
|
||||
// may override it with a user-configured value. Compressed SSE breaks the line
|
||||
// scanner regardless of user preference, so this is non-negotiable for streams.
|
||||
if stream {
|
||||
r.Header.Set("Accept-Encoding", "identity")
|
||||
}
|
||||
}
|
||||
|
||||
func claudeCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||
|
||||
@@ -2,6 +2,7 @@ package executor
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
@@ -583,3 +585,385 @@ func testClaudeExecutorInvalidCompressedErrorBody(
|
||||
t.Fatalf("expected status code 400, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding verifies that streaming
|
||||
// requests use Accept-Encoding: identity so the upstream cannot respond with a
|
||||
// compressed SSE body that would silently break the line scanner.
|
||||
func TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding(t *testing.T) {
|
||||
var gotEncoding, gotAccept string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotEncoding = r.Header.Get("Accept-Encoding")
|
||||
gotAccept = r.Header.Get("Accept")
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream error: %v", err)
|
||||
}
|
||||
for chunk := range result.Chunks {
|
||||
if chunk.Err != nil {
|
||||
t.Fatalf("unexpected chunk error: %v", chunk.Err)
|
||||
}
|
||||
}
|
||||
|
||||
if gotEncoding != "identity" {
|
||||
t.Errorf("Accept-Encoding = %q, want %q", gotEncoding, "identity")
|
||||
}
|
||||
if gotAccept != "text/event-stream" {
|
||||
t.Errorf("Accept = %q, want %q", gotAccept, "text/event-stream")
|
||||
}
|
||||
}
|
||||
|
||||
// TestClaudeExecutor_Execute_SetsCompressedAcceptEncoding verifies that non-streaming
|
||||
// requests keep the full accept-encoding to allow response compression (which
|
||||
// decodeResponseBody handles correctly).
|
||||
func TestClaudeExecutor_Execute_SetsCompressedAcceptEncoding(t *testing.T) {
|
||||
var gotEncoding, gotAccept string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotEncoding = r.Header.Get("Accept-Encoding")
|
||||
gotAccept = r.Header.Get("Accept")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet-20241022","role":"assistant","content":[{"type":"text","text":"hi"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute error: %v", err)
|
||||
}
|
||||
|
||||
if gotEncoding != "gzip, deflate, br, zstd" {
|
||||
t.Errorf("Accept-Encoding = %q, want %q", gotEncoding, "gzip, deflate, br, zstd")
|
||||
}
|
||||
if gotAccept != "application/json" {
|
||||
t.Errorf("Accept = %q, want %q", gotAccept, "application/json")
|
||||
}
|
||||
}
|
||||
|
||||
// TestClaudeExecutor_ExecuteStream_GzipSuccessBodyDecoded verifies that a streaming
|
||||
// HTTP 200 response with Content-Encoding: gzip is correctly decompressed before
|
||||
// the line scanner runs, so SSE chunks are not silently dropped.
|
||||
func TestClaudeExecutor_ExecuteStream_GzipSuccessBodyDecoded(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
gz := gzip.NewWriter(&buf)
|
||||
_, _ = gz.Write([]byte("data: {\"type\":\"message_stop\"}\n"))
|
||||
_ = gz.Close()
|
||||
compressedBody := buf.Bytes()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
_, _ = w.Write(compressedBody)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream error: %v", err)
|
||||
}
|
||||
|
||||
var combined strings.Builder
|
||||
for chunk := range result.Chunks {
|
||||
if chunk.Err != nil {
|
||||
t.Fatalf("chunk error: %v", chunk.Err)
|
||||
}
|
||||
combined.Write(chunk.Payload)
|
||||
}
|
||||
|
||||
if combined.Len() == 0 {
|
||||
t.Fatal("expected at least one chunk from gzip-encoded SSE body, got none (body was not decompressed)")
|
||||
}
|
||||
if !strings.Contains(combined.String(), "message_stop") {
|
||||
t.Errorf("expected SSE content in chunks, got: %q", combined.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecodeResponseBody_MagicByteGzipNoHeader verifies that decodeResponseBody
|
||||
// detects gzip-compressed content via magic bytes even when Content-Encoding is absent.
|
||||
func TestDecodeResponseBody_MagicByteGzipNoHeader(t *testing.T) {
|
||||
const plaintext = "data: {\"type\":\"message_stop\"}\n"
|
||||
|
||||
var buf bytes.Buffer
|
||||
gz := gzip.NewWriter(&buf)
|
||||
_, _ = gz.Write([]byte(plaintext))
|
||||
_ = gz.Close()
|
||||
|
||||
rc := io.NopCloser(&buf)
|
||||
decoded, err := decodeResponseBody(rc, "")
|
||||
if err != nil {
|
||||
t.Fatalf("decodeResponseBody error: %v", err)
|
||||
}
|
||||
defer decoded.Close()
|
||||
|
||||
got, err := io.ReadAll(decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll error: %v", err)
|
||||
}
|
||||
if string(got) != plaintext {
|
||||
t.Errorf("decoded = %q, want %q", got, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecodeResponseBody_PlainTextNoHeader verifies that decodeResponseBody returns
|
||||
// plain text untouched when Content-Encoding is absent and no magic bytes match.
|
||||
func TestDecodeResponseBody_PlainTextNoHeader(t *testing.T) {
|
||||
const plaintext = "data: {\"type\":\"message_stop\"}\n"
|
||||
rc := io.NopCloser(strings.NewReader(plaintext))
|
||||
decoded, err := decodeResponseBody(rc, "")
|
||||
if err != nil {
|
||||
t.Fatalf("decodeResponseBody error: %v", err)
|
||||
}
|
||||
defer decoded.Close()
|
||||
|
||||
got, err := io.ReadAll(decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll error: %v", err)
|
||||
}
|
||||
if string(got) != plaintext {
|
||||
t.Errorf("decoded = %q, want %q", got, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClaudeExecutor_ExecuteStream_GzipNoContentEncodingHeader verifies the full
|
||||
// pipeline: when the upstream returns a gzip-compressed SSE body WITHOUT setting
|
||||
// Content-Encoding (a misbehaving upstream), the magic-byte sniff in
|
||||
// decodeResponseBody still decompresses it, so chunks reach the caller.
|
||||
func TestClaudeExecutor_ExecuteStream_GzipNoContentEncodingHeader(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
gz := gzip.NewWriter(&buf)
|
||||
_, _ = gz.Write([]byte("data: {\"type\":\"message_stop\"}\n"))
|
||||
_ = gz.Close()
|
||||
compressedBody := buf.Bytes()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
// Intentionally omit Content-Encoding to simulate misbehaving upstream.
|
||||
_, _ = w.Write(compressedBody)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream error: %v", err)
|
||||
}
|
||||
|
||||
var combined strings.Builder
|
||||
for chunk := range result.Chunks {
|
||||
if chunk.Err != nil {
|
||||
t.Fatalf("chunk error: %v", chunk.Err)
|
||||
}
|
||||
combined.Write(chunk.Payload)
|
||||
}
|
||||
|
||||
if combined.Len() == 0 {
|
||||
t.Fatal("expected chunks from gzip body without Content-Encoding header, got none (magic-byte sniff failed)")
|
||||
}
|
||||
if !strings.Contains(combined.String(), "message_stop") {
|
||||
t.Errorf("unexpected chunk content: %q", combined.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity verifies
|
||||
// that injecting Accept-Encoding via auth.Attributes cannot override the stream
|
||||
// path's enforced identity encoding.
|
||||
func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity(t *testing.T) {
|
||||
var gotEncoding string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotEncoding = r.Header.Get("Accept-Encoding")
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
// Inject Accept-Encoding via the custom header attribute mechanism.
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
"header:Accept-Encoding": "gzip, deflate, br, zstd",
|
||||
}}
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream error: %v", err)
|
||||
}
|
||||
for chunk := range result.Chunks {
|
||||
if chunk.Err != nil {
|
||||
t.Fatalf("unexpected chunk error: %v", chunk.Err)
|
||||
}
|
||||
}
|
||||
|
||||
if gotEncoding != "identity" {
|
||||
t.Errorf("Accept-Encoding = %q; stream path must enforce identity regardless of auth.Attributes override", gotEncoding)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecodeResponseBody_MagicByteZstdNoHeader verifies that decodeResponseBody
|
||||
// detects zstd-compressed content via magic bytes (28 b5 2f fd) even when
|
||||
// Content-Encoding is absent.
|
||||
func TestDecodeResponseBody_MagicByteZstdNoHeader(t *testing.T) {
|
||||
const plaintext = "data: {\"type\":\"message_stop\"}\n"
|
||||
|
||||
var buf bytes.Buffer
|
||||
enc, err := zstd.NewWriter(&buf)
|
||||
if err != nil {
|
||||
t.Fatalf("zstd.NewWriter: %v", err)
|
||||
}
|
||||
_, _ = enc.Write([]byte(plaintext))
|
||||
_ = enc.Close()
|
||||
|
||||
rc := io.NopCloser(&buf)
|
||||
decoded, err := decodeResponseBody(rc, "")
|
||||
if err != nil {
|
||||
t.Fatalf("decodeResponseBody error: %v", err)
|
||||
}
|
||||
defer decoded.Close()
|
||||
|
||||
got, err := io.ReadAll(decoded)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll error: %v", err)
|
||||
}
|
||||
if string(got) != plaintext {
|
||||
t.Errorf("decoded = %q, want %q", got, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClaudeExecutor_Execute_GzipErrorBodyNoContentEncodingHeader verifies that the
|
||||
// error path (4xx) correctly decompresses a gzip body even when the upstream omits
|
||||
// the Content-Encoding header. This closes the gap left by PR #1771, which only
|
||||
// fixed header-declared compression on the error path.
|
||||
func TestClaudeExecutor_Execute_GzipErrorBodyNoContentEncodingHeader(t *testing.T) {
|
||||
const errJSON = `{"type":"error","error":{"type":"invalid_request_error","message":"test error"}}`
|
||||
|
||||
var buf bytes.Buffer
|
||||
gz := gzip.NewWriter(&buf)
|
||||
_, _ = gz.Write([]byte(errJSON))
|
||||
_ = gz.Close()
|
||||
compressedBody := buf.Bytes()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Intentionally omit Content-Encoding to simulate misbehaving upstream.
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write(compressedBody)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected an error for 400 response, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "test error") {
|
||||
t.Errorf("error message should contain decompressed JSON, got: %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader verifies
|
||||
// the same for the streaming executor: 4xx gzip body without Content-Encoding is
|
||||
// decoded and the error message is readable.
|
||||
func TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader(t *testing.T) {
|
||||
const errJSON = `{"type":"error","error":{"type":"invalid_request_error","message":"stream test error"}}`
|
||||
|
||||
var buf bytes.Buffer
|
||||
gz := gzip.NewWriter(&buf)
|
||||
_, _ = gz.Write([]byte(errJSON))
|
||||
_ = gz.Close()
|
||||
compressedBody := buf.Bytes()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Intentionally omit Content-Encoding to simulate misbehaving upstream.
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write(compressedBody)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
_, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected an error for 400 response, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "stream test error") {
|
||||
t.Errorf("error message should contain decompressed JSON, got: %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -616,6 +616,10 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
|
||||
if promptCacheKey.Exists() {
|
||||
cache.ID = promptCacheKey.String()
|
||||
}
|
||||
} else if from == "openai" {
|
||||
if apiKey := strings.TrimSpace(apiKeyFromContext(ctx)); apiKey != "" {
|
||||
cache.ID = uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+apiKey)).String()
|
||||
}
|
||||
}
|
||||
|
||||
if cache.ID != "" {
|
||||
|
||||
64
internal/runtime/executor/codex_executor_cache_test.go
Normal file
64
internal/runtime/executor/codex_executor_cache_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFromAPIKey(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(recorder)
|
||||
ginCtx.Set("apiKey", "test-api-key")
|
||||
|
||||
ctx := context.WithValue(context.Background(), "gin", ginCtx)
|
||||
executor := &CodexExecutor{}
|
||||
rawJSON := []byte(`{"model":"gpt-5.3-codex","stream":true}`)
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gpt-5.3-codex",
|
||||
Payload: []byte(`{"model":"gpt-5.3-codex"}`),
|
||||
}
|
||||
url := "https://example.com/responses"
|
||||
|
||||
httpReq, err := executor.cacheHelper(ctx, sdktranslator.FromString("openai"), url, req, rawJSON)
|
||||
if err != nil {
|
||||
t.Fatalf("cacheHelper error: %v", err)
|
||||
}
|
||||
|
||||
body, errRead := io.ReadAll(httpReq.Body)
|
||||
if errRead != nil {
|
||||
t.Fatalf("read request body: %v", errRead)
|
||||
}
|
||||
|
||||
expectedKey := uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:test-api-key")).String()
|
||||
gotKey := gjson.GetBytes(body, "prompt_cache_key").String()
|
||||
if gotKey != expectedKey {
|
||||
t.Fatalf("prompt_cache_key = %q, want %q", gotKey, expectedKey)
|
||||
}
|
||||
if gotConversation := httpReq.Header.Get("Conversation_id"); gotConversation != expectedKey {
|
||||
t.Fatalf("Conversation_id = %q, want %q", gotConversation, expectedKey)
|
||||
}
|
||||
if gotSession := httpReq.Header.Get("Session_id"); gotSession != expectedKey {
|
||||
t.Fatalf("Session_id = %q, want %q", gotSession, expectedKey)
|
||||
}
|
||||
|
||||
httpReq2, err := executor.cacheHelper(ctx, sdktranslator.FromString("openai"), url, req, rawJSON)
|
||||
if err != nil {
|
||||
t.Fatalf("cacheHelper error (second call): %v", err)
|
||||
}
|
||||
body2, errRead2 := io.ReadAll(httpReq2.Body)
|
||||
if errRead2 != nil {
|
||||
t.Fatalf("read request body (second call): %v", errRead2)
|
||||
}
|
||||
gotKey2 := gjson.GetBytes(body2, "prompt_cache_key").String()
|
||||
if gotKey2 != expectedKey {
|
||||
t.Fatalf("prompt_cache_key (second call) = %q, want %q", gotKey2, expectedKey)
|
||||
}
|
||||
}
|
||||
@@ -25,7 +25,7 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "max_completion_tokens")
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature")
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
|
||||
// rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "truncation")
|
||||
rawJSON = applyResponsesCompactionCompatibility(rawJSON)
|
||||
|
||||
|
||||
@@ -6,24 +6,14 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// ConvertCodexResponseToOpenAIResponses converts OpenAI Chat Completions streaming chunks
|
||||
// to OpenAI Responses SSE events (response.*).
|
||||
|
||||
func ConvertCodexResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||
func ConvertCodexResponseToOpenAIResponses(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) []string {
|
||||
if bytes.HasPrefix(rawJSON, []byte("data:")) {
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
if typeResult := gjson.GetBytes(rawJSON, "type"); typeResult.Exists() {
|
||||
typeStr := typeResult.String()
|
||||
if typeStr == "response.created" || typeStr == "response.in_progress" || typeStr == "response.completed" {
|
||||
if gjson.GetBytes(rawJSON, "response.instructions").Exists() {
|
||||
instructions := gjson.GetBytes(originalRequestRawJSON, "instructions").String()
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "response.instructions", instructions)
|
||||
}
|
||||
}
|
||||
}
|
||||
out := fmt.Sprintf("data: %s", string(rawJSON))
|
||||
return []string{out}
|
||||
}
|
||||
@@ -32,17 +22,12 @@ func ConvertCodexResponseToOpenAIResponses(ctx context.Context, modelName string
|
||||
|
||||
// ConvertCodexResponseToOpenAIResponsesNonStream builds a single Responses JSON
|
||||
// from a non-streaming OpenAI Chat Completions response.
|
||||
func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) string {
|
||||
rootResult := gjson.ParseBytes(rawJSON)
|
||||
// Verify this is a response.completed event
|
||||
if rootResult.Get("type").String() != "response.completed" {
|
||||
return ""
|
||||
}
|
||||
responseResult := rootResult.Get("response")
|
||||
template := responseResult.Raw
|
||||
if responseResult.Get("instructions").Exists() {
|
||||
instructions := gjson.GetBytes(originalRequestRawJSON, "instructions").String()
|
||||
template, _ = sjson.Set(template, "instructions", instructions)
|
||||
}
|
||||
return template
|
||||
return responseResult.Raw
|
||||
}
|
||||
|
||||
@@ -85,6 +85,11 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
|
||||
case "tool_use":
|
||||
functionName := contentResult.Get("name").String()
|
||||
if toolUseID := contentResult.Get("id").String(); toolUseID != "" {
|
||||
if derived := toolNameFromClaudeToolUseID(toolUseID); derived != "" {
|
||||
functionName = derived
|
||||
}
|
||||
}
|
||||
functionArgs := contentResult.Get("input").String()
|
||||
argsResult := gjson.Parse(functionArgs)
|
||||
if argsResult.IsObject() && gjson.Valid(functionArgs) {
|
||||
@@ -100,10 +105,9 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
if toolCallID == "" {
|
||||
return true
|
||||
}
|
||||
funcName := toolCallID
|
||||
toolCallIDs := strings.Split(toolCallID, "-")
|
||||
if len(toolCallIDs) > 1 {
|
||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
|
||||
funcName := toolNameFromClaudeToolUseID(toolCallID)
|
||||
if funcName == "" {
|
||||
funcName = toolCallID
|
||||
}
|
||||
responseData := contentResult.Get("content").Raw
|
||||
part := `{"functionResponse":{"name":"","response":{"result":""}}}`
|
||||
@@ -230,3 +234,11 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func toolNameFromClaudeToolUseID(toolUseID string) string {
|
||||
parts := strings.Split(toolUseID, "-")
|
||||
if len(parts) <= 1 {
|
||||
return ""
|
||||
}
|
||||
return strings.Join(parts[0:len(parts)-1], "-")
|
||||
}
|
||||
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -25,6 +25,8 @@ type Params struct {
|
||||
ResponseType int
|
||||
ResponseIndex int
|
||||
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
|
||||
ToolNameMap map[string]string
|
||||
SawToolCall bool
|
||||
}
|
||||
|
||||
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
|
||||
@@ -53,6 +55,8 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
HasFirstResponse: false,
|
||||
ResponseType: 0,
|
||||
ResponseIndex: 0,
|
||||
ToolNameMap: util.ToolNameMapFromClaudeRequest(originalRequestRawJSON),
|
||||
SawToolCall: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,8 +70,6 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Track whether tools are being used in this response chunk
|
||||
usedTool := false
|
||||
output := ""
|
||||
|
||||
// Initialize the streaming session with a message_start event
|
||||
@@ -175,12 +177,13 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
} else if functionCallResult.Exists() {
|
||||
// Handle function/tool calls from the AI model
|
||||
// This processes tool usage requests and formats them for Claude API compatibility
|
||||
usedTool = true
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
(*param).(*Params).SawToolCall = true
|
||||
upstreamToolName := functionCallResult.Get("name").String()
|
||||
clientToolName := util.MapToolName((*param).(*Params).ToolNameMap, upstreamToolName)
|
||||
|
||||
// FIX: Handle streaming split/delta where name might be empty in subsequent chunks.
|
||||
// If we are already in tool use mode and name is empty, treat as continuation (delta).
|
||||
if (*param).(*Params).ResponseType == 3 && fcName == "" {
|
||||
if (*param).(*Params).ResponseType == 3 && upstreamToolName == "" {
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex), "delta.partial_json", fcArgsResult.Raw)
|
||||
@@ -221,8 +224,8 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
|
||||
// Create the tool use block with unique ID and function details
|
||||
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
|
||||
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
|
||||
data, _ = sjson.Set(data, "content_block.name", fcName)
|
||||
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", upstreamToolName, atomic.AddUint64(&toolUseIDCounter, 1)))
|
||||
data, _ = sjson.Set(data, "content_block.name", clientToolName)
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
@@ -249,7 +252,7 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
output = output + `data: `
|
||||
|
||||
template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
if usedTool {
|
||||
if (*param).(*Params).SawToolCall {
|
||||
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
} else if finish := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finish.Exists() && finish.String() == "MAX_TOKENS" {
|
||||
template = `{"type":"message_delta","delta":{"stop_reason":"max_tokens","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
@@ -278,10 +281,10 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
// Returns:
|
||||
// - string: A Claude-compatible JSON response.
|
||||
func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
_ = originalRequestRawJSON
|
||||
_ = requestRawJSON
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
toolNameMap := util.ToolNameMapFromClaudeRequest(originalRequestRawJSON)
|
||||
|
||||
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
out, _ = sjson.Set(out, "id", root.Get("responseId").String())
|
||||
@@ -336,11 +339,12 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
|
||||
flushText()
|
||||
hasToolCall = true
|
||||
|
||||
name := functionCall.Get("name").String()
|
||||
upstreamToolName := functionCall.Get("name").String()
|
||||
clientToolName := util.MapToolName(toolNameMap, upstreamToolName)
|
||||
toolIDCounter++
|
||||
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
|
||||
toolBlock, _ = sjson.Set(toolBlock, "name", name)
|
||||
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("%s-%d", upstreamToolName, toolIDCounter))
|
||||
toolBlock, _ = sjson.Set(toolBlock, "name", clientToolName)
|
||||
inputRaw := "{}"
|
||||
if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() {
|
||||
inputRaw = args.Raw
|
||||
|
||||
@@ -22,9 +22,11 @@ var (
|
||||
|
||||
// ConvertOpenAIResponseToAnthropicParams holds parameters for response conversion
|
||||
type ConvertOpenAIResponseToAnthropicParams struct {
|
||||
MessageID string
|
||||
Model string
|
||||
CreatedAt int64
|
||||
MessageID string
|
||||
Model string
|
||||
CreatedAt int64
|
||||
ToolNameMap map[string]string
|
||||
SawToolCall bool
|
||||
// Content accumulator for streaming
|
||||
ContentAccumulator strings.Builder
|
||||
// Tool calls accumulator for streaming
|
||||
@@ -78,6 +80,8 @@ func ConvertOpenAIResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
MessageID: "",
|
||||
Model: "",
|
||||
CreatedAt: 0,
|
||||
ToolNameMap: nil,
|
||||
SawToolCall: false,
|
||||
ContentAccumulator: strings.Builder{},
|
||||
ToolCallsAccumulator: nil,
|
||||
TextContentBlockStarted: false,
|
||||
@@ -97,6 +101,10 @@ func ConvertOpenAIResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
}
|
||||
rawJSON = bytes.TrimSpace(rawJSON[5:])
|
||||
|
||||
if (*param).(*ConvertOpenAIResponseToAnthropicParams).ToolNameMap == nil {
|
||||
(*param).(*ConvertOpenAIResponseToAnthropicParams).ToolNameMap = util.ToolNameMapFromClaudeRequest(originalRequestRawJSON)
|
||||
}
|
||||
|
||||
// Check if this is the [DONE] marker
|
||||
rawStr := strings.TrimSpace(string(rawJSON))
|
||||
if rawStr == "[DONE]" {
|
||||
@@ -111,6 +119,16 @@ func ConvertOpenAIResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
}
|
||||
}
|
||||
|
||||
func effectiveOpenAIFinishReason(param *ConvertOpenAIResponseToAnthropicParams) string {
|
||||
if param == nil {
|
||||
return ""
|
||||
}
|
||||
if param.SawToolCall {
|
||||
return "tool_calls"
|
||||
}
|
||||
return param.FinishReason
|
||||
}
|
||||
|
||||
// convertOpenAIStreamingChunkToAnthropic converts OpenAI streaming chunk to Anthropic streaming events
|
||||
func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAIResponseToAnthropicParams) []string {
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
@@ -197,6 +215,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
||||
}
|
||||
|
||||
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
|
||||
param.SawToolCall = true
|
||||
index := int(toolCall.Get("index").Int())
|
||||
blockIndex := param.toolContentBlockIndex(index)
|
||||
|
||||
@@ -215,7 +234,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
||||
// Handle function name
|
||||
if function := toolCall.Get("function"); function.Exists() {
|
||||
if name := function.Get("name"); name.Exists() {
|
||||
accumulator.Name = name.String()
|
||||
accumulator.Name = util.MapToolName(param.ToolNameMap, name.String())
|
||||
|
||||
stopThinkingContentBlock(param, &results)
|
||||
|
||||
@@ -246,7 +265,11 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
||||
// Handle finish_reason (but don't send message_delta/message_stop yet)
|
||||
if finishReason := root.Get("choices.0.finish_reason"); finishReason.Exists() && finishReason.String() != "" {
|
||||
reason := finishReason.String()
|
||||
param.FinishReason = reason
|
||||
if param.SawToolCall {
|
||||
param.FinishReason = "tool_calls"
|
||||
} else {
|
||||
param.FinishReason = reason
|
||||
}
|
||||
|
||||
// Send content_block_stop for thinking content if needed
|
||||
if param.ThinkingContentBlockStarted {
|
||||
@@ -294,7 +317,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
||||
inputTokens, outputTokens, cachedTokens = extractOpenAIUsage(usage)
|
||||
// Send message_delta with usage
|
||||
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason))
|
||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(effectiveOpenAIFinishReason(param)))
|
||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.input_tokens", inputTokens)
|
||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.output_tokens", outputTokens)
|
||||
if cachedTokens > 0 {
|
||||
@@ -348,7 +371,7 @@ func convertOpenAIDoneToAnthropic(param *ConvertOpenAIResponseToAnthropicParams)
|
||||
// If we haven't sent message_delta yet (no usage info was received), send it now
|
||||
if param.FinishReason != "" && !param.MessageDeltaSent {
|
||||
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason))
|
||||
messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(effectiveOpenAIFinishReason(param)))
|
||||
results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n")
|
||||
param.MessageDeltaSent = true
|
||||
}
|
||||
@@ -531,10 +554,10 @@ func stopTextContentBlock(param *ConvertOpenAIResponseToAnthropicParams, results
|
||||
// Returns:
|
||||
// - string: An Anthropic-compatible JSON response.
|
||||
func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
_ = originalRequestRawJSON
|
||||
_ = requestRawJSON
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
toolNameMap := util.ToolNameMapFromClaudeRequest(originalRequestRawJSON)
|
||||
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
out, _ = sjson.Set(out, "id", root.Get("id").String())
|
||||
out, _ = sjson.Set(out, "model", root.Get("model").String())
|
||||
@@ -590,7 +613,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
|
||||
hasToolCall = true
|
||||
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolUse, _ = sjson.Set(toolUse, "id", tc.Get("id").String())
|
||||
toolUse, _ = sjson.Set(toolUse, "name", tc.Get("function.name").String())
|
||||
toolUse, _ = sjson.Set(toolUse, "name", util.MapToolName(toolNameMap, tc.Get("function.name").String()))
|
||||
|
||||
argsStr := util.FixJSON(tc.Get("function.arguments").String())
|
||||
if argsStr != "" && gjson.Valid(argsStr) {
|
||||
@@ -647,7 +670,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
|
||||
hasToolCall = true
|
||||
toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String())
|
||||
toolUseBlock, _ = sjson.Set(toolUseBlock, "name", toolCall.Get("function.name").String())
|
||||
toolUseBlock, _ = sjson.Set(toolUseBlock, "name", util.MapToolName(toolNameMap, toolCall.Get("function.name").String()))
|
||||
|
||||
argsStr := util.FixJSON(toolCall.Get("function.arguments").String())
|
||||
if argsStr != "" && gjson.Valid(argsStr) {
|
||||
|
||||
@@ -6,6 +6,7 @@ package util
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -219,3 +220,54 @@ func FixJSON(input string) string {
|
||||
|
||||
return out.String()
|
||||
}
|
||||
|
||||
func CanonicalToolName(name string) string {
|
||||
canonical := strings.TrimSpace(name)
|
||||
canonical = strings.TrimLeft(canonical, "_")
|
||||
return strings.ToLower(canonical)
|
||||
}
|
||||
|
||||
// ToolNameMapFromClaudeRequest returns a canonical-name -> original-name map extracted from a Claude request.
|
||||
// It is used to restore exact tool name casing for clients that require strict tool name matching (e.g. Claude Code).
|
||||
func ToolNameMapFromClaudeRequest(rawJSON []byte) map[string]string {
|
||||
if len(rawJSON) == 0 || !gjson.ValidBytes(rawJSON) {
|
||||
return nil
|
||||
}
|
||||
|
||||
tools := gjson.GetBytes(rawJSON, "tools")
|
||||
if !tools.Exists() || !tools.IsArray() {
|
||||
return nil
|
||||
}
|
||||
|
||||
toolResults := tools.Array()
|
||||
out := make(map[string]string, len(toolResults))
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
name := strings.TrimSpace(tool.Get("name").String())
|
||||
if name == "" {
|
||||
return true
|
||||
}
|
||||
key := CanonicalToolName(name)
|
||||
if key == "" {
|
||||
return true
|
||||
}
|
||||
if _, exists := out[key]; !exists {
|
||||
out[key] = name
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func MapToolName(toolNameMap map[string]string, name string) string {
|
||||
if name == "" || toolNameMap == nil {
|
||||
return name
|
||||
}
|
||||
if mapped, ok := toolNameMap[CanonicalToolName(name)]; ok && mapped != "" {
|
||||
return mapped
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
@@ -183,7 +183,7 @@ func (w *Watcher) addOrUpdateClient(path string) {
|
||||
|
||||
if w.reloadCallback != nil {
|
||||
log.Debugf("triggering server update callback after add/update")
|
||||
w.reloadCallback(cfg)
|
||||
w.triggerServerUpdate(cfg)
|
||||
}
|
||||
w.persistAuthAsync(fmt.Sprintf("Sync auth %s", filepath.Base(path)), path)
|
||||
}
|
||||
@@ -202,7 +202,7 @@ func (w *Watcher) removeClient(path string) {
|
||||
|
||||
if w.reloadCallback != nil {
|
||||
log.Debugf("triggering server update callback after removal")
|
||||
w.reloadCallback(cfg)
|
||||
w.triggerServerUpdate(cfg)
|
||||
}
|
||||
w.persistAuthAsync(fmt.Sprintf("Remove auth %s", filepath.Base(path)), path)
|
||||
}
|
||||
@@ -303,3 +303,79 @@ func (w *Watcher) persistAuthAsync(message string, paths ...string) {
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (w *Watcher) stopServerUpdateTimer() {
|
||||
w.serverUpdateMu.Lock()
|
||||
defer w.serverUpdateMu.Unlock()
|
||||
if w.serverUpdateTimer != nil {
|
||||
w.serverUpdateTimer.Stop()
|
||||
w.serverUpdateTimer = nil
|
||||
}
|
||||
w.serverUpdatePend = false
|
||||
}
|
||||
|
||||
func (w *Watcher) triggerServerUpdate(cfg *config.Config) {
|
||||
if w == nil || w.reloadCallback == nil || cfg == nil {
|
||||
return
|
||||
}
|
||||
if w.stopped.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
w.serverUpdateMu.Lock()
|
||||
if w.serverUpdateLast.IsZero() || now.Sub(w.serverUpdateLast) >= serverUpdateDebounce {
|
||||
w.serverUpdateLast = now
|
||||
if w.serverUpdateTimer != nil {
|
||||
w.serverUpdateTimer.Stop()
|
||||
w.serverUpdateTimer = nil
|
||||
}
|
||||
w.serverUpdatePend = false
|
||||
w.serverUpdateMu.Unlock()
|
||||
w.reloadCallback(cfg)
|
||||
return
|
||||
}
|
||||
|
||||
if w.serverUpdatePend {
|
||||
w.serverUpdateMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
delay := serverUpdateDebounce - now.Sub(w.serverUpdateLast)
|
||||
if delay < 10*time.Millisecond {
|
||||
delay = 10 * time.Millisecond
|
||||
}
|
||||
w.serverUpdatePend = true
|
||||
if w.serverUpdateTimer != nil {
|
||||
w.serverUpdateTimer.Stop()
|
||||
w.serverUpdateTimer = nil
|
||||
}
|
||||
var timer *time.Timer
|
||||
timer = time.AfterFunc(delay, func() {
|
||||
if w.stopped.Load() {
|
||||
return
|
||||
}
|
||||
w.clientsMutex.RLock()
|
||||
latestCfg := w.config
|
||||
w.clientsMutex.RUnlock()
|
||||
|
||||
w.serverUpdateMu.Lock()
|
||||
if w.serverUpdateTimer != timer || !w.serverUpdatePend {
|
||||
w.serverUpdateMu.Unlock()
|
||||
return
|
||||
}
|
||||
w.serverUpdateTimer = nil
|
||||
w.serverUpdatePend = false
|
||||
if latestCfg == nil || w.reloadCallback == nil || w.stopped.Load() {
|
||||
w.serverUpdateMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
w.serverUpdateLast = time.Now()
|
||||
w.serverUpdateMu.Unlock()
|
||||
w.reloadCallback(latestCfg)
|
||||
})
|
||||
w.serverUpdateTimer = timer
|
||||
w.serverUpdateMu.Unlock()
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
@@ -35,6 +36,11 @@ type Watcher struct {
|
||||
clientsMutex sync.RWMutex
|
||||
configReloadMu sync.Mutex
|
||||
configReloadTimer *time.Timer
|
||||
serverUpdateMu sync.Mutex
|
||||
serverUpdateTimer *time.Timer
|
||||
serverUpdateLast time.Time
|
||||
serverUpdatePend bool
|
||||
stopped atomic.Bool
|
||||
reloadCallback func(*config.Config)
|
||||
watcher *fsnotify.Watcher
|
||||
lastAuthHashes map[string]string
|
||||
@@ -76,6 +82,7 @@ const (
|
||||
replaceCheckDelay = 50 * time.Millisecond
|
||||
configReloadDebounce = 150 * time.Millisecond
|
||||
authRemoveDebounceWindow = 1 * time.Second
|
||||
serverUpdateDebounce = 1 * time.Second
|
||||
)
|
||||
|
||||
// NewWatcher creates a new file watcher instance
|
||||
@@ -114,8 +121,10 @@ func (w *Watcher) Start(ctx context.Context) error {
|
||||
|
||||
// Stop stops the file watcher
|
||||
func (w *Watcher) Stop() error {
|
||||
w.stopped.Store(true)
|
||||
w.stopDispatch()
|
||||
w.stopConfigReloadTimer()
|
||||
w.stopServerUpdateTimer()
|
||||
return w.watcher.Close()
|
||||
}
|
||||
|
||||
|
||||
@@ -441,6 +441,46 @@ func TestRemoveClientRemovesHash(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTriggerServerUpdateCancelsPendingTimerOnImmediate(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := &config.Config{AuthDir: tmpDir}
|
||||
|
||||
var reloads int32
|
||||
w := &Watcher{
|
||||
reloadCallback: func(*config.Config) {
|
||||
atomic.AddInt32(&reloads, 1)
|
||||
},
|
||||
}
|
||||
w.SetConfig(cfg)
|
||||
|
||||
w.serverUpdateMu.Lock()
|
||||
w.serverUpdateLast = time.Now().Add(-(serverUpdateDebounce - 100*time.Millisecond))
|
||||
w.serverUpdateMu.Unlock()
|
||||
w.triggerServerUpdate(cfg)
|
||||
|
||||
if got := atomic.LoadInt32(&reloads); got != 0 {
|
||||
t.Fatalf("expected no immediate reload, got %d", got)
|
||||
}
|
||||
|
||||
w.serverUpdateMu.Lock()
|
||||
if !w.serverUpdatePend || w.serverUpdateTimer == nil {
|
||||
w.serverUpdateMu.Unlock()
|
||||
t.Fatal("expected a pending server update timer")
|
||||
}
|
||||
w.serverUpdateLast = time.Now().Add(-(serverUpdateDebounce + 10*time.Millisecond))
|
||||
w.serverUpdateMu.Unlock()
|
||||
|
||||
w.triggerServerUpdate(cfg)
|
||||
if got := atomic.LoadInt32(&reloads); got != 1 {
|
||||
t.Fatalf("expected immediate reload once, got %d", got)
|
||||
}
|
||||
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
if got := atomic.LoadInt32(&reloads); got != 1 {
|
||||
t.Fatalf("expected pending timer to be cancelled, got %d reloads", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldDebounceRemove(t *testing.T) {
|
||||
w := &Watcher{}
|
||||
path := filepath.Clean("test.json")
|
||||
|
||||
Reference in New Issue
Block a user