mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-17 20:03:42 +00:00
Merge branch 'main' into plus
This commit is contained in:
@@ -848,6 +848,14 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
hasClaude1MHeader = true
|
||||
}
|
||||
}
|
||||
// Also check auth attributes — GitLab Duo sets gitlab_duo_force_context_1m
|
||||
// when routing through the Anthropic gateway, but the gin headers won't have
|
||||
// X-CPA-CLAUDE-1M because the request is internally constructed.
|
||||
if !hasClaude1MHeader && auth != nil && auth.Attributes != nil {
|
||||
if auth.Attributes["gitlab_duo_force_context_1m"] == "true" {
|
||||
hasClaude1MHeader = true
|
||||
}
|
||||
}
|
||||
|
||||
// Merge extra betas from request body and request flags.
|
||||
if len(extraBetas) > 0 || hasClaude1MHeader {
|
||||
|
||||
343
internal/runtime/executor/codebuddy_executor.go
Normal file
343
internal/runtime/executor/codebuddy_executor.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codebuddy"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
codeBuddyChatPath = "/v2/chat/completions"
|
||||
codeBuddyAuthType = "codebuddy"
|
||||
)
|
||||
|
||||
// CodeBuddyExecutor handles requests to the CodeBuddy API.
|
||||
type CodeBuddyExecutor struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewCodeBuddyExecutor creates a new CodeBuddy executor instance.
|
||||
func NewCodeBuddyExecutor(cfg *config.Config) *CodeBuddyExecutor {
|
||||
return &CodeBuddyExecutor{cfg: cfg}
|
||||
}
|
||||
|
||||
// Identifier returns the unique identifier for this executor.
|
||||
func (e *CodeBuddyExecutor) Identifier() string { return codeBuddyAuthType }
|
||||
|
||||
// codeBuddyCredentials extracts the access token and domain from auth metadata.
|
||||
func codeBuddyCredentials(auth *cliproxyauth.Auth) (accessToken, userID, domain string) {
|
||||
if auth == nil {
|
||||
return "", "", ""
|
||||
}
|
||||
accessToken = metaStringValue(auth.Metadata, "access_token")
|
||||
userID = metaStringValue(auth.Metadata, "user_id")
|
||||
domain = metaStringValue(auth.Metadata, "domain")
|
||||
if domain == "" {
|
||||
domain = codebuddy.DefaultDomain
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// PrepareRequest prepares the HTTP request before execution.
|
||||
func (e *CodeBuddyExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
accessToken, userID, domain := codeBuddyCredentials(auth)
|
||||
if accessToken == "" {
|
||||
return fmt.Errorf("codebuddy: missing access token")
|
||||
}
|
||||
e.applyHeaders(req, accessToken, userID, domain)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest executes a raw HTTP request.
|
||||
func (e *CodeBuddyExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("codebuddy executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
// Execute performs a non-streaming request.
|
||||
func (e *CodeBuddyExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
accessToken, userID, domain := codeBuddyCredentials(auth)
|
||||
if accessToken == "" {
|
||||
return resp, fmt.Errorf("codebuddy: missing access token")
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayloadSource, false)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
url := codebuddy.BaseURL + codeBuddyChatPath
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
e.applyHeaders(httpReq, accessToken, userID, domain)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: translated,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("codebuddy executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if !isHTTPSuccess(httpResp.StatusCode) {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("codebuddy executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, body)
|
||||
reporter.publish(ctx, parseOpenAIUsage(body))
|
||||
reporter.ensurePublished(ctx)
|
||||
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request.
|
||||
func (e *CodeBuddyExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
accessToken, userID, domain := codeBuddyCredentials(auth)
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("codebuddy: missing access token")
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayloadSource, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := codebuddy.BaseURL + codeBuddyChatPath
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e.applyHeaders(httpReq, accessToken, userID, domain)
|
||||
httpReq.Header.Set("Accept", "text/event-stream")
|
||||
httpReq.Header.Set("Cache-Control", "no-cache")
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: translated,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if !isHTTPSuccess(httpResp.StatusCode) {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
httpResp.Body.Close()
|
||||
log.Debugf("codebuddy executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("codebuddy executor: close stream body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
scanner.Buffer(nil, maxScannerBufferSize)
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
if !bytes.HasPrefix(line, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
reporter.ensurePublished(ctx)
|
||||
}()
|
||||
|
||||
return &cliproxyexecutor.StreamResult{
|
||||
Headers: httpResp.Header.Clone(),
|
||||
Chunks: out,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Refresh exchanges the CodeBuddy refresh token for a new access token.
|
||||
func (e *CodeBuddyExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
if auth == nil {
|
||||
return nil, fmt.Errorf("codebuddy: missing auth")
|
||||
}
|
||||
|
||||
refreshToken := metaStringValue(auth.Metadata, "refresh_token")
|
||||
if refreshToken == "" {
|
||||
log.Debugf("codebuddy executor: no refresh token available, skipping refresh")
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
accessToken, userID, domain := codeBuddyCredentials(auth)
|
||||
|
||||
authSvc := codebuddy.NewCodeBuddyAuth(e.cfg)
|
||||
storage, err := authSvc.RefreshToken(ctx, accessToken, refreshToken, userID, domain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("codebuddy: token refresh failed: %w", err)
|
||||
}
|
||||
|
||||
updated := auth.Clone()
|
||||
updated.Metadata["access_token"] = storage.AccessToken
|
||||
if storage.RefreshToken != "" {
|
||||
updated.Metadata["refresh_token"] = storage.RefreshToken
|
||||
}
|
||||
updated.Metadata["expires_in"] = storage.ExpiresIn
|
||||
updated.Metadata["domain"] = storage.Domain
|
||||
if storage.UserID != "" {
|
||||
updated.Metadata["user_id"] = storage.UserID
|
||||
}
|
||||
now := time.Now()
|
||||
updated.UpdatedAt = now
|
||||
updated.LastRefreshedAt = now
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// CountTokens is not supported for CodeBuddy.
|
||||
func (e *CodeBuddyExecutor) CountTokens(_ context.Context, _ *cliproxyauth.Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("codebuddy: count tokens not supported")
|
||||
}
|
||||
|
||||
// applyHeaders sets required headers for CodeBuddy API requests.
|
||||
func (e *CodeBuddyExecutor) applyHeaders(req *http.Request, accessToken, userID, domain string) {
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", codebuddy.UserAgent)
|
||||
req.Header.Set("X-User-Id", userID)
|
||||
req.Header.Set("X-Domain", domain)
|
||||
req.Header.Set("X-Product", "SaaS")
|
||||
req.Header.Set("X-IDE-Type", "CLI")
|
||||
req.Header.Set("X-IDE-Name", "CLI")
|
||||
req.Header.Set("X-IDE-Version", "2.63.2")
|
||||
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||
}
|
||||
129
internal/runtime/executor/compat_helpers.go
Normal file
129
internal/runtime/executor/compat_helpers.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
)
|
||||
|
||||
func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
||||
return helps.NewProxyAwareHTTPClient(ctx, cfg, auth, timeout)
|
||||
}
|
||||
|
||||
func parseOpenAIUsage(data []byte) usage.Detail {
|
||||
return helps.ParseOpenAIUsage(data)
|
||||
}
|
||||
|
||||
func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
return helps.ParseOpenAIStreamUsage(line)
|
||||
}
|
||||
|
||||
func parseOpenAIResponsesUsage(data []byte) usage.Detail {
|
||||
return helps.ParseOpenAIUsage(data)
|
||||
}
|
||||
|
||||
func parseOpenAIResponsesStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
return helps.ParseOpenAIStreamUsage(line)
|
||||
}
|
||||
|
||||
func getTokenizer(model string) (tokenizer.Codec, error) {
|
||||
return helps.TokenizerForModel(model)
|
||||
}
|
||||
|
||||
func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
||||
return helps.CountOpenAIChatTokens(enc, payload)
|
||||
}
|
||||
|
||||
func countClaudeChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
||||
return helps.CountClaudeChatTokens(enc, payload)
|
||||
}
|
||||
|
||||
func buildOpenAIUsageJSON(count int64) []byte {
|
||||
return helps.BuildOpenAIUsageJSON(count)
|
||||
}
|
||||
|
||||
type upstreamRequestLog = helps.UpstreamRequestLog
|
||||
|
||||
func recordAPIRequest(ctx context.Context, cfg *config.Config, info upstreamRequestLog) {
|
||||
helps.RecordAPIRequest(ctx, cfg, info)
|
||||
}
|
||||
|
||||
func recordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
|
||||
helps.RecordAPIResponseMetadata(ctx, cfg, status, headers)
|
||||
}
|
||||
|
||||
func recordAPIResponseError(ctx context.Context, cfg *config.Config, err error) {
|
||||
helps.RecordAPIResponseError(ctx, cfg, err)
|
||||
}
|
||||
|
||||
func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byte) {
|
||||
helps.AppendAPIResponseChunk(ctx, cfg, chunk)
|
||||
}
|
||||
|
||||
func payloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
|
||||
return helps.PayloadRequestedModel(opts, fallback)
|
||||
}
|
||||
|
||||
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte {
|
||||
return helps.ApplyPayloadConfigWithRoot(cfg, model, protocol, root, payload, original, requestedModel)
|
||||
}
|
||||
|
||||
func summarizeErrorBody(contentType string, body []byte) string {
|
||||
return helps.SummarizeErrorBody(contentType, body)
|
||||
}
|
||||
|
||||
func apiKeyFromContext(ctx context.Context) string {
|
||||
return helps.APIKeyFromContext(ctx)
|
||||
}
|
||||
|
||||
func tokenizerForModel(model string) (tokenizer.Codec, error) {
|
||||
return helps.TokenizerForModel(model)
|
||||
}
|
||||
|
||||
func collectOpenAIContent(content gjson.Result, segments *[]string) {
|
||||
helps.CollectOpenAIContent(content, segments)
|
||||
}
|
||||
|
||||
type usageReporter struct {
|
||||
reporter *helps.UsageReporter
|
||||
}
|
||||
|
||||
func newUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *usageReporter {
|
||||
return &usageReporter{reporter: helps.NewUsageReporter(ctx, provider, model, auth)}
|
||||
}
|
||||
|
||||
func (r *usageReporter) publish(ctx context.Context, detail usage.Detail) {
|
||||
if r == nil || r.reporter == nil {
|
||||
return
|
||||
}
|
||||
r.reporter.Publish(ctx, detail)
|
||||
}
|
||||
|
||||
func (r *usageReporter) publishFailure(ctx context.Context) {
|
||||
if r == nil || r.reporter == nil {
|
||||
return
|
||||
}
|
||||
r.reporter.PublishFailure(ctx)
|
||||
}
|
||||
|
||||
func (r *usageReporter) trackFailure(ctx context.Context, errPtr *error) {
|
||||
if r == nil || r.reporter == nil {
|
||||
return
|
||||
}
|
||||
r.reporter.TrackFailure(ctx, errPtr)
|
||||
}
|
||||
|
||||
func (r *usageReporter) ensurePublished(ctx context.Context) {
|
||||
if r == nil || r.reporter == nil {
|
||||
return
|
||||
}
|
||||
r.reporter.EnsurePublished(ctx)
|
||||
}
|
||||
1719
internal/runtime/executor/cursor_executor.go
Normal file
1719
internal/runtime/executor/cursor_executor.go
Normal file
File diff suppressed because it is too large
Load Diff
1649
internal/runtime/executor/github_copilot_executor.go
Normal file
1649
internal/runtime/executor/github_copilot_executor.go
Normal file
File diff suppressed because it is too large
Load Diff
817
internal/runtime/executor/github_copilot_executor_test.go
Normal file
817
internal/runtime/executor/github_copilot_executor_test.go
Normal file
@@ -0,0 +1,817 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
copilotauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestGitHubCopilotNormalizeModel_StripsSuffix(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
wantModel string
|
||||
}{
|
||||
{
|
||||
name: "suffix stripped",
|
||||
model: "claude-opus-4.6(medium)",
|
||||
wantModel: "claude-opus-4.6",
|
||||
},
|
||||
{
|
||||
name: "no suffix unchanged",
|
||||
model: "claude-opus-4.6",
|
||||
wantModel: "claude-opus-4.6",
|
||||
},
|
||||
{
|
||||
name: "different suffix stripped",
|
||||
model: "gpt-4o(high)",
|
||||
wantModel: "gpt-4o",
|
||||
},
|
||||
{
|
||||
name: "numeric suffix stripped",
|
||||
model: "gemini-2.5-pro(8192)",
|
||||
wantModel: "gemini-2.5-pro",
|
||||
},
|
||||
}
|
||||
|
||||
e := &GitHubCopilotExecutor{}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := []byte(`{"model":"` + tt.model + `","messages":[]}`)
|
||||
got := e.normalizeModel(tt.model, body)
|
||||
|
||||
gotModel := gjson.GetBytes(got, "model").String()
|
||||
if gotModel != tt.wantModel {
|
||||
t.Fatalf("normalizeModel() model = %q, want %q", gotModel, tt.wantModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_OpenAIResponseSource(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai-response"), "claude-3-5-sonnet") {
|
||||
t.Fatal("expected openai-response source to use /responses")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5-codex") {
|
||||
t.Fatal("expected codex model to use /responses")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_RegistryResponsesOnlyModel(t *testing.T) {
|
||||
// Not parallel: shares global model registry with DynamicRegistryWinsOverStatic.
|
||||
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
|
||||
t.Fatal("expected responses-only registry model to use /responses")
|
||||
}
|
||||
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4-mini") {
|
||||
t.Fatal("expected responses-only registry model to use /responses")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_DynamicRegistryWinsOverStatic(t *testing.T) {
|
||||
// Not parallel: mutates global model registry, conflicts with RegistryResponsesOnlyModel.
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
clientID := "github-copilot-test-client"
|
||||
reg.RegisterClient(clientID, "github-copilot", []*registry.ModelInfo{
|
||||
{
|
||||
ID: "gpt-5.4",
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.4-mini",
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
},
|
||||
})
|
||||
defer reg.UnregisterClient(clientID)
|
||||
|
||||
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4") {
|
||||
t.Fatal("expected dynamic registry definition to take precedence over static fallback")
|
||||
}
|
||||
|
||||
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5.4-mini") {
|
||||
t.Fatal("expected dynamic registry definition to take precedence over static fallback")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {
|
||||
t.Parallel()
|
||||
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "claude-3-5-sonnet") {
|
||||
t.Fatal("expected default openai source with non-codex model to use /chat/completions")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotChatTools_KeepFunctionOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[{"type":"function","function":{"name":"ok"}},{"type":"code_interpreter"}],"tool_choice":"auto"}`)
|
||||
got := normalizeGitHubCopilotChatTools(body)
|
||||
tools := gjson.GetBytes(got, "tools").Array()
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("tools len = %d, want 1", len(tools))
|
||||
}
|
||||
if tools[0].Get("type").String() != "function" {
|
||||
t.Fatalf("tool type = %q, want function", tools[0].Get("type").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotChatTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[],"tool_choice":{"type":"function","function":{"name":"x"}}}`)
|
||||
got := normalizeGitHubCopilotChatTools(body)
|
||||
if gjson.GetBytes(got, "tool_choice").String() != "auto" {
|
||||
t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesInput_MissingInputExtractedFromSystemAndMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"system":"sys text","messages":[{"role":"user","content":"user text"},{"role":"assistant","content":[{"type":"text","text":"assistant text"}]}]}`)
|
||||
got := normalizeGitHubCopilotResponsesInput(body)
|
||||
in := gjson.GetBytes(got, "input")
|
||||
if !in.IsArray() {
|
||||
t.Fatalf("input type = %v, want array", in.Type)
|
||||
}
|
||||
raw := in.Raw
|
||||
if !strings.Contains(raw, "sys text") || !strings.Contains(raw, "user text") || !strings.Contains(raw, "assistant text") {
|
||||
t.Fatalf("input = %s, want structured array with all texts", raw)
|
||||
}
|
||||
if gjson.GetBytes(got, "messages").Exists() {
|
||||
t.Fatal("messages should be removed after conversion")
|
||||
}
|
||||
if gjson.GetBytes(got, "system").Exists() {
|
||||
t.Fatal("system should be removed after conversion")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesInput_NonStringInputStringified(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"input":{"foo":"bar"}}`)
|
||||
got := normalizeGitHubCopilotResponsesInput(body)
|
||||
in := gjson.GetBytes(got, "input")
|
||||
if in.Type != gjson.String {
|
||||
t.Fatalf("input type = %v, want string", in.Type)
|
||||
}
|
||||
if !strings.Contains(in.String(), "foo") {
|
||||
t.Fatalf("input = %q, want stringified object", in.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesInput_StripsServiceTier(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"input":"user text","service_tier":"default"}`)
|
||||
got := normalizeGitHubCopilotResponsesInput(body)
|
||||
|
||||
if gjson.GetBytes(got, "service_tier").Exists() {
|
||||
t.Fatalf("service_tier should be removed, got %s", gjson.GetBytes(got, "service_tier").Raw)
|
||||
}
|
||||
if gjson.GetBytes(got, "input").String() != "user text" {
|
||||
t.Fatalf("input = %q, want %q", gjson.GetBytes(got, "input").String(), "user text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_FlattenFunctionTools(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[{"type":"function","function":{"name":"sum","description":"d","parameters":{"type":"object"}}},{"type":"web_search"}]}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
tools := gjson.GetBytes(got, "tools").Array()
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("tools len = %d, want 1", len(tools))
|
||||
}
|
||||
if tools[0].Get("name").String() != "sum" {
|
||||
t.Fatalf("tools[0].name = %q, want sum", tools[0].Get("name").String())
|
||||
}
|
||||
if !tools[0].Get("parameters").Exists() {
|
||||
t.Fatal("expected parameters to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_ClaudeFormatTools(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[{"name":"Bash","description":"Run commands","input_schema":{"type":"object","properties":{"command":{"type":"string"}},"required":["command"]}},{"name":"Read","description":"Read files","input_schema":{"type":"object","properties":{"path":{"type":"string"}}}}]}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
tools := gjson.GetBytes(got, "tools").Array()
|
||||
if len(tools) != 2 {
|
||||
t.Fatalf("tools len = %d, want 2", len(tools))
|
||||
}
|
||||
if tools[0].Get("type").String() != "function" {
|
||||
t.Fatalf("tools[0].type = %q, want function", tools[0].Get("type").String())
|
||||
}
|
||||
if tools[0].Get("name").String() != "Bash" {
|
||||
t.Fatalf("tools[0].name = %q, want Bash", tools[0].Get("name").String())
|
||||
}
|
||||
if tools[0].Get("description").String() != "Run commands" {
|
||||
t.Fatalf("tools[0].description = %q, want 'Run commands'", tools[0].Get("description").String())
|
||||
}
|
||||
if !tools[0].Get("parameters").Exists() {
|
||||
t.Fatal("expected parameters to be set from input_schema")
|
||||
}
|
||||
if tools[0].Get("parameters.properties.command").Exists() != true {
|
||||
t.Fatal("expected parameters.properties.command to exist")
|
||||
}
|
||||
if tools[1].Get("name").String() != "Read" {
|
||||
t.Fatalf("tools[1].name = %q, want Read", tools[1].Get("name").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_FlattenToolChoiceFunctionObject(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tool_choice":{"type":"function","function":{"name":"sum"}}}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
if gjson.GetBytes(got, "tool_choice.type").String() != "function" {
|
||||
t.Fatalf("tool_choice.type = %q, want function", gjson.GetBytes(got, "tool_choice.type").String())
|
||||
}
|
||||
if gjson.GetBytes(got, "tool_choice.name").String() != "sum" {
|
||||
t.Fatalf("tool_choice.name = %q, want sum", gjson.GetBytes(got, "tool_choice.name").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tool_choice":{"type":"function"}}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
if gjson.GetBytes(got, "tool_choice").String() != "auto" {
|
||||
t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateGitHubCopilotResponsesNonStreamToClaude_TextMapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
resp := []byte(`{"id":"resp_1","model":"gpt-5-codex","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":3,"output_tokens":5}}`)
|
||||
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
|
||||
if gjson.GetBytes(out, "type").String() != "message" {
|
||||
t.Fatalf("type = %q, want message", gjson.GetBytes(out, "type").String())
|
||||
}
|
||||
if gjson.GetBytes(out, "content.0.type").String() != "text" {
|
||||
t.Fatalf("content.0.type = %q, want text", gjson.GetBytes(out, "content.0.type").String())
|
||||
}
|
||||
if gjson.GetBytes(out, "content.0.text").String() != "hello" {
|
||||
t.Fatalf("content.0.text = %q, want hello", gjson.GetBytes(out, "content.0.text").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateGitHubCopilotResponsesNonStreamToClaude_ToolUseMapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
resp := []byte(`{"id":"resp_2","model":"gpt-5-codex","output":[{"type":"function_call","id":"fc_1","call_id":"call_1","name":"sum","arguments":"{\"a\":1}"}],"usage":{"input_tokens":1,"output_tokens":2}}`)
|
||||
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
|
||||
if gjson.GetBytes(out, "content.0.type").String() != "tool_use" {
|
||||
t.Fatalf("content.0.type = %q, want tool_use", gjson.GetBytes(out, "content.0.type").String())
|
||||
}
|
||||
if gjson.GetBytes(out, "content.0.name").String() != "sum" {
|
||||
t.Fatalf("content.0.name = %q, want sum", gjson.GetBytes(out, "content.0.name").String())
|
||||
}
|
||||
if gjson.GetBytes(out, "stop_reason").String() != "tool_use" {
|
||||
t.Fatalf("stop_reason = %q, want tool_use", gjson.GetBytes(out, "stop_reason").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateGitHubCopilotResponsesStreamToClaude_TextLifecycle(t *testing.T) {
|
||||
t.Parallel()
|
||||
var param any
|
||||
|
||||
created := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5-codex"}}`), ¶m)
|
||||
if len(created) == 0 || !strings.Contains(string(created[0]), "message_start") {
|
||||
t.Fatalf("created events = %#v, want message_start", created)
|
||||
}
|
||||
|
||||
delta := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.output_text.delta","delta":"he"}`), ¶m)
|
||||
var joinedDelta string
|
||||
for _, d := range delta {
|
||||
joinedDelta += string(d)
|
||||
}
|
||||
if !strings.Contains(joinedDelta, "content_block_start") || !strings.Contains(joinedDelta, "text_delta") {
|
||||
t.Fatalf("delta events = %#v, want content_block_start + text_delta", delta)
|
||||
}
|
||||
|
||||
completed := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.completed","response":{"usage":{"input_tokens":7,"output_tokens":9}}}`), ¶m)
|
||||
var joinedCompleted string
|
||||
for _, c := range completed {
|
||||
joinedCompleted += string(c)
|
||||
}
|
||||
if !strings.Contains(joinedCompleted, "message_delta") || !strings.Contains(joinedCompleted, "message_stop") {
|
||||
t.Fatalf("completed events = %#v, want message_delta + message_stop", completed)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests for X-Initiator detection logic (Problem L) ---
|
||||
|
||||
func TestApplyHeaders_XInitiator_UserOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
body := []byte(`{"messages":[{"role":"system","content":"sys"},{"role":"user","content":"hello"}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "user" {
|
||||
t.Fatalf("X-Initiator = %q, want user", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_AgentWhenLastUserButHistoryHasAssistant(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
// When the last role is "user" and the message contains tool_result content,
|
||||
// the request is a continuation (e.g. Claude tool result translated to a
|
||||
// synthetic user message). Should be "agent".
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"I will read the file"},{"role":"user","content":[{"type":"tool_result","tool_use_id":"tu1","content":"file contents..."}]}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||
t.Fatalf("X-Initiator = %q, want agent (last user contains tool_result)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_AgentWithToolRole(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
// When the last message has role "tool", it's clearly agent-initiated.
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"tool","content":"result"}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||
t.Fatalf("X-Initiator = %q, want agent (last role is tool)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_InputArrayLastAssistantMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
body := []byte(`{"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"Hi"}]},{"type":"message","role":"assistant","content":[{"type":"output_text","text":"Hello"}]}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||
t.Fatalf("X-Initiator = %q, want agent (last role is assistant)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_InputArrayAgentWhenLastUserButHistoryHasAssistant(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
// Responses API: last item is user-role but history contains assistant → agent.
|
||||
body := []byte(`{"input":[{"type":"message","role":"assistant","content":[{"type":"output_text","text":"I can help"}]},{"type":"message","role":"user","content":[{"type":"input_text","text":"Do X"}]}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||
t.Fatalf("X-Initiator = %q, want agent (history has assistant)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_InputArrayLastFunctionCallOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
body := []byte(`{"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"Use tool"}]},{"type":"function_call","call_id":"c1","name":"Read","arguments":"{}"},{"type":"function_call_output","call_id":"c1","output":"ok"}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "agent" {
|
||||
t.Fatalf("X-Initiator = %q, want agent (last item maps to tool role)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_UserInMultiTurnNoTools(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
// Genuine multi-turn: user → assistant (plain text) → user follow-up.
|
||||
// No tool messages → should be "user" (not a false-positive).
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"Hi there!"},{"role":"user","content":"what is 2+2?"}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "user" {
|
||||
t.Fatalf("X-Initiator = %q, want user (genuine multi-turn, no tools)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_XInitiator_UserFollowUpAfterToolHistory(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
// User follow-up after a completed tool-use conversation.
|
||||
// The last message is a genuine user question — should be "user", not "agent".
|
||||
// This aligns with opencode's behavior: only active tool loops are agent-initiated.
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":[{"type":"tool_use","id":"tu1","name":"Read","input":{}}]},{"role":"tool","tool_call_id":"tu1","content":"file data"},{"role":"assistant","content":"I read the file."},{"role":"user","content":"What did we do so far?"}]}`)
|
||||
e.applyHeaders(req, "token", body)
|
||||
if got := req.Header.Get("X-Initiator"); got != "user" {
|
||||
t.Fatalf("X-Initiator = %q, want user (genuine follow-up after tool history)", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests for x-github-api-version header (Problem M) ---
|
||||
|
||||
func TestApplyHeaders_GitHubAPIVersion(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
e.applyHeaders(req, "token", nil)
|
||||
if got := req.Header.Get("X-Github-Api-Version"); got != "2025-04-01" {
|
||||
t.Fatalf("X-Github-Api-Version = %q, want 2025-04-01", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests for vision detection (Problem P) ---
|
||||
|
||||
func TestDetectVisionContent_WithImageURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`)
|
||||
if !detectVisionContent(body) {
|
||||
t.Fatal("expected vision content to be detected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectVisionContent_WithImageType(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"messages":[{"role":"user","content":[{"type":"image","source":{"data":"abc","media_type":"image/png"}}]}]}`)
|
||||
if !detectVisionContent(body) {
|
||||
t.Fatal("expected image type to be detected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectVisionContent_NoVision(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
|
||||
if detectVisionContent(body) {
|
||||
t.Fatal("expected no vision content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectVisionContent_NoMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
// After Responses API normalization, messages is removed — detection should return false
|
||||
body := []byte(`{"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"hello"}]}]}`)
|
||||
if detectVisionContent(body) {
|
||||
t.Fatal("expected no vision content when messages field is absent")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests for applyGitHubCopilotResponsesDefaults ---
|
||||
|
||||
func TestApplyGitHubCopilotResponsesDefaults_SetsAllDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"input":"hello","reasoning":{"effort":"medium"}}`)
|
||||
got := applyGitHubCopilotResponsesDefaults(body)
|
||||
|
||||
if gjson.GetBytes(got, "store").Bool() != false {
|
||||
t.Fatalf("store = %v, want false", gjson.GetBytes(got, "store").Raw)
|
||||
}
|
||||
inc := gjson.GetBytes(got, "include")
|
||||
if !inc.IsArray() || inc.Array()[0].String() != "reasoning.encrypted_content" {
|
||||
t.Fatalf("include = %s, want [\"reasoning.encrypted_content\"]", inc.Raw)
|
||||
}
|
||||
if gjson.GetBytes(got, "reasoning.summary").String() != "auto" {
|
||||
t.Fatalf("reasoning.summary = %q, want auto", gjson.GetBytes(got, "reasoning.summary").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyGitHubCopilotResponsesDefaults_DoesNotOverrideExisting(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"input":"hello","store":true,"include":["other"],"reasoning":{"effort":"high","summary":"concise"}}`)
|
||||
got := applyGitHubCopilotResponsesDefaults(body)
|
||||
|
||||
if gjson.GetBytes(got, "store").Bool() != true {
|
||||
t.Fatalf("store should not be overridden, got %s", gjson.GetBytes(got, "store").Raw)
|
||||
}
|
||||
if gjson.GetBytes(got, "include").Array()[0].String() != "other" {
|
||||
t.Fatalf("include should not be overridden, got %s", gjson.GetBytes(got, "include").Raw)
|
||||
}
|
||||
if gjson.GetBytes(got, "reasoning.summary").String() != "concise" {
|
||||
t.Fatalf("reasoning.summary should not be overridden, got %q", gjson.GetBytes(got, "reasoning.summary").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyGitHubCopilotResponsesDefaults_NoReasoningEffort(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"input":"hello"}`)
|
||||
got := applyGitHubCopilotResponsesDefaults(body)
|
||||
|
||||
if gjson.GetBytes(got, "store").Bool() != false {
|
||||
t.Fatalf("store = %v, want false", gjson.GetBytes(got, "store").Raw)
|
||||
}
|
||||
// reasoning.summary should NOT be set when reasoning.effort is absent
|
||||
if gjson.GetBytes(got, "reasoning.summary").Exists() {
|
||||
t.Fatalf("reasoning.summary should not be set when reasoning.effort is absent, got %q", gjson.GetBytes(got, "reasoning.summary").String())
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests for normalizeGitHubCopilotReasoningField ---
|
||||
|
||||
func TestNormalizeReasoningField_NonStreaming(t *testing.T) {
|
||||
t.Parallel()
|
||||
data := []byte(`{"choices":[{"message":{"content":"hello","reasoning_text":"I think..."}}]}`)
|
||||
got := normalizeGitHubCopilotReasoningField(data)
|
||||
rc := gjson.GetBytes(got, "choices.0.message.reasoning_content").String()
|
||||
if rc != "I think..." {
|
||||
t.Fatalf("reasoning_content = %q, want %q", rc, "I think...")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeReasoningField_Streaming(t *testing.T) {
|
||||
t.Parallel()
|
||||
data := []byte(`{"choices":[{"delta":{"reasoning_text":"thinking delta"}}]}`)
|
||||
got := normalizeGitHubCopilotReasoningField(data)
|
||||
rc := gjson.GetBytes(got, "choices.0.delta.reasoning_content").String()
|
||||
if rc != "thinking delta" {
|
||||
t.Fatalf("reasoning_content = %q, want %q", rc, "thinking delta")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeReasoningField_PreservesExistingReasoningContent(t *testing.T) {
|
||||
t.Parallel()
|
||||
data := []byte(`{"choices":[{"message":{"reasoning_text":"old","reasoning_content":"existing"}}]}`)
|
||||
got := normalizeGitHubCopilotReasoningField(data)
|
||||
rc := gjson.GetBytes(got, "choices.0.message.reasoning_content").String()
|
||||
if rc != "existing" {
|
||||
t.Fatalf("reasoning_content = %q, want %q (should not overwrite)", rc, "existing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeReasoningField_MultiChoice(t *testing.T) {
|
||||
t.Parallel()
|
||||
data := []byte(`{"choices":[{"message":{"reasoning_text":"thought-0"}},{"message":{"reasoning_text":"thought-1"}}]}`)
|
||||
got := normalizeGitHubCopilotReasoningField(data)
|
||||
rc0 := gjson.GetBytes(got, "choices.0.message.reasoning_content").String()
|
||||
rc1 := gjson.GetBytes(got, "choices.1.message.reasoning_content").String()
|
||||
if rc0 != "thought-0" {
|
||||
t.Fatalf("choices[0].reasoning_content = %q, want %q", rc0, "thought-0")
|
||||
}
|
||||
if rc1 != "thought-1" {
|
||||
t.Fatalf("choices[1].reasoning_content = %q, want %q", rc1, "thought-1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeReasoningField_NoChoices(t *testing.T) {
|
||||
t.Parallel()
|
||||
data := []byte(`{"id":"chatcmpl-123"}`)
|
||||
got := normalizeGitHubCopilotReasoningField(data)
|
||||
if string(got) != string(data) {
|
||||
t.Fatalf("expected no change, got %s", string(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyHeaders_OpenAIIntentValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
|
||||
e.applyHeaders(req, "token", nil)
|
||||
if got := req.Header.Get("Openai-Intent"); got != "conversation-edits" {
|
||||
t.Fatalf("Openai-Intent = %q, want conversation-edits", got)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Tests for CountTokens (local tiktoken estimation) ---
|
||||
|
||||
func TestCountTokens_ReturnsPositiveCount(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
body := []byte(`{"model":"gpt-4o","messages":[{"role":"user","content":"Hello, world!"}]}`)
|
||||
resp, err := e.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
|
||||
Model: "gpt-4o",
|
||||
Payload: body,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CountTokens() error: %v", err)
|
||||
}
|
||||
if len(resp.Payload) == 0 {
|
||||
t.Fatal("CountTokens() returned empty payload")
|
||||
}
|
||||
// The response should contain a positive token count.
|
||||
tokens := gjson.GetBytes(resp.Payload, "usage.prompt_tokens").Int()
|
||||
if tokens <= 0 {
|
||||
t.Fatalf("expected positive token count, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCountTokens_ClaudeSourceFormatTranslates(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
body := []byte(`{"model":"claude-sonnet-4","messages":[{"role":"user","content":"Tell me a joke"}],"max_tokens":1024}`)
|
||||
resp, err := e.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
|
||||
Model: "claude-sonnet-4",
|
||||
Payload: body,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CountTokens() error: %v", err)
|
||||
}
|
||||
// Claude source format → should get input_tokens in response
|
||||
inputTokens := gjson.GetBytes(resp.Payload, "input_tokens").Int()
|
||||
if inputTokens <= 0 {
|
||||
// Fallback: check usage.prompt_tokens (depends on translator registration)
|
||||
promptTokens := gjson.GetBytes(resp.Payload, "usage.prompt_tokens").Int()
|
||||
if promptTokens <= 0 {
|
||||
t.Fatalf("expected positive token count, got payload: %s", resp.Payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCountTokens_EmptyPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
e := &GitHubCopilotExecutor{}
|
||||
resp, err := e.CountTokens(context.Background(), nil, cliproxyexecutor.Request{
|
||||
Model: "gpt-4o",
|
||||
Payload: []byte(`{"model":"gpt-4o","messages":[]}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CountTokens() error: %v", err)
|
||||
}
|
||||
tokens := gjson.GetBytes(resp.Payload, "usage.prompt_tokens").Int()
|
||||
// Empty messages should return 0 tokens.
|
||||
if tokens != 0 {
|
||||
t.Fatalf("expected 0 tokens for empty messages, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripUnsupportedBetas_RemovesContext1M(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := []byte(`{"model":"claude-opus-4.6","betas":["interleaved-thinking-2025-05-14","context-1m-2025-08-07","claude-code-20250219"],"messages":[]}`)
|
||||
result := stripUnsupportedBetas(body)
|
||||
|
||||
betas := gjson.GetBytes(result, "betas")
|
||||
if !betas.Exists() {
|
||||
t.Fatal("betas field should still exist after stripping")
|
||||
}
|
||||
for _, item := range betas.Array() {
|
||||
if item.String() == "context-1m-2025-08-07" {
|
||||
t.Fatal("context-1m-2025-08-07 should have been stripped")
|
||||
}
|
||||
}
|
||||
// Other betas should be preserved
|
||||
found := false
|
||||
for _, item := range betas.Array() {
|
||||
if item.String() == "interleaved-thinking-2025-05-14" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("other betas should be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripUnsupportedBetas_NoBetasField(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := []byte(`{"model":"gpt-4o","messages":[]}`)
|
||||
result := stripUnsupportedBetas(body)
|
||||
|
||||
// Should be unchanged
|
||||
if string(result) != string(body) {
|
||||
t.Fatalf("body should be unchanged when no betas field exists, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripUnsupportedBetas_MetadataBetas(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := []byte(`{"model":"claude-opus-4.6","metadata":{"betas":["context-1m-2025-08-07","other-beta"]},"messages":[]}`)
|
||||
result := stripUnsupportedBetas(body)
|
||||
|
||||
betas := gjson.GetBytes(result, "metadata.betas")
|
||||
if !betas.Exists() {
|
||||
t.Fatal("metadata.betas field should still exist after stripping")
|
||||
}
|
||||
for _, item := range betas.Array() {
|
||||
if item.String() == "context-1m-2025-08-07" {
|
||||
t.Fatal("context-1m-2025-08-07 should have been stripped from metadata.betas")
|
||||
}
|
||||
}
|
||||
if betas.Array()[0].String() != "other-beta" {
|
||||
t.Fatal("other betas in metadata.betas should be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripUnsupportedBetas_AllBetasStripped(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := []byte(`{"model":"claude-opus-4.6","betas":["context-1m-2025-08-07"],"messages":[]}`)
|
||||
result := stripUnsupportedBetas(body)
|
||||
|
||||
betas := gjson.GetBytes(result, "betas")
|
||||
if betas.Exists() {
|
||||
t.Fatal("betas field should be deleted when all betas are stripped")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopilotModelEntry_Limits(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
capabilities map[string]any
|
||||
wantNil bool
|
||||
wantPrompt int
|
||||
wantOutput int
|
||||
wantContext int
|
||||
}{
|
||||
{
|
||||
name: "nil capabilities",
|
||||
capabilities: nil,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "no limits key",
|
||||
capabilities: map[string]any{"family": "claude-opus-4.6"},
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "limits is not a map",
|
||||
capabilities: map[string]any{"limits": "invalid"},
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "all zero values",
|
||||
capabilities: map[string]any{
|
||||
"limits": map[string]any{
|
||||
"max_context_window_tokens": float64(0),
|
||||
"max_prompt_tokens": float64(0),
|
||||
"max_output_tokens": float64(0),
|
||||
},
|
||||
},
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "individual account limits (128K prompt)",
|
||||
capabilities: map[string]any{
|
||||
"limits": map[string]any{
|
||||
"max_context_window_tokens": float64(144000),
|
||||
"max_prompt_tokens": float64(128000),
|
||||
"max_output_tokens": float64(64000),
|
||||
},
|
||||
},
|
||||
wantNil: false,
|
||||
wantPrompt: 128000,
|
||||
wantOutput: 64000,
|
||||
wantContext: 144000,
|
||||
},
|
||||
{
|
||||
name: "business account limits (168K prompt)",
|
||||
capabilities: map[string]any{
|
||||
"limits": map[string]any{
|
||||
"max_context_window_tokens": float64(200000),
|
||||
"max_prompt_tokens": float64(168000),
|
||||
"max_output_tokens": float64(32000),
|
||||
},
|
||||
},
|
||||
wantNil: false,
|
||||
wantPrompt: 168000,
|
||||
wantOutput: 32000,
|
||||
wantContext: 200000,
|
||||
},
|
||||
{
|
||||
name: "partial limits (only prompt)",
|
||||
capabilities: map[string]any{
|
||||
"limits": map[string]any{
|
||||
"max_prompt_tokens": float64(128000),
|
||||
},
|
||||
},
|
||||
wantNil: false,
|
||||
wantPrompt: 128000,
|
||||
wantOutput: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
entry := copilotauth.CopilotModelEntry{
|
||||
ID: "claude-opus-4.6",
|
||||
Capabilities: tt.capabilities,
|
||||
}
|
||||
limits := entry.Limits()
|
||||
if tt.wantNil {
|
||||
if limits != nil {
|
||||
t.Fatalf("expected nil limits, got %+v", limits)
|
||||
}
|
||||
return
|
||||
}
|
||||
if limits == nil {
|
||||
t.Fatal("expected non-nil limits, got nil")
|
||||
}
|
||||
if limits.MaxPromptTokens != tt.wantPrompt {
|
||||
t.Errorf("MaxPromptTokens = %d, want %d", limits.MaxPromptTokens, tt.wantPrompt)
|
||||
}
|
||||
if limits.MaxOutputTokens != tt.wantOutput {
|
||||
t.Errorf("MaxOutputTokens = %d, want %d", limits.MaxOutputTokens, tt.wantOutput)
|
||||
}
|
||||
if tt.wantContext > 0 && limits.MaxContextWindowTokens != tt.wantContext {
|
||||
t.Errorf("MaxContextWindowTokens = %d, want %d", limits.MaxContextWindowTokens, tt.wantContext)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
1374
internal/runtime/executor/gitlab_executor.go
Normal file
1374
internal/runtime/executor/gitlab_executor.go
Normal file
File diff suppressed because it is too large
Load Diff
539
internal/runtime/executor/gitlab_executor_test.go
Normal file
539
internal/runtime/executor/gitlab_executor_test.go
Normal file
@@ -0,0 +1,539 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestGitLabExecutorExecuteUsesChatEndpoint(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != gitLabChatEndpoint {
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
_, _ = w.Write([]byte(`"chat response"`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"base_url": srv.URL,
|
||||
"access_token": "oauth-access",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gitlab-duo",
|
||||
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":"hello"}]}`),
|
||||
}
|
||||
|
||||
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "chat response" {
|
||||
t.Fatalf("expected chat response, got %q", got)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "model").String(); got != "claude-sonnet-4-5" {
|
||||
t.Fatalf("expected resolved model, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteFallsBackToCodeSuggestions(t *testing.T) {
|
||||
chatCalls := 0
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case gitLabChatEndpoint:
|
||||
chatCalls++
|
||||
http.Error(w, "feature unavailable", http.StatusForbidden)
|
||||
case gitLabCodeSuggestionsEndpoint:
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{{
|
||||
"text": "fallback response",
|
||||
}},
|
||||
})
|
||||
default:
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"base_url": srv.URL,
|
||||
"personal_access_token": "glpat-token",
|
||||
"auth_method": "pat",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gitlab-duo",
|
||||
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":"write code"}]}`),
|
||||
}
|
||||
|
||||
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
if chatCalls != 1 {
|
||||
t.Fatalf("expected chat endpoint to be tried once, got %d", chatCalls)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "fallback response" {
|
||||
t.Fatalf("expected fallback response, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteUsesAnthropicGateway(t *testing.T) {
|
||||
var gotAuthHeader, gotRealmHeader string
|
||||
var gotPath string
|
||||
var gotModel string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAuthHeader = r.Header.Get("Authorization")
|
||||
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
|
||||
gotModel = gjson.GetBytes(readBody(t, r), "model").String()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","role":"assistant","model":"claude-sonnet-4-5","content":[{"type":"tool_use","id":"toolu_1","name":"Bash","input":{"cmd":"ls"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":11,"output_tokens":4}}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"duo_gateway_base_url": srv.URL,
|
||||
"duo_gateway_token": "gateway-token",
|
||||
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gitlab-duo",
|
||||
Payload: []byte(`{
|
||||
"model":"gitlab-duo",
|
||||
"messages":[{"role":"user","content":[{"type":"text","text":"list files"}]}],
|
||||
"tools":[{"name":"Bash","description":"run bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}},"required":["cmd"]}}],
|
||||
"max_tokens":128
|
||||
}`),
|
||||
}
|
||||
|
||||
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
if gotPath != "/v1/proxy/anthropic/v1/messages" {
|
||||
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
|
||||
}
|
||||
if gotAuthHeader != "Bearer gateway-token" {
|
||||
t.Fatalf("Authorization = %q, want Bearer gateway-token", gotAuthHeader)
|
||||
}
|
||||
if gotRealmHeader != "saas" {
|
||||
t.Fatalf("X-Gitlab-Realm = %q, want saas", gotRealmHeader)
|
||||
}
|
||||
if gotModel != "claude-sonnet-4-5" {
|
||||
t.Fatalf("model = %q, want claude-sonnet-4-5", gotModel)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "content.0.type").String(); got != "tool_use" {
|
||||
t.Fatalf("expected tool_use response, got %q", got)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "content.0.name").String(); got != "Bash" {
|
||||
t.Fatalf("expected tool name Bash, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteUsesOpenAIGateway(t *testing.T) {
|
||||
var gotAuthHeader, gotRealmHeader string
|
||||
var gotPath string
|
||||
var gotModel string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAuthHeader = r.Header.Get("Authorization")
|
||||
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
|
||||
gotModel = gjson.GetBytes(readBody(t, r), "model").String()
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello from openai gateway\"}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"gpt-5-codex\",\"output\":[{\"type\":\"message\",\"id\":\"msg_1\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"hello from openai gateway\"}]}],\"usage\":{\"input_tokens\":11,\"output_tokens\":4,\"total_tokens\":15}}}\n\n"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"duo_gateway_base_url": srv.URL,
|
||||
"duo_gateway_token": "gateway-token",
|
||||
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||
"model_provider": "openai",
|
||||
"model_name": "gpt-5-codex",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gitlab-duo",
|
||||
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":"hello"}]}`),
|
||||
}
|
||||
|
||||
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
if gotPath != "/v1/proxy/openai/v1/responses" {
|
||||
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/openai/v1/responses")
|
||||
}
|
||||
if gotAuthHeader != "Bearer gateway-token" {
|
||||
t.Fatalf("Authorization = %q, want Bearer gateway-token", gotAuthHeader)
|
||||
}
|
||||
if gotRealmHeader != "saas" {
|
||||
t.Fatalf("X-Gitlab-Realm = %q, want saas", gotRealmHeader)
|
||||
}
|
||||
if gotModel != "gpt-5-codex" {
|
||||
t.Fatalf("model = %q, want gpt-5-codex", gotModel)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "hello from openai gateway" {
|
||||
t.Fatalf("expected openai gateway response, got %q payload=%s", got, string(resp.Payload))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteUsesRequestedModelToSelectOpenAIGateway(t *testing.T) {
|
||||
var gotAuthHeader, gotRealmHeader, gotBetaHeader, gotUserAgent string
|
||||
var gotPath string
|
||||
var gotModel string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAuthHeader = r.Header.Get("Authorization")
|
||||
gotRealmHeader = r.Header.Get("X-Gitlab-Realm")
|
||||
gotBetaHeader = r.Header.Get("anthropic-beta")
|
||||
gotUserAgent = r.Header.Get("User-Agent")
|
||||
gotModel = gjson.GetBytes(readBody(t, r), "model").String()
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"duo-chat-gpt-5-codex\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.output_text.delta\",\"delta\":\"hello from explicit openai model\"}\n\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"created_at\":1710000000,\"model\":\"duo-chat-gpt-5-codex\",\"output\":[{\"type\":\"message\",\"id\":\"msg_1\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"hello from explicit openai model\"}]}],\"usage\":{\"input_tokens\":11,\"output_tokens\":4,\"total_tokens\":15}}}\n\n"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"duo_gateway_base_url": srv.URL,
|
||||
"duo_gateway_token": "gateway-token",
|
||||
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "duo-chat-gpt-5-codex",
|
||||
Payload: []byte(`{"model":"duo-chat-gpt-5-codex","messages":[{"role":"user","content":"hello"}]}`),
|
||||
}
|
||||
|
||||
resp, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
if gotPath != "/v1/proxy/openai/v1/responses" {
|
||||
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/openai/v1/responses")
|
||||
}
|
||||
if gotAuthHeader != "Bearer gateway-token" {
|
||||
t.Fatalf("Authorization = %q, want Bearer gateway-token", gotAuthHeader)
|
||||
}
|
||||
if gotRealmHeader != "saas" {
|
||||
t.Fatalf("X-Gitlab-Realm = %q, want saas", gotRealmHeader)
|
||||
}
|
||||
if gotBetaHeader != gitLabContext1MBeta {
|
||||
t.Fatalf("anthropic-beta = %q, want %q", gotBetaHeader, gitLabContext1MBeta)
|
||||
}
|
||||
if gotUserAgent != gitLabNativeUserAgent {
|
||||
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, gitLabNativeUserAgent)
|
||||
}
|
||||
if gotModel != "duo-chat-gpt-5-codex" {
|
||||
t.Fatalf("model = %q, want duo-chat-gpt-5-codex", gotModel)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "hello from explicit openai model" {
|
||||
t.Fatalf("expected explicit openai model response, got %q payload=%s", got, string(resp.Payload))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorRefreshUpdatesMetadata(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/oauth/token":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"access_token": "oauth-refreshed",
|
||||
"refresh_token": "oauth-refresh",
|
||||
"token_type": "Bearer",
|
||||
"scope": "api read_user",
|
||||
"created_at": 1710000000,
|
||||
"expires_in": 3600,
|
||||
})
|
||||
case "/api/v4/code_suggestions/direct_access":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"base_url": "https://cloud.gitlab.example.com",
|
||||
"token": "gateway-token",
|
||||
"expires_at": 1710003600,
|
||||
"headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||
"model_details": map[string]any{
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
})
|
||||
default:
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "gitlab-auth.json",
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"base_url": srv.URL,
|
||||
"access_token": "oauth-access",
|
||||
"refresh_token": "oauth-refresh",
|
||||
"oauth_client_id": "client-id",
|
||||
"auth_method": "oauth",
|
||||
"oauth_expires_at": "2000-01-01T00:00:00Z",
|
||||
},
|
||||
}
|
||||
|
||||
updated, err := exec.Refresh(context.Background(), auth)
|
||||
if err != nil {
|
||||
t.Fatalf("Refresh() error = %v", err)
|
||||
}
|
||||
if got := updated.Metadata["access_token"]; got != "oauth-refreshed" {
|
||||
t.Fatalf("expected refreshed access token, got %#v", got)
|
||||
}
|
||||
if got := updated.Metadata["model_name"]; got != "claude-sonnet-4-5" {
|
||||
t.Fatalf("expected refreshed model metadata, got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteStreamUsesCodeSuggestionsSSE(t *testing.T) {
|
||||
var gotAccept, gotStreamingHeader, gotEncoding string
|
||||
var gotStreamFlag bool
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != gitLabCodeSuggestionsEndpoint {
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
gotAccept = r.Header.Get("Accept")
|
||||
gotStreamingHeader = r.Header.Get(gitLabSSEStreamingHeader)
|
||||
gotEncoding = r.Header.Get("Accept-Encoding")
|
||||
gotStreamFlag = gjson.GetBytes(readBody(t, r), "stream").Bool()
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("event: stream_start\n"))
|
||||
_, _ = w.Write([]byte("data: {\"model\":{\"name\":\"claude-sonnet-4-5\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("event: content_chunk\n"))
|
||||
_, _ = w.Write([]byte("data: {\"content\":\"hello\"}\n\n"))
|
||||
_, _ = w.Write([]byte("event: content_chunk\n"))
|
||||
_, _ = w.Write([]byte("data: {\"content\":\" world\"}\n\n"))
|
||||
_, _ = w.Write([]byte("event: stream_end\n"))
|
||||
_, _ = w.Write([]byte("data: {}\n\n"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"base_url": srv.URL,
|
||||
"access_token": "oauth-access",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gitlab-duo",
|
||||
Payload: []byte(`{"model":"gitlab-duo","stream":true,"messages":[{"role":"user","content":"hello"}]}`),
|
||||
}
|
||||
|
||||
result, err := exec.ExecuteStream(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream() error = %v", err)
|
||||
}
|
||||
|
||||
lines := collectStreamLines(t, result)
|
||||
if gotAccept != "text/event-stream" {
|
||||
t.Fatalf("Accept = %q, want text/event-stream", gotAccept)
|
||||
}
|
||||
if gotStreamingHeader != "true" {
|
||||
t.Fatalf("%s = %q, want true", gitLabSSEStreamingHeader, gotStreamingHeader)
|
||||
}
|
||||
if gotEncoding != "identity" {
|
||||
t.Fatalf("Accept-Encoding = %q, want identity", gotEncoding)
|
||||
}
|
||||
if !gotStreamFlag {
|
||||
t.Fatalf("expected upstream request to set stream=true")
|
||||
}
|
||||
if len(lines) < 4 {
|
||||
t.Fatalf("expected translated stream chunks, got %d", len(lines))
|
||||
}
|
||||
if !strings.Contains(strings.Join(lines, "\n"), `"content":"hello"`) {
|
||||
t.Fatalf("expected hello delta in stream, got %q", strings.Join(lines, "\n"))
|
||||
}
|
||||
if !strings.Contains(strings.Join(lines, "\n"), `"content":" world"`) {
|
||||
t.Fatalf("expected world delta in stream, got %q", strings.Join(lines, "\n"))
|
||||
}
|
||||
last := lines[len(lines)-1]
|
||||
if last != "data: [DONE]" && !strings.Contains(last, `"finish_reason":"stop"`) {
|
||||
t.Fatalf("expected stream terminator, got %q", last)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteStreamFallsBackToSyntheticChat(t *testing.T) {
|
||||
chatCalls := 0
|
||||
streamCalls := 0
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case gitLabCodeSuggestionsEndpoint:
|
||||
streamCalls++
|
||||
http.Error(w, "feature unavailable", http.StatusForbidden)
|
||||
case gitLabChatEndpoint:
|
||||
chatCalls++
|
||||
_, _ = w.Write([]byte(`"chat fallback response"`))
|
||||
default:
|
||||
t.Fatalf("unexpected path %q", r.URL.Path)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"base_url": srv.URL,
|
||||
"access_token": "oauth-access",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gitlab-duo",
|
||||
Payload: []byte(`{"model":"gitlab-duo","stream":true,"messages":[{"role":"user","content":"hello"}]}`),
|
||||
}
|
||||
|
||||
result, err := exec.ExecuteStream(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream() error = %v", err)
|
||||
}
|
||||
|
||||
lines := collectStreamLines(t, result)
|
||||
if streamCalls != 1 {
|
||||
t.Fatalf("expected streaming endpoint once, got %d", streamCalls)
|
||||
}
|
||||
if chatCalls != 1 {
|
||||
t.Fatalf("expected chat fallback once, got %d", chatCalls)
|
||||
}
|
||||
if !strings.Contains(strings.Join(lines, "\n"), `"content":"chat fallback response"`) {
|
||||
t.Fatalf("expected fallback content in stream, got %q", strings.Join(lines, "\n"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitLabExecutorExecuteStreamUsesAnthropicGateway(t *testing.T) {
|
||||
var gotPath, gotBetaHeader, gotUserAgent string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotBetaHeader = r.Header.Get("Anthropic-Beta")
|
||||
gotUserAgent = r.Header.Get("User-Agent")
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("event: message_start\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"model\":\"claude-sonnet-4-5\",\"content\":[],\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":0,\"output_tokens\":0}}}\n\n"))
|
||||
_, _ = w.Write([]byte("event: content_block_start\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("event: content_block_delta\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"hello from gateway\"}}\n\n"))
|
||||
_, _ = w.Write([]byte("event: message_delta\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":10,\"output_tokens\":3}}\n\n"))
|
||||
_, _ = w.Write([]byte("event: message_stop\n"))
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
exec := NewGitLabExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "gitlab",
|
||||
Metadata: map[string]any{
|
||||
"duo_gateway_base_url": srv.URL,
|
||||
"duo_gateway_token": "gateway-token",
|
||||
"duo_gateway_headers": map[string]string{"X-Gitlab-Realm": "saas"},
|
||||
"model_provider": "anthropic",
|
||||
"model_name": "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gitlab-duo",
|
||||
Payload: []byte(`{"model":"gitlab-duo","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}],"max_tokens":64}`),
|
||||
}
|
||||
|
||||
result, err := exec.ExecuteStream(context.Background(), auth, req, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("claude"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream() error = %v", err)
|
||||
}
|
||||
|
||||
lines := collectStreamLines(t, result)
|
||||
if gotPath != "/v1/proxy/anthropic/v1/messages" {
|
||||
t.Fatalf("Path = %q, want %q", gotPath, "/v1/proxy/anthropic/v1/messages")
|
||||
}
|
||||
if !strings.Contains(gotBetaHeader, gitLabContext1MBeta) {
|
||||
t.Fatalf("Anthropic-Beta = %q, want to contain %q", gotBetaHeader, gitLabContext1MBeta)
|
||||
}
|
||||
if gotUserAgent != gitLabNativeUserAgent {
|
||||
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, gitLabNativeUserAgent)
|
||||
}
|
||||
if !strings.Contains(strings.Join(lines, "\n"), "hello from gateway") {
|
||||
t.Fatalf("expected anthropic gateway stream, got %q", strings.Join(lines, "\n"))
|
||||
}
|
||||
}
|
||||
|
||||
func collectStreamLines(t *testing.T, result *cliproxyexecutor.StreamResult) []string {
|
||||
t.Helper()
|
||||
lines := make([]string, 0, 8)
|
||||
for chunk := range result.Chunks {
|
||||
if chunk.Err != nil {
|
||||
t.Fatalf("unexpected stream error: %v", chunk.Err)
|
||||
}
|
||||
lines = append(lines, string(chunk.Payload))
|
||||
}
|
||||
return lines
|
||||
}
|
||||
|
||||
func readBody(t *testing.T, r *http.Request) []byte {
|
||||
t.Helper()
|
||||
defer func() { _ = r.Body.Close() }()
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll() error = %v", err)
|
||||
}
|
||||
return body
|
||||
}
|
||||
@@ -29,6 +29,7 @@ func startCodexCacheCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(codexCacheCleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
purgeExpiredCodexCache()
|
||||
}
|
||||
@@ -38,8 +39,10 @@ func startCodexCacheCleanup() {
|
||||
// purgeExpiredCodexCache removes entries that have expired.
|
||||
func purgeExpiredCodexCache() {
|
||||
now := time.Now()
|
||||
|
||||
codexCacheMu.Lock()
|
||||
defer codexCacheMu.Unlock()
|
||||
|
||||
for key, cache := range codexCacheMap {
|
||||
if cache.Expire.Before(now) {
|
||||
delete(codexCacheMap, key)
|
||||
@@ -66,3 +69,10 @@ func SetCodexCache(key string, cache CodexCache) {
|
||||
codexCacheMap[key] = cache
|
||||
codexCacheMu.Unlock()
|
||||
}
|
||||
|
||||
// deleteCodexCache deletes a cache entry.
|
||||
func deleteCodexCache(key string) {
|
||||
codexCacheMu.Lock()
|
||||
delete(codexCacheMap, key)
|
||||
codexCacheMu.Unlock()
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
@@ -12,11 +13,19 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// httpClientCache caches HTTP clients by proxy URL to enable connection reuse
|
||||
var (
|
||||
httpClientCache = make(map[string]*http.Client)
|
||||
httpClientCacheMutex sync.RWMutex
|
||||
)
|
||||
|
||||
// NewProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
|
||||
// 1. Use auth.ProxyURL if configured (highest priority)
|
||||
// 2. Use cfg.ProxyURL if auth proxy is not configured
|
||||
// 3. Use RoundTripper from context if neither are configured
|
||||
//
|
||||
// This function caches HTTP clients by proxy URL to enable TCP/TLS connection reuse.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context containing optional RoundTripper
|
||||
// - cfg: The application configuration
|
||||
@@ -26,11 +35,6 @@ import (
|
||||
// Returns:
|
||||
// - *http.Client: An HTTP client with configured proxy or transport
|
||||
func NewProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
||||
httpClient := &http.Client{}
|
||||
if timeout > 0 {
|
||||
httpClient.Timeout = timeout
|
||||
}
|
||||
|
||||
// Priority 1: Use auth.ProxyURL if configured
|
||||
var proxyURL string
|
||||
if auth != nil {
|
||||
@@ -42,11 +46,34 @@ func NewProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
|
||||
proxyURL = strings.TrimSpace(cfg.ProxyURL)
|
||||
}
|
||||
|
||||
// If we have a proxy URL configured, try cache first to reuse TCP/TLS connections.
|
||||
if proxyURL != "" {
|
||||
httpClientCacheMutex.RLock()
|
||||
if cachedClient, ok := httpClientCache[proxyURL]; ok {
|
||||
httpClientCacheMutex.RUnlock()
|
||||
if timeout > 0 {
|
||||
return &http.Client{Transport: cachedClient.Transport, Timeout: timeout}
|
||||
}
|
||||
return cachedClient
|
||||
}
|
||||
httpClientCacheMutex.RUnlock()
|
||||
}
|
||||
|
||||
// Create new client
|
||||
httpClient := &http.Client{}
|
||||
if timeout > 0 {
|
||||
httpClient.Timeout = timeout
|
||||
}
|
||||
|
||||
// If we have a proxy URL configured, set up the transport
|
||||
if proxyURL != "" {
|
||||
transport := buildProxyTransport(proxyURL)
|
||||
if transport != nil {
|
||||
httpClient.Transport = transport
|
||||
// Cache the client
|
||||
httpClientCacheMutex.Lock()
|
||||
httpClientCache[proxyURL] = httpClient
|
||||
httpClientCacheMutex.Unlock()
|
||||
return httpClient
|
||||
}
|
||||
// If proxy setup failed, log and fall through to context RoundTripper
|
||||
|
||||
@@ -3,21 +3,65 @@ package helps
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
)
|
||||
|
||||
// tokenizerCache stores tokenizer instances to avoid repeated creation.
|
||||
var tokenizerCache sync.Map
|
||||
|
||||
type adjustedTokenizer struct {
|
||||
tokenizer.Codec
|
||||
adjustmentFactor float64
|
||||
}
|
||||
|
||||
func (tw *adjustedTokenizer) Count(text string) (int, error) {
|
||||
count, err := tw.Codec.Count(text)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if tw.adjustmentFactor > 0 && tw.adjustmentFactor != 1.0 {
|
||||
return int(float64(count) * tw.adjustmentFactor), nil
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// TokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
|
||||
// For Claude-like models, it applies an adjustment factor since tiktoken may underestimate token counts.
|
||||
func TokenizerForModel(model string) (tokenizer.Codec, error) {
|
||||
sanitized := strings.ToLower(strings.TrimSpace(model))
|
||||
switch {
|
||||
case sanitized == "":
|
||||
if cached, ok := tokenizerCache.Load(sanitized); ok {
|
||||
return cached.(tokenizer.Codec), nil
|
||||
}
|
||||
|
||||
enc, err := tokenizerForModel(sanitized)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
actual, _ := tokenizerCache.LoadOrStore(sanitized, enc)
|
||||
return actual.(tokenizer.Codec), nil
|
||||
}
|
||||
|
||||
func tokenizerForModel(sanitized string) (tokenizer.Codec, error) {
|
||||
if sanitized == "" {
|
||||
return tokenizer.Get(tokenizer.Cl100kBase)
|
||||
}
|
||||
|
||||
// Claude models use cl100k_base with an adjustment factor because tiktoken may underestimate.
|
||||
if strings.Contains(sanitized, "claude") || strings.HasPrefix(sanitized, "kiro-") || strings.HasPrefix(sanitized, "amazonq-") {
|
||||
enc, err := tokenizer.Get(tokenizer.Cl100kBase)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &adjustedTokenizer{Codec: enc, adjustmentFactor: 1.1}, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(sanitized, "gpt-5"):
|
||||
return tokenizer.ForModel(tokenizer.GPT5)
|
||||
case strings.HasPrefix(sanitized, "gpt-5.1"):
|
||||
return tokenizer.ForModel(tokenizer.GPT5)
|
||||
case strings.HasPrefix(sanitized, "gpt-4.1"):
|
||||
return tokenizer.ForModel(tokenizer.GPT41)
|
||||
case strings.HasPrefix(sanitized, "gpt-4o"):
|
||||
@@ -69,6 +113,34 @@ func CountOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
||||
return int64(count), nil
|
||||
}
|
||||
|
||||
// CountClaudeChatTokens approximates prompt tokens for Claude API chat payloads.
|
||||
func CountClaudeChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
||||
if enc == nil {
|
||||
return 0, fmt.Errorf("encoder is nil")
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
root := gjson.ParseBytes(payload)
|
||||
segments := make([]string, 0, 32)
|
||||
imageTokens := 0
|
||||
|
||||
collectClaudeContent(root.Get("system"), &segments, &imageTokens)
|
||||
collectClaudeMessages(root.Get("messages"), &segments, &imageTokens)
|
||||
collectClaudeTools(root.Get("tools"), &segments)
|
||||
|
||||
joined := strings.TrimSpace(strings.Join(segments, "\n"))
|
||||
if joined == "" {
|
||||
return int64(imageTokens), nil
|
||||
}
|
||||
count, err := enc.Count(joined)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(count + imageTokens), nil
|
||||
}
|
||||
|
||||
// BuildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.
|
||||
func BuildOpenAIUsageJSON(count int64) []byte {
|
||||
return []byte(fmt.Sprintf(`{"usage":{"prompt_tokens":%d,"completion_tokens":0,"total_tokens":%d}}`, count, count))
|
||||
@@ -129,6 +201,10 @@ func collectOpenAIContent(content gjson.Result, segments *[]string) {
|
||||
}
|
||||
}
|
||||
|
||||
func CollectOpenAIContent(content gjson.Result, segments *[]string) {
|
||||
collectOpenAIContent(content, segments)
|
||||
}
|
||||
|
||||
func collectOpenAIToolCalls(calls gjson.Result, segments *[]string) {
|
||||
if !calls.Exists() || !calls.IsArray() {
|
||||
return
|
||||
@@ -226,6 +302,98 @@ func appendToolPayload(tool gjson.Result, segments *[]string) {
|
||||
}
|
||||
}
|
||||
|
||||
func collectClaudeMessages(messages gjson.Result, segments *[]string, imageTokens *int) {
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return
|
||||
}
|
||||
messages.ForEach(func(_, message gjson.Result) bool {
|
||||
addIfNotEmpty(segments, message.Get("role").String())
|
||||
collectClaudeContent(message.Get("content"), segments, imageTokens)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func collectClaudeContent(content gjson.Result, segments *[]string, imageTokens *int) {
|
||||
if !content.Exists() {
|
||||
return
|
||||
}
|
||||
if content.Type == gjson.String {
|
||||
addIfNotEmpty(segments, content.String())
|
||||
return
|
||||
}
|
||||
if content.IsArray() {
|
||||
content.ForEach(func(_, part gjson.Result) bool {
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "text":
|
||||
addIfNotEmpty(segments, part.Get("text").String())
|
||||
case "image":
|
||||
source := part.Get("source")
|
||||
width := source.Get("width").Float()
|
||||
height := source.Get("height").Float()
|
||||
if imageTokens != nil {
|
||||
*imageTokens += estimateImageTokens(width, height)
|
||||
}
|
||||
case "tool_use":
|
||||
addIfNotEmpty(segments, part.Get("id").String())
|
||||
addIfNotEmpty(segments, part.Get("name").String())
|
||||
if input := part.Get("input"); input.Exists() {
|
||||
addIfNotEmpty(segments, input.Raw)
|
||||
}
|
||||
case "tool_result":
|
||||
addIfNotEmpty(segments, part.Get("tool_use_id").String())
|
||||
collectClaudeContent(part.Get("content"), segments, imageTokens)
|
||||
case "thinking":
|
||||
addIfNotEmpty(segments, part.Get("thinking").String())
|
||||
default:
|
||||
if part.Type == gjson.String {
|
||||
addIfNotEmpty(segments, part.String())
|
||||
} else if part.Type == gjson.JSON {
|
||||
addIfNotEmpty(segments, part.Raw)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
if content.Type == gjson.JSON {
|
||||
addIfNotEmpty(segments, content.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func collectClaudeTools(tools gjson.Result, segments *[]string) {
|
||||
if !tools.Exists() || !tools.IsArray() {
|
||||
return
|
||||
}
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
addIfNotEmpty(segments, tool.Get("name").String())
|
||||
addIfNotEmpty(segments, tool.Get("description").String())
|
||||
if inputSchema := tool.Get("input_schema"); inputSchema.Exists() {
|
||||
addIfNotEmpty(segments, inputSchema.Raw)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// estimateImageTokens calculates estimated tokens for an image based on dimensions.
|
||||
// Based on Claude's image token calculation: tokens ≈ (width * height) / 750
|
||||
// Minimum 85 tokens, maximum 1590 tokens (for 1568x1568 images).
|
||||
func estimateImageTokens(width, height float64) int {
|
||||
if width <= 0 || height <= 0 {
|
||||
// No valid dimensions, use default estimate (medium-sized image).
|
||||
return 1000
|
||||
}
|
||||
|
||||
tokens := int(width * height / 750)
|
||||
if tokens < 85 {
|
||||
return 85
|
||||
}
|
||||
if tokens > 1590 {
|
||||
return 1590
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func addIfNotEmpty(segments *[]string, value string) {
|
||||
if segments == nil {
|
||||
return
|
||||
|
||||
@@ -247,15 +247,34 @@ func ParseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
if !usageNode.Exists() {
|
||||
return usage.Detail{}, false
|
||||
}
|
||||
|
||||
inputNode := usageNode.Get("prompt_tokens")
|
||||
if !inputNode.Exists() {
|
||||
inputNode = usageNode.Get("input_tokens")
|
||||
}
|
||||
outputNode := usageNode.Get("completion_tokens")
|
||||
if !outputNode.Exists() {
|
||||
outputNode = usageNode.Get("output_tokens")
|
||||
}
|
||||
detail := usage.Detail{
|
||||
InputTokens: usageNode.Get("prompt_tokens").Int(),
|
||||
OutputTokens: usageNode.Get("completion_tokens").Int(),
|
||||
InputTokens: inputNode.Int(),
|
||||
OutputTokens: outputNode.Int(),
|
||||
TotalTokens: usageNode.Get("total_tokens").Int(),
|
||||
}
|
||||
if cached := usageNode.Get("prompt_tokens_details.cached_tokens"); cached.Exists() {
|
||||
|
||||
cached := usageNode.Get("prompt_tokens_details.cached_tokens")
|
||||
if !cached.Exists() {
|
||||
cached = usageNode.Get("input_tokens_details.cached_tokens")
|
||||
}
|
||||
if cached.Exists() {
|
||||
detail.CachedTokens = cached.Int()
|
||||
}
|
||||
if reasoning := usageNode.Get("completion_tokens_details.reasoning_tokens"); reasoning.Exists() {
|
||||
|
||||
reasoning := usageNode.Get("completion_tokens_details.reasoning_tokens")
|
||||
if !reasoning.Exists() {
|
||||
reasoning = usageNode.Get("output_tokens_details.reasoning_tokens")
|
||||
}
|
||||
if reasoning.Exists() {
|
||||
detail.ReasoningTokens = reasoning.Int()
|
||||
}
|
||||
return detail, true
|
||||
|
||||
460
internal/runtime/executor/kilo_executor.go
Normal file
460
internal/runtime/executor/kilo_executor.go
Normal file
@@ -0,0 +1,460 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// KiloExecutor handles requests to Kilo API.
|
||||
type KiloExecutor struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewKiloExecutor creates a new Kilo executor instance.
|
||||
func NewKiloExecutor(cfg *config.Config) *KiloExecutor {
|
||||
return &KiloExecutor{cfg: cfg}
|
||||
}
|
||||
|
||||
// Identifier returns the unique identifier for this executor.
|
||||
func (e *KiloExecutor) Identifier() string { return "kilo" }
|
||||
|
||||
// PrepareRequest prepares the HTTP request before execution.
|
||||
func (e *KiloExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
accessToken, _ := kiloCredentials(auth)
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return fmt.Errorf("kilo: missing access token")
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(req, attrs)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest executes a raw HTTP request.
|
||||
func (e *KiloExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("kilo executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
// Execute performs a non-streaming request.
|
||||
func (e *KiloExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
accessToken, orgID := kiloCredentials(auth)
|
||||
if accessToken == "" {
|
||||
return resp, fmt.Errorf("kilo: missing access token")
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
endpoint := "/api/openrouter/chat/completions"
|
||||
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
url := "https://api.kilo.ai" + endpoint
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
if orgID != "" {
|
||||
httpReq.Header.Set("X-Kilocode-OrganizationID", orgID)
|
||||
}
|
||||
httpReq.Header.Set("User-Agent", "cli-proxy-kilo")
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: translated,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, body)
|
||||
reporter.publish(ctx, parseOpenAIUsage(body))
|
||||
reporter.ensurePublished(ctx)
|
||||
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request.
|
||||
func (e *KiloExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
accessToken, orgID := kiloCredentials(auth)
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("kilo: missing access token")
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
endpoint := "/api/openrouter/chat/completions"
|
||||
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := "https://api.kilo.ai" + endpoint
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
if orgID != "" {
|
||||
httpReq.Header.Set("X-Kilocode-OrganizationID", orgID)
|
||||
}
|
||||
httpReq.Header.Set("User-Agent", "cli-proxy-kilo")
|
||||
httpReq.Header.Set("Accept", "text/event-stream")
|
||||
httpReq.Header.Set("Cache-Control", "no-cache")
|
||||
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: translated,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
httpResp.Body.Close()
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer httpResp.Body.Close()
|
||||
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
scanner.Buffer(nil, 52_428_800)
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
if !bytes.HasPrefix(line, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
reporter.ensurePublished(ctx)
|
||||
}()
|
||||
|
||||
return &cliproxyexecutor.StreamResult{
|
||||
Headers: httpResp.Header.Clone(),
|
||||
Chunks: out,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Refresh validates the Kilo token.
|
||||
func (e *KiloExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
if auth == nil {
|
||||
return nil, fmt.Errorf("missing auth")
|
||||
}
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// CountTokens returns the token count for the given request.
|
||||
func (e *KiloExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("kilo: count tokens not supported")
|
||||
}
|
||||
|
||||
// kiloCredentials extracts access token and other info from auth.
|
||||
func kiloCredentials(auth *cliproxyauth.Auth) (accessToken, orgID string) {
|
||||
if auth == nil {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// Prefer kilocode specific keys, then fall back to generic keys.
|
||||
// Check metadata first, then attributes.
|
||||
if auth.Metadata != nil {
|
||||
if token, ok := auth.Metadata["kilocodeToken"].(string); ok && token != "" {
|
||||
accessToken = token
|
||||
} else if token, ok := auth.Metadata["access_token"].(string); ok && token != "" {
|
||||
accessToken = token
|
||||
}
|
||||
|
||||
if org, ok := auth.Metadata["kilocodeOrganizationId"].(string); ok && org != "" {
|
||||
orgID = org
|
||||
} else if org, ok := auth.Metadata["organization_id"].(string); ok && org != "" {
|
||||
orgID = org
|
||||
}
|
||||
}
|
||||
|
||||
if accessToken == "" && auth.Attributes != nil {
|
||||
if token := auth.Attributes["kilocodeToken"]; token != "" {
|
||||
accessToken = token
|
||||
} else if token := auth.Attributes["access_token"]; token != "" {
|
||||
accessToken = token
|
||||
}
|
||||
}
|
||||
|
||||
if orgID == "" && auth.Attributes != nil {
|
||||
if org := auth.Attributes["kilocodeOrganizationId"]; org != "" {
|
||||
orgID = org
|
||||
} else if org := auth.Attributes["organization_id"]; org != "" {
|
||||
orgID = org
|
||||
}
|
||||
}
|
||||
|
||||
return accessToken, orgID
|
||||
}
|
||||
|
||||
// FetchKiloModels fetches models from Kilo API.
|
||||
func FetchKiloModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
|
||||
accessToken, orgID := kiloCredentials(auth)
|
||||
if accessToken == "" {
|
||||
log.Infof("kilo: no access token found, skipping dynamic model fetch (using static kilo/auto)")
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
|
||||
log.Debugf("kilo: fetching dynamic models (orgID: %s)", orgID)
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.kilo.ai/api/openrouter/models", nil)
|
||||
if err != nil {
|
||||
log.Warnf("kilo: failed to create model fetch request: %v", err)
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
if orgID != "" {
|
||||
req.Header.Set("X-Kilocode-OrganizationID", orgID)
|
||||
}
|
||||
req.Header.Set("User-Agent", "cli-proxy-kilo")
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
log.Warnf("kilo: fetch models canceled: %v", err)
|
||||
} else {
|
||||
log.Warnf("kilo: using static models (API fetch failed: %v)", err)
|
||||
}
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Warnf("kilo: failed to read models response: %v", err)
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Warnf("kilo: fetch models failed: status %d, body: %s", resp.StatusCode, string(body))
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
|
||||
result := gjson.GetBytes(body, "data")
|
||||
if !result.Exists() {
|
||||
// Try root if data field is missing
|
||||
result = gjson.ParseBytes(body)
|
||||
if !result.IsArray() {
|
||||
log.Debugf("kilo: response body: %s", string(body))
|
||||
log.Warn("kilo: invalid API response format (expected array or data field with array)")
|
||||
return registry.GetKiloModels()
|
||||
}
|
||||
}
|
||||
|
||||
var dynamicModels []*registry.ModelInfo
|
||||
now := time.Now().Unix()
|
||||
count := 0
|
||||
totalCount := 0
|
||||
|
||||
result.ForEach(func(key, value gjson.Result) bool {
|
||||
totalCount++
|
||||
id := value.Get("id").String()
|
||||
pIdxResult := value.Get("preferredIndex")
|
||||
preferredIndex := pIdxResult.Int()
|
||||
|
||||
// Filter models where preferredIndex > 0 (Kilo-curated models)
|
||||
if preferredIndex <= 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if it's free. We look for :free suffix, is_free flag, or zero pricing.
|
||||
isFree := strings.HasSuffix(id, ":free") || id == "giga-potato" || value.Get("is_free").Bool()
|
||||
if !isFree {
|
||||
// Check pricing as fallback
|
||||
promptPricing := value.Get("pricing.prompt").String()
|
||||
if promptPricing == "0" || promptPricing == "0.0" {
|
||||
isFree = true
|
||||
}
|
||||
}
|
||||
|
||||
if !isFree {
|
||||
log.Debugf("kilo: skipping curated paid model: %s", id)
|
||||
return true
|
||||
}
|
||||
|
||||
log.Debugf("kilo: found curated model: %s (preferredIndex: %d)", id, preferredIndex)
|
||||
|
||||
dynamicModels = append(dynamicModels, ®istry.ModelInfo{
|
||||
ID: id,
|
||||
DisplayName: value.Get("name").String(),
|
||||
ContextLength: int(value.Get("context_length").Int()),
|
||||
OwnedBy: "kilo",
|
||||
Type: "kilo",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
})
|
||||
count++
|
||||
return true
|
||||
})
|
||||
|
||||
log.Infof("kilo: fetched %d models from API, %d curated free (preferredIndex > 0)", totalCount, count)
|
||||
if count == 0 && totalCount > 0 {
|
||||
log.Warn("kilo: no curated free models found (check API response fields)")
|
||||
}
|
||||
|
||||
staticModels := registry.GetKiloModels()
|
||||
// Always include kilo/auto (first static model)
|
||||
allModels := append(staticModels[:1], dynamicModels...)
|
||||
|
||||
return allModels
|
||||
}
|
||||
4706
internal/runtime/executor/kiro_executor.go
Normal file
4706
internal/runtime/executor/kiro_executor.go
Normal file
File diff suppressed because it is too large
Load Diff
423
internal/runtime/executor/kiro_executor_test.go
Normal file
423
internal/runtime/executor/kiro_executor_test.go
Normal file
@@ -0,0 +1,423 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestBuildKiroEndpointConfigs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
region string
|
||||
expectedURL string
|
||||
expectedOrigin string
|
||||
expectedName string
|
||||
}{
|
||||
{
|
||||
name: "Empty region - defaults to us-east-1",
|
||||
region: "",
|
||||
expectedURL: "https://q.us-east-1.amazonaws.com/generateAssistantResponse",
|
||||
expectedOrigin: "AI_EDITOR",
|
||||
expectedName: "AmazonQ",
|
||||
},
|
||||
{
|
||||
name: "us-east-1",
|
||||
region: "us-east-1",
|
||||
expectedURL: "https://q.us-east-1.amazonaws.com/generateAssistantResponse",
|
||||
expectedOrigin: "AI_EDITOR",
|
||||
expectedName: "AmazonQ",
|
||||
},
|
||||
{
|
||||
name: "ap-southeast-1",
|
||||
region: "ap-southeast-1",
|
||||
expectedURL: "https://q.ap-southeast-1.amazonaws.com/generateAssistantResponse",
|
||||
expectedOrigin: "AI_EDITOR",
|
||||
expectedName: "AmazonQ",
|
||||
},
|
||||
{
|
||||
name: "eu-west-1",
|
||||
region: "eu-west-1",
|
||||
expectedURL: "https://q.eu-west-1.amazonaws.com/generateAssistantResponse",
|
||||
expectedOrigin: "AI_EDITOR",
|
||||
expectedName: "AmazonQ",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
configs := buildKiroEndpointConfigs(tt.region)
|
||||
|
||||
if len(configs) != 2 {
|
||||
t.Fatalf("expected 2 endpoint configs, got %d", len(configs))
|
||||
}
|
||||
|
||||
// Check primary endpoint (AmazonQ)
|
||||
primary := configs[0]
|
||||
if primary.URL != tt.expectedURL {
|
||||
t.Errorf("primary URL = %q, want %q", primary.URL, tt.expectedURL)
|
||||
}
|
||||
if primary.Origin != tt.expectedOrigin {
|
||||
t.Errorf("primary Origin = %q, want %q", primary.Origin, tt.expectedOrigin)
|
||||
}
|
||||
if primary.Name != tt.expectedName {
|
||||
t.Errorf("primary Name = %q, want %q", primary.Name, tt.expectedName)
|
||||
}
|
||||
if primary.AmzTarget != "" {
|
||||
t.Errorf("primary AmzTarget should be empty, got %q", primary.AmzTarget)
|
||||
}
|
||||
|
||||
// Check fallback endpoint (CodeWhisperer)
|
||||
fallback := configs[1]
|
||||
if fallback.Name != "CodeWhisperer" {
|
||||
t.Errorf("fallback Name = %q, want %q", fallback.Name, "CodeWhisperer")
|
||||
}
|
||||
// CodeWhisperer fallback uses the same region as Q endpoint
|
||||
expectedRegion := tt.region
|
||||
if expectedRegion == "" {
|
||||
expectedRegion = kiroDefaultRegion
|
||||
}
|
||||
expectedFallbackURL := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com/generateAssistantResponse", expectedRegion)
|
||||
if fallback.URL != expectedFallbackURL {
|
||||
t.Errorf("fallback URL = %q, want %q", fallback.URL, expectedFallbackURL)
|
||||
}
|
||||
if fallback.AmzTarget == "" {
|
||||
t.Error("fallback AmzTarget should NOT be empty")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetKiroEndpointConfigs_NilAuth(t *testing.T) {
|
||||
configs := getKiroEndpointConfigs(nil)
|
||||
|
||||
if len(configs) != 2 {
|
||||
t.Fatalf("expected 2 endpoint configs, got %d", len(configs))
|
||||
}
|
||||
|
||||
// Should return default us-east-1 configs
|
||||
if configs[0].Name != "AmazonQ" {
|
||||
t.Errorf("first config Name = %q, want %q", configs[0].Name, "AmazonQ")
|
||||
}
|
||||
expectedURL := "https://q.us-east-1.amazonaws.com/generateAssistantResponse"
|
||||
if configs[0].URL != expectedURL {
|
||||
t.Errorf("first config URL = %q, want %q", configs[0].URL, expectedURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetKiroEndpointConfigs_WithRegionFromProfileArn(t *testing.T) {
|
||||
auth := &cliproxyauth.Auth{
|
||||
Metadata: map[string]any{
|
||||
"profile_arn": "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
|
||||
},
|
||||
}
|
||||
|
||||
configs := getKiroEndpointConfigs(auth)
|
||||
|
||||
if len(configs) != 2 {
|
||||
t.Fatalf("expected 2 endpoint configs, got %d", len(configs))
|
||||
}
|
||||
|
||||
expectedURL := "https://q.ap-southeast-1.amazonaws.com/generateAssistantResponse"
|
||||
if configs[0].URL != expectedURL {
|
||||
t.Errorf("primary URL = %q, want %q", configs[0].URL, expectedURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetKiroEndpointConfigs_WithApiRegionOverride(t *testing.T) {
|
||||
auth := &cliproxyauth.Auth{
|
||||
Metadata: map[string]any{
|
||||
"api_region": "eu-central-1",
|
||||
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
|
||||
},
|
||||
}
|
||||
|
||||
configs := getKiroEndpointConfigs(auth)
|
||||
|
||||
// api_region should take precedence over profile_arn
|
||||
expectedURL := "https://q.eu-central-1.amazonaws.com/generateAssistantResponse"
|
||||
if configs[0].URL != expectedURL {
|
||||
t.Errorf("primary URL = %q, want %q", configs[0].URL, expectedURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetKiroEndpointConfigs_PreferredEndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
preference string
|
||||
expectedFirstName string
|
||||
}{
|
||||
{
|
||||
name: "Prefer codewhisperer",
|
||||
preference: "codewhisperer",
|
||||
expectedFirstName: "CodeWhisperer",
|
||||
},
|
||||
{
|
||||
name: "Prefer ide (alias for codewhisperer)",
|
||||
preference: "ide",
|
||||
expectedFirstName: "CodeWhisperer",
|
||||
},
|
||||
{
|
||||
name: "Prefer amazonq",
|
||||
preference: "amazonq",
|
||||
expectedFirstName: "AmazonQ",
|
||||
},
|
||||
{
|
||||
name: "Prefer q (alias for amazonq)",
|
||||
preference: "q",
|
||||
expectedFirstName: "AmazonQ",
|
||||
},
|
||||
{
|
||||
name: "Prefer cli (alias for amazonq)",
|
||||
preference: "cli",
|
||||
expectedFirstName: "AmazonQ",
|
||||
},
|
||||
{
|
||||
name: "Unknown preference - no reordering",
|
||||
preference: "unknown",
|
||||
expectedFirstName: "AmazonQ",
|
||||
},
|
||||
{
|
||||
name: "Empty preference - no reordering",
|
||||
preference: "",
|
||||
expectedFirstName: "AmazonQ",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
auth := &cliproxyauth.Auth{
|
||||
Metadata: map[string]any{
|
||||
"preferred_endpoint": tt.preference,
|
||||
},
|
||||
}
|
||||
|
||||
configs := getKiroEndpointConfigs(auth)
|
||||
|
||||
if configs[0].Name != tt.expectedFirstName {
|
||||
t.Errorf("first endpoint Name = %q, want %q", configs[0].Name, tt.expectedFirstName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetKiroEndpointConfigs_PreferredEndpointFromAttributes(t *testing.T) {
|
||||
// Test that preferred_endpoint can also come from Attributes
|
||||
auth := &cliproxyauth.Auth{
|
||||
Metadata: map[string]any{},
|
||||
Attributes: map[string]string{"preferred_endpoint": "codewhisperer"},
|
||||
}
|
||||
|
||||
configs := getKiroEndpointConfigs(auth)
|
||||
|
||||
if configs[0].Name != "CodeWhisperer" {
|
||||
t.Errorf("first endpoint Name = %q, want %q", configs[0].Name, "CodeWhisperer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetKiroEndpointConfigs_MetadataTakesPrecedenceOverAttributes(t *testing.T) {
|
||||
auth := &cliproxyauth.Auth{
|
||||
Metadata: map[string]any{"preferred_endpoint": "amazonq"},
|
||||
Attributes: map[string]string{"preferred_endpoint": "codewhisperer"},
|
||||
}
|
||||
|
||||
configs := getKiroEndpointConfigs(auth)
|
||||
|
||||
// Metadata should take precedence
|
||||
if configs[0].Name != "AmazonQ" {
|
||||
t.Errorf("first endpoint Name = %q, want %q", configs[0].Name, "AmazonQ")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAuthValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
auth *cliproxyauth.Auth
|
||||
key string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "From metadata",
|
||||
auth: &cliproxyauth.Auth{
|
||||
Metadata: map[string]any{"test_key": "metadata_value"},
|
||||
},
|
||||
key: "test_key",
|
||||
expected: "metadata_value",
|
||||
},
|
||||
{
|
||||
name: "From attributes (fallback)",
|
||||
auth: &cliproxyauth.Auth{
|
||||
Attributes: map[string]string{"test_key": "attribute_value"},
|
||||
},
|
||||
key: "test_key",
|
||||
expected: "attribute_value",
|
||||
},
|
||||
{
|
||||
name: "Metadata takes precedence",
|
||||
auth: &cliproxyauth.Auth{
|
||||
Metadata: map[string]any{"test_key": "metadata_value"},
|
||||
Attributes: map[string]string{"test_key": "attribute_value"},
|
||||
},
|
||||
key: "test_key",
|
||||
expected: "metadata_value",
|
||||
},
|
||||
{
|
||||
name: "Key not found",
|
||||
auth: &cliproxyauth.Auth{
|
||||
Metadata: map[string]any{"other_key": "value"},
|
||||
Attributes: map[string]string{"another_key": "value"},
|
||||
},
|
||||
key: "test_key",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Nil metadata",
|
||||
auth: &cliproxyauth.Auth{
|
||||
Attributes: map[string]string{"test_key": "attribute_value"},
|
||||
},
|
||||
key: "test_key",
|
||||
expected: "attribute_value",
|
||||
},
|
||||
{
|
||||
name: "Both nil",
|
||||
auth: &cliproxyauth.Auth{},
|
||||
key: "test_key",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Value is trimmed and lowercased",
|
||||
auth: &cliproxyauth.Auth{
|
||||
Metadata: map[string]any{"test_key": " UPPER_VALUE "},
|
||||
},
|
||||
key: "test_key",
|
||||
expected: "upper_value",
|
||||
},
|
||||
{
|
||||
name: "Empty string value in metadata - falls back to attributes",
|
||||
auth: &cliproxyauth.Auth{
|
||||
Metadata: map[string]any{"test_key": ""},
|
||||
Attributes: map[string]string{"test_key": "attribute_value"},
|
||||
},
|
||||
key: "test_key",
|
||||
expected: "attribute_value",
|
||||
},
|
||||
{
|
||||
name: "Non-string value in metadata - falls back to attributes",
|
||||
auth: &cliproxyauth.Auth{
|
||||
Metadata: map[string]any{"test_key": 123},
|
||||
Attributes: map[string]string{"test_key": "attribute_value"},
|
||||
},
|
||||
key: "test_key",
|
||||
expected: "attribute_value",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := getAuthValue(tt.auth, tt.key)
|
||||
if result != tt.expected {
|
||||
t.Errorf("getAuthValue() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
auth *cliproxyauth.Auth
|
||||
checkFn func(t *testing.T, result string)
|
||||
}{
|
||||
{
|
||||
name: "From client_id",
|
||||
auth: &cliproxyauth.Auth{
|
||||
Metadata: map[string]any{
|
||||
"client_id": "test-client-id-123",
|
||||
"refresh_token": "test-refresh-token-456",
|
||||
},
|
||||
},
|
||||
checkFn: func(t *testing.T, result string) {
|
||||
expected := kiroauth.GetAccountKey("test-client-id-123", "test-refresh-token-456")
|
||||
if result != expected {
|
||||
t.Errorf("expected %s, got %s", expected, result)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "From refresh_token only",
|
||||
auth: &cliproxyauth.Auth{
|
||||
Metadata: map[string]any{
|
||||
"refresh_token": "test-refresh-token-789",
|
||||
},
|
||||
},
|
||||
checkFn: func(t *testing.T, result string) {
|
||||
expected := kiroauth.GetAccountKey("", "test-refresh-token-789")
|
||||
if result != expected {
|
||||
t.Errorf("expected %s, got %s", expected, result)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Nil auth",
|
||||
auth: nil,
|
||||
checkFn: func(t *testing.T, result string) {
|
||||
if len(result) != 16 {
|
||||
t.Errorf("expected 16 char key, got %d chars", len(result))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Nil metadata",
|
||||
auth: &cliproxyauth.Auth{},
|
||||
checkFn: func(t *testing.T, result string) {
|
||||
if len(result) != 16 {
|
||||
t.Errorf("expected 16 char key, got %d chars", len(result))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Empty metadata",
|
||||
auth: &cliproxyauth.Auth{
|
||||
Metadata: map[string]any{},
|
||||
},
|
||||
checkFn: func(t *testing.T, result string) {
|
||||
if len(result) != 16 {
|
||||
t.Errorf("expected 16 char key, got %d chars", len(result))
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := getAccountKey(tt.auth)
|
||||
tt.checkFn(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEndpointAliases(t *testing.T) {
|
||||
// Verify all expected aliases are defined
|
||||
expectedAliases := map[string]string{
|
||||
"codewhisperer": "codewhisperer",
|
||||
"ide": "codewhisperer",
|
||||
"amazonq": "amazonq",
|
||||
"q": "amazonq",
|
||||
"cli": "amazonq",
|
||||
}
|
||||
|
||||
for alias, target := range expectedAliases {
|
||||
if actual, ok := endpointAliases[alias]; !ok {
|
||||
t.Errorf("missing alias %q", alias)
|
||||
} else if actual != target {
|
||||
t.Errorf("alias %q = %q, want %q", alias, actual, target)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify no unexpected aliases
|
||||
if len(endpointAliases) != len(expectedAliases) {
|
||||
t.Errorf("unexpected number of aliases: got %d, want %d", len(endpointAliases), len(expectedAliases))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user