mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-24 03:50:33 +00:00
Merge branch 'main' into plus
This commit is contained in:
@@ -1007,7 +1007,12 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
|
||||
exec := &AntigravityExecutor{cfg: cfg}
|
||||
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
|
||||
if errToken != nil || token == "" {
|
||||
if errToken != nil {
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: token error: %v", auth.ID, errToken)
|
||||
return nil
|
||||
}
|
||||
if token == "" {
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: got empty token", auth.ID)
|
||||
return nil
|
||||
}
|
||||
if updatedAuth != nil {
|
||||
@@ -1021,6 +1026,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
modelsURL := baseURL + antigravityModelsPath
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`)))
|
||||
if errReq != nil {
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: create request error: %v", auth.ID, errReq)
|
||||
return nil
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
@@ -1033,12 +1039,14 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: context canceled: %v", auth.ID, errDo)
|
||||
return nil
|
||||
}
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: request error: %v", auth.ID, errDo)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1051,6 +1059,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: read body error: %v", auth.ID, errRead)
|
||||
return nil
|
||||
}
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
@@ -1058,11 +1067,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: unexpected status %d, body: %s", auth.ID, httpResp.StatusCode, string(bodyBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
result := gjson.GetBytes(bodyBytes, "models")
|
||||
if !result.Exists() {
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: no models field in response, body: %s", auth.ID, string(bodyBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ func startCodexCacheCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(codexCacheCleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
purgeExpiredCodexCache()
|
||||
}
|
||||
@@ -38,8 +39,10 @@ func startCodexCacheCleanup() {
|
||||
// purgeExpiredCodexCache removes entries that have expired.
|
||||
func purgeExpiredCodexCache() {
|
||||
now := time.Now()
|
||||
|
||||
codexCacheMu.Lock()
|
||||
defer codexCacheMu.Unlock()
|
||||
|
||||
for key, cache := range codexCacheMap {
|
||||
if cache.Expire.Before(now) {
|
||||
delete(codexCacheMap, key)
|
||||
@@ -66,3 +69,10 @@ func setCodexCache(key string, cache codexCache) {
|
||||
codexCacheMap[key] = cache
|
||||
codexCacheMu.Unlock()
|
||||
}
|
||||
|
||||
// deleteCodexCache deletes a cache entry.
|
||||
func deleteCodexCache(key string) {
|
||||
codexCacheMu.Lock()
|
||||
delete(codexCacheMap, key)
|
||||
codexCacheMu.Unlock()
|
||||
}
|
||||
|
||||
976
internal/runtime/executor/github_copilot_executor.go
Normal file
976
internal/runtime/executor/github_copilot_executor.go
Normal file
@@ -0,0 +1,976 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
copilotauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
githubCopilotBaseURL = "https://api.githubcopilot.com"
|
||||
githubCopilotChatPath = "/chat/completions"
|
||||
githubCopilotResponsesPath = "/responses"
|
||||
githubCopilotAuthType = "github-copilot"
|
||||
githubCopilotTokenCacheTTL = 25 * time.Minute
|
||||
// tokenExpiryBuffer is the time before expiry when we should refresh the token.
|
||||
tokenExpiryBuffer = 5 * time.Minute
|
||||
// maxScannerBufferSize is the maximum buffer size for SSE scanning (20MB).
|
||||
maxScannerBufferSize = 20_971_520
|
||||
|
||||
// Copilot API header values.
|
||||
copilotUserAgent = "GitHubCopilotChat/0.35.0"
|
||||
copilotEditorVersion = "vscode/1.107.0"
|
||||
copilotPluginVersion = "copilot-chat/0.35.0"
|
||||
copilotIntegrationID = "vscode-chat"
|
||||
copilotOpenAIIntent = "conversation-edits"
|
||||
)
|
||||
|
||||
// GitHubCopilotExecutor handles requests to the GitHub Copilot API.
|
||||
type GitHubCopilotExecutor struct {
|
||||
cfg *config.Config
|
||||
mu sync.RWMutex
|
||||
cache map[string]*cachedAPIToken
|
||||
}
|
||||
|
||||
// cachedAPIToken stores a cached Copilot API token with its expiry.
|
||||
type cachedAPIToken struct {
|
||||
token string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// NewGitHubCopilotExecutor constructs a new executor instance.
|
||||
func NewGitHubCopilotExecutor(cfg *config.Config) *GitHubCopilotExecutor {
|
||||
return &GitHubCopilotExecutor{
|
||||
cfg: cfg,
|
||||
cache: make(map[string]*cachedAPIToken),
|
||||
}
|
||||
}
|
||||
|
||||
// Identifier implements ProviderExecutor.
|
||||
func (e *GitHubCopilotExecutor) Identifier() string { return githubCopilotAuthType }
|
||||
|
||||
// PrepareRequest implements ProviderExecutor.
|
||||
func (e *GitHubCopilotExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
ctx := req.Context()
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
apiToken, errToken := e.ensureAPIToken(ctx, auth)
|
||||
if errToken != nil {
|
||||
return errToken
|
||||
}
|
||||
e.applyHeaders(req, apiToken, nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest injects GitHub Copilot credentials into the request and executes it.
|
||||
func (e *GitHubCopilotExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("github-copilot executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if errPrepare := e.PrepareRequest(httpReq, auth); errPrepare != nil {
|
||||
return nil, errPrepare
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
// Execute handles non-streaming requests to GitHub Copilot.
|
||||
func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
apiToken, errToken := e.ensureAPIToken(ctx, auth)
|
||||
if errToken != nil {
|
||||
return resp, errToken
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
useResponses := useGitHubCopilotResponsesEndpoint(from, req.Model)
|
||||
to := sdktranslator.FromString("openai")
|
||||
if useResponses {
|
||||
to = sdktranslator.FromString("openai-response")
|
||||
}
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = e.normalizeModel(req.Model, body)
|
||||
body = flattenAssistantContent(body)
|
||||
|
||||
thinkingProvider := "openai"
|
||||
if useResponses {
|
||||
thinkingProvider = "codex"
|
||||
}
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), thinkingProvider, e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
if useResponses {
|
||||
body = normalizeGitHubCopilotResponsesInput(body)
|
||||
body = normalizeGitHubCopilotResponsesTools(body)
|
||||
} else {
|
||||
body = normalizeGitHubCopilotChatTools(body)
|
||||
}
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "stream", false)
|
||||
|
||||
path := githubCopilotChatPath
|
||||
if useResponses {
|
||||
path = githubCopilotResponsesPath
|
||||
}
|
||||
url := githubCopilotBaseURL + path
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
e.applyHeaders(httpReq, apiToken, body)
|
||||
|
||||
// Add Copilot-Vision-Request header if the request contains vision content
|
||||
if detectVisionContent(body) {
|
||||
httpReq.Header.Set("Copilot-Vision-Request", "true")
|
||||
}
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: body,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("github-copilot executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
|
||||
if !isHTTPSuccess(httpResp.StatusCode) {
|
||||
data, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
log.Debugf("github-copilot executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
|
||||
detail := parseOpenAIUsage(data)
|
||||
if useResponses && detail.TotalTokens == 0 {
|
||||
detail = parseOpenAIResponsesUsage(data)
|
||||
}
|
||||
if detail.TotalTokens > 0 {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
|
||||
var param any
|
||||
converted := ""
|
||||
if useResponses && from.String() == "claude" {
|
||||
converted = translateGitHubCopilotResponsesNonStreamToClaude(data)
|
||||
} else {
|
||||
converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
reporter.ensurePublished(ctx)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream handles streaming requests to GitHub Copilot.
|
||||
func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
apiToken, errToken := e.ensureAPIToken(ctx, auth)
|
||||
if errToken != nil {
|
||||
return nil, errToken
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
useResponses := useGitHubCopilotResponsesEndpoint(from, req.Model)
|
||||
to := sdktranslator.FromString("openai")
|
||||
if useResponses {
|
||||
to = sdktranslator.FromString("openai-response")
|
||||
}
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
body = e.normalizeModel(req.Model, body)
|
||||
body = flattenAssistantContent(body)
|
||||
|
||||
thinkingProvider := "openai"
|
||||
if useResponses {
|
||||
thinkingProvider = "codex"
|
||||
}
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), thinkingProvider, e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if useResponses {
|
||||
body = normalizeGitHubCopilotResponsesInput(body)
|
||||
body = normalizeGitHubCopilotResponsesTools(body)
|
||||
} else {
|
||||
body = normalizeGitHubCopilotChatTools(body)
|
||||
}
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "stream", true)
|
||||
// Enable stream options for usage stats in stream
|
||||
if !useResponses {
|
||||
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
|
||||
}
|
||||
|
||||
path := githubCopilotChatPath
|
||||
if useResponses {
|
||||
path = githubCopilotResponsesPath
|
||||
}
|
||||
url := githubCopilotBaseURL + path
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e.applyHeaders(httpReq, apiToken, body)
|
||||
|
||||
// Add Copilot-Vision-Request header if the request contains vision content
|
||||
if detectVisionContent(body) {
|
||||
httpReq.Header.Set("Copilot-Vision-Request", "true")
|
||||
}
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: body,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
|
||||
if !isHTTPSuccess(httpResp.StatusCode) {
|
||||
data, readErr := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("github-copilot executor: close response body error: %v", errClose)
|
||||
}
|
||||
if readErr != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, readErr)
|
||||
return nil, readErr
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
log.Debugf("github-copilot executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("github-copilot executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
scanner.Buffer(nil, maxScannerBufferSize)
|
||||
var param any
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
|
||||
// Parse SSE data
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
data := bytes.TrimSpace(line[5:])
|
||||
if bytes.Equal(data, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
} else if useResponses {
|
||||
if detail, ok := parseOpenAIResponsesStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var chunks []string
|
||||
if useResponses && from.String() == "claude" {
|
||||
chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), ¶m)
|
||||
} else {
|
||||
chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
}
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
}
|
||||
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
} else {
|
||||
reporter.ensurePublished(ctx)
|
||||
}
|
||||
}()
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// CountTokens is not supported for GitHub Copilot.
|
||||
func (e *GitHubCopilotExecutor) CountTokens(_ context.Context, _ *cliproxyauth.Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported for github-copilot"}
|
||||
}
|
||||
|
||||
// Refresh validates the GitHub token is still working.
|
||||
// GitHub OAuth tokens don't expire traditionally, so we just validate.
|
||||
func (e *GitHubCopilotExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
if auth == nil {
|
||||
return nil, statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
|
||||
}
|
||||
|
||||
// Get the GitHub access token
|
||||
accessToken := metaStringValue(auth.Metadata, "access_token")
|
||||
if accessToken == "" {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// Validate the token can still get a Copilot API token
|
||||
copilotAuth := copilotauth.NewCopilotAuth(e.cfg)
|
||||
_, err := copilotAuth.GetCopilotAPIToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
return nil, statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("github-copilot token validation failed: %v", err)}
|
||||
}
|
||||
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// ensureAPIToken gets or refreshes the Copilot API token.
|
||||
func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *cliproxyauth.Auth) (string, error) {
|
||||
if auth == nil {
|
||||
return "", statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
|
||||
}
|
||||
|
||||
// Get the GitHub access token
|
||||
accessToken := metaStringValue(auth.Metadata, "access_token")
|
||||
if accessToken == "" {
|
||||
return "", statusErr{code: http.StatusUnauthorized, msg: "missing github access token"}
|
||||
}
|
||||
|
||||
// Check for cached API token using thread-safe access
|
||||
e.mu.RLock()
|
||||
if cached, ok := e.cache[accessToken]; ok && cached.expiresAt.After(time.Now().Add(tokenExpiryBuffer)) {
|
||||
e.mu.RUnlock()
|
||||
return cached.token, nil
|
||||
}
|
||||
e.mu.RUnlock()
|
||||
|
||||
// Get a new Copilot API token
|
||||
copilotAuth := copilotauth.NewCopilotAuth(e.cfg)
|
||||
apiToken, err := copilotAuth.GetCopilotAPIToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
return "", statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("failed to get copilot api token: %v", err)}
|
||||
}
|
||||
|
||||
// Cache the token with thread-safe access
|
||||
expiresAt := time.Now().Add(githubCopilotTokenCacheTTL)
|
||||
if apiToken.ExpiresAt > 0 {
|
||||
expiresAt = time.Unix(apiToken.ExpiresAt, 0)
|
||||
}
|
||||
e.mu.Lock()
|
||||
e.cache[accessToken] = &cachedAPIToken{
|
||||
token: apiToken.Token,
|
||||
expiresAt: expiresAt,
|
||||
}
|
||||
e.mu.Unlock()
|
||||
|
||||
return apiToken.Token, nil
|
||||
}
|
||||
|
||||
// applyHeaders sets the required headers for GitHub Copilot API requests.
|
||||
func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string, body []byte) {
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
r.Header.Set("Authorization", "Bearer "+apiToken)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
r.Header.Set("User-Agent", copilotUserAgent)
|
||||
r.Header.Set("Editor-Version", copilotEditorVersion)
|
||||
r.Header.Set("Editor-Plugin-Version", copilotPluginVersion)
|
||||
r.Header.Set("Openai-Intent", copilotOpenAIIntent)
|
||||
r.Header.Set("Copilot-Integration-Id", copilotIntegrationID)
|
||||
r.Header.Set("X-Request-Id", uuid.NewString())
|
||||
|
||||
initiator := "user"
|
||||
if len(body) > 0 {
|
||||
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
|
||||
arr := messages.Array()
|
||||
if len(arr) > 0 {
|
||||
lastRole := arr[len(arr)-1].Get("role").String()
|
||||
if lastRole != "" && lastRole != "user" {
|
||||
initiator = "agent"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
r.Header.Set("X-Initiator", initiator)
|
||||
}
|
||||
|
||||
// detectVisionContent checks if the request body contains vision/image content.
|
||||
// Returns true if the request includes image_url or image type content blocks.
|
||||
func detectVisionContent(body []byte) bool {
|
||||
// Parse messages array
|
||||
messagesResult := gjson.GetBytes(body, "messages")
|
||||
if !messagesResult.Exists() || !messagesResult.IsArray() {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check each message for vision content
|
||||
for _, message := range messagesResult.Array() {
|
||||
content := message.Get("content")
|
||||
|
||||
// If content is an array, check each content block
|
||||
if content.IsArray() {
|
||||
for _, block := range content.Array() {
|
||||
blockType := block.Get("type").String()
|
||||
// Check for image_url or image type
|
||||
if blockType == "image_url" || blockType == "image" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// normalizeModel strips the suffix (e.g. "(medium)") from the model name
|
||||
// before sending to GitHub Copilot, as the upstream API does not accept
|
||||
// suffixed model identifiers.
|
||||
func (e *GitHubCopilotExecutor) normalizeModel(model string, body []byte) []byte {
|
||||
baseModel := thinking.ParseSuffix(model).ModelName
|
||||
if baseModel != model {
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format, model string) bool {
|
||||
if sourceFormat.String() == "openai-response" {
|
||||
return true
|
||||
}
|
||||
baseModel := strings.ToLower(thinking.ParseSuffix(model).ModelName)
|
||||
return strings.Contains(baseModel, "codex")
|
||||
}
|
||||
|
||||
// flattenAssistantContent converts assistant message content from array format
|
||||
// to a joined string. GitHub Copilot requires assistant content as a string;
|
||||
// sending it as an array causes Claude models to re-answer all previous prompts.
|
||||
func flattenAssistantContent(body []byte) []byte {
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return body
|
||||
}
|
||||
result := body
|
||||
for i, msg := range messages.Array() {
|
||||
if msg.Get("role").String() != "assistant" {
|
||||
continue
|
||||
}
|
||||
content := msg.Get("content")
|
||||
if !content.Exists() || !content.IsArray() {
|
||||
continue
|
||||
}
|
||||
var textParts []string
|
||||
for _, part := range content.Array() {
|
||||
if part.Get("type").String() == "text" {
|
||||
if t := part.Get("text").String(); t != "" {
|
||||
textParts = append(textParts, t)
|
||||
}
|
||||
}
|
||||
}
|
||||
joined := strings.Join(textParts, "")
|
||||
path := fmt.Sprintf("messages.%d.content", i)
|
||||
result, _ = sjson.SetBytes(result, path, joined)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func normalizeGitHubCopilotChatTools(body []byte) []byte {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if tools.Exists() {
|
||||
filtered := "[]"
|
||||
if tools.IsArray() {
|
||||
for _, tool := range tools.Array() {
|
||||
if tool.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
filtered, _ = sjson.SetRaw(filtered, "-1", tool.Raw)
|
||||
}
|
||||
}
|
||||
body, _ = sjson.SetRawBytes(body, "tools", []byte(filtered))
|
||||
}
|
||||
|
||||
toolChoice := gjson.GetBytes(body, "tool_choice")
|
||||
if !toolChoice.Exists() {
|
||||
return body
|
||||
}
|
||||
if toolChoice.Type == gjson.String {
|
||||
switch toolChoice.String() {
|
||||
case "auto", "none", "required":
|
||||
return body
|
||||
}
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "tool_choice", "auto")
|
||||
return body
|
||||
}
|
||||
|
||||
func normalizeGitHubCopilotResponsesInput(body []byte) []byte {
|
||||
input := gjson.GetBytes(body, "input")
|
||||
if input.Exists() {
|
||||
if input.Type == gjson.String {
|
||||
return body
|
||||
}
|
||||
inputString := input.Raw
|
||||
if input.Type != gjson.JSON {
|
||||
inputString = input.String()
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "input", inputString)
|
||||
return body
|
||||
}
|
||||
|
||||
var parts []string
|
||||
if system := gjson.GetBytes(body, "system"); system.Exists() {
|
||||
if text := strings.TrimSpace(collectTextFromNode(system)); text != "" {
|
||||
parts = append(parts, text)
|
||||
}
|
||||
}
|
||||
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
|
||||
for _, msg := range messages.Array() {
|
||||
if text := strings.TrimSpace(collectTextFromNode(msg.Get("content"))); text != "" {
|
||||
parts = append(parts, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "input", strings.Join(parts, "\n"))
|
||||
return body
|
||||
}
|
||||
|
||||
func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if tools.Exists() {
|
||||
filtered := "[]"
|
||||
if tools.IsArray() {
|
||||
for _, tool := range tools.Array() {
|
||||
toolType := tool.Get("type").String()
|
||||
// Accept OpenAI format (type="function") and Claude format
|
||||
// (no type field, but has top-level name + input_schema).
|
||||
if toolType != "" && toolType != "function" {
|
||||
continue
|
||||
}
|
||||
name := tool.Get("name").String()
|
||||
if name == "" {
|
||||
name = tool.Get("function.name").String()
|
||||
}
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
normalized := `{"type":"function","name":""}`
|
||||
normalized, _ = sjson.Set(normalized, "name", name)
|
||||
if desc := tool.Get("description").String(); desc != "" {
|
||||
normalized, _ = sjson.Set(normalized, "description", desc)
|
||||
} else if desc = tool.Get("function.description").String(); desc != "" {
|
||||
normalized, _ = sjson.Set(normalized, "description", desc)
|
||||
}
|
||||
if params := tool.Get("parameters"); params.Exists() {
|
||||
normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw)
|
||||
} else if params = tool.Get("function.parameters"); params.Exists() {
|
||||
normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw)
|
||||
} else if params = tool.Get("input_schema"); params.Exists() {
|
||||
normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw)
|
||||
}
|
||||
filtered, _ = sjson.SetRaw(filtered, "-1", normalized)
|
||||
}
|
||||
}
|
||||
body, _ = sjson.SetRawBytes(body, "tools", []byte(filtered))
|
||||
}
|
||||
|
||||
toolChoice := gjson.GetBytes(body, "tool_choice")
|
||||
if !toolChoice.Exists() {
|
||||
return body
|
||||
}
|
||||
if toolChoice.Type == gjson.String {
|
||||
switch toolChoice.String() {
|
||||
case "auto", "none", "required":
|
||||
return body
|
||||
default:
|
||||
body, _ = sjson.SetBytes(body, "tool_choice", "auto")
|
||||
return body
|
||||
}
|
||||
}
|
||||
if toolChoice.Type == gjson.JSON {
|
||||
choiceType := toolChoice.Get("type").String()
|
||||
if choiceType == "function" {
|
||||
name := toolChoice.Get("name").String()
|
||||
if name == "" {
|
||||
name = toolChoice.Get("function.name").String()
|
||||
}
|
||||
if name != "" {
|
||||
normalized := `{"type":"function","name":""}`
|
||||
normalized, _ = sjson.Set(normalized, "name", name)
|
||||
body, _ = sjson.SetRawBytes(body, "tool_choice", []byte(normalized))
|
||||
return body
|
||||
}
|
||||
}
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "tool_choice", "auto")
|
||||
return body
|
||||
}
|
||||
|
||||
func collectTextFromNode(node gjson.Result) string {
|
||||
if !node.Exists() {
|
||||
return ""
|
||||
}
|
||||
if node.Type == gjson.String {
|
||||
return node.String()
|
||||
}
|
||||
if node.IsArray() {
|
||||
var parts []string
|
||||
for _, item := range node.Array() {
|
||||
if item.Type == gjson.String {
|
||||
if text := item.String(); text != "" {
|
||||
parts = append(parts, text)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if text := item.Get("text").String(); text != "" {
|
||||
parts = append(parts, text)
|
||||
continue
|
||||
}
|
||||
if nested := collectTextFromNode(item.Get("content")); nested != "" {
|
||||
parts = append(parts, nested)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
if node.Type == gjson.JSON {
|
||||
if text := node.Get("text").String(); text != "" {
|
||||
return text
|
||||
}
|
||||
if nested := collectTextFromNode(node.Get("content")); nested != "" {
|
||||
return nested
|
||||
}
|
||||
return node.Raw
|
||||
}
|
||||
return node.String()
|
||||
}
|
||||
|
||||
type githubCopilotResponsesStreamToolState struct {
|
||||
Index int
|
||||
ID string
|
||||
Name string
|
||||
}
|
||||
|
||||
type githubCopilotResponsesStreamState struct {
|
||||
MessageStarted bool
|
||||
MessageStopSent bool
|
||||
TextBlockStarted bool
|
||||
TextBlockIndex int
|
||||
NextContentIndex int
|
||||
HasToolUse bool
|
||||
OutputIndexToTool map[int]*githubCopilotResponsesStreamToolState
|
||||
ItemIDToTool map[string]*githubCopilotResponsesStreamToolState
|
||||
}
|
||||
|
||||
func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string {
|
||||
root := gjson.ParseBytes(data)
|
||||
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
out, _ = sjson.Set(out, "id", root.Get("id").String())
|
||||
out, _ = sjson.Set(out, "model", root.Get("model").String())
|
||||
|
||||
hasToolUse := false
|
||||
if output := root.Get("output"); output.Exists() && output.IsArray() {
|
||||
for _, item := range output.Array() {
|
||||
switch item.Get("type").String() {
|
||||
case "message":
|
||||
if content := item.Get("content"); content.Exists() && content.IsArray() {
|
||||
for _, part := range content.Array() {
|
||||
if part.Get("type").String() != "output_text" {
|
||||
continue
|
||||
}
|
||||
text := part.Get("text").String()
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
block := `{"type":"text","text":""}`
|
||||
block, _ = sjson.Set(block, "text", text)
|
||||
out, _ = sjson.SetRaw(out, "content.-1", block)
|
||||
}
|
||||
}
|
||||
case "function_call":
|
||||
hasToolUse = true
|
||||
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolID := item.Get("call_id").String()
|
||||
if toolID == "" {
|
||||
toolID = item.Get("id").String()
|
||||
}
|
||||
toolUse, _ = sjson.Set(toolUse, "id", toolID)
|
||||
toolUse, _ = sjson.Set(toolUse, "name", item.Get("name").String())
|
||||
if args := item.Get("arguments").String(); args != "" && gjson.Valid(args) {
|
||||
argObj := gjson.Parse(args)
|
||||
if argObj.IsObject() {
|
||||
toolUse, _ = sjson.SetRaw(toolUse, "input", argObj.Raw)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRaw(out, "content.-1", toolUse)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inputTokens := root.Get("usage.input_tokens").Int()
|
||||
outputTokens := root.Get("usage.output_tokens").Int()
|
||||
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
|
||||
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
|
||||
if hasToolUse {
|
||||
out, _ = sjson.Set(out, "stop_reason", "tool_use")
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "stop_reason", "end_turn")
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []string {
|
||||
if *param == nil {
|
||||
*param = &githubCopilotResponsesStreamState{
|
||||
TextBlockIndex: -1,
|
||||
OutputIndexToTool: make(map[int]*githubCopilotResponsesStreamToolState),
|
||||
ItemIDToTool: make(map[string]*githubCopilotResponsesStreamToolState),
|
||||
}
|
||||
}
|
||||
state := (*param).(*githubCopilotResponsesStreamState)
|
||||
|
||||
if !bytes.HasPrefix(line, dataTag) {
|
||||
return nil
|
||||
}
|
||||
payload := bytes.TrimSpace(line[5:])
|
||||
if bytes.Equal(payload, []byte("[DONE]")) {
|
||||
return nil
|
||||
}
|
||||
if !gjson.ValidBytes(payload) {
|
||||
return nil
|
||||
}
|
||||
|
||||
event := gjson.GetBytes(payload, "type").String()
|
||||
results := make([]string, 0, 4)
|
||||
ensureMessageStart := func() {
|
||||
if state.MessageStarted {
|
||||
return
|
||||
}
|
||||
messageStart := `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}`
|
||||
messageStart, _ = sjson.Set(messageStart, "message.id", gjson.GetBytes(payload, "response.id").String())
|
||||
messageStart, _ = sjson.Set(messageStart, "message.model", gjson.GetBytes(payload, "response.model").String())
|
||||
results = append(results, "event: message_start\ndata: "+messageStart+"\n\n")
|
||||
state.MessageStarted = true
|
||||
}
|
||||
startTextBlockIfNeeded := func() {
|
||||
if state.TextBlockStarted {
|
||||
return
|
||||
}
|
||||
if state.TextBlockIndex < 0 {
|
||||
state.TextBlockIndex = state.NextContentIndex
|
||||
state.NextContentIndex++
|
||||
}
|
||||
contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
|
||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", state.TextBlockIndex)
|
||||
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
|
||||
state.TextBlockStarted = true
|
||||
}
|
||||
stopTextBlockIfNeeded := func() {
|
||||
if !state.TextBlockStarted {
|
||||
return
|
||||
}
|
||||
contentBlockStop := `{"type":"content_block_stop","index":0}`
|
||||
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", state.TextBlockIndex)
|
||||
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
|
||||
state.TextBlockStarted = false
|
||||
state.TextBlockIndex = -1
|
||||
}
|
||||
resolveTool := func(itemID string, outputIndex int) *githubCopilotResponsesStreamToolState {
|
||||
if itemID != "" {
|
||||
if tool, ok := state.ItemIDToTool[itemID]; ok {
|
||||
return tool
|
||||
}
|
||||
}
|
||||
if tool, ok := state.OutputIndexToTool[outputIndex]; ok {
|
||||
if itemID != "" {
|
||||
state.ItemIDToTool[itemID] = tool
|
||||
}
|
||||
return tool
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
switch event {
|
||||
case "response.created":
|
||||
ensureMessageStart()
|
||||
case "response.output_text.delta":
|
||||
ensureMessageStart()
|
||||
startTextBlockIfNeeded()
|
||||
delta := gjson.GetBytes(payload, "delta").String()
|
||||
if delta != "" {
|
||||
contentDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
|
||||
contentDelta, _ = sjson.Set(contentDelta, "index", state.TextBlockIndex)
|
||||
contentDelta, _ = sjson.Set(contentDelta, "delta.text", delta)
|
||||
results = append(results, "event: content_block_delta\ndata: "+contentDelta+"\n\n")
|
||||
}
|
||||
case "response.output_item.added":
|
||||
if gjson.GetBytes(payload, "item.type").String() != "function_call" {
|
||||
break
|
||||
}
|
||||
ensureMessageStart()
|
||||
stopTextBlockIfNeeded()
|
||||
state.HasToolUse = true
|
||||
tool := &githubCopilotResponsesStreamToolState{
|
||||
Index: state.NextContentIndex,
|
||||
ID: gjson.GetBytes(payload, "item.call_id").String(),
|
||||
Name: gjson.GetBytes(payload, "item.name").String(),
|
||||
}
|
||||
if tool.ID == "" {
|
||||
tool.ID = gjson.GetBytes(payload, "item.id").String()
|
||||
}
|
||||
state.NextContentIndex++
|
||||
outputIndex := int(gjson.GetBytes(payload, "output_index").Int())
|
||||
state.OutputIndexToTool[outputIndex] = tool
|
||||
if itemID := gjson.GetBytes(payload, "item.id").String(); itemID != "" {
|
||||
state.ItemIDToTool[itemID] = tool
|
||||
}
|
||||
contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
|
||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", tool.Index)
|
||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.id", tool.ID)
|
||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.name", tool.Name)
|
||||
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
|
||||
case "response.output_item.delta":
|
||||
item := gjson.GetBytes(payload, "item")
|
||||
if item.Get("type").String() != "function_call" {
|
||||
break
|
||||
}
|
||||
tool := resolveTool(item.Get("id").String(), int(gjson.GetBytes(payload, "output_index").Int()))
|
||||
if tool == nil {
|
||||
break
|
||||
}
|
||||
partial := gjson.GetBytes(payload, "delta").String()
|
||||
if partial == "" {
|
||||
partial = item.Get("arguments").String()
|
||||
}
|
||||
if partial == "" {
|
||||
break
|
||||
}
|
||||
inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
|
||||
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
|
||||
results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n")
|
||||
case "response.output_item.done":
|
||||
if gjson.GetBytes(payload, "item.type").String() != "function_call" {
|
||||
break
|
||||
}
|
||||
tool := resolveTool(gjson.GetBytes(payload, "item.id").String(), int(gjson.GetBytes(payload, "output_index").Int()))
|
||||
if tool == nil {
|
||||
break
|
||||
}
|
||||
contentBlockStop := `{"type":"content_block_stop","index":0}`
|
||||
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", tool.Index)
|
||||
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
|
||||
case "response.completed":
|
||||
ensureMessageStart()
|
||||
stopTextBlockIfNeeded()
|
||||
if !state.MessageStopSent {
|
||||
stopReason := "end_turn"
|
||||
if state.HasToolUse {
|
||||
stopReason = "tool_use"
|
||||
}
|
||||
messageDelta := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
messageDelta, _ = sjson.Set(messageDelta, "delta.stop_reason", stopReason)
|
||||
messageDelta, _ = sjson.Set(messageDelta, "usage.input_tokens", gjson.GetBytes(payload, "response.usage.input_tokens").Int())
|
||||
messageDelta, _ = sjson.Set(messageDelta, "usage.output_tokens", gjson.GetBytes(payload, "response.usage.output_tokens").Int())
|
||||
results = append(results, "event: message_delta\ndata: "+messageDelta+"\n\n")
|
||||
results = append(results, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
|
||||
state.MessageStopSent = true
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// isHTTPSuccess checks if the status code indicates success (2xx).
|
||||
func isHTTPSuccess(statusCode int) bool {
|
||||
return statusCode >= 200 && statusCode < 300
|
||||
}
|
||||
242
internal/runtime/executor/github_copilot_executor_test.go
Normal file
242
internal/runtime/executor/github_copilot_executor_test.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestGitHubCopilotNormalizeModel_StripsSuffix(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
wantModel string
|
||||
}{
|
||||
{
|
||||
name: "suffix stripped",
|
||||
model: "claude-opus-4.6(medium)",
|
||||
wantModel: "claude-opus-4.6",
|
||||
},
|
||||
{
|
||||
name: "no suffix unchanged",
|
||||
model: "claude-opus-4.6",
|
||||
wantModel: "claude-opus-4.6",
|
||||
},
|
||||
{
|
||||
name: "different suffix stripped",
|
||||
model: "gpt-4o(high)",
|
||||
wantModel: "gpt-4o",
|
||||
},
|
||||
{
|
||||
name: "numeric suffix stripped",
|
||||
model: "gemini-2.5-pro(8192)",
|
||||
wantModel: "gemini-2.5-pro",
|
||||
},
|
||||
}
|
||||
|
||||
e := &GitHubCopilotExecutor{}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := []byte(`{"model":"` + tt.model + `","messages":[]}`)
|
||||
got := e.normalizeModel(tt.model, body)
|
||||
|
||||
gotModel := gjson.GetBytes(got, "model").String()
|
||||
if gotModel != tt.wantModel {
|
||||
t.Fatalf("normalizeModel() model = %q, want %q", gotModel, tt.wantModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_OpenAIResponseSource(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai-response"), "claude-3-5-sonnet") {
|
||||
t.Fatal("expected openai-response source to use /responses")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5-codex") {
|
||||
t.Fatal("expected codex model to use /responses")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {
|
||||
t.Parallel()
|
||||
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "claude-3-5-sonnet") {
|
||||
t.Fatal("expected default openai source with non-codex model to use /chat/completions")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotChatTools_KeepFunctionOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[{"type":"function","function":{"name":"ok"}},{"type":"code_interpreter"}],"tool_choice":"auto"}`)
|
||||
got := normalizeGitHubCopilotChatTools(body)
|
||||
tools := gjson.GetBytes(got, "tools").Array()
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("tools len = %d, want 1", len(tools))
|
||||
}
|
||||
if tools[0].Get("type").String() != "function" {
|
||||
t.Fatalf("tool type = %q, want function", tools[0].Get("type").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotChatTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[],"tool_choice":{"type":"function","function":{"name":"x"}}}`)
|
||||
got := normalizeGitHubCopilotChatTools(body)
|
||||
if gjson.GetBytes(got, "tool_choice").String() != "auto" {
|
||||
t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesInput_MissingInputExtractedFromSystemAndMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"system":"sys text","messages":[{"role":"user","content":"user text"},{"role":"assistant","content":[{"type":"text","text":"assistant text"}]}]}`)
|
||||
got := normalizeGitHubCopilotResponsesInput(body)
|
||||
in := gjson.GetBytes(got, "input")
|
||||
if in.Type != gjson.String {
|
||||
t.Fatalf("input type = %v, want string", in.Type)
|
||||
}
|
||||
if !strings.Contains(in.String(), "sys text") || !strings.Contains(in.String(), "user text") || !strings.Contains(in.String(), "assistant text") {
|
||||
t.Fatalf("input = %q, want merged text", in.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesInput_NonStringInputStringified(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"input":{"foo":"bar"}}`)
|
||||
got := normalizeGitHubCopilotResponsesInput(body)
|
||||
in := gjson.GetBytes(got, "input")
|
||||
if in.Type != gjson.String {
|
||||
t.Fatalf("input type = %v, want string", in.Type)
|
||||
}
|
||||
if !strings.Contains(in.String(), "foo") {
|
||||
t.Fatalf("input = %q, want stringified object", in.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_FlattenFunctionTools(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[{"type":"function","function":{"name":"sum","description":"d","parameters":{"type":"object"}}},{"type":"web_search"}]}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
tools := gjson.GetBytes(got, "tools").Array()
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("tools len = %d, want 1", len(tools))
|
||||
}
|
||||
if tools[0].Get("name").String() != "sum" {
|
||||
t.Fatalf("tools[0].name = %q, want sum", tools[0].Get("name").String())
|
||||
}
|
||||
if !tools[0].Get("parameters").Exists() {
|
||||
t.Fatal("expected parameters to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_ClaudeFormatTools(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[{"name":"Bash","description":"Run commands","input_schema":{"type":"object","properties":{"command":{"type":"string"}},"required":["command"]}},{"name":"Read","description":"Read files","input_schema":{"type":"object","properties":{"path":{"type":"string"}}}}]}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
tools := gjson.GetBytes(got, "tools").Array()
|
||||
if len(tools) != 2 {
|
||||
t.Fatalf("tools len = %d, want 2", len(tools))
|
||||
}
|
||||
if tools[0].Get("type").String() != "function" {
|
||||
t.Fatalf("tools[0].type = %q, want function", tools[0].Get("type").String())
|
||||
}
|
||||
if tools[0].Get("name").String() != "Bash" {
|
||||
t.Fatalf("tools[0].name = %q, want Bash", tools[0].Get("name").String())
|
||||
}
|
||||
if tools[0].Get("description").String() != "Run commands" {
|
||||
t.Fatalf("tools[0].description = %q, want 'Run commands'", tools[0].Get("description").String())
|
||||
}
|
||||
if !tools[0].Get("parameters").Exists() {
|
||||
t.Fatal("expected parameters to be set from input_schema")
|
||||
}
|
||||
if tools[0].Get("parameters.properties.command").Exists() != true {
|
||||
t.Fatal("expected parameters.properties.command to exist")
|
||||
}
|
||||
if tools[1].Get("name").String() != "Read" {
|
||||
t.Fatalf("tools[1].name = %q, want Read", tools[1].Get("name").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_FlattenToolChoiceFunctionObject(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tool_choice":{"type":"function","function":{"name":"sum"}}}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
if gjson.GetBytes(got, "tool_choice.type").String() != "function" {
|
||||
t.Fatalf("tool_choice.type = %q, want function", gjson.GetBytes(got, "tool_choice.type").String())
|
||||
}
|
||||
if gjson.GetBytes(got, "tool_choice.name").String() != "sum" {
|
||||
t.Fatalf("tool_choice.name = %q, want sum", gjson.GetBytes(got, "tool_choice.name").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tool_choice":{"type":"function"}}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
if gjson.GetBytes(got, "tool_choice").String() != "auto" {
|
||||
t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateGitHubCopilotResponsesNonStreamToClaude_TextMapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
resp := []byte(`{"id":"resp_1","model":"gpt-5-codex","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":3,"output_tokens":5}}`)
|
||||
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
|
||||
if gjson.Get(out, "type").String() != "message" {
|
||||
t.Fatalf("type = %q, want message", gjson.Get(out, "type").String())
|
||||
}
|
||||
if gjson.Get(out, "content.0.type").String() != "text" {
|
||||
t.Fatalf("content.0.type = %q, want text", gjson.Get(out, "content.0.type").String())
|
||||
}
|
||||
if gjson.Get(out, "content.0.text").String() != "hello" {
|
||||
t.Fatalf("content.0.text = %q, want hello", gjson.Get(out, "content.0.text").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateGitHubCopilotResponsesNonStreamToClaude_ToolUseMapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
resp := []byte(`{"id":"resp_2","model":"gpt-5-codex","output":[{"type":"function_call","id":"fc_1","call_id":"call_1","name":"sum","arguments":"{\"a\":1}"}],"usage":{"input_tokens":1,"output_tokens":2}}`)
|
||||
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
|
||||
if gjson.Get(out, "content.0.type").String() != "tool_use" {
|
||||
t.Fatalf("content.0.type = %q, want tool_use", gjson.Get(out, "content.0.type").String())
|
||||
}
|
||||
if gjson.Get(out, "content.0.name").String() != "sum" {
|
||||
t.Fatalf("content.0.name = %q, want sum", gjson.Get(out, "content.0.name").String())
|
||||
}
|
||||
if gjson.Get(out, "stop_reason").String() != "tool_use" {
|
||||
t.Fatalf("stop_reason = %q, want tool_use", gjson.Get(out, "stop_reason").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateGitHubCopilotResponsesStreamToClaude_TextLifecycle(t *testing.T) {
|
||||
t.Parallel()
|
||||
var param any
|
||||
|
||||
created := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5-codex"}}`), ¶m)
|
||||
if len(created) == 0 || !strings.Contains(created[0], "message_start") {
|
||||
t.Fatalf("created events = %#v, want message_start", created)
|
||||
}
|
||||
|
||||
delta := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.output_text.delta","delta":"he"}`), ¶m)
|
||||
joinedDelta := strings.Join(delta, "")
|
||||
if !strings.Contains(joinedDelta, "content_block_start") || !strings.Contains(joinedDelta, "text_delta") {
|
||||
t.Fatalf("delta events = %#v, want content_block_start + text_delta", delta)
|
||||
}
|
||||
|
||||
completed := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.completed","response":{"usage":{"input_tokens":7,"output_tokens":9}}}`), ¶m)
|
||||
joinedCompleted := strings.Join(completed, "")
|
||||
if !strings.Contains(joinedCompleted, "message_delta") || !strings.Contains(joinedCompleted, "message_stop") {
|
||||
t.Fatalf("completed events = %#v, want message_delta + message_stop", completed)
|
||||
}
|
||||
}
|
||||
4781
internal/runtime/executor/kiro_executor.go
Normal file
4781
internal/runtime/executor/kiro_executor.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
@@ -14,11 +15,19 @@ import (
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
// httpClientCache caches HTTP clients by proxy URL to enable connection reuse
|
||||
var (
|
||||
httpClientCache = make(map[string]*http.Client)
|
||||
httpClientCacheMutex sync.RWMutex
|
||||
)
|
||||
|
||||
// newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
|
||||
// 1. Use auth.ProxyURL if configured (highest priority)
|
||||
// 2. Use cfg.ProxyURL if auth proxy is not configured
|
||||
// 3. Use RoundTripper from context if neither are configured
|
||||
//
|
||||
// This function caches HTTP clients by proxy URL to enable TCP/TLS connection reuse.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context containing optional RoundTripper
|
||||
// - cfg: The application configuration
|
||||
@@ -28,11 +37,6 @@ import (
|
||||
// Returns:
|
||||
// - *http.Client: An HTTP client with configured proxy or transport
|
||||
func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
||||
httpClient := &http.Client{}
|
||||
if timeout > 0 {
|
||||
httpClient.Timeout = timeout
|
||||
}
|
||||
|
||||
// Priority 1: Use auth.ProxyURL if configured
|
||||
var proxyURL string
|
||||
if auth != nil {
|
||||
@@ -44,11 +48,39 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
|
||||
proxyURL = strings.TrimSpace(cfg.ProxyURL)
|
||||
}
|
||||
|
||||
// Build cache key from proxy URL (empty string for no proxy)
|
||||
cacheKey := proxyURL
|
||||
|
||||
// Check cache first
|
||||
httpClientCacheMutex.RLock()
|
||||
if cachedClient, ok := httpClientCache[cacheKey]; ok {
|
||||
httpClientCacheMutex.RUnlock()
|
||||
// Return a wrapper with the requested timeout but shared transport
|
||||
if timeout > 0 {
|
||||
return &http.Client{
|
||||
Transport: cachedClient.Transport,
|
||||
Timeout: timeout,
|
||||
}
|
||||
}
|
||||
return cachedClient
|
||||
}
|
||||
httpClientCacheMutex.RUnlock()
|
||||
|
||||
// Create new client
|
||||
httpClient := &http.Client{}
|
||||
if timeout > 0 {
|
||||
httpClient.Timeout = timeout
|
||||
}
|
||||
|
||||
// If we have a proxy URL configured, set up the transport
|
||||
if proxyURL != "" {
|
||||
transport := buildProxyTransport(proxyURL)
|
||||
if transport != nil {
|
||||
httpClient.Transport = transport
|
||||
// Cache the client
|
||||
httpClientCacheMutex.Lock()
|
||||
httpClientCache[cacheKey] = httpClient
|
||||
httpClientCacheMutex.Unlock()
|
||||
return httpClient
|
||||
}
|
||||
// If proxy setup failed, log and fall through to context RoundTripper
|
||||
@@ -60,6 +92,13 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
|
||||
httpClient.Transport = rt
|
||||
}
|
||||
|
||||
// Cache the client for no-proxy case
|
||||
if proxyURL == "" {
|
||||
httpClientCacheMutex.Lock()
|
||||
httpClientCache[cacheKey] = httpClient
|
||||
httpClientCacheMutex.Unlock()
|
||||
}
|
||||
|
||||
return httpClient
|
||||
}
|
||||
|
||||
|
||||
@@ -2,43 +2,109 @@ package executor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
)
|
||||
|
||||
// tokenizerCache stores tokenizer instances to avoid repeated creation
|
||||
var tokenizerCache sync.Map
|
||||
|
||||
// TokenizerWrapper wraps a tokenizer codec with an adjustment factor for models
|
||||
// where tiktoken may not accurately estimate token counts (e.g., Claude models)
|
||||
type TokenizerWrapper struct {
|
||||
Codec tokenizer.Codec
|
||||
AdjustmentFactor float64 // 1.0 means no adjustment, >1.0 means tiktoken underestimates
|
||||
}
|
||||
|
||||
// Count returns the token count with adjustment factor applied
|
||||
func (tw *TokenizerWrapper) Count(text string) (int, error) {
|
||||
count, err := tw.Codec.Count(text)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if tw.AdjustmentFactor != 1.0 && tw.AdjustmentFactor > 0 {
|
||||
return int(float64(count) * tw.AdjustmentFactor), nil
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// getTokenizer returns a cached tokenizer for the given model.
|
||||
// This improves performance by avoiding repeated tokenizer creation.
|
||||
func getTokenizer(model string) (*TokenizerWrapper, error) {
|
||||
// Check cache first
|
||||
if cached, ok := tokenizerCache.Load(model); ok {
|
||||
return cached.(*TokenizerWrapper), nil
|
||||
}
|
||||
|
||||
// Cache miss, create new tokenizer
|
||||
wrapper, err := tokenizerForModel(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store in cache (use LoadOrStore to handle race conditions)
|
||||
actual, _ := tokenizerCache.LoadOrStore(model, wrapper)
|
||||
return actual.(*TokenizerWrapper), nil
|
||||
}
|
||||
|
||||
// tokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
|
||||
func tokenizerForModel(model string) (tokenizer.Codec, error) {
|
||||
// For Claude models, applies a 1.1 adjustment factor since tiktoken may underestimate.
|
||||
func tokenizerForModel(model string) (*TokenizerWrapper, error) {
|
||||
sanitized := strings.ToLower(strings.TrimSpace(model))
|
||||
|
||||
// Claude models use cl100k_base with 1.1 adjustment factor
|
||||
// because tiktoken may underestimate Claude's actual token count
|
||||
if strings.Contains(sanitized, "claude") || strings.HasPrefix(sanitized, "kiro-") || strings.HasPrefix(sanitized, "amazonq-") {
|
||||
enc, err := tokenizer.Get(tokenizer.Cl100kBase)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.1}, nil
|
||||
}
|
||||
|
||||
var enc tokenizer.Codec
|
||||
var err error
|
||||
|
||||
switch {
|
||||
case sanitized == "":
|
||||
return tokenizer.Get(tokenizer.Cl100kBase)
|
||||
case strings.HasPrefix(sanitized, "gpt-5"):
|
||||
return tokenizer.ForModel(tokenizer.GPT5)
|
||||
enc, err = tokenizer.Get(tokenizer.Cl100kBase)
|
||||
case strings.HasPrefix(sanitized, "gpt-5.2"):
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT5)
|
||||
case strings.HasPrefix(sanitized, "gpt-5.1"):
|
||||
return tokenizer.ForModel(tokenizer.GPT5)
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT5)
|
||||
case strings.HasPrefix(sanitized, "gpt-5"):
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT5)
|
||||
case strings.HasPrefix(sanitized, "gpt-4.1"):
|
||||
return tokenizer.ForModel(tokenizer.GPT41)
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT41)
|
||||
case strings.HasPrefix(sanitized, "gpt-4o"):
|
||||
return tokenizer.ForModel(tokenizer.GPT4o)
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT4o)
|
||||
case strings.HasPrefix(sanitized, "gpt-4"):
|
||||
return tokenizer.ForModel(tokenizer.GPT4)
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT4)
|
||||
case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"):
|
||||
return tokenizer.ForModel(tokenizer.GPT35Turbo)
|
||||
enc, err = tokenizer.ForModel(tokenizer.GPT35Turbo)
|
||||
case strings.HasPrefix(sanitized, "o1"):
|
||||
return tokenizer.ForModel(tokenizer.O1)
|
||||
enc, err = tokenizer.ForModel(tokenizer.O1)
|
||||
case strings.HasPrefix(sanitized, "o3"):
|
||||
return tokenizer.ForModel(tokenizer.O3)
|
||||
enc, err = tokenizer.ForModel(tokenizer.O3)
|
||||
case strings.HasPrefix(sanitized, "o4"):
|
||||
return tokenizer.ForModel(tokenizer.O4Mini)
|
||||
enc, err = tokenizer.ForModel(tokenizer.O4Mini)
|
||||
default:
|
||||
return tokenizer.Get(tokenizer.O200kBase)
|
||||
enc, err = tokenizer.Get(tokenizer.O200kBase)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.0}, nil
|
||||
}
|
||||
|
||||
// countOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
|
||||
func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
||||
func countOpenAIChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
|
||||
if enc == nil {
|
||||
return 0, fmt.Errorf("encoder is nil")
|
||||
}
|
||||
@@ -62,11 +128,206 @@ func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Count text tokens
|
||||
count, err := enc.Count(joined)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(count), nil
|
||||
|
||||
// Extract and add image tokens from placeholders
|
||||
imageTokens := extractImageTokens(joined)
|
||||
|
||||
return int64(count) + int64(imageTokens), nil
|
||||
}
|
||||
|
||||
// countClaudeChatTokens approximates prompt tokens for Claude API chat completions payloads.
|
||||
// This handles Claude's message format with system, messages, and tools.
|
||||
// Image tokens are estimated based on image dimensions when available.
|
||||
func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
|
||||
if enc == nil {
|
||||
return 0, fmt.Errorf("encoder is nil")
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
root := gjson.ParseBytes(payload)
|
||||
segments := make([]string, 0, 32)
|
||||
|
||||
// Collect system prompt (can be string or array of content blocks)
|
||||
collectClaudeSystem(root.Get("system"), &segments)
|
||||
|
||||
// Collect messages
|
||||
collectClaudeMessages(root.Get("messages"), &segments)
|
||||
|
||||
// Collect tools
|
||||
collectClaudeTools(root.Get("tools"), &segments)
|
||||
|
||||
joined := strings.TrimSpace(strings.Join(segments, "\n"))
|
||||
if joined == "" {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Count text tokens
|
||||
count, err := enc.Count(joined)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Extract and add image tokens from placeholders
|
||||
imageTokens := extractImageTokens(joined)
|
||||
|
||||
return int64(count) + int64(imageTokens), nil
|
||||
}
|
||||
|
||||
// imageTokenPattern matches [IMAGE:xxx tokens] format for extracting estimated image tokens
|
||||
var imageTokenPattern = regexp.MustCompile(`\[IMAGE:(\d+) tokens\]`)
|
||||
|
||||
// extractImageTokens extracts image token estimates from placeholder text.
|
||||
// Placeholders are in the format [IMAGE:xxx tokens] where xxx is the estimated token count.
|
||||
func extractImageTokens(text string) int {
|
||||
matches := imageTokenPattern.FindAllStringSubmatch(text, -1)
|
||||
total := 0
|
||||
for _, match := range matches {
|
||||
if len(match) > 1 {
|
||||
if tokens, err := strconv.Atoi(match[1]); err == nil {
|
||||
total += tokens
|
||||
}
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// estimateImageTokens calculates estimated tokens for an image based on dimensions.
|
||||
// Based on Claude's image token calculation: tokens ≈ (width * height) / 750
|
||||
// Minimum 85 tokens, maximum 1590 tokens (for 1568x1568 images).
|
||||
func estimateImageTokens(width, height float64) int {
|
||||
if width <= 0 || height <= 0 {
|
||||
// No valid dimensions, use default estimate (medium-sized image)
|
||||
return 1000
|
||||
}
|
||||
|
||||
tokens := int(width * height / 750)
|
||||
|
||||
// Apply bounds
|
||||
if tokens < 85 {
|
||||
tokens = 85
|
||||
}
|
||||
if tokens > 1590 {
|
||||
tokens = 1590
|
||||
}
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
// collectClaudeSystem extracts text from Claude's system field.
|
||||
// System can be a string or an array of content blocks.
|
||||
func collectClaudeSystem(system gjson.Result, segments *[]string) {
|
||||
if !system.Exists() {
|
||||
return
|
||||
}
|
||||
if system.Type == gjson.String {
|
||||
addIfNotEmpty(segments, system.String())
|
||||
return
|
||||
}
|
||||
if system.IsArray() {
|
||||
system.ForEach(func(_, block gjson.Result) bool {
|
||||
blockType := block.Get("type").String()
|
||||
if blockType == "text" || blockType == "" {
|
||||
addIfNotEmpty(segments, block.Get("text").String())
|
||||
}
|
||||
// Also handle plain string blocks
|
||||
if block.Type == gjson.String {
|
||||
addIfNotEmpty(segments, block.String())
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// collectClaudeMessages extracts text from Claude's messages array.
|
||||
func collectClaudeMessages(messages gjson.Result, segments *[]string) {
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return
|
||||
}
|
||||
messages.ForEach(func(_, message gjson.Result) bool {
|
||||
addIfNotEmpty(segments, message.Get("role").String())
|
||||
collectClaudeContent(message.Get("content"), segments)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// collectClaudeContent extracts text from Claude's content field.
|
||||
// Content can be a string or an array of content blocks.
|
||||
// For images, estimates token count based on dimensions when available.
|
||||
func collectClaudeContent(content gjson.Result, segments *[]string) {
|
||||
if !content.Exists() {
|
||||
return
|
||||
}
|
||||
if content.Type == gjson.String {
|
||||
addIfNotEmpty(segments, content.String())
|
||||
return
|
||||
}
|
||||
if content.IsArray() {
|
||||
content.ForEach(func(_, part gjson.Result) bool {
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "text":
|
||||
addIfNotEmpty(segments, part.Get("text").String())
|
||||
case "image":
|
||||
// Estimate image tokens based on dimensions if available
|
||||
source := part.Get("source")
|
||||
if source.Exists() {
|
||||
width := source.Get("width").Float()
|
||||
height := source.Get("height").Float()
|
||||
if width > 0 && height > 0 {
|
||||
tokens := estimateImageTokens(width, height)
|
||||
addIfNotEmpty(segments, fmt.Sprintf("[IMAGE:%d tokens]", tokens))
|
||||
} else {
|
||||
// No dimensions available, use default estimate
|
||||
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
|
||||
}
|
||||
} else {
|
||||
// No source info, use default estimate
|
||||
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
|
||||
}
|
||||
case "tool_use":
|
||||
addIfNotEmpty(segments, part.Get("id").String())
|
||||
addIfNotEmpty(segments, part.Get("name").String())
|
||||
if input := part.Get("input"); input.Exists() {
|
||||
addIfNotEmpty(segments, input.Raw)
|
||||
}
|
||||
case "tool_result":
|
||||
addIfNotEmpty(segments, part.Get("tool_use_id").String())
|
||||
collectClaudeContent(part.Get("content"), segments)
|
||||
case "thinking":
|
||||
addIfNotEmpty(segments, part.Get("thinking").String())
|
||||
default:
|
||||
// For unknown types, try to extract any text content
|
||||
if part.Type == gjson.String {
|
||||
addIfNotEmpty(segments, part.String())
|
||||
} else if part.Type == gjson.JSON {
|
||||
addIfNotEmpty(segments, part.Raw)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// collectClaudeTools extracts text from Claude's tools array.
|
||||
func collectClaudeTools(tools gjson.Result, segments *[]string) {
|
||||
if !tools.Exists() || !tools.IsArray() {
|
||||
return
|
||||
}
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
addIfNotEmpty(segments, tool.Get("name").String())
|
||||
addIfNotEmpty(segments, tool.Get("description").String())
|
||||
if inputSchema := tool.Get("input_schema"); inputSchema.Exists() {
|
||||
addIfNotEmpty(segments, inputSchema.Raw)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// buildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.
|
||||
|
||||
@@ -252,6 +252,44 @@ func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
return detail, true
|
||||
}
|
||||
|
||||
func parseOpenAIResponsesUsageDetail(usageNode gjson.Result) usage.Detail {
|
||||
detail := usage.Detail{
|
||||
InputTokens: usageNode.Get("input_tokens").Int(),
|
||||
OutputTokens: usageNode.Get("output_tokens").Int(),
|
||||
TotalTokens: usageNode.Get("total_tokens").Int(),
|
||||
}
|
||||
if detail.TotalTokens == 0 {
|
||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens
|
||||
}
|
||||
if cached := usageNode.Get("input_tokens_details.cached_tokens"); cached.Exists() {
|
||||
detail.CachedTokens = cached.Int()
|
||||
}
|
||||
if reasoning := usageNode.Get("output_tokens_details.reasoning_tokens"); reasoning.Exists() {
|
||||
detail.ReasoningTokens = reasoning.Int()
|
||||
}
|
||||
return detail
|
||||
}
|
||||
|
||||
func parseOpenAIResponsesUsage(data []byte) usage.Detail {
|
||||
usageNode := gjson.ParseBytes(data).Get("usage")
|
||||
if !usageNode.Exists() {
|
||||
return usage.Detail{}
|
||||
}
|
||||
return parseOpenAIResponsesUsageDetail(usageNode)
|
||||
}
|
||||
|
||||
func parseOpenAIResponsesStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
payload := jsonPayload(line)
|
||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||
return usage.Detail{}, false
|
||||
}
|
||||
usageNode := gjson.GetBytes(payload, "usage")
|
||||
if !usageNode.Exists() {
|
||||
return usage.Detail{}, false
|
||||
}
|
||||
return parseOpenAIResponsesUsageDetail(usageNode), true
|
||||
}
|
||||
|
||||
func parseClaudeUsage(data []byte) usage.Detail {
|
||||
usageNode := gjson.ParseBytes(data).Get("usage")
|
||||
if !usageNode.Exists() {
|
||||
|
||||
Reference in New Issue
Block a user