Merge branch 'main' into plus

This commit is contained in:
Luis Pater
2026-02-16 23:49:50 +08:00
committed by GitHub
116 changed files with 27820 additions and 391 deletions

View File

@@ -1007,7 +1007,12 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
exec := &AntigravityExecutor{cfg: cfg}
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
if errToken != nil || token == "" {
if errToken != nil {
log.Warnf("antigravity executor: fetch models failed for %s: token error: %v", auth.ID, errToken)
return nil
}
if token == "" {
log.Warnf("antigravity executor: fetch models failed for %s: got empty token", auth.ID)
return nil
}
if updatedAuth != nil {
@@ -1021,6 +1026,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
modelsURL := baseURL + antigravityModelsPath
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`)))
if errReq != nil {
log.Warnf("antigravity executor: fetch models failed for %s: create request error: %v", auth.ID, errReq)
return nil
}
httpReq.Header.Set("Content-Type", "application/json")
@@ -1033,12 +1039,14 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
log.Warnf("antigravity executor: fetch models failed for %s: context canceled: %v", auth.ID, errDo)
return nil
}
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
log.Warnf("antigravity executor: fetch models failed for %s: request error: %v", auth.ID, errDo)
return nil
}
@@ -1051,6 +1059,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
log.Warnf("antigravity executor: fetch models failed for %s: read body error: %v", auth.ID, errRead)
return nil
}
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
@@ -1058,11 +1067,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
log.Warnf("antigravity executor: fetch models failed for %s: unexpected status %d, body: %s", auth.ID, httpResp.StatusCode, string(bodyBytes))
return nil
}
result := gjson.GetBytes(bodyBytes, "models")
if !result.Exists() {
log.Warnf("antigravity executor: fetch models failed for %s: no models field in response, body: %s", auth.ID, string(bodyBytes))
return nil
}

View File

@@ -29,6 +29,7 @@ func startCodexCacheCleanup() {
go func() {
ticker := time.NewTicker(codexCacheCleanupInterval)
defer ticker.Stop()
for range ticker.C {
purgeExpiredCodexCache()
}
@@ -38,8 +39,10 @@ func startCodexCacheCleanup() {
// purgeExpiredCodexCache removes entries that have expired.
func purgeExpiredCodexCache() {
now := time.Now()
codexCacheMu.Lock()
defer codexCacheMu.Unlock()
for key, cache := range codexCacheMap {
if cache.Expire.Before(now) {
delete(codexCacheMap, key)
@@ -66,3 +69,10 @@ func setCodexCache(key string, cache codexCache) {
codexCacheMap[key] = cache
codexCacheMu.Unlock()
}
// deleteCodexCache deletes a cache entry.
func deleteCodexCache(key string) {
codexCacheMu.Lock()
delete(codexCacheMap, key)
codexCacheMu.Unlock()
}

View File

@@ -0,0 +1,976 @@
package executor
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/google/uuid"
copilotauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
githubCopilotBaseURL = "https://api.githubcopilot.com"
githubCopilotChatPath = "/chat/completions"
githubCopilotResponsesPath = "/responses"
githubCopilotAuthType = "github-copilot"
githubCopilotTokenCacheTTL = 25 * time.Minute
// tokenExpiryBuffer is the time before expiry when we should refresh the token.
tokenExpiryBuffer = 5 * time.Minute
// maxScannerBufferSize is the maximum buffer size for SSE scanning (20MB).
maxScannerBufferSize = 20_971_520
// Copilot API header values.
copilotUserAgent = "GitHubCopilotChat/0.35.0"
copilotEditorVersion = "vscode/1.107.0"
copilotPluginVersion = "copilot-chat/0.35.0"
copilotIntegrationID = "vscode-chat"
copilotOpenAIIntent = "conversation-edits"
)
// GitHubCopilotExecutor handles requests to the GitHub Copilot API.
type GitHubCopilotExecutor struct {
cfg *config.Config
mu sync.RWMutex
cache map[string]*cachedAPIToken
}
// cachedAPIToken stores a cached Copilot API token with its expiry.
type cachedAPIToken struct {
token string
expiresAt time.Time
}
// NewGitHubCopilotExecutor constructs a new executor instance.
func NewGitHubCopilotExecutor(cfg *config.Config) *GitHubCopilotExecutor {
return &GitHubCopilotExecutor{
cfg: cfg,
cache: make(map[string]*cachedAPIToken),
}
}
// Identifier implements ProviderExecutor.
func (e *GitHubCopilotExecutor) Identifier() string { return githubCopilotAuthType }
// PrepareRequest implements ProviderExecutor.
func (e *GitHubCopilotExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
if req == nil {
return nil
}
ctx := req.Context()
if ctx == nil {
ctx = context.Background()
}
apiToken, errToken := e.ensureAPIToken(ctx, auth)
if errToken != nil {
return errToken
}
e.applyHeaders(req, apiToken, nil)
return nil
}
// HttpRequest injects GitHub Copilot credentials into the request and executes it.
func (e *GitHubCopilotExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("github-copilot executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if errPrepare := e.PrepareRequest(httpReq, auth); errPrepare != nil {
return nil, errPrepare
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
// Execute handles non-streaming requests to GitHub Copilot.
func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
apiToken, errToken := e.ensureAPIToken(ctx, auth)
if errToken != nil {
return resp, errToken
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
useResponses := useGitHubCopilotResponsesEndpoint(from, req.Model)
to := sdktranslator.FromString("openai")
if useResponses {
to = sdktranslator.FromString("openai-response")
}
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = e.normalizeModel(req.Model, body)
body = flattenAssistantContent(body)
thinkingProvider := "openai"
if useResponses {
thinkingProvider = "codex"
}
body, err = thinking.ApplyThinking(body, req.Model, from.String(), thinkingProvider, e.Identifier())
if err != nil {
return resp, err
}
if useResponses {
body = normalizeGitHubCopilotResponsesInput(body)
body = normalizeGitHubCopilotResponsesTools(body)
} else {
body = normalizeGitHubCopilotChatTools(body)
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "stream", false)
path := githubCopilotChatPath
if useResponses {
path = githubCopilotResponsesPath
}
url := githubCopilotBaseURL + path
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return resp, err
}
e.applyHeaders(httpReq, apiToken, body)
// Add Copilot-Vision-Request header if the request contains vision content
if detectVisionContent(body) {
httpReq.Header.Set("Copilot-Vision-Request", "true")
}
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: body,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("github-copilot executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if !isHTTPSuccess(httpResp.StatusCode) {
data, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, data)
log.Debugf("github-copilot executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
detail := parseOpenAIUsage(data)
if useResponses && detail.TotalTokens == 0 {
detail = parseOpenAIResponsesUsage(data)
}
if detail.TotalTokens > 0 {
reporter.publish(ctx, detail)
}
var param any
converted := ""
if useResponses && from.String() == "claude" {
converted = translateGitHubCopilotResponsesNonStreamToClaude(data)
} else {
converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, &param)
}
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
reporter.ensurePublished(ctx)
return resp, nil
}
// ExecuteStream handles streaming requests to GitHub Copilot.
func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
apiToken, errToken := e.ensureAPIToken(ctx, auth)
if errToken != nil {
return nil, errToken
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
useResponses := useGitHubCopilotResponsesEndpoint(from, req.Model)
to := sdktranslator.FromString("openai")
if useResponses {
to = sdktranslator.FromString("openai-response")
}
originalPayload := bytes.Clone(req.Payload)
if len(opts.OriginalRequest) > 0 {
originalPayload = bytes.Clone(opts.OriginalRequest)
}
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
body = e.normalizeModel(req.Model, body)
body = flattenAssistantContent(body)
thinkingProvider := "openai"
if useResponses {
thinkingProvider = "codex"
}
body, err = thinking.ApplyThinking(body, req.Model, from.String(), thinkingProvider, e.Identifier())
if err != nil {
return nil, err
}
if useResponses {
body = normalizeGitHubCopilotResponsesInput(body)
body = normalizeGitHubCopilotResponsesTools(body)
} else {
body = normalizeGitHubCopilotChatTools(body)
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "stream", true)
// Enable stream options for usage stats in stream
if !useResponses {
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
}
path := githubCopilotChatPath
if useResponses {
path = githubCopilotResponsesPath
}
url := githubCopilotBaseURL + path
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return nil, err
}
e.applyHeaders(httpReq, apiToken, body)
// Add Copilot-Vision-Request header if the request contains vision content
if detectVisionContent(body) {
httpReq.Header.Set("Copilot-Vision-Request", "true")
}
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: body,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if !isHTTPSuccess(httpResp.StatusCode) {
data, readErr := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("github-copilot executor: close response body error: %v", errClose)
}
if readErr != nil {
recordAPIResponseError(ctx, e.cfg, readErr)
return nil, readErr
}
appendAPIResponseChunk(ctx, e.cfg, data)
log.Debugf("github-copilot executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("github-copilot executor: close response body error: %v", errClose)
}
}()
scanner := bufio.NewScanner(httpResp.Body)
scanner.Buffer(nil, maxScannerBufferSize)
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
// Parse SSE data
if bytes.HasPrefix(line, dataTag) {
data := bytes.TrimSpace(line[5:])
if bytes.Equal(data, []byte("[DONE]")) {
continue
}
if detail, ok := parseOpenAIStreamUsage(line); ok {
reporter.publish(ctx, detail)
} else if useResponses {
if detail, ok := parseOpenAIResponsesStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
}
}
var chunks []string
if useResponses && from.String() == "claude" {
chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), &param)
} else {
chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), &param)
}
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
} else {
reporter.ensurePublished(ctx)
}
}()
return stream, nil
}
// CountTokens is not supported for GitHub Copilot.
func (e *GitHubCopilotExecutor) CountTokens(_ context.Context, _ *cliproxyauth.Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported for github-copilot"}
}
// Refresh validates the GitHub token is still working.
// GitHub OAuth tokens don't expire traditionally, so we just validate.
func (e *GitHubCopilotExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
if auth == nil {
return nil, statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
}
// Get the GitHub access token
accessToken := metaStringValue(auth.Metadata, "access_token")
if accessToken == "" {
return auth, nil
}
// Validate the token can still get a Copilot API token
copilotAuth := copilotauth.NewCopilotAuth(e.cfg)
_, err := copilotAuth.GetCopilotAPIToken(ctx, accessToken)
if err != nil {
return nil, statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("github-copilot token validation failed: %v", err)}
}
return auth, nil
}
// ensureAPIToken gets or refreshes the Copilot API token.
func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *cliproxyauth.Auth) (string, error) {
if auth == nil {
return "", statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
}
// Get the GitHub access token
accessToken := metaStringValue(auth.Metadata, "access_token")
if accessToken == "" {
return "", statusErr{code: http.StatusUnauthorized, msg: "missing github access token"}
}
// Check for cached API token using thread-safe access
e.mu.RLock()
if cached, ok := e.cache[accessToken]; ok && cached.expiresAt.After(time.Now().Add(tokenExpiryBuffer)) {
e.mu.RUnlock()
return cached.token, nil
}
e.mu.RUnlock()
// Get a new Copilot API token
copilotAuth := copilotauth.NewCopilotAuth(e.cfg)
apiToken, err := copilotAuth.GetCopilotAPIToken(ctx, accessToken)
if err != nil {
return "", statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("failed to get copilot api token: %v", err)}
}
// Cache the token with thread-safe access
expiresAt := time.Now().Add(githubCopilotTokenCacheTTL)
if apiToken.ExpiresAt > 0 {
expiresAt = time.Unix(apiToken.ExpiresAt, 0)
}
e.mu.Lock()
e.cache[accessToken] = &cachedAPIToken{
token: apiToken.Token,
expiresAt: expiresAt,
}
e.mu.Unlock()
return apiToken.Token, nil
}
// applyHeaders sets the required headers for GitHub Copilot API requests.
func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string, body []byte) {
r.Header.Set("Content-Type", "application/json")
r.Header.Set("Authorization", "Bearer "+apiToken)
r.Header.Set("Accept", "application/json")
r.Header.Set("User-Agent", copilotUserAgent)
r.Header.Set("Editor-Version", copilotEditorVersion)
r.Header.Set("Editor-Plugin-Version", copilotPluginVersion)
r.Header.Set("Openai-Intent", copilotOpenAIIntent)
r.Header.Set("Copilot-Integration-Id", copilotIntegrationID)
r.Header.Set("X-Request-Id", uuid.NewString())
initiator := "user"
if len(body) > 0 {
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
arr := messages.Array()
if len(arr) > 0 {
lastRole := arr[len(arr)-1].Get("role").String()
if lastRole != "" && lastRole != "user" {
initiator = "agent"
}
}
}
}
r.Header.Set("X-Initiator", initiator)
}
// detectVisionContent checks if the request body contains vision/image content.
// Returns true if the request includes image_url or image type content blocks.
func detectVisionContent(body []byte) bool {
// Parse messages array
messagesResult := gjson.GetBytes(body, "messages")
if !messagesResult.Exists() || !messagesResult.IsArray() {
return false
}
// Check each message for vision content
for _, message := range messagesResult.Array() {
content := message.Get("content")
// If content is an array, check each content block
if content.IsArray() {
for _, block := range content.Array() {
blockType := block.Get("type").String()
// Check for image_url or image type
if blockType == "image_url" || blockType == "image" {
return true
}
}
}
}
return false
}
// normalizeModel strips the suffix (e.g. "(medium)") from the model name
// before sending to GitHub Copilot, as the upstream API does not accept
// suffixed model identifiers.
func (e *GitHubCopilotExecutor) normalizeModel(model string, body []byte) []byte {
baseModel := thinking.ParseSuffix(model).ModelName
if baseModel != model {
body, _ = sjson.SetBytes(body, "model", baseModel)
}
return body
}
func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format, model string) bool {
if sourceFormat.String() == "openai-response" {
return true
}
baseModel := strings.ToLower(thinking.ParseSuffix(model).ModelName)
return strings.Contains(baseModel, "codex")
}
// flattenAssistantContent converts assistant message content from array format
// to a joined string. GitHub Copilot requires assistant content as a string;
// sending it as an array causes Claude models to re-answer all previous prompts.
func flattenAssistantContent(body []byte) []byte {
messages := gjson.GetBytes(body, "messages")
if !messages.Exists() || !messages.IsArray() {
return body
}
result := body
for i, msg := range messages.Array() {
if msg.Get("role").String() != "assistant" {
continue
}
content := msg.Get("content")
if !content.Exists() || !content.IsArray() {
continue
}
var textParts []string
for _, part := range content.Array() {
if part.Get("type").String() == "text" {
if t := part.Get("text").String(); t != "" {
textParts = append(textParts, t)
}
}
}
joined := strings.Join(textParts, "")
path := fmt.Sprintf("messages.%d.content", i)
result, _ = sjson.SetBytes(result, path, joined)
}
return result
}
func normalizeGitHubCopilotChatTools(body []byte) []byte {
tools := gjson.GetBytes(body, "tools")
if tools.Exists() {
filtered := "[]"
if tools.IsArray() {
for _, tool := range tools.Array() {
if tool.Get("type").String() != "function" {
continue
}
filtered, _ = sjson.SetRaw(filtered, "-1", tool.Raw)
}
}
body, _ = sjson.SetRawBytes(body, "tools", []byte(filtered))
}
toolChoice := gjson.GetBytes(body, "tool_choice")
if !toolChoice.Exists() {
return body
}
if toolChoice.Type == gjson.String {
switch toolChoice.String() {
case "auto", "none", "required":
return body
}
}
body, _ = sjson.SetBytes(body, "tool_choice", "auto")
return body
}
func normalizeGitHubCopilotResponsesInput(body []byte) []byte {
input := gjson.GetBytes(body, "input")
if input.Exists() {
if input.Type == gjson.String {
return body
}
inputString := input.Raw
if input.Type != gjson.JSON {
inputString = input.String()
}
body, _ = sjson.SetBytes(body, "input", inputString)
return body
}
var parts []string
if system := gjson.GetBytes(body, "system"); system.Exists() {
if text := strings.TrimSpace(collectTextFromNode(system)); text != "" {
parts = append(parts, text)
}
}
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
for _, msg := range messages.Array() {
if text := strings.TrimSpace(collectTextFromNode(msg.Get("content"))); text != "" {
parts = append(parts, text)
}
}
}
body, _ = sjson.SetBytes(body, "input", strings.Join(parts, "\n"))
return body
}
func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
tools := gjson.GetBytes(body, "tools")
if tools.Exists() {
filtered := "[]"
if tools.IsArray() {
for _, tool := range tools.Array() {
toolType := tool.Get("type").String()
// Accept OpenAI format (type="function") and Claude format
// (no type field, but has top-level name + input_schema).
if toolType != "" && toolType != "function" {
continue
}
name := tool.Get("name").String()
if name == "" {
name = tool.Get("function.name").String()
}
if name == "" {
continue
}
normalized := `{"type":"function","name":""}`
normalized, _ = sjson.Set(normalized, "name", name)
if desc := tool.Get("description").String(); desc != "" {
normalized, _ = sjson.Set(normalized, "description", desc)
} else if desc = tool.Get("function.description").String(); desc != "" {
normalized, _ = sjson.Set(normalized, "description", desc)
}
if params := tool.Get("parameters"); params.Exists() {
normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw)
} else if params = tool.Get("function.parameters"); params.Exists() {
normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw)
} else if params = tool.Get("input_schema"); params.Exists() {
normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw)
}
filtered, _ = sjson.SetRaw(filtered, "-1", normalized)
}
}
body, _ = sjson.SetRawBytes(body, "tools", []byte(filtered))
}
toolChoice := gjson.GetBytes(body, "tool_choice")
if !toolChoice.Exists() {
return body
}
if toolChoice.Type == gjson.String {
switch toolChoice.String() {
case "auto", "none", "required":
return body
default:
body, _ = sjson.SetBytes(body, "tool_choice", "auto")
return body
}
}
if toolChoice.Type == gjson.JSON {
choiceType := toolChoice.Get("type").String()
if choiceType == "function" {
name := toolChoice.Get("name").String()
if name == "" {
name = toolChoice.Get("function.name").String()
}
if name != "" {
normalized := `{"type":"function","name":""}`
normalized, _ = sjson.Set(normalized, "name", name)
body, _ = sjson.SetRawBytes(body, "tool_choice", []byte(normalized))
return body
}
}
}
body, _ = sjson.SetBytes(body, "tool_choice", "auto")
return body
}
func collectTextFromNode(node gjson.Result) string {
if !node.Exists() {
return ""
}
if node.Type == gjson.String {
return node.String()
}
if node.IsArray() {
var parts []string
for _, item := range node.Array() {
if item.Type == gjson.String {
if text := item.String(); text != "" {
parts = append(parts, text)
}
continue
}
if text := item.Get("text").String(); text != "" {
parts = append(parts, text)
continue
}
if nested := collectTextFromNode(item.Get("content")); nested != "" {
parts = append(parts, nested)
}
}
return strings.Join(parts, "\n")
}
if node.Type == gjson.JSON {
if text := node.Get("text").String(); text != "" {
return text
}
if nested := collectTextFromNode(node.Get("content")); nested != "" {
return nested
}
return node.Raw
}
return node.String()
}
type githubCopilotResponsesStreamToolState struct {
Index int
ID string
Name string
}
type githubCopilotResponsesStreamState struct {
MessageStarted bool
MessageStopSent bool
TextBlockStarted bool
TextBlockIndex int
NextContentIndex int
HasToolUse bool
OutputIndexToTool map[int]*githubCopilotResponsesStreamToolState
ItemIDToTool map[string]*githubCopilotResponsesStreamToolState
}
func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string {
root := gjson.ParseBytes(data)
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
out, _ = sjson.Set(out, "id", root.Get("id").String())
out, _ = sjson.Set(out, "model", root.Get("model").String())
hasToolUse := false
if output := root.Get("output"); output.Exists() && output.IsArray() {
for _, item := range output.Array() {
switch item.Get("type").String() {
case "message":
if content := item.Get("content"); content.Exists() && content.IsArray() {
for _, part := range content.Array() {
if part.Get("type").String() != "output_text" {
continue
}
text := part.Get("text").String()
if text == "" {
continue
}
block := `{"type":"text","text":""}`
block, _ = sjson.Set(block, "text", text)
out, _ = sjson.SetRaw(out, "content.-1", block)
}
}
case "function_call":
hasToolUse = true
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
toolID := item.Get("call_id").String()
if toolID == "" {
toolID = item.Get("id").String()
}
toolUse, _ = sjson.Set(toolUse, "id", toolID)
toolUse, _ = sjson.Set(toolUse, "name", item.Get("name").String())
if args := item.Get("arguments").String(); args != "" && gjson.Valid(args) {
argObj := gjson.Parse(args)
if argObj.IsObject() {
toolUse, _ = sjson.SetRaw(toolUse, "input", argObj.Raw)
}
}
out, _ = sjson.SetRaw(out, "content.-1", toolUse)
}
}
}
inputTokens := root.Get("usage.input_tokens").Int()
outputTokens := root.Get("usage.output_tokens").Int()
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
if hasToolUse {
out, _ = sjson.Set(out, "stop_reason", "tool_use")
} else {
out, _ = sjson.Set(out, "stop_reason", "end_turn")
}
return out
}
func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []string {
if *param == nil {
*param = &githubCopilotResponsesStreamState{
TextBlockIndex: -1,
OutputIndexToTool: make(map[int]*githubCopilotResponsesStreamToolState),
ItemIDToTool: make(map[string]*githubCopilotResponsesStreamToolState),
}
}
state := (*param).(*githubCopilotResponsesStreamState)
if !bytes.HasPrefix(line, dataTag) {
return nil
}
payload := bytes.TrimSpace(line[5:])
if bytes.Equal(payload, []byte("[DONE]")) {
return nil
}
if !gjson.ValidBytes(payload) {
return nil
}
event := gjson.GetBytes(payload, "type").String()
results := make([]string, 0, 4)
ensureMessageStart := func() {
if state.MessageStarted {
return
}
messageStart := `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}`
messageStart, _ = sjson.Set(messageStart, "message.id", gjson.GetBytes(payload, "response.id").String())
messageStart, _ = sjson.Set(messageStart, "message.model", gjson.GetBytes(payload, "response.model").String())
results = append(results, "event: message_start\ndata: "+messageStart+"\n\n")
state.MessageStarted = true
}
startTextBlockIfNeeded := func() {
if state.TextBlockStarted {
return
}
if state.TextBlockIndex < 0 {
state.TextBlockIndex = state.NextContentIndex
state.NextContentIndex++
}
contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", state.TextBlockIndex)
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
state.TextBlockStarted = true
}
stopTextBlockIfNeeded := func() {
if !state.TextBlockStarted {
return
}
contentBlockStop := `{"type":"content_block_stop","index":0}`
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", state.TextBlockIndex)
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
state.TextBlockStarted = false
state.TextBlockIndex = -1
}
resolveTool := func(itemID string, outputIndex int) *githubCopilotResponsesStreamToolState {
if itemID != "" {
if tool, ok := state.ItemIDToTool[itemID]; ok {
return tool
}
}
if tool, ok := state.OutputIndexToTool[outputIndex]; ok {
if itemID != "" {
state.ItemIDToTool[itemID] = tool
}
return tool
}
return nil
}
switch event {
case "response.created":
ensureMessageStart()
case "response.output_text.delta":
ensureMessageStart()
startTextBlockIfNeeded()
delta := gjson.GetBytes(payload, "delta").String()
if delta != "" {
contentDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
contentDelta, _ = sjson.Set(contentDelta, "index", state.TextBlockIndex)
contentDelta, _ = sjson.Set(contentDelta, "delta.text", delta)
results = append(results, "event: content_block_delta\ndata: "+contentDelta+"\n\n")
}
case "response.output_item.added":
if gjson.GetBytes(payload, "item.type").String() != "function_call" {
break
}
ensureMessageStart()
stopTextBlockIfNeeded()
state.HasToolUse = true
tool := &githubCopilotResponsesStreamToolState{
Index: state.NextContentIndex,
ID: gjson.GetBytes(payload, "item.call_id").String(),
Name: gjson.GetBytes(payload, "item.name").String(),
}
if tool.ID == "" {
tool.ID = gjson.GetBytes(payload, "item.id").String()
}
state.NextContentIndex++
outputIndex := int(gjson.GetBytes(payload, "output_index").Int())
state.OutputIndexToTool[outputIndex] = tool
if itemID := gjson.GetBytes(payload, "item.id").String(); itemID != "" {
state.ItemIDToTool[itemID] = tool
}
contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", tool.Index)
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.id", tool.ID)
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.name", tool.Name)
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
case "response.output_item.delta":
item := gjson.GetBytes(payload, "item")
if item.Get("type").String() != "function_call" {
break
}
tool := resolveTool(item.Get("id").String(), int(gjson.GetBytes(payload, "output_index").Int()))
if tool == nil {
break
}
partial := gjson.GetBytes(payload, "delta").String()
if partial == "" {
partial = item.Get("arguments").String()
}
if partial == "" {
break
}
inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n")
case "response.output_item.done":
if gjson.GetBytes(payload, "item.type").String() != "function_call" {
break
}
tool := resolveTool(gjson.GetBytes(payload, "item.id").String(), int(gjson.GetBytes(payload, "output_index").Int()))
if tool == nil {
break
}
contentBlockStop := `{"type":"content_block_stop","index":0}`
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", tool.Index)
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
case "response.completed":
ensureMessageStart()
stopTextBlockIfNeeded()
if !state.MessageStopSent {
stopReason := "end_turn"
if state.HasToolUse {
stopReason = "tool_use"
}
messageDelta := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
messageDelta, _ = sjson.Set(messageDelta, "delta.stop_reason", stopReason)
messageDelta, _ = sjson.Set(messageDelta, "usage.input_tokens", gjson.GetBytes(payload, "response.usage.input_tokens").Int())
messageDelta, _ = sjson.Set(messageDelta, "usage.output_tokens", gjson.GetBytes(payload, "response.usage.output_tokens").Int())
results = append(results, "event: message_delta\ndata: "+messageDelta+"\n\n")
results = append(results, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
state.MessageStopSent = true
}
}
return results
}
// isHTTPSuccess checks if the status code indicates success (2xx).
func isHTTPSuccess(statusCode int) bool {
return statusCode >= 200 && statusCode < 300
}

View File

@@ -0,0 +1,242 @@
package executor
import (
"strings"
"testing"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
func TestGitHubCopilotNormalizeModel_StripsSuffix(t *testing.T) {
t.Parallel()
tests := []struct {
name string
model string
wantModel string
}{
{
name: "suffix stripped",
model: "claude-opus-4.6(medium)",
wantModel: "claude-opus-4.6",
},
{
name: "no suffix unchanged",
model: "claude-opus-4.6",
wantModel: "claude-opus-4.6",
},
{
name: "different suffix stripped",
model: "gpt-4o(high)",
wantModel: "gpt-4o",
},
{
name: "numeric suffix stripped",
model: "gemini-2.5-pro(8192)",
wantModel: "gemini-2.5-pro",
},
}
e := &GitHubCopilotExecutor{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
body := []byte(`{"model":"` + tt.model + `","messages":[]}`)
got := e.normalizeModel(tt.model, body)
gotModel := gjson.GetBytes(got, "model").String()
if gotModel != tt.wantModel {
t.Fatalf("normalizeModel() model = %q, want %q", gotModel, tt.wantModel)
}
})
}
}
func TestUseGitHubCopilotResponsesEndpoint_OpenAIResponseSource(t *testing.T) {
t.Parallel()
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai-response"), "claude-3-5-sonnet") {
t.Fatal("expected openai-response source to use /responses")
}
}
func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) {
t.Parallel()
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5-codex") {
t.Fatal("expected codex model to use /responses")
}
}
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {
t.Parallel()
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "claude-3-5-sonnet") {
t.Fatal("expected default openai source with non-codex model to use /chat/completions")
}
}
func TestNormalizeGitHubCopilotChatTools_KeepFunctionOnly(t *testing.T) {
t.Parallel()
body := []byte(`{"tools":[{"type":"function","function":{"name":"ok"}},{"type":"code_interpreter"}],"tool_choice":"auto"}`)
got := normalizeGitHubCopilotChatTools(body)
tools := gjson.GetBytes(got, "tools").Array()
if len(tools) != 1 {
t.Fatalf("tools len = %d, want 1", len(tools))
}
if tools[0].Get("type").String() != "function" {
t.Fatalf("tool type = %q, want function", tools[0].Get("type").String())
}
}
func TestNormalizeGitHubCopilotChatTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
t.Parallel()
body := []byte(`{"tools":[],"tool_choice":{"type":"function","function":{"name":"x"}}}`)
got := normalizeGitHubCopilotChatTools(body)
if gjson.GetBytes(got, "tool_choice").String() != "auto" {
t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
}
}
func TestNormalizeGitHubCopilotResponsesInput_MissingInputExtractedFromSystemAndMessages(t *testing.T) {
t.Parallel()
body := []byte(`{"system":"sys text","messages":[{"role":"user","content":"user text"},{"role":"assistant","content":[{"type":"text","text":"assistant text"}]}]}`)
got := normalizeGitHubCopilotResponsesInput(body)
in := gjson.GetBytes(got, "input")
if in.Type != gjson.String {
t.Fatalf("input type = %v, want string", in.Type)
}
if !strings.Contains(in.String(), "sys text") || !strings.Contains(in.String(), "user text") || !strings.Contains(in.String(), "assistant text") {
t.Fatalf("input = %q, want merged text", in.String())
}
}
func TestNormalizeGitHubCopilotResponsesInput_NonStringInputStringified(t *testing.T) {
t.Parallel()
body := []byte(`{"input":{"foo":"bar"}}`)
got := normalizeGitHubCopilotResponsesInput(body)
in := gjson.GetBytes(got, "input")
if in.Type != gjson.String {
t.Fatalf("input type = %v, want string", in.Type)
}
if !strings.Contains(in.String(), "foo") {
t.Fatalf("input = %q, want stringified object", in.String())
}
}
func TestNormalizeGitHubCopilotResponsesTools_FlattenFunctionTools(t *testing.T) {
t.Parallel()
body := []byte(`{"tools":[{"type":"function","function":{"name":"sum","description":"d","parameters":{"type":"object"}}},{"type":"web_search"}]}`)
got := normalizeGitHubCopilotResponsesTools(body)
tools := gjson.GetBytes(got, "tools").Array()
if len(tools) != 1 {
t.Fatalf("tools len = %d, want 1", len(tools))
}
if tools[0].Get("name").String() != "sum" {
t.Fatalf("tools[0].name = %q, want sum", tools[0].Get("name").String())
}
if !tools[0].Get("parameters").Exists() {
t.Fatal("expected parameters to be preserved")
}
}
func TestNormalizeGitHubCopilotResponsesTools_ClaudeFormatTools(t *testing.T) {
t.Parallel()
body := []byte(`{"tools":[{"name":"Bash","description":"Run commands","input_schema":{"type":"object","properties":{"command":{"type":"string"}},"required":["command"]}},{"name":"Read","description":"Read files","input_schema":{"type":"object","properties":{"path":{"type":"string"}}}}]}`)
got := normalizeGitHubCopilotResponsesTools(body)
tools := gjson.GetBytes(got, "tools").Array()
if len(tools) != 2 {
t.Fatalf("tools len = %d, want 2", len(tools))
}
if tools[0].Get("type").String() != "function" {
t.Fatalf("tools[0].type = %q, want function", tools[0].Get("type").String())
}
if tools[0].Get("name").String() != "Bash" {
t.Fatalf("tools[0].name = %q, want Bash", tools[0].Get("name").String())
}
if tools[0].Get("description").String() != "Run commands" {
t.Fatalf("tools[0].description = %q, want 'Run commands'", tools[0].Get("description").String())
}
if !tools[0].Get("parameters").Exists() {
t.Fatal("expected parameters to be set from input_schema")
}
if tools[0].Get("parameters.properties.command").Exists() != true {
t.Fatal("expected parameters.properties.command to exist")
}
if tools[1].Get("name").String() != "Read" {
t.Fatalf("tools[1].name = %q, want Read", tools[1].Get("name").String())
}
}
func TestNormalizeGitHubCopilotResponsesTools_FlattenToolChoiceFunctionObject(t *testing.T) {
t.Parallel()
body := []byte(`{"tool_choice":{"type":"function","function":{"name":"sum"}}}`)
got := normalizeGitHubCopilotResponsesTools(body)
if gjson.GetBytes(got, "tool_choice.type").String() != "function" {
t.Fatalf("tool_choice.type = %q, want function", gjson.GetBytes(got, "tool_choice.type").String())
}
if gjson.GetBytes(got, "tool_choice.name").String() != "sum" {
t.Fatalf("tool_choice.name = %q, want sum", gjson.GetBytes(got, "tool_choice.name").String())
}
}
func TestNormalizeGitHubCopilotResponsesTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
t.Parallel()
body := []byte(`{"tool_choice":{"type":"function"}}`)
got := normalizeGitHubCopilotResponsesTools(body)
if gjson.GetBytes(got, "tool_choice").String() != "auto" {
t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
}
}
func TestTranslateGitHubCopilotResponsesNonStreamToClaude_TextMapping(t *testing.T) {
t.Parallel()
resp := []byte(`{"id":"resp_1","model":"gpt-5-codex","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":3,"output_tokens":5}}`)
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
if gjson.Get(out, "type").String() != "message" {
t.Fatalf("type = %q, want message", gjson.Get(out, "type").String())
}
if gjson.Get(out, "content.0.type").String() != "text" {
t.Fatalf("content.0.type = %q, want text", gjson.Get(out, "content.0.type").String())
}
if gjson.Get(out, "content.0.text").String() != "hello" {
t.Fatalf("content.0.text = %q, want hello", gjson.Get(out, "content.0.text").String())
}
}
func TestTranslateGitHubCopilotResponsesNonStreamToClaude_ToolUseMapping(t *testing.T) {
t.Parallel()
resp := []byte(`{"id":"resp_2","model":"gpt-5-codex","output":[{"type":"function_call","id":"fc_1","call_id":"call_1","name":"sum","arguments":"{\"a\":1}"}],"usage":{"input_tokens":1,"output_tokens":2}}`)
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
if gjson.Get(out, "content.0.type").String() != "tool_use" {
t.Fatalf("content.0.type = %q, want tool_use", gjson.Get(out, "content.0.type").String())
}
if gjson.Get(out, "content.0.name").String() != "sum" {
t.Fatalf("content.0.name = %q, want sum", gjson.Get(out, "content.0.name").String())
}
if gjson.Get(out, "stop_reason").String() != "tool_use" {
t.Fatalf("stop_reason = %q, want tool_use", gjson.Get(out, "stop_reason").String())
}
}
func TestTranslateGitHubCopilotResponsesStreamToClaude_TextLifecycle(t *testing.T) {
t.Parallel()
var param any
created := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5-codex"}}`), &param)
if len(created) == 0 || !strings.Contains(created[0], "message_start") {
t.Fatalf("created events = %#v, want message_start", created)
}
delta := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.output_text.delta","delta":"he"}`), &param)
joinedDelta := strings.Join(delta, "")
if !strings.Contains(joinedDelta, "content_block_start") || !strings.Contains(joinedDelta, "text_delta") {
t.Fatalf("delta events = %#v, want content_block_start + text_delta", delta)
}
completed := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.completed","response":{"usage":{"input_tokens":7,"output_tokens":9}}}`), &param)
joinedCompleted := strings.Join(completed, "")
if !strings.Contains(joinedCompleted, "message_delta") || !strings.Contains(joinedCompleted, "message_stop") {
t.Fatalf("completed events = %#v, want message_delta + message_stop", completed)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -6,6 +6,7 @@ import (
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
@@ -14,11 +15,19 @@ import (
"golang.org/x/net/proxy"
)
// httpClientCache caches HTTP clients by proxy URL to enable connection reuse
var (
httpClientCache = make(map[string]*http.Client)
httpClientCacheMutex sync.RWMutex
)
// newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
// 1. Use auth.ProxyURL if configured (highest priority)
// 2. Use cfg.ProxyURL if auth proxy is not configured
// 3. Use RoundTripper from context if neither are configured
//
// This function caches HTTP clients by proxy URL to enable TCP/TLS connection reuse.
//
// Parameters:
// - ctx: The context containing optional RoundTripper
// - cfg: The application configuration
@@ -28,11 +37,6 @@ import (
// Returns:
// - *http.Client: An HTTP client with configured proxy or transport
func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
httpClient := &http.Client{}
if timeout > 0 {
httpClient.Timeout = timeout
}
// Priority 1: Use auth.ProxyURL if configured
var proxyURL string
if auth != nil {
@@ -44,11 +48,39 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
proxyURL = strings.TrimSpace(cfg.ProxyURL)
}
// Build cache key from proxy URL (empty string for no proxy)
cacheKey := proxyURL
// Check cache first
httpClientCacheMutex.RLock()
if cachedClient, ok := httpClientCache[cacheKey]; ok {
httpClientCacheMutex.RUnlock()
// Return a wrapper with the requested timeout but shared transport
if timeout > 0 {
return &http.Client{
Transport: cachedClient.Transport,
Timeout: timeout,
}
}
return cachedClient
}
httpClientCacheMutex.RUnlock()
// Create new client
httpClient := &http.Client{}
if timeout > 0 {
httpClient.Timeout = timeout
}
// If we have a proxy URL configured, set up the transport
if proxyURL != "" {
transport := buildProxyTransport(proxyURL)
if transport != nil {
httpClient.Transport = transport
// Cache the client
httpClientCacheMutex.Lock()
httpClientCache[cacheKey] = httpClient
httpClientCacheMutex.Unlock()
return httpClient
}
// If proxy setup failed, log and fall through to context RoundTripper
@@ -60,6 +92,13 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
httpClient.Transport = rt
}
// Cache the client for no-proxy case
if proxyURL == "" {
httpClientCacheMutex.Lock()
httpClientCache[cacheKey] = httpClient
httpClientCacheMutex.Unlock()
}
return httpClient
}

View File

@@ -2,43 +2,109 @@ package executor
import (
"fmt"
"regexp"
"strconv"
"strings"
"sync"
"github.com/tidwall/gjson"
"github.com/tiktoken-go/tokenizer"
)
// tokenizerCache stores tokenizer instances to avoid repeated creation
var tokenizerCache sync.Map
// TokenizerWrapper wraps a tokenizer codec with an adjustment factor for models
// where tiktoken may not accurately estimate token counts (e.g., Claude models)
type TokenizerWrapper struct {
Codec tokenizer.Codec
AdjustmentFactor float64 // 1.0 means no adjustment, >1.0 means tiktoken underestimates
}
// Count returns the token count with adjustment factor applied
func (tw *TokenizerWrapper) Count(text string) (int, error) {
count, err := tw.Codec.Count(text)
if err != nil {
return 0, err
}
if tw.AdjustmentFactor != 1.0 && tw.AdjustmentFactor > 0 {
return int(float64(count) * tw.AdjustmentFactor), nil
}
return count, nil
}
// getTokenizer returns a cached tokenizer for the given model.
// This improves performance by avoiding repeated tokenizer creation.
func getTokenizer(model string) (*TokenizerWrapper, error) {
// Check cache first
if cached, ok := tokenizerCache.Load(model); ok {
return cached.(*TokenizerWrapper), nil
}
// Cache miss, create new tokenizer
wrapper, err := tokenizerForModel(model)
if err != nil {
return nil, err
}
// Store in cache (use LoadOrStore to handle race conditions)
actual, _ := tokenizerCache.LoadOrStore(model, wrapper)
return actual.(*TokenizerWrapper), nil
}
// tokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
func tokenizerForModel(model string) (tokenizer.Codec, error) {
// For Claude models, applies a 1.1 adjustment factor since tiktoken may underestimate.
func tokenizerForModel(model string) (*TokenizerWrapper, error) {
sanitized := strings.ToLower(strings.TrimSpace(model))
// Claude models use cl100k_base with 1.1 adjustment factor
// because tiktoken may underestimate Claude's actual token count
if strings.Contains(sanitized, "claude") || strings.HasPrefix(sanitized, "kiro-") || strings.HasPrefix(sanitized, "amazonq-") {
enc, err := tokenizer.Get(tokenizer.Cl100kBase)
if err != nil {
return nil, err
}
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.1}, nil
}
var enc tokenizer.Codec
var err error
switch {
case sanitized == "":
return tokenizer.Get(tokenizer.Cl100kBase)
case strings.HasPrefix(sanitized, "gpt-5"):
return tokenizer.ForModel(tokenizer.GPT5)
enc, err = tokenizer.Get(tokenizer.Cl100kBase)
case strings.HasPrefix(sanitized, "gpt-5.2"):
enc, err = tokenizer.ForModel(tokenizer.GPT5)
case strings.HasPrefix(sanitized, "gpt-5.1"):
return tokenizer.ForModel(tokenizer.GPT5)
enc, err = tokenizer.ForModel(tokenizer.GPT5)
case strings.HasPrefix(sanitized, "gpt-5"):
enc, err = tokenizer.ForModel(tokenizer.GPT5)
case strings.HasPrefix(sanitized, "gpt-4.1"):
return tokenizer.ForModel(tokenizer.GPT41)
enc, err = tokenizer.ForModel(tokenizer.GPT41)
case strings.HasPrefix(sanitized, "gpt-4o"):
return tokenizer.ForModel(tokenizer.GPT4o)
enc, err = tokenizer.ForModel(tokenizer.GPT4o)
case strings.HasPrefix(sanitized, "gpt-4"):
return tokenizer.ForModel(tokenizer.GPT4)
enc, err = tokenizer.ForModel(tokenizer.GPT4)
case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"):
return tokenizer.ForModel(tokenizer.GPT35Turbo)
enc, err = tokenizer.ForModel(tokenizer.GPT35Turbo)
case strings.HasPrefix(sanitized, "o1"):
return tokenizer.ForModel(tokenizer.O1)
enc, err = tokenizer.ForModel(tokenizer.O1)
case strings.HasPrefix(sanitized, "o3"):
return tokenizer.ForModel(tokenizer.O3)
enc, err = tokenizer.ForModel(tokenizer.O3)
case strings.HasPrefix(sanitized, "o4"):
return tokenizer.ForModel(tokenizer.O4Mini)
enc, err = tokenizer.ForModel(tokenizer.O4Mini)
default:
return tokenizer.Get(tokenizer.O200kBase)
enc, err = tokenizer.Get(tokenizer.O200kBase)
}
if err != nil {
return nil, err
}
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.0}, nil
}
// countOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
func countOpenAIChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
if enc == nil {
return 0, fmt.Errorf("encoder is nil")
}
@@ -62,11 +128,206 @@ func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
return 0, nil
}
// Count text tokens
count, err := enc.Count(joined)
if err != nil {
return 0, err
}
return int64(count), nil
// Extract and add image tokens from placeholders
imageTokens := extractImageTokens(joined)
return int64(count) + int64(imageTokens), nil
}
// countClaudeChatTokens approximates prompt tokens for Claude API chat completions payloads.
// This handles Claude's message format with system, messages, and tools.
// Image tokens are estimated based on image dimensions when available.
func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
if enc == nil {
return 0, fmt.Errorf("encoder is nil")
}
if len(payload) == 0 {
return 0, nil
}
root := gjson.ParseBytes(payload)
segments := make([]string, 0, 32)
// Collect system prompt (can be string or array of content blocks)
collectClaudeSystem(root.Get("system"), &segments)
// Collect messages
collectClaudeMessages(root.Get("messages"), &segments)
// Collect tools
collectClaudeTools(root.Get("tools"), &segments)
joined := strings.TrimSpace(strings.Join(segments, "\n"))
if joined == "" {
return 0, nil
}
// Count text tokens
count, err := enc.Count(joined)
if err != nil {
return 0, err
}
// Extract and add image tokens from placeholders
imageTokens := extractImageTokens(joined)
return int64(count) + int64(imageTokens), nil
}
// imageTokenPattern matches [IMAGE:xxx tokens] format for extracting estimated image tokens
var imageTokenPattern = regexp.MustCompile(`\[IMAGE:(\d+) tokens\]`)
// extractImageTokens extracts image token estimates from placeholder text.
// Placeholders are in the format [IMAGE:xxx tokens] where xxx is the estimated token count.
func extractImageTokens(text string) int {
matches := imageTokenPattern.FindAllStringSubmatch(text, -1)
total := 0
for _, match := range matches {
if len(match) > 1 {
if tokens, err := strconv.Atoi(match[1]); err == nil {
total += tokens
}
}
}
return total
}
// estimateImageTokens calculates estimated tokens for an image based on dimensions.
// Based on Claude's image token calculation: tokens ≈ (width * height) / 750
// Minimum 85 tokens, maximum 1590 tokens (for 1568x1568 images).
func estimateImageTokens(width, height float64) int {
if width <= 0 || height <= 0 {
// No valid dimensions, use default estimate (medium-sized image)
return 1000
}
tokens := int(width * height / 750)
// Apply bounds
if tokens < 85 {
tokens = 85
}
if tokens > 1590 {
tokens = 1590
}
return tokens
}
// collectClaudeSystem extracts text from Claude's system field.
// System can be a string or an array of content blocks.
func collectClaudeSystem(system gjson.Result, segments *[]string) {
if !system.Exists() {
return
}
if system.Type == gjson.String {
addIfNotEmpty(segments, system.String())
return
}
if system.IsArray() {
system.ForEach(func(_, block gjson.Result) bool {
blockType := block.Get("type").String()
if blockType == "text" || blockType == "" {
addIfNotEmpty(segments, block.Get("text").String())
}
// Also handle plain string blocks
if block.Type == gjson.String {
addIfNotEmpty(segments, block.String())
}
return true
})
}
}
// collectClaudeMessages extracts text from Claude's messages array.
func collectClaudeMessages(messages gjson.Result, segments *[]string) {
if !messages.Exists() || !messages.IsArray() {
return
}
messages.ForEach(func(_, message gjson.Result) bool {
addIfNotEmpty(segments, message.Get("role").String())
collectClaudeContent(message.Get("content"), segments)
return true
})
}
// collectClaudeContent extracts text from Claude's content field.
// Content can be a string or an array of content blocks.
// For images, estimates token count based on dimensions when available.
func collectClaudeContent(content gjson.Result, segments *[]string) {
if !content.Exists() {
return
}
if content.Type == gjson.String {
addIfNotEmpty(segments, content.String())
return
}
if content.IsArray() {
content.ForEach(func(_, part gjson.Result) bool {
partType := part.Get("type").String()
switch partType {
case "text":
addIfNotEmpty(segments, part.Get("text").String())
case "image":
// Estimate image tokens based on dimensions if available
source := part.Get("source")
if source.Exists() {
width := source.Get("width").Float()
height := source.Get("height").Float()
if width > 0 && height > 0 {
tokens := estimateImageTokens(width, height)
addIfNotEmpty(segments, fmt.Sprintf("[IMAGE:%d tokens]", tokens))
} else {
// No dimensions available, use default estimate
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
}
} else {
// No source info, use default estimate
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
}
case "tool_use":
addIfNotEmpty(segments, part.Get("id").String())
addIfNotEmpty(segments, part.Get("name").String())
if input := part.Get("input"); input.Exists() {
addIfNotEmpty(segments, input.Raw)
}
case "tool_result":
addIfNotEmpty(segments, part.Get("tool_use_id").String())
collectClaudeContent(part.Get("content"), segments)
case "thinking":
addIfNotEmpty(segments, part.Get("thinking").String())
default:
// For unknown types, try to extract any text content
if part.Type == gjson.String {
addIfNotEmpty(segments, part.String())
} else if part.Type == gjson.JSON {
addIfNotEmpty(segments, part.Raw)
}
}
return true
})
}
}
// collectClaudeTools extracts text from Claude's tools array.
func collectClaudeTools(tools gjson.Result, segments *[]string) {
if !tools.Exists() || !tools.IsArray() {
return
}
tools.ForEach(func(_, tool gjson.Result) bool {
addIfNotEmpty(segments, tool.Get("name").String())
addIfNotEmpty(segments, tool.Get("description").String())
if inputSchema := tool.Get("input_schema"); inputSchema.Exists() {
addIfNotEmpty(segments, inputSchema.Raw)
}
return true
})
}
// buildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.

