From 3dd50957923703eb7ccebcc8d921954046b91e04 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Wed, 24 Sep 2025 11:59:38 +0800 Subject: [PATCH] feat(translators): add token counting support for Claude and Gemini responses - Implemented `TokenCount` transform method across translators to calculate token usage. - Integrated token counting logic into executor pipelines for Claude, Gemini, and CLI translators. - Added corresponding API endpoints and handlers (`/messages/count_tokens`) for token usage retrieval. - Enhanced translation registry to support `TokenCount` functionality alongside existing response types. --- internal/api/handlers/claude/code_handlers.go | 37 +++++++++++ internal/api/server.go | 1 + internal/runtime/executor/claude_executor.go | 64 ++++++++++++++++++- internal/runtime/executor/codex_executor.go | 8 ++- .../runtime/executor/gemini_cli_executor.go | 6 +- internal/runtime/executor/gemini_executor.go | 7 +- .../gemini-cli/claude_gemini-cli_response.go | 5 +- internal/translator/claude/gemini-cli/init.go | 5 +- .../claude/gemini/claude_gemini_response.go | 5 ++ internal/translator/claude/gemini/init.go | 5 +- .../claude/gemini-cli_claude_response.go | 4 ++ internal/translator/gemini-cli/claude/init.go | 5 +- .../gemini/gemini_gemini-cli_request.go | 5 ++ internal/translator/gemini-cli/gemini/init.go | 5 +- .../gemini/claude/gemini_claude_response.go | 4 ++ internal/translator/gemini/claude/init.go | 5 +- .../gemini-cli/gemini_gemini-cli_response.go | 6 ++ internal/translator/gemini/gemini-cli/init.go | 5 +- .../gemini/gemini/gemini_gemini_response.go | 5 ++ internal/translator/gemini/gemini/init.go | 5 +- sdk/translator/registry.go | 18 ++++++ sdk/translator/types.go | 7 +- 22 files changed, 192 insertions(+), 25 deletions(-) diff --git a/internal/api/handlers/claude/code_handlers.go b/internal/api/handlers/claude/code_handlers.go index 4b848ae3..1de542dc 100644 --- a/internal/api/handlers/claude/code_handlers.go +++ b/internal/api/handlers/claude/code_handlers.go @@ -82,6 +82,43 @@ func (h *ClaudeCodeAPIHandler) ClaudeMessages(c *gin.Context) { } } +// ClaudeMessages handles Claude-compatible streaming chat completions. +// This function implements a sophisticated client rotation and quota management system +// to ensure high availability and optimal resource utilization across multiple backend clients. +// +// Parameters: +// - c: The Gin context for the request. +func (h *ClaudeCodeAPIHandler) ClaudeCountTokens(c *gin.Context) { + // Extract raw JSON data from the incoming request + rawJSON, err := c.GetRawData() + // If data retrieval fails, return a 400 Bad Request error. + if err != nil { + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: fmt.Sprintf("Invalid request: %v", err), + Type: "invalid_request_error", + }, + }) + return + } + + c.Header("Content-Type", "application/json") + + alt := h.GetAlt(c) + cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) + + modelName := gjson.GetBytes(rawJSON, "model").String() + + resp, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) + if errMsg != nil { + h.WriteErrorResponse(c, errMsg) + cliCancel(errMsg.Error) + return + } + _, _ = c.Writer.Write(resp) + cliCancel() +} + // ClaudeModels handles the Claude models listing endpoint. // It returns a JSON response containing available Claude models and their specifications. // diff --git a/internal/api/server.go b/internal/api/server.go index 3067ecad..bea114a9 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -198,6 +198,7 @@ func (s *Server) setupRoutes() { v1.POST("/chat/completions", openaiHandlers.ChatCompletions) v1.POST("/completions", openaiHandlers.Completions) v1.POST("/messages", claudeCodeHandlers.ClaudeMessages) + v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens) v1.POST("/responses", openaiResponsesHandlers.Responses) } diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index fdd7571a..2bea5d5e 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -18,6 +18,7 @@ import ( cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" "github.com/tidwall/sjson" "github.com/gin-gonic/gin" @@ -175,7 +176,68 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A } func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { - return cliproxyexecutor.Response{Payload: []byte{}}, fmt.Errorf("not implemented") + apiKey, baseURL := claudeCreds(auth) + if apiKey == "" { + return NewClientAdapter("claude").Execute(ctx, auth, req, opts) + } + if baseURL == "" { + baseURL = "https://api.anthropic.com" + } + + from := opts.SourceFormat + to := sdktranslator.FromString("claude") + // Use streaming translation to preserve function calling, except for claude. + stream := from != to + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream) + + if !strings.HasPrefix(req.Model, "claude-3-5-haiku") { + body, _ = sjson.SetRawBytes(body, "system", []byte(misc.ClaudeCodeInstructions)) + } + + url := fmt.Sprintf("%s/v1/messages/count_tokens?beta=true", baseURL) + recordAPIRequest(ctx, e.cfg, body) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return cliproxyexecutor.Response{}, err + } + applyClaudeHeaders(httpReq, apiKey, false) + + httpClient := &http.Client{} + if rt, ok := ctx.Value("cliproxy.roundtripper").(http.RoundTripper); ok && rt != nil { + httpClient.Transport = rt + } + resp, err := httpClient.Do(httpReq) + if err != nil { + return cliproxyexecutor.Response{}, err + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("response body close error: %v", errClose) + } + }() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + b, _ := io.ReadAll(resp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(b)} + } + reader := io.Reader(resp.Body) + var decoder *zstd.Decoder + if hasZSTDEcoding(resp.Header.Get("Content-Encoding")) { + decoder, err = zstd.NewReader(resp.Body) + if err != nil { + return cliproxyexecutor.Response{}, fmt.Errorf("failed to initialize zstd decoder: %w", err) + } + reader = decoder + defer decoder.Close() + } + data, err := io.ReadAll(reader) + if err != nil { + return cliproxyexecutor.Response{}, err + } + appendAPIResponseChunk(ctx, e.cfg, data) + count := gjson.GetBytes(data, "input_tokens").Int() + out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) + return cliproxyexecutor.Response{Payload: []byte(out)}, nil } func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index 441b0f38..5f4779e3 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -53,9 +53,11 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re to := sdktranslator.FromString("codex") body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) - if util.InArray([]string{"gpt-5-minimal", "gpt-5-low", "gpt-5-medium", "gpt-5-high"}, req.Model) { + if util.InArray([]string{"gpt-5", "gpt-5-minimal", "gpt-5-low", "gpt-5-medium", "gpt-5-high"}, req.Model) { body, _ = sjson.SetBytes(body, "model", "gpt-5") switch req.Model { + case "gpt-5": + body, _ = sjson.DeleteBytes(body, "reasoning.effort") case "gpt-5-minimal": body, _ = sjson.SetBytes(body, "reasoning.effort", "minimal") case "gpt-5-low": @@ -146,9 +148,11 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au to := sdktranslator.FromString("codex") body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) - if util.InArray([]string{"gpt-5-minimal", "gpt-5-low", "gpt-5-medium", "gpt-5-high"}, req.Model) { + if util.InArray([]string{"gpt-5", "gpt-5-minimal", "gpt-5-low", "gpt-5-medium", "gpt-5-high"}, req.Model) { body, _ = sjson.SetBytes(body, "model", "gpt-5") switch req.Model { + case "gpt-5": + body, _ = sjson.DeleteBytes(body, "reasoning.effort") case "gpt-5-minimal": body, _ = sjson.SetBytes(body, "reasoning.effort", "minimal") case "gpt-5-low": diff --git a/internal/runtime/executor/gemini_cli_executor.go b/internal/runtime/executor/gemini_cli_executor.go index 7284a570..876eafd4 100644 --- a/internal/runtime/executor/gemini_cli_executor.go +++ b/internal/runtime/executor/gemini_cli_executor.go @@ -18,6 +18,7 @@ import ( cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" "github.com/tidwall/sjson" "golang.org/x/oauth2" "golang.org/x/oauth2/google" @@ -156,7 +157,6 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut httpClient := newHTTPClient(ctx, 0) respCtx := context.WithValue(ctx, "alt", opts.Alt) - dataTag := []byte("data:") var lastStatus int var lastBody []byte @@ -321,8 +321,8 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth. _ = resp.Body.Close() appendAPIResponseChunk(ctx, e.cfg, data) if resp.StatusCode >= 200 && resp.StatusCode < 300 { - var param any - translated := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), payload, data, ¶m) + count := gjson.GetBytes(data, "totalTokens").Int() + translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data) return cliproxyexecutor.Response{Payload: []byte(translated)}, nil } lastStatus = resp.StatusCode diff --git a/internal/runtime/executor/gemini_executor.go b/internal/runtime/executor/gemini_executor.go index f652f952..2c46f5b9 100644 --- a/internal/runtime/executor/gemini_executor.go +++ b/internal/runtime/executor/gemini_executor.go @@ -15,6 +15,7 @@ import ( cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" "github.com/tidwall/sjson" "golang.org/x/oauth2" "golang.org/x/oauth2/google" @@ -182,9 +183,11 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) respCtx := context.WithValue(ctx, "alt", opts.Alt) translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") + translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") url := fmt.Sprintf("%s/%s/models/%s:%s", glEndpoint, glAPIVersion, req.Model, "countTokens") recordAPIRequest(ctx, e.cfg, translatedReq) + requestBody := bytes.NewReader(translatedReq) httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, requestBody) @@ -218,8 +221,8 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(data)} } - var param any - translated := sdktranslator.TranslateNonStream(respCtx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, data, ¶m) + count := gjson.GetBytes(data, "totalTokens").Int() + translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data) return cliproxyexecutor.Response{Payload: []byte(translated)}, nil } diff --git a/internal/translator/claude/gemini-cli/claude_gemini-cli_response.go b/internal/translator/claude/gemini-cli/claude_gemini-cli_response.go index 11521a9a..bc072b30 100644 --- a/internal/translator/claude/gemini-cli/claude_gemini-cli_response.go +++ b/internal/translator/claude/gemini-cli/claude_gemini-cli_response.go @@ -54,5 +54,8 @@ func ConvertClaudeResponseToGeminiCLINonStream(ctx context.Context, modelName st json := `{"response": {}}` strJSON, _ = sjson.SetRaw(json, "response", strJSON) return strJSON - +} + +func GeminiCLITokenCount(ctx context.Context, count int64) string { + return GeminiTokenCount(ctx, count) } diff --git a/internal/translator/claude/gemini-cli/init.go b/internal/translator/claude/gemini-cli/init.go index 8a7b822c..ca364a6e 100644 --- a/internal/translator/claude/gemini-cli/init.go +++ b/internal/translator/claude/gemini-cli/init.go @@ -12,8 +12,9 @@ func init() { Claude, ConvertGeminiCLIRequestToClaude, interfaces.TranslateResponse{ - Stream: ConvertClaudeResponseToGeminiCLI, - NonStream: ConvertClaudeResponseToGeminiCLINonStream, + Stream: ConvertClaudeResponseToGeminiCLI, + NonStream: ConvertClaudeResponseToGeminiCLINonStream, + TokenCount: GeminiCLITokenCount, }, ) } diff --git a/internal/translator/claude/gemini/claude_gemini_response.go b/internal/translator/claude/gemini/claude_gemini_response.go index 74de0c0b..23950fdb 100644 --- a/internal/translator/claude/gemini/claude_gemini_response.go +++ b/internal/translator/claude/gemini/claude_gemini_response.go @@ -9,6 +9,7 @@ import ( "bufio" "bytes" "context" + "fmt" "strings" "time" @@ -530,6 +531,10 @@ func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string, return template } +func GeminiTokenCount(ctx context.Context, count int64) string { + return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count) +} + // consolidateParts merges consecutive text parts and thinking parts to create a cleaner response. // This function processes the parts array to combine adjacent text elements and thinking elements // into single consolidated parts, which results in a more readable and efficient response structure. diff --git a/internal/translator/claude/gemini/init.go b/internal/translator/claude/gemini/init.go index 00d75ac9..8924f62c 100644 --- a/internal/translator/claude/gemini/init.go +++ b/internal/translator/claude/gemini/init.go @@ -12,8 +12,9 @@ func init() { Claude, ConvertGeminiRequestToClaude, interfaces.TranslateResponse{ - Stream: ConvertClaudeResponseToGemini, - NonStream: ConvertClaudeResponseToGeminiNonStream, + Stream: ConvertClaudeResponseToGemini, + NonStream: ConvertClaudeResponseToGeminiNonStream, + TokenCount: GeminiTokenCount, }, ) } diff --git a/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go b/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go index 8f0b3829..733668f3 100644 --- a/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go +++ b/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go @@ -376,3 +376,7 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig } return string(encoded) } + +func ClaudeTokenCount(ctx context.Context, count int64) string { + return fmt.Sprintf(`{"input_tokens":%d}`, count) +} diff --git a/internal/translator/gemini-cli/claude/init.go b/internal/translator/gemini-cli/claude/init.go index 7899d710..79ed03c6 100644 --- a/internal/translator/gemini-cli/claude/init.go +++ b/internal/translator/gemini-cli/claude/init.go @@ -12,8 +12,9 @@ func init() { GeminiCLI, ConvertClaudeRequestToCLI, interfaces.TranslateResponse{ - Stream: ConvertGeminiCLIResponseToClaude, - NonStream: ConvertGeminiCLIResponseToClaudeNonStream, + Stream: ConvertGeminiCLIResponseToClaude, + NonStream: ConvertGeminiCLIResponseToClaudeNonStream, + TokenCount: ClaudeTokenCount, }, ) } diff --git a/internal/translator/gemini-cli/gemini/gemini_gemini-cli_request.go b/internal/translator/gemini-cli/gemini/gemini_gemini-cli_request.go index 8e765648..fc90105b 100644 --- a/internal/translator/gemini-cli/gemini/gemini_gemini-cli_request.go +++ b/internal/translator/gemini-cli/gemini/gemini_gemini-cli_request.go @@ -7,6 +7,7 @@ package gemini import ( "context" + "fmt" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -74,3 +75,7 @@ func ConvertGeminiCliRequestToGeminiNonStream(_ context.Context, _ string, origi } return string(rawJSON) } + +func GeminiTokenCount(ctx context.Context, count int64) string { + return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count) +} diff --git a/internal/translator/gemini-cli/gemini/init.go b/internal/translator/gemini-cli/gemini/init.go index 2a372ea6..934edddb 100644 --- a/internal/translator/gemini-cli/gemini/init.go +++ b/internal/translator/gemini-cli/gemini/init.go @@ -12,8 +12,9 @@ func init() { GeminiCLI, ConvertGeminiRequestToGeminiCLI, interfaces.TranslateResponse{ - Stream: ConvertGeminiCliRequestToGemini, - NonStream: ConvertGeminiCliRequestToGeminiNonStream, + Stream: ConvertGeminiCliRequestToGemini, + NonStream: ConvertGeminiCliRequestToGeminiNonStream, + TokenCount: GeminiTokenCount, }, ) } diff --git a/internal/translator/gemini/claude/gemini_claude_response.go b/internal/translator/gemini/claude/gemini_claude_response.go index 824e3519..a80171a9 100644 --- a/internal/translator/gemini/claude/gemini_claude_response.go +++ b/internal/translator/gemini/claude/gemini_claude_response.go @@ -370,3 +370,7 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina } return string(encoded) } + +func ClaudeTokenCount(ctx context.Context, count int64) string { + return fmt.Sprintf(`{"input_tokens":%d}`, count) +} diff --git a/internal/translator/gemini/claude/init.go b/internal/translator/gemini/claude/init.go index 89b663b9..66fe51e7 100644 --- a/internal/translator/gemini/claude/init.go +++ b/internal/translator/gemini/claude/init.go @@ -12,8 +12,9 @@ func init() { Gemini, ConvertClaudeRequestToGemini, interfaces.TranslateResponse{ - Stream: ConvertGeminiResponseToClaude, - NonStream: ConvertGeminiResponseToClaudeNonStream, + Stream: ConvertGeminiResponseToClaude, + NonStream: ConvertGeminiResponseToClaudeNonStream, + TokenCount: ClaudeTokenCount, }, ) } diff --git a/internal/translator/gemini/gemini-cli/gemini_gemini-cli_response.go b/internal/translator/gemini/gemini-cli/gemini_gemini-cli_response.go index e1a10fc1..6bc038e2 100644 --- a/internal/translator/gemini/gemini-cli/gemini_gemini-cli_response.go +++ b/internal/translator/gemini/gemini-cli/gemini_gemini-cli_response.go @@ -7,6 +7,8 @@ package geminiCLI import ( "bytes" "context" + "fmt" + "github.com/tidwall/sjson" ) @@ -47,3 +49,7 @@ func ConvertGeminiResponseToGeminiCLINonStream(_ context.Context, _ string, orig rawJSON, _ = sjson.SetRawBytes([]byte(json), "response", rawJSON) return string(rawJSON) } + +func GeminiCLITokenCount(ctx context.Context, count int64) string { + return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count) +} diff --git a/internal/translator/gemini/gemini-cli/init.go b/internal/translator/gemini/gemini-cli/init.go index d30713cc..2c2224f7 100644 --- a/internal/translator/gemini/gemini-cli/init.go +++ b/internal/translator/gemini/gemini-cli/init.go @@ -12,8 +12,9 @@ func init() { Gemini, ConvertGeminiCLIRequestToGemini, interfaces.TranslateResponse{ - Stream: ConvertGeminiResponseToGeminiCLI, - NonStream: ConvertGeminiResponseToGeminiCLINonStream, + Stream: ConvertGeminiResponseToGeminiCLI, + NonStream: ConvertGeminiResponseToGeminiCLINonStream, + TokenCount: GeminiCLITokenCount, }, ) } diff --git a/internal/translator/gemini/gemini/gemini_gemini_response.go b/internal/translator/gemini/gemini/gemini_gemini_response.go index df9deb67..05fb6ab9 100644 --- a/internal/translator/gemini/gemini/gemini_gemini_response.go +++ b/internal/translator/gemini/gemini/gemini_gemini_response.go @@ -3,6 +3,7 @@ package gemini import ( "bytes" "context" + "fmt" ) // PassthroughGeminiResponseStream forwards Gemini responses unchanged. @@ -22,3 +23,7 @@ func PassthroughGeminiResponseStream(_ context.Context, _ string, originalReques func PassthroughGeminiResponseNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string { return string(rawJSON) } + +func GeminiTokenCount(ctx context.Context, count int64) string { + return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count) +} diff --git a/internal/translator/gemini/gemini/init.go b/internal/translator/gemini/gemini/init.go index 6a95fef3..28c97083 100644 --- a/internal/translator/gemini/gemini/init.go +++ b/internal/translator/gemini/gemini/init.go @@ -14,8 +14,9 @@ func init() { Gemini, ConvertGeminiRequestToGemini, interfaces.TranslateResponse{ - Stream: PassthroughGeminiResponseStream, - NonStream: PassthroughGeminiResponseNonStream, + Stream: PassthroughGeminiResponseStream, + NonStream: PassthroughGeminiResponseNonStream, + TokenCount: GeminiTokenCount, }, ) } diff --git a/sdk/translator/registry.go b/sdk/translator/registry.go index 2ef333ec..ace97137 100644 --- a/sdk/translator/registry.go +++ b/sdk/translator/registry.go @@ -91,6 +91,19 @@ func (r *Registry) TranslateNonStream(ctx context.Context, from, to Format, mode return string(rawJSON) } +// TranslateNonStream applies the registered non-stream response translator. +func (r *Registry) TranslateTokenCount(ctx context.Context, from, to Format, count int64, rawJSON []byte) string { + r.mu.RLock() + defer r.mu.RUnlock() + + if byTarget, ok := r.responses[to]; ok { + if fn, isOk := byTarget[from]; isOk && fn.TokenCount != nil { + return fn.TokenCount(ctx, count) + } + } + return string(rawJSON) +} + var defaultRegistry = NewRegistry() // Default exposes the package-level registry for shared use. @@ -122,3 +135,8 @@ func TranslateStream(ctx context.Context, from, to Format, model string, origina func TranslateNonStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string { return defaultRegistry.TranslateNonStream(ctx, from, to, model, originalRequestRawJSON, requestRawJSON, rawJSON, param) } + +// TranslateTokenCount is a helper on the default registry. +func TranslateTokenCount(ctx context.Context, from, to Format, count int64, rawJSON []byte) string { + return defaultRegistry.TranslateTokenCount(ctx, from, to, count, rawJSON) +} diff --git a/sdk/translator/types.go b/sdk/translator/types.go index 408281c3..9655ba23 100644 --- a/sdk/translator/types.go +++ b/sdk/translator/types.go @@ -11,8 +11,11 @@ type ResponseStreamTransform func(ctx context.Context, model string, originalReq // ResponseNonStreamTransform converts non-stream responses between schemas. type ResponseNonStreamTransform func(ctx context.Context, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string +type ResponseTokenCountTransform func(ctx context.Context, count int64) string + // ResponseTransform groups streaming and non-streaming transforms. type ResponseTransform struct { - Stream ResponseStreamTransform - NonStream ResponseNonStreamTransform + Stream ResponseStreamTransform + NonStream ResponseNonStreamTransform + TokenCount ResponseTokenCountTransform }