View File

@@ -252,6 +252,44 @@ func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
return detail, true
}
func parseOpenAIResponsesUsageDetail(usageNode gjson.Result) usage.Detail {
detail := usage.Detail{
InputTokens: usageNode.Get("input_tokens").Int(),
OutputTokens: usageNode.Get("output_tokens").Int(),
TotalTokens: usageNode.Get("total_tokens").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens
}
if cached := usageNode.Get("input_tokens_details.cached_tokens"); cached.Exists() {
detail.CachedTokens = cached.Int()
}
if reasoning := usageNode.Get("output_tokens_details.reasoning_tokens"); reasoning.Exists() {
detail.ReasoningTokens = reasoning.Int()
}
return detail
}
func parseOpenAIResponsesUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data).Get("usage")
if !usageNode.Exists() {
return usage.Detail{}
}
return parseOpenAIResponsesUsageDetail(usageNode)
}
func parseOpenAIResponsesStreamUsage(line []byte) (usage.Detail, bool) {
payload := jsonPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return usage.Detail{}, false
}
usageNode := gjson.GetBytes(payload, "usage")
if !usageNode.Exists() {
return usage.Detail{}, false
}
return parseOpenAIResponsesUsageDetail(usageNode), true
}
func parseClaudeUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data).Get("usage")
if !usageNode.Exists() {