mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-12 08:43:58 +00:00
Compare commits
86 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d560c20c26 | ||
|
|
5abeca1f9e | ||
|
|
294eac3a88 | ||
|
|
a31104020c | ||
|
|
65bec4d734 | ||
|
|
edb2993838 | ||
|
|
c0d8e0dec7 | ||
|
|
795da13d5d | ||
|
|
55789df275 | ||
|
|
9e652a3540 | ||
|
|
46a6782065 | ||
|
|
c359f61859 | ||
|
|
908c8eab5b | ||
|
|
f5f2c69233 | ||
|
|
63d4de5eea | ||
|
|
af15083496 | ||
|
|
c4722e42b1 | ||
|
|
f9a991365f | ||
|
|
6df16bedba | ||
|
|
632a2fd2f2 | ||
|
|
5626637fbd | ||
|
|
2db89211a9 | ||
|
|
587371eb14 | ||
|
|
75818b1e25 | ||
|
|
cbe56955a9 | ||
|
|
8ea6ac913d | ||
|
|
ae1e8a5191 | ||
|
|
b3ccc55f09 | ||
|
|
1ce56d7413 | ||
|
|
41a78be3a2 | ||
|
|
1ff5de9a31 | ||
|
|
46a6853046 | ||
|
|
4b2d40bd67 | ||
|
|
726f1a590c | ||
|
|
575881cb59 | ||
|
|
d02df0141b | ||
|
|
e4bc9da913 | ||
|
|
8c6be49625 | ||
|
|
c727e4251f | ||
|
|
99266be998 | ||
|
|
d0f3fd96f8 | ||
|
|
f361b2716d | ||
|
|
086d8d0d0b | ||
|
|
627dee1dac | ||
|
|
55c3197fb8 | ||
|
|
5a2cf0d53c | ||
|
|
2573358173 | ||
|
|
09cd3cff91 | ||
|
|
ab0bf1b517 | ||
|
|
58e09f8e5f | ||
|
|
2334a2b174 | ||
|
|
bc61bf36b2 | ||
|
|
7726a44ca2 | ||
|
|
dc55fb0ce3 | ||
|
|
a146c6c0aa | ||
|
|
4c133d3ea9 | ||
|
|
544238772a | ||
|
|
f3ccd85ba1 | ||
|
|
dc279de443 | ||
|
|
bf1634bda0 | ||
|
|
166d2d24d9 | ||
|
|
4cbcc835d1 | ||
|
|
b93026d83a | ||
|
|
5ed2133ff9 | ||
|
|
e9dd44e623 | ||
|
|
cc8c4ffb5f | ||
|
|
1510bfcb6f | ||
|
|
bcd2208b51 | ||
|
|
09b19f5c4e | ||
|
|
7b01ca0e2e | ||
|
|
9c65e17a21 | ||
|
|
fe6fc628ed | ||
|
|
8192eeabc8 | ||
|
|
c3f1cdd7e5 | ||
|
|
c6bd91b86b | ||
|
|
349ddcaa89 | ||
|
|
bb9fe52f1e | ||
|
|
afe4c1bfb7 | ||
|
|
865af9f19e | ||
|
|
2b97cb98b5 | ||
|
|
938a799263 | ||
|
|
15bc99f6ea | ||
|
|
3ec7991e5f | ||
|
|
40e85a6759 | ||
|
|
cc116ce67d | ||
|
|
40efc2ba43 |
2
.github/workflows/release.yaml
vendored
2
.github/workflows/release.yaml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- run: git fetch --force --tags
|
||||
- uses: actions/setup-go@v4
|
||||
with:
|
||||
go-version: '>=1.24.0'
|
||||
go-version: '>=1.26.0'
|
||||
cache: true
|
||||
- name: Generate Build Metadata
|
||||
run: |
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -6,7 +6,7 @@ cliproxy
|
||||
# Configuration
|
||||
config.yaml
|
||||
.env
|
||||
|
||||
.mcp.json
|
||||
# Generated content
|
||||
bin/*
|
||||
logs/*
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM golang:1.24-alpine AS builder
|
||||
FROM golang:1.26-alpine AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
2
go.mod
2
go.mod
@@ -1,6 +1,6 @@
|
||||
module github.com/router-for-me/CLIProxyAPI/v6
|
||||
|
||||
go 1.24.0
|
||||
go 1.26.0
|
||||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.0.6
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -189,9 +190,21 @@ func (h *Handler) APICall(c *gin.Context) {
|
||||
reqHeaders[key] = strings.ReplaceAll(value, "$TOKEN$", token)
|
||||
}
|
||||
|
||||
// When caller indicates CBOR in request headers, convert JSON string payload to CBOR bytes.
|
||||
useCBORPayload := headerContainsValue(reqHeaders, "Content-Type", "application/cbor")
|
||||
|
||||
var requestBody io.Reader
|
||||
if body.Data != "" {
|
||||
requestBody = strings.NewReader(body.Data)
|
||||
if useCBORPayload {
|
||||
cborPayload, errEncode := encodeJSONStringToCBOR(body.Data)
|
||||
if errEncode != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json data for cbor content-type"})
|
||||
return
|
||||
}
|
||||
requestBody = bytes.NewReader(cborPayload)
|
||||
} else {
|
||||
requestBody = strings.NewReader(body.Data)
|
||||
}
|
||||
}
|
||||
|
||||
req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody)
|
||||
@@ -234,10 +247,18 @@ func (h *Handler) APICall(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// For CBOR upstream responses, decode into plain text or JSON string before returning.
|
||||
responseBodyText := string(respBody)
|
||||
if headerContainsValue(reqHeaders, "Accept", "application/cbor") || strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "application/cbor") {
|
||||
if decodedBody, errDecode := decodeCBORBodyToTextOrJSON(respBody); errDecode == nil {
|
||||
responseBodyText = decodedBody
|
||||
}
|
||||
}
|
||||
|
||||
response := apiCallResponse{
|
||||
StatusCode: resp.StatusCode,
|
||||
Header: resp.Header,
|
||||
Body: string(respBody),
|
||||
Body: responseBodyText,
|
||||
}
|
||||
|
||||
// If this is a GitHub Copilot token endpoint response, try to enrich with quota information
|
||||
@@ -747,6 +768,83 @@ func buildProxyTransport(proxyStr string) *http.Transport {
|
||||
return nil
|
||||
}
|
||||
|
||||
// headerContainsValue checks whether a header map contains a target value (case-insensitive key and value).
|
||||
func headerContainsValue(headers map[string]string, targetKey, targetValue string) bool {
|
||||
if len(headers) == 0 {
|
||||
return false
|
||||
}
|
||||
for key, value := range headers {
|
||||
if !strings.EqualFold(strings.TrimSpace(key), strings.TrimSpace(targetKey)) {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(strings.ToLower(value), strings.ToLower(strings.TrimSpace(targetValue))) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// encodeJSONStringToCBOR converts a JSON string payload into CBOR bytes.
|
||||
func encodeJSONStringToCBOR(jsonString string) ([]byte, error) {
|
||||
var payload any
|
||||
if errUnmarshal := json.Unmarshal([]byte(jsonString), &payload); errUnmarshal != nil {
|
||||
return nil, errUnmarshal
|
||||
}
|
||||
return cbor.Marshal(payload)
|
||||
}
|
||||
|
||||
// decodeCBORBodyToTextOrJSON decodes CBOR bytes to plain text (for string payloads) or JSON string.
|
||||
func decodeCBORBodyToTextOrJSON(raw []byte) (string, error) {
|
||||
if len(raw) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var payload any
|
||||
if errUnmarshal := cbor.Unmarshal(raw, &payload); errUnmarshal != nil {
|
||||
return "", errUnmarshal
|
||||
}
|
||||
|
||||
jsonCompatible := cborValueToJSONCompatible(payload)
|
||||
switch typed := jsonCompatible.(type) {
|
||||
case string:
|
||||
return typed, nil
|
||||
case []byte:
|
||||
return string(typed), nil
|
||||
default:
|
||||
jsonBytes, errMarshal := json.Marshal(jsonCompatible)
|
||||
if errMarshal != nil {
|
||||
return "", errMarshal
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
||||
}
|
||||
|
||||
// cborValueToJSONCompatible recursively converts CBOR-decoded values into JSON-marshalable values.
|
||||
func cborValueToJSONCompatible(value any) any {
|
||||
switch typed := value.(type) {
|
||||
case map[any]any:
|
||||
out := make(map[string]any, len(typed))
|
||||
for key, item := range typed {
|
||||
out[fmt.Sprint(key)] = cborValueToJSONCompatible(item)
|
||||
}
|
||||
return out
|
||||
case map[string]any:
|
||||
out := make(map[string]any, len(typed))
|
||||
for key, item := range typed {
|
||||
out[key] = cborValueToJSONCompatible(item)
|
||||
}
|
||||
return out
|
||||
case []any:
|
||||
out := make([]any, len(typed))
|
||||
for i, item := range typed {
|
||||
out[i] = cborValueToJSONCompatible(item)
|
||||
}
|
||||
return out
|
||||
default:
|
||||
return typed
|
||||
}
|
||||
}
|
||||
|
||||
// QuotaDetail represents quota information for a specific resource type
|
||||
type QuotaDetail struct {
|
||||
Entitlement float64 `json:"entitlement"`
|
||||
|
||||
@@ -1193,6 +1193,30 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
}
|
||||
ts.ProjectID = strings.Join(projects, ",")
|
||||
ts.Checked = true
|
||||
} else if strings.EqualFold(requestedProjectID, "GOOGLE_ONE") {
|
||||
ts.Auto = false
|
||||
if errSetup := performGeminiCLISetup(ctx, gemClient, &ts, ""); errSetup != nil {
|
||||
log.Errorf("Google One auto-discovery failed: %v", errSetup)
|
||||
SetOAuthSessionError(state, "Google One auto-discovery failed")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(ts.ProjectID) == "" {
|
||||
log.Error("Google One auto-discovery returned empty project ID")
|
||||
SetOAuthSessionError(state, "Google One auto-discovery returned empty project ID")
|
||||
return
|
||||
}
|
||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
||||
if errCheck != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
return
|
||||
}
|
||||
ts.Checked = isChecked
|
||||
if !isChecked {
|
||||
log.Error("Cloud AI API is not enabled for the auto-discovered project")
|
||||
SetOAuthSessionError(state, "Cloud AI API not enabled")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
|
||||
@@ -2124,7 +2148,48 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
|
||||
}
|
||||
}
|
||||
if projectID == "" {
|
||||
return &projectSelectionRequiredError{}
|
||||
// Auto-discovery: try onboardUser without specifying a project
|
||||
// to let Google auto-provision one (matches Gemini CLI headless behavior
|
||||
// and Antigravity's FetchProjectID pattern).
|
||||
autoOnboardReq := map[string]any{
|
||||
"tierId": tierID,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
autoCtx, autoCancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer autoCancel()
|
||||
for attempt := 1; ; attempt++ {
|
||||
var onboardResp map[string]any
|
||||
if errOnboard := callGeminiCLI(autoCtx, httpClient, "onboardUser", autoOnboardReq, &onboardResp); errOnboard != nil {
|
||||
return fmt.Errorf("auto-discovery onboardUser: %w", errOnboard)
|
||||
}
|
||||
|
||||
if done, okDone := onboardResp["done"].(bool); okDone && done {
|
||||
if resp, okResp := onboardResp["response"].(map[string]any); okResp {
|
||||
switch v := resp["cloudaicompanionProject"].(type) {
|
||||
case string:
|
||||
projectID = strings.TrimSpace(v)
|
||||
case map[string]any:
|
||||
if id, okID := v["id"].(string); okID {
|
||||
projectID = strings.TrimSpace(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
log.Debugf("Auto-discovery: onboarding in progress, attempt %d...", attempt)
|
||||
select {
|
||||
case <-autoCtx.Done():
|
||||
return &projectSelectionRequiredError{}
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
if projectID == "" {
|
||||
return &projectSelectionRequiredError{}
|
||||
}
|
||||
log.Infof("Auto-discovered project ID via onboarding: %s", projectID)
|
||||
}
|
||||
|
||||
onboardReqBody := map[string]any{
|
||||
|
||||
@@ -28,8 +28,7 @@ func (h *Handler) GetConfig(c *gin.Context) {
|
||||
c.JSON(200, gin.H{})
|
||||
return
|
||||
}
|
||||
cfgCopy := *h.cfg
|
||||
c.JSON(200, &cfgCopy)
|
||||
c.JSON(200, new(*h.cfg))
|
||||
}
|
||||
|
||||
type releaseInfo struct {
|
||||
|
||||
@@ -796,10 +796,10 @@ func (h *Handler) DeleteOAuthModelAlias(c *gin.Context) {
|
||||
c.JSON(404, gin.H{"error": "channel not found"})
|
||||
return
|
||||
}
|
||||
delete(h.cfg.OAuthModelAlias, channel)
|
||||
if len(h.cfg.OAuthModelAlias) == 0 {
|
||||
h.cfg.OAuthModelAlias = nil
|
||||
}
|
||||
// Set to nil instead of deleting the key so that the "explicitly disabled"
|
||||
// marker survives config reload and prevents SanitizeOAuthModelAlias from
|
||||
// re-injecting default aliases (fixes #222).
|
||||
h.cfg.OAuthModelAlias[channel] = nil
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
|
||||
@@ -127,8 +127,7 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
||||
m.modelMapper = NewModelMapper(settings.ModelMappings)
|
||||
|
||||
// Store initial config for partial reload comparison
|
||||
settingsCopy := settings
|
||||
m.lastConfig = &settingsCopy
|
||||
m.lastConfig = new(settings)
|
||||
|
||||
// Initialize localhost restriction setting (hot-reloadable)
|
||||
m.setRestrictToLocalhost(settings.RestrictManagementToLocalhost)
|
||||
|
||||
@@ -40,8 +40,7 @@ func DoClaudeLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts)
|
||||
if err != nil {
|
||||
var authErr *claude.AuthenticationError
|
||||
if errors.As(err, &authErr) {
|
||||
if authErr, ok := errors.AsType[*claude.AuthenticationError](err); ok {
|
||||
log.Error(claude.GetUserFriendlyMessage(authErr))
|
||||
if authErr.Type == claude.ErrPortInUse.Type {
|
||||
os.Exit(claude.ErrPortInUse.Code)
|
||||
|
||||
@@ -32,8 +32,7 @@ func DoIFlowLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "iflow", cfg, authOpts)
|
||||
if err != nil {
|
||||
var emailErr *sdkAuth.EmailRequiredError
|
||||
if errors.As(err, &emailErr) {
|
||||
if emailErr, ok := errors.AsType[*sdkAuth.EmailRequiredError](err); ok {
|
||||
log.Error(emailErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -100,49 +100,74 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||
|
||||
log.Info("Authentication successful.")
|
||||
|
||||
projects, errProjects := fetchGCPProjects(ctx, httpClient)
|
||||
if errProjects != nil {
|
||||
log.Errorf("Failed to get project list: %v", errProjects)
|
||||
return
|
||||
var activatedProjects []string
|
||||
|
||||
useGoogleOne := false
|
||||
if trimmedProjectID == "" && promptFn != nil {
|
||||
fmt.Println("\nSelect login mode:")
|
||||
fmt.Println(" 1. Code Assist (GCP project, manual selection)")
|
||||
fmt.Println(" 2. Google One (personal account, auto-discover project)")
|
||||
choice, errPrompt := promptFn("Enter choice [1/2] (default: 1): ")
|
||||
if errPrompt == nil && strings.TrimSpace(choice) == "2" {
|
||||
useGoogleOne = true
|
||||
}
|
||||
}
|
||||
|
||||
selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn)
|
||||
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
|
||||
if errSelection != nil {
|
||||
log.Errorf("Invalid project selection: %v", errSelection)
|
||||
return
|
||||
}
|
||||
if len(projectSelections) == 0 {
|
||||
log.Error("No project selected; aborting login.")
|
||||
return
|
||||
}
|
||||
|
||||
activatedProjects := make([]string, 0, len(projectSelections))
|
||||
seenProjects := make(map[string]bool)
|
||||
for _, candidateID := range projectSelections {
|
||||
log.Infof("Activating project %s", candidateID)
|
||||
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
|
||||
var projectErr *projectSelectionRequiredError
|
||||
if errors.As(errSetup, &projectErr) {
|
||||
log.Error("Failed to start user onboarding: A project ID is required.")
|
||||
showProjectSelectionHelp(storage.Email, projects)
|
||||
return
|
||||
}
|
||||
log.Errorf("Failed to complete user setup: %v", errSetup)
|
||||
if useGoogleOne {
|
||||
log.Info("Google One mode: auto-discovering project...")
|
||||
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, ""); errSetup != nil {
|
||||
log.Errorf("Google One auto-discovery failed: %v", errSetup)
|
||||
return
|
||||
}
|
||||
finalID := strings.TrimSpace(storage.ProjectID)
|
||||
if finalID == "" {
|
||||
finalID = candidateID
|
||||
autoProject := strings.TrimSpace(storage.ProjectID)
|
||||
if autoProject == "" {
|
||||
log.Error("Google One auto-discovery returned empty project ID")
|
||||
return
|
||||
}
|
||||
log.Infof("Auto-discovered project: %s", autoProject)
|
||||
activatedProjects = []string{autoProject}
|
||||
} else {
|
||||
projects, errProjects := fetchGCPProjects(ctx, httpClient)
|
||||
if errProjects != nil {
|
||||
log.Errorf("Failed to get project list: %v", errProjects)
|
||||
return
|
||||
}
|
||||
|
||||
// Skip duplicates
|
||||
if seenProjects[finalID] {
|
||||
log.Infof("Project %s already activated, skipping", finalID)
|
||||
continue
|
||||
selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn)
|
||||
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
|
||||
if errSelection != nil {
|
||||
log.Errorf("Invalid project selection: %v", errSelection)
|
||||
return
|
||||
}
|
||||
if len(projectSelections) == 0 {
|
||||
log.Error("No project selected; aborting login.")
|
||||
return
|
||||
}
|
||||
|
||||
seenProjects := make(map[string]bool)
|
||||
for _, candidateID := range projectSelections {
|
||||
log.Infof("Activating project %s", candidateID)
|
||||
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
|
||||
if _, ok := errors.AsType[*projectSelectionRequiredError](errSetup); ok {
|
||||
log.Error("Failed to start user onboarding: A project ID is required.")
|
||||
showProjectSelectionHelp(storage.Email, projects)
|
||||
return
|
||||
}
|
||||
log.Errorf("Failed to complete user setup: %v", errSetup)
|
||||
return
|
||||
}
|
||||
finalID := strings.TrimSpace(storage.ProjectID)
|
||||
if finalID == "" {
|
||||
finalID = candidateID
|
||||
}
|
||||
|
||||
if seenProjects[finalID] {
|
||||
log.Infof("Project %s already activated, skipping", finalID)
|
||||
continue
|
||||
}
|
||||
seenProjects[finalID] = true
|
||||
activatedProjects = append(activatedProjects, finalID)
|
||||
}
|
||||
seenProjects[finalID] = true
|
||||
activatedProjects = append(activatedProjects, finalID)
|
||||
}
|
||||
|
||||
storage.Auto = false
|
||||
@@ -235,7 +260,48 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
|
||||
}
|
||||
}
|
||||
if projectID == "" {
|
||||
return &projectSelectionRequiredError{}
|
||||
// Auto-discovery: try onboardUser without specifying a project
|
||||
// to let Google auto-provision one (matches Gemini CLI headless behavior
|
||||
// and Antigravity's FetchProjectID pattern).
|
||||
autoOnboardReq := map[string]any{
|
||||
"tierId": tierID,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
autoCtx, autoCancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer autoCancel()
|
||||
for attempt := 1; ; attempt++ {
|
||||
var onboardResp map[string]any
|
||||
if errOnboard := callGeminiCLI(autoCtx, httpClient, "onboardUser", autoOnboardReq, &onboardResp); errOnboard != nil {
|
||||
return fmt.Errorf("auto-discovery onboardUser: %w", errOnboard)
|
||||
}
|
||||
|
||||
if done, okDone := onboardResp["done"].(bool); okDone && done {
|
||||
if resp, okResp := onboardResp["response"].(map[string]any); okResp {
|
||||
switch v := resp["cloudaicompanionProject"].(type) {
|
||||
case string:
|
||||
projectID = strings.TrimSpace(v)
|
||||
case map[string]any:
|
||||
if id, okID := v["id"].(string); okID {
|
||||
projectID = strings.TrimSpace(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
log.Debugf("Auto-discovery: onboarding in progress, attempt %d...", attempt)
|
||||
select {
|
||||
case <-autoCtx.Done():
|
||||
return &projectSelectionRequiredError{}
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
if projectID == "" {
|
||||
return &projectSelectionRequiredError{}
|
||||
}
|
||||
log.Infof("Auto-discovered project ID via onboarding: %s", projectID)
|
||||
}
|
||||
|
||||
onboardReqBody := map[string]any{
|
||||
@@ -617,7 +683,7 @@ func updateAuthRecord(record *cliproxyauth.Auth, storage *gemini.GeminiTokenStor
|
||||
return
|
||||
}
|
||||
|
||||
finalName := gemini.CredentialFileName(storage.Email, storage.ProjectID, false)
|
||||
finalName := gemini.CredentialFileName(storage.Email, storage.ProjectID, true)
|
||||
|
||||
if record.Metadata == nil {
|
||||
record.Metadata = make(map[string]any)
|
||||
|
||||
@@ -54,8 +54,7 @@ func DoCodexLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
|
||||
if err != nil {
|
||||
var authErr *codex.AuthenticationError
|
||||
if errors.As(err, &authErr) {
|
||||
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
|
||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||
if authErr.Type == codex.ErrPortInUse.Type {
|
||||
os.Exit(codex.ErrPortInUse.Code)
|
||||
|
||||
@@ -44,8 +44,7 @@ func DoQwenLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "qwen", cfg, authOpts)
|
||||
if err != nil {
|
||||
var emailErr *sdkAuth.EmailRequiredError
|
||||
if errors.As(err, &emailErr) {
|
||||
if emailErr, ok := errors.AsType[*sdkAuth.EmailRequiredError](err); ok {
|
||||
log.Error(emailErr.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -736,14 +736,44 @@ func payloadRawString(value any) ([]byte, bool) {
|
||||
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
|
||||
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
|
||||
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
|
||||
// It also injects default aliases for channels that have built-in defaults (e.g., kiro)
|
||||
// when no user-configured aliases exist for those channels.
|
||||
func (cfg *Config) SanitizeOAuthModelAlias() {
|
||||
if cfg == nil || len(cfg.OAuthModelAlias) == 0 {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Inject default Kiro aliases if no user-configured kiro aliases exist
|
||||
if cfg.OAuthModelAlias == nil {
|
||||
cfg.OAuthModelAlias = make(map[string][]OAuthModelAlias)
|
||||
}
|
||||
if _, hasKiro := cfg.OAuthModelAlias["kiro"]; !hasKiro {
|
||||
// Check case-insensitive too
|
||||
found := false
|
||||
for k := range cfg.OAuthModelAlias {
|
||||
if strings.EqualFold(strings.TrimSpace(k), "kiro") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
cfg.OAuthModelAlias["kiro"] = defaultKiroAliases()
|
||||
}
|
||||
}
|
||||
|
||||
if len(cfg.OAuthModelAlias) == 0 {
|
||||
return
|
||||
}
|
||||
out := make(map[string][]OAuthModelAlias, len(cfg.OAuthModelAlias))
|
||||
for rawChannel, aliases := range cfg.OAuthModelAlias {
|
||||
channel := strings.ToLower(strings.TrimSpace(rawChannel))
|
||||
if channel == "" || len(aliases) == 0 {
|
||||
if channel == "" {
|
||||
continue
|
||||
}
|
||||
// Preserve channels that were explicitly set to empty/nil – they act
|
||||
// as "disabled" markers so default injection won't re-add them (#222).
|
||||
if len(aliases) == 0 {
|
||||
out[channel] = nil
|
||||
continue
|
||||
}
|
||||
seenAlias := make(map[string]struct{}, len(aliases))
|
||||
|
||||
@@ -20,6 +20,28 @@ var antigravityModelConversionTable = map[string]string{
|
||||
"gemini-claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
|
||||
}
|
||||
|
||||
// defaultKiroAliases returns the default oauth-model-alias configuration
|
||||
// for the kiro channel. Maps kiro-prefixed model names to standard Claude model
|
||||
// names so that clients like Claude Code can use standard names directly.
|
||||
func defaultKiroAliases() []OAuthModelAlias {
|
||||
return []OAuthModelAlias{
|
||||
// Sonnet 4.5
|
||||
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5-20250929", Fork: true},
|
||||
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5", Fork: true},
|
||||
// Sonnet 4
|
||||
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4-20250514", Fork: true},
|
||||
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4", Fork: true},
|
||||
// Opus 4.6
|
||||
{Name: "kiro-claude-opus-4-6", Alias: "claude-opus-4-6", Fork: true},
|
||||
// Opus 4.5
|
||||
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5-20251101", Fork: true},
|
||||
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5", Fork: true},
|
||||
// Haiku 4.5
|
||||
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5-20251001", Fork: true},
|
||||
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5", Fork: true},
|
||||
}
|
||||
}
|
||||
|
||||
// defaultAntigravityAliases returns the default oauth-model-alias configuration
|
||||
// for the antigravity channel when neither field exists.
|
||||
func defaultAntigravityAliases() []OAuthModelAlias {
|
||||
|
||||
@@ -54,3 +54,132 @@ func TestSanitizeOAuthModelAlias_AllowsMultipleAliasesForSameName(t *testing.T)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_InjectsDefaultKiroAliases(t *testing.T) {
|
||||
// When no kiro aliases are configured, defaults should be injected
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
"codex": {
|
||||
{Name: "gpt-5", Alias: "g5"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
kiroAliases := cfg.OAuthModelAlias["kiro"]
|
||||
if len(kiroAliases) == 0 {
|
||||
t.Fatal("expected default kiro aliases to be injected")
|
||||
}
|
||||
|
||||
// Check that standard Claude model names are present
|
||||
aliasSet := make(map[string]bool)
|
||||
for _, a := range kiroAliases {
|
||||
aliasSet[a.Alias] = true
|
||||
}
|
||||
expectedAliases := []string{
|
||||
"claude-sonnet-4-5-20250929",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-sonnet-4-20250514",
|
||||
"claude-sonnet-4",
|
||||
"claude-opus-4-6",
|
||||
"claude-opus-4-5-20251101",
|
||||
"claude-opus-4-5",
|
||||
"claude-haiku-4-5-20251001",
|
||||
"claude-haiku-4-5",
|
||||
}
|
||||
for _, expected := range expectedAliases {
|
||||
if !aliasSet[expected] {
|
||||
t.Fatalf("expected default kiro alias %q to be present", expected)
|
||||
}
|
||||
}
|
||||
|
||||
// All should have fork=true
|
||||
for _, a := range kiroAliases {
|
||||
if !a.Fork {
|
||||
t.Fatalf("expected all default kiro aliases to have fork=true, got fork=false for %q", a.Alias)
|
||||
}
|
||||
}
|
||||
|
||||
// Codex aliases should still be preserved
|
||||
if len(cfg.OAuthModelAlias["codex"]) != 1 {
|
||||
t.Fatal("expected codex aliases to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_DoesNotOverrideUserKiroAliases(t *testing.T) {
|
||||
// When user has configured kiro aliases, defaults should NOT be injected
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
"kiro": {
|
||||
{Name: "kiro-claude-sonnet-4", Alias: "my-custom-sonnet", Fork: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
kiroAliases := cfg.OAuthModelAlias["kiro"]
|
||||
if len(kiroAliases) != 1 {
|
||||
t.Fatalf("expected 1 user-configured kiro alias, got %d", len(kiroAliases))
|
||||
}
|
||||
if kiroAliases[0].Alias != "my-custom-sonnet" {
|
||||
t.Fatalf("expected user alias to be preserved, got %q", kiroAliases[0].Alias)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletion(t *testing.T) {
|
||||
// When user explicitly deletes kiro aliases (key exists with nil value),
|
||||
// defaults should NOT be re-injected on subsequent sanitize calls (#222).
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
"kiro": nil, // explicitly deleted
|
||||
"codex": {{Name: "gpt-5", Alias: "g5"}},
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
kiroAliases := cfg.OAuthModelAlias["kiro"]
|
||||
if len(kiroAliases) != 0 {
|
||||
t.Fatalf("expected kiro aliases to remain empty after explicit deletion, got %d aliases", len(kiroAliases))
|
||||
}
|
||||
// The key itself must still be present to prevent re-injection on next reload
|
||||
if _, exists := cfg.OAuthModelAlias["kiro"]; !exists {
|
||||
t.Fatal("expected kiro key to be preserved as nil marker after sanitization")
|
||||
}
|
||||
// Other channels should be unaffected
|
||||
if len(cfg.OAuthModelAlias["codex"]) != 1 {
|
||||
t.Fatal("expected codex aliases to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletionEmpty(t *testing.T) {
|
||||
// Same as above but with empty slice instead of nil (PUT with empty body).
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
"kiro": {}, // explicitly set to empty
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
if len(cfg.OAuthModelAlias["kiro"]) != 0 {
|
||||
t.Fatalf("expected kiro aliases to remain empty, got %d aliases", len(cfg.OAuthModelAlias["kiro"]))
|
||||
}
|
||||
if _, exists := cfg.OAuthModelAlias["kiro"]; !exists {
|
||||
t.Fatal("expected kiro key to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_InjectsDefaultKiroWhenEmpty(t *testing.T) {
|
||||
// When OAuthModelAlias is nil, kiro defaults should still be injected
|
||||
cfg := &Config{}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
kiroAliases := cfg.OAuthModelAlias["kiro"]
|
||||
if len(kiroAliases) == 0 {
|
||||
t.Fatal("expected default kiro aliases to be injected when OAuthModelAlias is nil")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,6 +144,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5-mini",
|
||||
@@ -156,6 +157,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5-codex",
|
||||
@@ -168,6 +170,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.1",
|
||||
@@ -180,6 +183,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.1-codex",
|
||||
@@ -192,6 +196,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.1-codex-mini",
|
||||
@@ -204,6 +209,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.1-codex-max",
|
||||
@@ -216,6 +222,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.2",
|
||||
@@ -228,6 +235,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/chat/completions", "/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.2-codex",
|
||||
@@ -240,6 +248,7 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32768,
|
||||
SupportedEndpoints: []string{"/responses"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "claude-haiku-4.5",
|
||||
@@ -448,6 +457,87 @@ func GetKiroModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
// --- 第三方模型 (通过 Kiro 接入) ---
|
||||
{
|
||||
ID: "kiro-deepseek-3-2",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro DeepSeek 3.2",
|
||||
Description: "DeepSeek 3.2 via Kiro",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-minimax-m2-1",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro MiniMax M2.1",
|
||||
Description: "MiniMax M2.1 via Kiro",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-qwen3-coder-next",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro Qwen3 Coder Next",
|
||||
Description: "Qwen3 Coder Next via Kiro",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-gpt-4o",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro GPT-4o",
|
||||
Description: "OpenAI GPT-4o via Kiro",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
},
|
||||
{
|
||||
ID: "kiro-gpt-4",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro GPT-4",
|
||||
Description: "OpenAI GPT-4 via Kiro",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 8192,
|
||||
},
|
||||
{
|
||||
ID: "kiro-gpt-4-turbo",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro GPT-4 Turbo",
|
||||
Description: "OpenAI GPT-4 Turbo via Kiro",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 16384,
|
||||
},
|
||||
{
|
||||
ID: "kiro-gpt-3-5-turbo",
|
||||
Object: "model",
|
||||
Created: 1732752000,
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro GPT-3.5 Turbo",
|
||||
Description: "OpenAI GPT-3.5 Turbo via Kiro",
|
||||
ContextLength: 16384,
|
||||
MaxCompletionTokens: 4096,
|
||||
},
|
||||
// --- Agentic Variants (Optimized for coding agents with chunked writes) ---
|
||||
{
|
||||
ID: "kiro-claude-opus-4-6-agentic",
|
||||
|
||||
@@ -742,6 +742,20 @@ func GetOpenAIModels() []*ModelInfo {
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.3-codex-spark",
|
||||
Object: "model",
|
||||
Created: 1770912000,
|
||||
OwnedBy: "openai",
|
||||
Type: "openai",
|
||||
Version: "gpt-5.3",
|
||||
DisplayName: "GPT 5.3 Codex Spark",
|
||||
Description: "Ultra-fast coding model.",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 128000,
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -814,6 +828,7 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
|
||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
||||
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "glm-5", DisplayName: "GLM-5", Description: "Zhipu GLM 5 general model", Created: 1770768000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
|
||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
||||
@@ -828,6 +843,7 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "minimax-m2.5", DisplayName: "MiniMax-M2.5", Description: "MiniMax M2.5", Created: 1770825600, Thinking: iFlowThinkingSupport},
|
||||
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
|
||||
{ID: "kimi-k2.5", DisplayName: "Kimi-K2.5", Description: "Moonshot Kimi K2.5", Created: 1769443200, Thinking: iFlowThinkingSupport},
|
||||
}
|
||||
@@ -866,7 +882,7 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 128000},
|
||||
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
|
||||
"gpt-oss-120b-medium": {},
|
||||
"tab_flash_lite_preview": {},
|
||||
|
||||
@@ -601,8 +601,7 @@ func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
|
||||
defer r.mutex.Unlock()
|
||||
|
||||
if registration, exists := r.models[modelID]; exists {
|
||||
now := time.Now()
|
||||
registration.QuotaExceededClients[clientID] = &now
|
||||
registration.QuotaExceededClients[clientID] = new(time.Now())
|
||||
log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1007,7 +1007,12 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
|
||||
exec := &AntigravityExecutor{cfg: cfg}
|
||||
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
|
||||
if errToken != nil || token == "" {
|
||||
if errToken != nil {
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: token error: %v", auth.ID, errToken)
|
||||
return nil
|
||||
}
|
||||
if token == "" {
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: got empty token", auth.ID)
|
||||
return nil
|
||||
}
|
||||
if updatedAuth != nil {
|
||||
@@ -1021,6 +1026,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
modelsURL := baseURL + antigravityModelsPath
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`)))
|
||||
if errReq != nil {
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: create request error: %v", auth.ID, errReq)
|
||||
return nil
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
@@ -1033,12 +1039,14 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: context canceled: %v", auth.ID, errDo)
|
||||
return nil
|
||||
}
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: request error: %v", auth.ID, errDo)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1051,6 +1059,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: read body error: %v", auth.ID, errRead)
|
||||
return nil
|
||||
}
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
@@ -1058,11 +1067,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: unexpected status %d, body: %s", auth.ID, httpResp.StatusCode, string(bodyBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
result := gjson.GetBytes(bodyBytes, "models")
|
||||
if !result.Exists() {
|
||||
log.Warnf("antigravity executor: fetch models failed for %s: no models field in response, body: %s", auth.ID, string(bodyBytes))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestAntigravityBuildRequest_SanitizesGeminiToolSchema(t *testing.T) {
|
||||
body := buildRequestBodyFromPayload(t, "gemini-2.5-pro")
|
||||
|
||||
decl := extractFirstFunctionDeclaration(t, body)
|
||||
if _, ok := decl["parametersJsonSchema"]; ok {
|
||||
t.Fatalf("parametersJsonSchema should be renamed to parameters")
|
||||
}
|
||||
|
||||
params, ok := decl["parameters"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("parameters missing or invalid type")
|
||||
}
|
||||
assertSchemaSanitizedAndPropertyPreserved(t, params)
|
||||
}
|
||||
|
||||
func TestAntigravityBuildRequest_SanitizesAntigravityToolSchema(t *testing.T) {
|
||||
body := buildRequestBodyFromPayload(t, "claude-opus-4-6")
|
||||
|
||||
decl := extractFirstFunctionDeclaration(t, body)
|
||||
params, ok := decl["parameters"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("parameters missing or invalid type")
|
||||
}
|
||||
assertSchemaSanitizedAndPropertyPreserved(t, params)
|
||||
}
|
||||
|
||||
func buildRequestBodyFromPayload(t *testing.T, modelName string) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
executor := &AntigravityExecutor{}
|
||||
auth := &cliproxyauth.Auth{}
|
||||
payload := []byte(`{
|
||||
"request": {
|
||||
"tools": [
|
||||
{
|
||||
"function_declarations": [
|
||||
{
|
||||
"name": "tool_1",
|
||||
"parametersJsonSchema": {
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"$id": "root-schema",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"$id": {"type": "string"},
|
||||
"arg": {
|
||||
"type": "object",
|
||||
"prefill": "hello",
|
||||
"properties": {
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": ["a", "b"],
|
||||
"enumTitles": ["A", "B"]
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"patternProperties": {
|
||||
"^x-": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
|
||||
req, err := executor.buildRequest(context.Background(), auth, "token", modelName, payload, false, "", "https://example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("buildRequest error: %v", err)
|
||||
}
|
||||
|
||||
raw, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read request body error: %v", err)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(raw, &body); err != nil {
|
||||
t.Fatalf("unmarshal request body error: %v, body=%s", err, string(raw))
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func extractFirstFunctionDeclaration(t *testing.T, body map[string]any) map[string]any {
|
||||
t.Helper()
|
||||
|
||||
request, ok := body["request"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("request missing or invalid type")
|
||||
}
|
||||
tools, ok := request["tools"].([]any)
|
||||
if !ok || len(tools) == 0 {
|
||||
t.Fatalf("tools missing or empty")
|
||||
}
|
||||
tool, ok := tools[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("first tool invalid type")
|
||||
}
|
||||
decls, ok := tool["function_declarations"].([]any)
|
||||
if !ok || len(decls) == 0 {
|
||||
t.Fatalf("function_declarations missing or empty")
|
||||
}
|
||||
decl, ok := decls[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("first function declaration invalid type")
|
||||
}
|
||||
return decl
|
||||
}
|
||||
|
||||
func assertSchemaSanitizedAndPropertyPreserved(t *testing.T, params map[string]any) {
|
||||
t.Helper()
|
||||
|
||||
if _, ok := params["$id"]; ok {
|
||||
t.Fatalf("root $id should be removed from schema")
|
||||
}
|
||||
if _, ok := params["patternProperties"]; ok {
|
||||
t.Fatalf("patternProperties should be removed from schema")
|
||||
}
|
||||
|
||||
props, ok := params["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("properties missing or invalid type")
|
||||
}
|
||||
if _, ok := props["$id"]; !ok {
|
||||
t.Fatalf("property named $id should be preserved")
|
||||
}
|
||||
|
||||
arg, ok := props["arg"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("arg property missing or invalid type")
|
||||
}
|
||||
if _, ok := arg["prefill"]; ok {
|
||||
t.Fatalf("prefill should be removed from nested schema")
|
||||
}
|
||||
|
||||
argProps, ok := arg["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("arg.properties missing or invalid type")
|
||||
}
|
||||
mode, ok := argProps["mode"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("mode property missing or invalid type")
|
||||
}
|
||||
if _, ok := mode["enumTitles"]; ok {
|
||||
t.Fatalf("enumTitles should be removed from nested schema")
|
||||
}
|
||||
}
|
||||
@@ -28,8 +28,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
codexClientVersion = "0.98.0"
|
||||
codexUserAgent = "codex_cli_rs/0.98.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
|
||||
codexClientVersion = "0.101.0"
|
||||
codexUserAgent = "codex_cli_rs/0.101.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
|
||||
)
|
||||
|
||||
var dataTag = []byte("data:")
|
||||
|
||||
@@ -899,8 +899,7 @@ func parseRetryDelay(errorBody []byte) (*time.Duration, error) {
|
||||
if matches := re.FindStringSubmatch(message); len(matches) > 1 {
|
||||
seconds, err := strconv.Atoi(matches[1])
|
||||
if err == nil {
|
||||
duration := time.Duration(seconds) * time.Second
|
||||
return &duration, nil
|
||||
return new(time.Duration(seconds) * time.Second), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
copilotauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
@@ -109,7 +110,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
useResponses := useGitHubCopilotResponsesEndpoint(from)
|
||||
useResponses := useGitHubCopilotResponsesEndpoint(from, req.Model)
|
||||
to := sdktranslator.FromString("openai")
|
||||
if useResponses {
|
||||
to = sdktranslator.FromString("openai-response")
|
||||
@@ -122,6 +123,22 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = e.normalizeModel(req.Model, body)
|
||||
body = flattenAssistantContent(body)
|
||||
|
||||
thinkingProvider := "openai"
|
||||
if useResponses {
|
||||
thinkingProvider = "codex"
|
||||
}
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), thinkingProvider, e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
if useResponses {
|
||||
body = normalizeGitHubCopilotResponsesInput(body)
|
||||
body = normalizeGitHubCopilotResponsesTools(body)
|
||||
} else {
|
||||
body = normalizeGitHubCopilotChatTools(body)
|
||||
}
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "stream", false)
|
||||
@@ -198,7 +215,12 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
||||
}
|
||||
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
converted := ""
|
||||
if useResponses && from.String() == "claude" {
|
||||
converted = translateGitHubCopilotResponsesNonStreamToClaude(data)
|
||||
} else {
|
||||
converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
}
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
reporter.ensurePublished(ctx)
|
||||
return resp, nil
|
||||
@@ -215,7 +237,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
useResponses := useGitHubCopilotResponsesEndpoint(from)
|
||||
useResponses := useGitHubCopilotResponsesEndpoint(from, req.Model)
|
||||
to := sdktranslator.FromString("openai")
|
||||
if useResponses {
|
||||
to = sdktranslator.FromString("openai-response")
|
||||
@@ -228,6 +250,22 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
body = e.normalizeModel(req.Model, body)
|
||||
body = flattenAssistantContent(body)
|
||||
|
||||
thinkingProvider := "openai"
|
||||
if useResponses {
|
||||
thinkingProvider = "codex"
|
||||
}
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), thinkingProvider, e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if useResponses {
|
||||
body = normalizeGitHubCopilotResponsesInput(body)
|
||||
body = normalizeGitHubCopilotResponsesTools(body)
|
||||
} else {
|
||||
body = normalizeGitHubCopilotChatTools(body)
|
||||
}
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "stream", true)
|
||||
@@ -328,7 +366,12 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
||||
}
|
||||
}
|
||||
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
var chunks []string
|
||||
if useResponses && from.String() == "claude" {
|
||||
chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), ¶m)
|
||||
} else {
|
||||
chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
}
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
@@ -471,14 +514,23 @@ func detectVisionContent(body []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// normalizeModel is a no-op as GitHub Copilot accepts model names directly.
|
||||
// Model mapping should be done at the registry level if needed.
|
||||
func (e *GitHubCopilotExecutor) normalizeModel(_ string, body []byte) []byte {
|
||||
// normalizeModel strips the suffix (e.g. "(medium)") from the model name
|
||||
// before sending to GitHub Copilot, as the upstream API does not accept
|
||||
// suffixed model identifiers.
|
||||
func (e *GitHubCopilotExecutor) normalizeModel(model string, body []byte) []byte {
|
||||
baseModel := thinking.ParseSuffix(model).ModelName
|
||||
if baseModel != model {
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format) bool {
|
||||
return sourceFormat.String() == "openai-response"
|
||||
func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format, model string) bool {
|
||||
if sourceFormat.String() == "openai-response" {
|
||||
return true
|
||||
}
|
||||
baseModel := strings.ToLower(thinking.ParseSuffix(model).ModelName)
|
||||
return strings.Contains(baseModel, "codex")
|
||||
}
|
||||
|
||||
// flattenAssistantContent converts assistant message content from array format
|
||||
@@ -513,6 +565,411 @@ func flattenAssistantContent(body []byte) []byte {
|
||||
return result
|
||||
}
|
||||
|
||||
func normalizeGitHubCopilotChatTools(body []byte) []byte {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if tools.Exists() {
|
||||
filtered := "[]"
|
||||
if tools.IsArray() {
|
||||
for _, tool := range tools.Array() {
|
||||
if tool.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
filtered, _ = sjson.SetRaw(filtered, "-1", tool.Raw)
|
||||
}
|
||||
}
|
||||
body, _ = sjson.SetRawBytes(body, "tools", []byte(filtered))
|
||||
}
|
||||
|
||||
toolChoice := gjson.GetBytes(body, "tool_choice")
|
||||
if !toolChoice.Exists() {
|
||||
return body
|
||||
}
|
||||
if toolChoice.Type == gjson.String {
|
||||
switch toolChoice.String() {
|
||||
case "auto", "none", "required":
|
||||
return body
|
||||
}
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "tool_choice", "auto")
|
||||
return body
|
||||
}
|
||||
|
||||
func normalizeGitHubCopilotResponsesInput(body []byte) []byte {
|
||||
input := gjson.GetBytes(body, "input")
|
||||
if input.Exists() {
|
||||
if input.Type == gjson.String {
|
||||
return body
|
||||
}
|
||||
inputString := input.Raw
|
||||
if input.Type != gjson.JSON {
|
||||
inputString = input.String()
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "input", inputString)
|
||||
return body
|
||||
}
|
||||
|
||||
var parts []string
|
||||
if system := gjson.GetBytes(body, "system"); system.Exists() {
|
||||
if text := strings.TrimSpace(collectTextFromNode(system)); text != "" {
|
||||
parts = append(parts, text)
|
||||
}
|
||||
}
|
||||
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
|
||||
for _, msg := range messages.Array() {
|
||||
if text := strings.TrimSpace(collectTextFromNode(msg.Get("content"))); text != "" {
|
||||
parts = append(parts, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "input", strings.Join(parts, "\n"))
|
||||
return body
|
||||
}
|
||||
|
||||
func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if tools.Exists() {
|
||||
filtered := "[]"
|
||||
if tools.IsArray() {
|
||||
for _, tool := range tools.Array() {
|
||||
toolType := tool.Get("type").String()
|
||||
// Accept OpenAI format (type="function") and Claude format
|
||||
// (no type field, but has top-level name + input_schema).
|
||||
if toolType != "" && toolType != "function" {
|
||||
continue
|
||||
}
|
||||
name := tool.Get("name").String()
|
||||
if name == "" {
|
||||
name = tool.Get("function.name").String()
|
||||
}
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
normalized := `{"type":"function","name":""}`
|
||||
normalized, _ = sjson.Set(normalized, "name", name)
|
||||
if desc := tool.Get("description").String(); desc != "" {
|
||||
normalized, _ = sjson.Set(normalized, "description", desc)
|
||||
} else if desc = tool.Get("function.description").String(); desc != "" {
|
||||
normalized, _ = sjson.Set(normalized, "description", desc)
|
||||
}
|
||||
if params := tool.Get("parameters"); params.Exists() {
|
||||
normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw)
|
||||
} else if params = tool.Get("function.parameters"); params.Exists() {
|
||||
normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw)
|
||||
} else if params = tool.Get("input_schema"); params.Exists() {
|
||||
normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw)
|
||||
}
|
||||
filtered, _ = sjson.SetRaw(filtered, "-1", normalized)
|
||||
}
|
||||
}
|
||||
body, _ = sjson.SetRawBytes(body, "tools", []byte(filtered))
|
||||
}
|
||||
|
||||
toolChoice := gjson.GetBytes(body, "tool_choice")
|
||||
if !toolChoice.Exists() {
|
||||
return body
|
||||
}
|
||||
if toolChoice.Type == gjson.String {
|
||||
switch toolChoice.String() {
|
||||
case "auto", "none", "required":
|
||||
return body
|
||||
default:
|
||||
body, _ = sjson.SetBytes(body, "tool_choice", "auto")
|
||||
return body
|
||||
}
|
||||
}
|
||||
if toolChoice.Type == gjson.JSON {
|
||||
choiceType := toolChoice.Get("type").String()
|
||||
if choiceType == "function" {
|
||||
name := toolChoice.Get("name").String()
|
||||
if name == "" {
|
||||
name = toolChoice.Get("function.name").String()
|
||||
}
|
||||
if name != "" {
|
||||
normalized := `{"type":"function","name":""}`
|
||||
normalized, _ = sjson.Set(normalized, "name", name)
|
||||
body, _ = sjson.SetRawBytes(body, "tool_choice", []byte(normalized))
|
||||
return body
|
||||
}
|
||||
}
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "tool_choice", "auto")
|
||||
return body
|
||||
}
|
||||
|
||||
func collectTextFromNode(node gjson.Result) string {
|
||||
if !node.Exists() {
|
||||
return ""
|
||||
}
|
||||
if node.Type == gjson.String {
|
||||
return node.String()
|
||||
}
|
||||
if node.IsArray() {
|
||||
var parts []string
|
||||
for _, item := range node.Array() {
|
||||
if item.Type == gjson.String {
|
||||
if text := item.String(); text != "" {
|
||||
parts = append(parts, text)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if text := item.Get("text").String(); text != "" {
|
||||
parts = append(parts, text)
|
||||
continue
|
||||
}
|
||||
if nested := collectTextFromNode(item.Get("content")); nested != "" {
|
||||
parts = append(parts, nested)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
if node.Type == gjson.JSON {
|
||||
if text := node.Get("text").String(); text != "" {
|
||||
return text
|
||||
}
|
||||
if nested := collectTextFromNode(node.Get("content")); nested != "" {
|
||||
return nested
|
||||
}
|
||||
return node.Raw
|
||||
}
|
||||
return node.String()
|
||||
}
|
||||
|
||||
type githubCopilotResponsesStreamToolState struct {
|
||||
Index int
|
||||
ID string
|
||||
Name string
|
||||
}
|
||||
|
||||
type githubCopilotResponsesStreamState struct {
|
||||
MessageStarted bool
|
||||
MessageStopSent bool
|
||||
TextBlockStarted bool
|
||||
TextBlockIndex int
|
||||
NextContentIndex int
|
||||
HasToolUse bool
|
||||
OutputIndexToTool map[int]*githubCopilotResponsesStreamToolState
|
||||
ItemIDToTool map[string]*githubCopilotResponsesStreamToolState
|
||||
}
|
||||
|
||||
func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string {
|
||||
root := gjson.ParseBytes(data)
|
||||
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
out, _ = sjson.Set(out, "id", root.Get("id").String())
|
||||
out, _ = sjson.Set(out, "model", root.Get("model").String())
|
||||
|
||||
hasToolUse := false
|
||||
if output := root.Get("output"); output.Exists() && output.IsArray() {
|
||||
for _, item := range output.Array() {
|
||||
switch item.Get("type").String() {
|
||||
case "message":
|
||||
if content := item.Get("content"); content.Exists() && content.IsArray() {
|
||||
for _, part := range content.Array() {
|
||||
if part.Get("type").String() != "output_text" {
|
||||
continue
|
||||
}
|
||||
text := part.Get("text").String()
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
block := `{"type":"text","text":""}`
|
||||
block, _ = sjson.Set(block, "text", text)
|
||||
out, _ = sjson.SetRaw(out, "content.-1", block)
|
||||
}
|
||||
}
|
||||
case "function_call":
|
||||
hasToolUse = true
|
||||
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolID := item.Get("call_id").String()
|
||||
if toolID == "" {
|
||||
toolID = item.Get("id").String()
|
||||
}
|
||||
toolUse, _ = sjson.Set(toolUse, "id", toolID)
|
||||
toolUse, _ = sjson.Set(toolUse, "name", item.Get("name").String())
|
||||
if args := item.Get("arguments").String(); args != "" && gjson.Valid(args) {
|
||||
argObj := gjson.Parse(args)
|
||||
if argObj.IsObject() {
|
||||
toolUse, _ = sjson.SetRaw(toolUse, "input", argObj.Raw)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRaw(out, "content.-1", toolUse)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inputTokens := root.Get("usage.input_tokens").Int()
|
||||
outputTokens := root.Get("usage.output_tokens").Int()
|
||||
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
|
||||
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
|
||||
if hasToolUse {
|
||||
out, _ = sjson.Set(out, "stop_reason", "tool_use")
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "stop_reason", "end_turn")
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []string {
|
||||
if *param == nil {
|
||||
*param = &githubCopilotResponsesStreamState{
|
||||
TextBlockIndex: -1,
|
||||
OutputIndexToTool: make(map[int]*githubCopilotResponsesStreamToolState),
|
||||
ItemIDToTool: make(map[string]*githubCopilotResponsesStreamToolState),
|
||||
}
|
||||
}
|
||||
state := (*param).(*githubCopilotResponsesStreamState)
|
||||
|
||||
if !bytes.HasPrefix(line, dataTag) {
|
||||
return nil
|
||||
}
|
||||
payload := bytes.TrimSpace(line[5:])
|
||||
if bytes.Equal(payload, []byte("[DONE]")) {
|
||||
return nil
|
||||
}
|
||||
if !gjson.ValidBytes(payload) {
|
||||
return nil
|
||||
}
|
||||
|
||||
event := gjson.GetBytes(payload, "type").String()
|
||||
results := make([]string, 0, 4)
|
||||
ensureMessageStart := func() {
|
||||
if state.MessageStarted {
|
||||
return
|
||||
}
|
||||
messageStart := `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}`
|
||||
messageStart, _ = sjson.Set(messageStart, "message.id", gjson.GetBytes(payload, "response.id").String())
|
||||
messageStart, _ = sjson.Set(messageStart, "message.model", gjson.GetBytes(payload, "response.model").String())
|
||||
results = append(results, "event: message_start\ndata: "+messageStart+"\n\n")
|
||||
state.MessageStarted = true
|
||||
}
|
||||
startTextBlockIfNeeded := func() {
|
||||
if state.TextBlockStarted {
|
||||
return
|
||||
}
|
||||
if state.TextBlockIndex < 0 {
|
||||
state.TextBlockIndex = state.NextContentIndex
|
||||
state.NextContentIndex++
|
||||
}
|
||||
contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
|
||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", state.TextBlockIndex)
|
||||
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
|
||||
state.TextBlockStarted = true
|
||||
}
|
||||
stopTextBlockIfNeeded := func() {
|
||||
if !state.TextBlockStarted {
|
||||
return
|
||||
}
|
||||
contentBlockStop := `{"type":"content_block_stop","index":0}`
|
||||
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", state.TextBlockIndex)
|
||||
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
|
||||
state.TextBlockStarted = false
|
||||
state.TextBlockIndex = -1
|
||||
}
|
||||
resolveTool := func(itemID string, outputIndex int) *githubCopilotResponsesStreamToolState {
|
||||
if itemID != "" {
|
||||
if tool, ok := state.ItemIDToTool[itemID]; ok {
|
||||
return tool
|
||||
}
|
||||
}
|
||||
if tool, ok := state.OutputIndexToTool[outputIndex]; ok {
|
||||
if itemID != "" {
|
||||
state.ItemIDToTool[itemID] = tool
|
||||
}
|
||||
return tool
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
switch event {
|
||||
case "response.created":
|
||||
ensureMessageStart()
|
||||
case "response.output_text.delta":
|
||||
ensureMessageStart()
|
||||
startTextBlockIfNeeded()
|
||||
delta := gjson.GetBytes(payload, "delta").String()
|
||||
if delta != "" {
|
||||
contentDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
|
||||
contentDelta, _ = sjson.Set(contentDelta, "index", state.TextBlockIndex)
|
||||
contentDelta, _ = sjson.Set(contentDelta, "delta.text", delta)
|
||||
results = append(results, "event: content_block_delta\ndata: "+contentDelta+"\n\n")
|
||||
}
|
||||
case "response.output_item.added":
|
||||
if gjson.GetBytes(payload, "item.type").String() != "function_call" {
|
||||
break
|
||||
}
|
||||
ensureMessageStart()
|
||||
stopTextBlockIfNeeded()
|
||||
state.HasToolUse = true
|
||||
tool := &githubCopilotResponsesStreamToolState{
|
||||
Index: state.NextContentIndex,
|
||||
ID: gjson.GetBytes(payload, "item.call_id").String(),
|
||||
Name: gjson.GetBytes(payload, "item.name").String(),
|
||||
}
|
||||
if tool.ID == "" {
|
||||
tool.ID = gjson.GetBytes(payload, "item.id").String()
|
||||
}
|
||||
state.NextContentIndex++
|
||||
outputIndex := int(gjson.GetBytes(payload, "output_index").Int())
|
||||
state.OutputIndexToTool[outputIndex] = tool
|
||||
if itemID := gjson.GetBytes(payload, "item.id").String(); itemID != "" {
|
||||
state.ItemIDToTool[itemID] = tool
|
||||
}
|
||||
contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
|
||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "index", tool.Index)
|
||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.id", tool.ID)
|
||||
contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.name", tool.Name)
|
||||
results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
|
||||
case "response.output_item.delta":
|
||||
item := gjson.GetBytes(payload, "item")
|
||||
if item.Get("type").String() != "function_call" {
|
||||
break
|
||||
}
|
||||
tool := resolveTool(item.Get("id").String(), int(gjson.GetBytes(payload, "output_index").Int()))
|
||||
if tool == nil {
|
||||
break
|
||||
}
|
||||
partial := gjson.GetBytes(payload, "delta").String()
|
||||
if partial == "" {
|
||||
partial = item.Get("arguments").String()
|
||||
}
|
||||
if partial == "" {
|
||||
break
|
||||
}
|
||||
inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
|
||||
inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
|
||||
inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
|
||||
results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n")
|
||||
case "response.output_item.done":
|
||||
if gjson.GetBytes(payload, "item.type").String() != "function_call" {
|
||||
break
|
||||
}
|
||||
tool := resolveTool(gjson.GetBytes(payload, "item.id").String(), int(gjson.GetBytes(payload, "output_index").Int()))
|
||||
if tool == nil {
|
||||
break
|
||||
}
|
||||
contentBlockStop := `{"type":"content_block_stop","index":0}`
|
||||
contentBlockStop, _ = sjson.Set(contentBlockStop, "index", tool.Index)
|
||||
results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
|
||||
case "response.completed":
|
||||
ensureMessageStart()
|
||||
stopTextBlockIfNeeded()
|
||||
if !state.MessageStopSent {
|
||||
stopReason := "end_turn"
|
||||
if state.HasToolUse {
|
||||
stopReason = "tool_use"
|
||||
}
|
||||
messageDelta := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
messageDelta, _ = sjson.Set(messageDelta, "delta.stop_reason", stopReason)
|
||||
messageDelta, _ = sjson.Set(messageDelta, "usage.input_tokens", gjson.GetBytes(payload, "response.usage.input_tokens").Int())
|
||||
messageDelta, _ = sjson.Set(messageDelta, "usage.output_tokens", gjson.GetBytes(payload, "response.usage.output_tokens").Int())
|
||||
results = append(results, "event: message_delta\ndata: "+messageDelta+"\n\n")
|
||||
results = append(results, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
|
||||
state.MessageStopSent = true
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// isHTTPSuccess checks if the status code indicates success (2xx).
|
||||
func isHTTPSuccess(statusCode int) bool {
|
||||
return statusCode >= 200 && statusCode < 300
|
||||
|
||||
242
internal/runtime/executor/github_copilot_executor_test.go
Normal file
242
internal/runtime/executor/github_copilot_executor_test.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestGitHubCopilotNormalizeModel_StripsSuffix(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
wantModel string
|
||||
}{
|
||||
{
|
||||
name: "suffix stripped",
|
||||
model: "claude-opus-4.6(medium)",
|
||||
wantModel: "claude-opus-4.6",
|
||||
},
|
||||
{
|
||||
name: "no suffix unchanged",
|
||||
model: "claude-opus-4.6",
|
||||
wantModel: "claude-opus-4.6",
|
||||
},
|
||||
{
|
||||
name: "different suffix stripped",
|
||||
model: "gpt-4o(high)",
|
||||
wantModel: "gpt-4o",
|
||||
},
|
||||
{
|
||||
name: "numeric suffix stripped",
|
||||
model: "gemini-2.5-pro(8192)",
|
||||
wantModel: "gemini-2.5-pro",
|
||||
},
|
||||
}
|
||||
|
||||
e := &GitHubCopilotExecutor{}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := []byte(`{"model":"` + tt.model + `","messages":[]}`)
|
||||
got := e.normalizeModel(tt.model, body)
|
||||
|
||||
gotModel := gjson.GetBytes(got, "model").String()
|
||||
if gotModel != tt.wantModel {
|
||||
t.Fatalf("normalizeModel() model = %q, want %q", gotModel, tt.wantModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_OpenAIResponseSource(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai-response"), "claude-3-5-sonnet") {
|
||||
t.Fatal("expected openai-response source to use /responses")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5-codex") {
|
||||
t.Fatal("expected codex model to use /responses")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {
|
||||
t.Parallel()
|
||||
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "claude-3-5-sonnet") {
|
||||
t.Fatal("expected default openai source with non-codex model to use /chat/completions")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotChatTools_KeepFunctionOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[{"type":"function","function":{"name":"ok"}},{"type":"code_interpreter"}],"tool_choice":"auto"}`)
|
||||
got := normalizeGitHubCopilotChatTools(body)
|
||||
tools := gjson.GetBytes(got, "tools").Array()
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("tools len = %d, want 1", len(tools))
|
||||
}
|
||||
if tools[0].Get("type").String() != "function" {
|
||||
t.Fatalf("tool type = %q, want function", tools[0].Get("type").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotChatTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[],"tool_choice":{"type":"function","function":{"name":"x"}}}`)
|
||||
got := normalizeGitHubCopilotChatTools(body)
|
||||
if gjson.GetBytes(got, "tool_choice").String() != "auto" {
|
||||
t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesInput_MissingInputExtractedFromSystemAndMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"system":"sys text","messages":[{"role":"user","content":"user text"},{"role":"assistant","content":[{"type":"text","text":"assistant text"}]}]}`)
|
||||
got := normalizeGitHubCopilotResponsesInput(body)
|
||||
in := gjson.GetBytes(got, "input")
|
||||
if in.Type != gjson.String {
|
||||
t.Fatalf("input type = %v, want string", in.Type)
|
||||
}
|
||||
if !strings.Contains(in.String(), "sys text") || !strings.Contains(in.String(), "user text") || !strings.Contains(in.String(), "assistant text") {
|
||||
t.Fatalf("input = %q, want merged text", in.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesInput_NonStringInputStringified(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"input":{"foo":"bar"}}`)
|
||||
got := normalizeGitHubCopilotResponsesInput(body)
|
||||
in := gjson.GetBytes(got, "input")
|
||||
if in.Type != gjson.String {
|
||||
t.Fatalf("input type = %v, want string", in.Type)
|
||||
}
|
||||
if !strings.Contains(in.String(), "foo") {
|
||||
t.Fatalf("input = %q, want stringified object", in.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_FlattenFunctionTools(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[{"type":"function","function":{"name":"sum","description":"d","parameters":{"type":"object"}}},{"type":"web_search"}]}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
tools := gjson.GetBytes(got, "tools").Array()
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("tools len = %d, want 1", len(tools))
|
||||
}
|
||||
if tools[0].Get("name").String() != "sum" {
|
||||
t.Fatalf("tools[0].name = %q, want sum", tools[0].Get("name").String())
|
||||
}
|
||||
if !tools[0].Get("parameters").Exists() {
|
||||
t.Fatal("expected parameters to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_ClaudeFormatTools(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tools":[{"name":"Bash","description":"Run commands","input_schema":{"type":"object","properties":{"command":{"type":"string"}},"required":["command"]}},{"name":"Read","description":"Read files","input_schema":{"type":"object","properties":{"path":{"type":"string"}}}}]}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
tools := gjson.GetBytes(got, "tools").Array()
|
||||
if len(tools) != 2 {
|
||||
t.Fatalf("tools len = %d, want 2", len(tools))
|
||||
}
|
||||
if tools[0].Get("type").String() != "function" {
|
||||
t.Fatalf("tools[0].type = %q, want function", tools[0].Get("type").String())
|
||||
}
|
||||
if tools[0].Get("name").String() != "Bash" {
|
||||
t.Fatalf("tools[0].name = %q, want Bash", tools[0].Get("name").String())
|
||||
}
|
||||
if tools[0].Get("description").String() != "Run commands" {
|
||||
t.Fatalf("tools[0].description = %q, want 'Run commands'", tools[0].Get("description").String())
|
||||
}
|
||||
if !tools[0].Get("parameters").Exists() {
|
||||
t.Fatal("expected parameters to be set from input_schema")
|
||||
}
|
||||
if tools[0].Get("parameters.properties.command").Exists() != true {
|
||||
t.Fatal("expected parameters.properties.command to exist")
|
||||
}
|
||||
if tools[1].Get("name").String() != "Read" {
|
||||
t.Fatalf("tools[1].name = %q, want Read", tools[1].Get("name").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_FlattenToolChoiceFunctionObject(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tool_choice":{"type":"function","function":{"name":"sum"}}}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
if gjson.GetBytes(got, "tool_choice.type").String() != "function" {
|
||||
t.Fatalf("tool_choice.type = %q, want function", gjson.GetBytes(got, "tool_choice.type").String())
|
||||
}
|
||||
if gjson.GetBytes(got, "tool_choice.name").String() != "sum" {
|
||||
t.Fatalf("tool_choice.name = %q, want sum", gjson.GetBytes(got, "tool_choice.name").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGitHubCopilotResponsesTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := []byte(`{"tool_choice":{"type":"function"}}`)
|
||||
got := normalizeGitHubCopilotResponsesTools(body)
|
||||
if gjson.GetBytes(got, "tool_choice").String() != "auto" {
|
||||
t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateGitHubCopilotResponsesNonStreamToClaude_TextMapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
resp := []byte(`{"id":"resp_1","model":"gpt-5-codex","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":3,"output_tokens":5}}`)
|
||||
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
|
||||
if gjson.Get(out, "type").String() != "message" {
|
||||
t.Fatalf("type = %q, want message", gjson.Get(out, "type").String())
|
||||
}
|
||||
if gjson.Get(out, "content.0.type").String() != "text" {
|
||||
t.Fatalf("content.0.type = %q, want text", gjson.Get(out, "content.0.type").String())
|
||||
}
|
||||
if gjson.Get(out, "content.0.text").String() != "hello" {
|
||||
t.Fatalf("content.0.text = %q, want hello", gjson.Get(out, "content.0.text").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateGitHubCopilotResponsesNonStreamToClaude_ToolUseMapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
resp := []byte(`{"id":"resp_2","model":"gpt-5-codex","output":[{"type":"function_call","id":"fc_1","call_id":"call_1","name":"sum","arguments":"{\"a\":1}"}],"usage":{"input_tokens":1,"output_tokens":2}}`)
|
||||
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
|
||||
if gjson.Get(out, "content.0.type").String() != "tool_use" {
|
||||
t.Fatalf("content.0.type = %q, want tool_use", gjson.Get(out, "content.0.type").String())
|
||||
}
|
||||
if gjson.Get(out, "content.0.name").String() != "sum" {
|
||||
t.Fatalf("content.0.name = %q, want sum", gjson.Get(out, "content.0.name").String())
|
||||
}
|
||||
if gjson.Get(out, "stop_reason").String() != "tool_use" {
|
||||
t.Fatalf("stop_reason = %q, want tool_use", gjson.Get(out, "stop_reason").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTranslateGitHubCopilotResponsesStreamToClaude_TextLifecycle(t *testing.T) {
|
||||
t.Parallel()
|
||||
var param any
|
||||
|
||||
created := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5-codex"}}`), ¶m)
|
||||
if len(created) == 0 || !strings.Contains(created[0], "message_start") {
|
||||
t.Fatalf("created events = %#v, want message_start", created)
|
||||
}
|
||||
|
||||
delta := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.output_text.delta","delta":"he"}`), ¶m)
|
||||
joinedDelta := strings.Join(delta, "")
|
||||
if !strings.Contains(joinedDelta, "content_block_start") || !strings.Contains(joinedDelta, "text_delta") {
|
||||
t.Fatalf("delta events = %#v, want content_block_start + text_delta", delta)
|
||||
}
|
||||
|
||||
completed := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.completed","response":{"usage":{"input_tokens":7,"output_tokens":9}}}`), ¶m)
|
||||
joinedCompleted := strings.Join(completed, "")
|
||||
if !strings.Contains(joinedCompleted, "message_delta") || !strings.Contains(joinedCompleted, "message_stop") {
|
||||
t.Fatalf("completed events = %#v, want message_delta + message_stop", completed)
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -385,6 +386,35 @@ func buildKiroEndpointConfigs(region string) []kiroEndpointConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// resolveKiroAPIRegion determines the AWS region for Kiro API calls.
|
||||
// Region priority:
|
||||
// 1. auth.Metadata["api_region"] - explicit API region override
|
||||
// 2. ProfileARN region - extracted from arn:aws:service:REGION:account:resource
|
||||
// 3. kiroDefaultRegion (us-east-1) - fallback
|
||||
// Note: OIDC "region" is NOT used - it's for token refresh, not API calls
|
||||
func resolveKiroAPIRegion(auth *cliproxyauth.Auth) string {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return kiroDefaultRegion
|
||||
}
|
||||
// Priority 1: Explicit api_region override
|
||||
if r, ok := auth.Metadata["api_region"].(string); ok && r != "" {
|
||||
log.Debugf("kiro: using region %s (source: api_region)", r)
|
||||
return r
|
||||
}
|
||||
// Priority 2: Extract from ProfileARN
|
||||
if profileArn, ok := auth.Metadata["profile_arn"].(string); ok && profileArn != "" {
|
||||
if arnRegion := extractRegionFromProfileARN(profileArn); arnRegion != "" {
|
||||
log.Debugf("kiro: using region %s (source: profile_arn)", arnRegion)
|
||||
return arnRegion
|
||||
}
|
||||
}
|
||||
// Note: OIDC "region" field is NOT used for API endpoint
|
||||
// Kiro API only exists in us-east-1, while OIDC region can vary (e.g., ap-northeast-2)
|
||||
// Using OIDC region for API calls causes DNS failures
|
||||
log.Debugf("kiro: using region %s (source: default)", kiroDefaultRegion)
|
||||
return kiroDefaultRegion
|
||||
}
|
||||
|
||||
// kiroEndpointConfigs is kept for backward compatibility with default us-east-1 region.
|
||||
// Prefer using buildKiroEndpointConfigs(region) for dynamic region support.
|
||||
var kiroEndpointConfigs = buildKiroEndpointConfigs(kiroDefaultRegion)
|
||||
@@ -403,30 +433,8 @@ func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig {
|
||||
return kiroEndpointConfigs
|
||||
}
|
||||
|
||||
// Determine API region with priority: api_region > profile_arn > region > default
|
||||
region := kiroDefaultRegion
|
||||
regionSource := "default"
|
||||
|
||||
if auth.Metadata != nil {
|
||||
// Priority 1: Explicit api_region override
|
||||
if r, ok := auth.Metadata["api_region"].(string); ok && r != "" {
|
||||
region = r
|
||||
regionSource = "api_region"
|
||||
} else {
|
||||
// Priority 2: Extract from ProfileARN
|
||||
if profileArn, ok := auth.Metadata["profile_arn"].(string); ok && profileArn != "" {
|
||||
if arnRegion := extractRegionFromProfileARN(profileArn); arnRegion != "" {
|
||||
region = arnRegion
|
||||
regionSource = "profile_arn"
|
||||
}
|
||||
}
|
||||
// Note: OIDC "region" field is NOT used for API endpoint
|
||||
// Kiro API only exists in us-east-1, while OIDC region can vary (e.g., ap-northeast-2)
|
||||
// Using OIDC region for API calls causes DNS failures
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("kiro: using region %s (source: %s)", region, regionSource)
|
||||
// Determine API region using shared resolution logic
|
||||
region := resolveKiroAPIRegion(auth)
|
||||
|
||||
// Build endpoint configs for the specified region
|
||||
endpointConfigs := buildKiroEndpointConfigs(region)
|
||||
@@ -519,8 +527,12 @@ func buildKiroPayloadForFormat(body []byte, modelID, profileArn, origin string,
|
||||
case "openai":
|
||||
log.Debugf("kiro: using OpenAI payload builder for source format: %s", sourceFormat.String())
|
||||
return kiroopenai.BuildKiroPayloadFromOpenAI(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil)
|
||||
case "kiro":
|
||||
// Body is already in Kiro format — pass through directly
|
||||
log.Debugf("kiro: body already in Kiro format, passing through directly")
|
||||
return body, false
|
||||
default:
|
||||
// Default to Claude format (also handles "claude", "kiro", etc.)
|
||||
// Default to Claude format
|
||||
log.Debugf("kiro: using Claude payload builder for source format: %s", sourceFormat.String())
|
||||
return kiroclaude.BuildKiroPayload(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil)
|
||||
}
|
||||
@@ -636,10 +648,7 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
rateLimiter.WaitForToken(tokenKey)
|
||||
log.Debugf("kiro: rate limiter cleared for token %s", tokenKey)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
// Check if token is expired before making request
|
||||
// Check if token is expired before making request (covers both normal and web_search paths)
|
||||
if e.isTokenExpired(accessToken) {
|
||||
log.Infof("kiro: access token expired, attempting recovery")
|
||||
|
||||
@@ -668,6 +677,16 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
}
|
||||
}
|
||||
|
||||
// Check for pure web_search request
|
||||
// Route to MCP endpoint instead of normal Kiro API
|
||||
if kiroclaude.HasWebSearchTool(req.Payload) {
|
||||
log.Infof("kiro: detected pure web_search request (non-stream), routing to MCP endpoint")
|
||||
return e.handleWebSearch(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("kiro")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
@@ -1014,8 +1033,9 @@ func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.
|
||||
|
||||
// Build response in Claude format for Kiro translator
|
||||
// stopReason is extracted from upstream response by parseEventStream
|
||||
kiroResponse := kiroclaude.BuildClaudeResponse(content, toolUses, req.Model, usageInfo, stopReason)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
kiroResponse := kiroclaude.BuildClaudeResponse(content, toolUses, requestedModel, usageInfo, stopReason)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, requestedModel, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -1057,10 +1077,7 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
rateLimiter.WaitForToken(tokenKey)
|
||||
log.Debugf("kiro: stream rate limiter cleared for token %s", tokenKey)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
// Check if token is expired before making request
|
||||
// Check if token is expired before making request (covers both normal and web_search paths)
|
||||
if e.isTokenExpired(accessToken) {
|
||||
log.Infof("kiro: access token expired, attempting recovery before stream request")
|
||||
|
||||
@@ -1089,6 +1106,16 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
}
|
||||
}
|
||||
|
||||
// Check for pure web_search request
|
||||
// Route to MCP endpoint instead of normal Kiro API
|
||||
if kiroclaude.HasWebSearchTool(req.Payload) {
|
||||
log.Infof("kiro: detected pure web_search request, routing to MCP endpoint")
|
||||
return e.handleWebSearchStream(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("kiro")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
@@ -1405,7 +1432,7 @@ func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliprox
|
||||
// So we always enable thinking parsing for Kiro responses
|
||||
log.Debugf("kiro: stream thinkingEnabled = %v (always true for Kiro)", thinkingEnabled)
|
||||
|
||||
e.streamToChannel(ctx, resp.Body, out, from, req.Model, opts.OriginalRequest, body, reporter, thinkingEnabled)
|
||||
e.streamToChannel(ctx, resp.Body, out, from, payloadRequestedModel(opts, req.Model), opts.OriginalRequest, body, reporter, thinkingEnabled)
|
||||
}(httpResp, thinkingEnabled)
|
||||
|
||||
return out, nil
|
||||
@@ -4096,6 +4123,659 @@ func (e *KiroExecutor) isTokenExpired(accessToken string) bool {
|
||||
return isExpired
|
||||
}
|
||||
|
||||
// NOTE: Message merging functions moved to internal/translator/kiro/common/message_merge.go
|
||||
// NOTE: Tool calling support functions moved to internal/translator/kiro/claude/kiro_claude_tools.go
|
||||
// The executor now uses kiroclaude.* and kirocommon.* functions instead
|
||||
// ══════════════════════════════════════════════════════════════════════════════
|
||||
// Web Search Handler (MCP API)
|
||||
// ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
// fetchToolDescription caching:
|
||||
// Uses a mutex + fetched flag to ensure only one goroutine fetches at a time,
|
||||
// with automatic retry on failure:
|
||||
// - On failure, fetched stays false so subsequent calls will retry
|
||||
// - On success, fetched is set to true — subsequent calls skip immediately (mutex-free fast path)
|
||||
// The cached description is stored in the translator package via kiroclaude.SetWebSearchDescription(),
|
||||
// enabling the translator's convertClaudeToolsToKiro to read it when building Kiro requests.
|
||||
var (
|
||||
toolDescMu sync.Mutex
|
||||
toolDescFetched atomic.Bool
|
||||
)
|
||||
|
||||
// fetchToolDescription calls MCP tools/list to get the web_search tool description
|
||||
// and caches it. Safe to call concurrently — only one goroutine fetches at a time.
|
||||
// If the fetch fails, subsequent calls will retry. On success, no further fetches occur.
|
||||
// The httpClient parameter allows reusing a shared pooled HTTP client.
|
||||
func fetchToolDescription(ctx context.Context, mcpEndpoint, authToken string, httpClient *http.Client, auth *cliproxyauth.Auth, authAttrs map[string]string) {
|
||||
// Fast path: already fetched successfully, no lock needed
|
||||
if toolDescFetched.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
toolDescMu.Lock()
|
||||
defer toolDescMu.Unlock()
|
||||
|
||||
// Double-check after acquiring lock
|
||||
if toolDescFetched.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
handler := newWebSearchHandler(ctx, mcpEndpoint, authToken, httpClient, auth, authAttrs)
|
||||
reqBody := []byte(`{"id":"tools_list","jsonrpc":"2.0","method":"tools/list"}`)
|
||||
log.Debugf("kiro/websearch MCP tools/list request: %d bytes", len(reqBody))
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", mcpEndpoint, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: failed to create tools/list request: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Reuse same headers as callMcpAPI
|
||||
handler.setMcpHeaders(req)
|
||||
|
||||
resp, err := handler.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: tools/list request failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil || resp.StatusCode != http.StatusOK {
|
||||
log.Warnf("kiro/websearch: tools/list returned status %d", resp.StatusCode)
|
||||
return
|
||||
}
|
||||
log.Debugf("kiro/websearch MCP tools/list response: [%d] %d bytes", resp.StatusCode, len(body))
|
||||
|
||||
// Parse: {"result":{"tools":[{"name":"web_search","description":"..."}]}}
|
||||
var result struct {
|
||||
Result *struct {
|
||||
Tools []struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
} `json:"tools"`
|
||||
} `json:"result"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil || result.Result == nil {
|
||||
log.Warnf("kiro/websearch: failed to parse tools/list response")
|
||||
return
|
||||
}
|
||||
|
||||
for _, tool := range result.Result.Tools {
|
||||
if tool.Name == "web_search" && tool.Description != "" {
|
||||
kiroclaude.SetWebSearchDescription(tool.Description)
|
||||
toolDescFetched.Store(true) // success — no more fetches
|
||||
log.Infof("kiro/websearch: cached web_search description from tools/list (%d bytes)", len(tool.Description))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// web_search tool not found in response
|
||||
log.Warnf("kiro/websearch: web_search tool not found in tools/list response")
|
||||
}
|
||||
|
||||
// webSearchHandler handles web search requests via Kiro MCP API
|
||||
type webSearchHandler struct {
|
||||
ctx context.Context
|
||||
mcpEndpoint string
|
||||
httpClient *http.Client
|
||||
authToken string
|
||||
auth *cliproxyauth.Auth // for applyDynamicFingerprint
|
||||
authAttrs map[string]string // optional, for custom headers from auth.Attributes
|
||||
}
|
||||
|
||||
// newWebSearchHandler creates a new webSearchHandler.
|
||||
// If httpClient is nil, a default client with 30s timeout is used.
|
||||
// Pass a shared pooled client (e.g. from getKiroPooledHTTPClient) for connection reuse.
|
||||
func newWebSearchHandler(ctx context.Context, mcpEndpoint, authToken string, httpClient *http.Client, auth *cliproxyauth.Auth, authAttrs map[string]string) *webSearchHandler {
|
||||
if httpClient == nil {
|
||||
httpClient = &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
return &webSearchHandler{
|
||||
ctx: ctx,
|
||||
mcpEndpoint: mcpEndpoint,
|
||||
httpClient: httpClient,
|
||||
authToken: authToken,
|
||||
auth: auth,
|
||||
authAttrs: authAttrs,
|
||||
}
|
||||
}
|
||||
|
||||
// setMcpHeaders sets standard MCP API headers on the request,
|
||||
// aligned with the GAR request pattern.
|
||||
func (h *webSearchHandler) setMcpHeaders(req *http.Request) {
|
||||
// 1. Content-Type & Accept (aligned with GAR)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
|
||||
// 2. Kiro-specific headers (aligned with GAR)
|
||||
req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
|
||||
req.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||
|
||||
// 3. User-Agent: Reuse applyDynamicFingerprint for consistency
|
||||
applyDynamicFingerprint(req, h.auth)
|
||||
|
||||
// 4. AWS SDK identifiers
|
||||
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
||||
|
||||
// 5. Authentication
|
||||
req.Header.Set("Authorization", "Bearer "+h.authToken)
|
||||
|
||||
// 6. Custom headers from auth attributes
|
||||
util.ApplyCustomHeadersFromAttrs(req, h.authAttrs)
|
||||
}
|
||||
|
||||
// mcpMaxRetries is the maximum number of retries for MCP API calls.
|
||||
const mcpMaxRetries = 2
|
||||
|
||||
// callMcpAPI calls the Kiro MCP API with the given request.
|
||||
// Includes retry logic with exponential backoff for retryable errors.
|
||||
func (h *webSearchHandler) callMcpAPI(request *kiroclaude.McpRequest) (*kiroclaude.McpResponse, error) {
|
||||
requestBody, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal MCP request: %w", err)
|
||||
}
|
||||
log.Debugf("kiro/websearch MCP request → %s (%d bytes)", h.mcpEndpoint, len(requestBody))
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt <= mcpMaxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
backoff := time.Duration(1<<attempt) * time.Second
|
||||
if backoff > 10*time.Second {
|
||||
backoff = 10 * time.Second
|
||||
}
|
||||
log.Warnf("kiro/websearch: MCP retry %d/%d after %v (last error: %v)", attempt, mcpMaxRetries, backoff, lastErr)
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
return nil, h.ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(h.ctx, "POST", h.mcpEndpoint, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||
}
|
||||
|
||||
h.setMcpHeaders(req)
|
||||
|
||||
resp, err := h.httpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("MCP API request failed: %w", err)
|
||||
continue // network error → retry
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("failed to read MCP response: %w", err)
|
||||
continue // read error → retry
|
||||
}
|
||||
log.Debugf("kiro/websearch MCP response ← [%d] (%d bytes)", resp.StatusCode, len(body))
|
||||
|
||||
// Retryable HTTP status codes (aligned with GAR: 502, 503, 504)
|
||||
if resp.StatusCode >= 502 && resp.StatusCode <= 504 {
|
||||
lastErr = fmt.Errorf("MCP API returned retryable status %d: %s", resp.StatusCode, string(body))
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("MCP API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var mcpResponse kiroclaude.McpResponse
|
||||
if err := json.Unmarshal(body, &mcpResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse MCP response: %w", err)
|
||||
}
|
||||
|
||||
if mcpResponse.Error != nil {
|
||||
code := -1
|
||||
if mcpResponse.Error.Code != nil {
|
||||
code = *mcpResponse.Error.Code
|
||||
}
|
||||
msg := "Unknown error"
|
||||
if mcpResponse.Error.Message != nil {
|
||||
msg = *mcpResponse.Error.Message
|
||||
}
|
||||
return nil, fmt.Errorf("MCP error %d: %s", code, msg)
|
||||
}
|
||||
|
||||
return &mcpResponse, nil
|
||||
}
|
||||
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
// webSearchAuthAttrs extracts auth attributes for MCP calls.
|
||||
// Used by handleWebSearch and handleWebSearchStream to pass custom headers.
|
||||
func webSearchAuthAttrs(auth *cliproxyauth.Auth) map[string]string {
|
||||
if auth != nil {
|
||||
return auth.Attributes
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const maxWebSearchIterations = 5
|
||||
|
||||
// handleWebSearchStream handles web_search requests:
|
||||
// Step 1: tools/list (sync) → fetch/cache tool description
|
||||
// Step 2+: MCP search → InjectToolResultsClaude → callKiroAndBuffer loop
|
||||
// Note: We skip the "model decides to search" step because Claude Code already
|
||||
// decided to use web_search. The Kiro tool description restricts non-coding
|
||||
// topics, so asking the model again would cause it to refuse valid searches.
|
||||
func (e *KiroExecutor) handleWebSearchStream(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
// Extract search query from Claude Code's web_search tool_use
|
||||
query := kiroclaude.ExtractSearchQuery(req.Payload)
|
||||
if query == "" {
|
||||
log.Warnf("kiro/websearch: failed to extract search query, falling back to normal flow")
|
||||
return e.callKiroDirectStream(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
// Build MCP endpoint using shared region resolution (supports api_region + ProfileARN fallback)
|
||||
region := resolveKiroAPIRegion(auth)
|
||||
mcpEndpoint := kiroclaude.BuildMcpEndpoint(region)
|
||||
|
||||
// ── Step 1: tools/list (SYNC) — cache tool description ──
|
||||
{
|
||||
authAttrs := webSearchAuthAttrs(auth)
|
||||
fetchToolDescription(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
|
||||
}
|
||||
|
||||
// Create output channel
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
|
||||
// Usage reporting: track web search requests like normal streaming requests
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
|
||||
go func() {
|
||||
var wsErr error
|
||||
defer reporter.trackFailure(ctx, &wsErr)
|
||||
defer close(out)
|
||||
|
||||
// Estimate input tokens using tokenizer (matching streamToChannel pattern)
|
||||
var totalUsage usage.Detail
|
||||
if enc, tokErr := getTokenizer(req.Model); tokErr == nil {
|
||||
if inp, e := countClaudeChatTokens(enc, req.Payload); e == nil && inp > 0 {
|
||||
totalUsage.InputTokens = inp
|
||||
} else {
|
||||
totalUsage.InputTokens = int64(len(req.Payload) / 4)
|
||||
}
|
||||
} else {
|
||||
totalUsage.InputTokens = int64(len(req.Payload) / 4)
|
||||
}
|
||||
if totalUsage.InputTokens == 0 && len(req.Payload) > 0 {
|
||||
totalUsage.InputTokens = 1
|
||||
}
|
||||
var accumulatedOutputLen int
|
||||
defer func() {
|
||||
if wsErr != nil {
|
||||
return // let trackFailure handle failure reporting
|
||||
}
|
||||
totalUsage.OutputTokens = int64(accumulatedOutputLen / 4)
|
||||
if accumulatedOutputLen > 0 && totalUsage.OutputTokens == 0 {
|
||||
totalUsage.OutputTokens = 1
|
||||
}
|
||||
reporter.publish(ctx, totalUsage)
|
||||
}()
|
||||
|
||||
// Send message_start event to client (aligned with streamToChannel pattern)
|
||||
// Use payloadRequestedModel to return user's original model alias
|
||||
msgStart := kiroclaude.BuildClaudeMessageStartEvent(
|
||||
payloadRequestedModel(opts, req.Model),
|
||||
totalUsage.InputTokens,
|
||||
)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: append(msgStart, '\n', '\n')}:
|
||||
}
|
||||
|
||||
// ── Step 2+: MCP search → InjectToolResultsClaude → callKiroAndBuffer loop ──
|
||||
contentBlockIndex := 0
|
||||
currentQuery := query
|
||||
|
||||
// Replace web_search tool description with a minimal one that allows re-search.
|
||||
// The original tools/list description from Kiro restricts non-coding topics,
|
||||
// but we've already decided to search. We keep the tool so the model can
|
||||
// request additional searches when results are insufficient.
|
||||
simplifiedPayload, simplifyErr := kiroclaude.ReplaceWebSearchToolDescription(bytes.Clone(req.Payload))
|
||||
if simplifyErr != nil {
|
||||
log.Warnf("kiro/websearch: failed to simplify web_search tool: %v, using original payload", simplifyErr)
|
||||
simplifiedPayload = bytes.Clone(req.Payload)
|
||||
}
|
||||
|
||||
currentClaudePayload := simplifiedPayload
|
||||
totalSearches := 0
|
||||
|
||||
// Generate toolUseId for the first iteration (Claude Code already decided to search)
|
||||
currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID())
|
||||
|
||||
for iteration := 0; iteration < maxWebSearchIterations; iteration++ {
|
||||
log.Infof("kiro/websearch: search iteration %d/%d",
|
||||
iteration+1, maxWebSearchIterations)
|
||||
|
||||
// MCP search
|
||||
_, mcpRequest := kiroclaude.CreateMcpRequest(currentQuery)
|
||||
|
||||
authAttrs := webSearchAuthAttrs(auth)
|
||||
handler := newWebSearchHandler(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
|
||||
mcpResponse, mcpErr := handler.callMcpAPI(mcpRequest)
|
||||
|
||||
var searchResults *kiroclaude.WebSearchResults
|
||||
if mcpErr != nil {
|
||||
log.Warnf("kiro/websearch: MCP API call failed: %v, continuing with empty results", mcpErr)
|
||||
} else {
|
||||
searchResults = kiroclaude.ParseSearchResults(mcpResponse)
|
||||
}
|
||||
|
||||
resultCount := 0
|
||||
if searchResults != nil {
|
||||
resultCount = len(searchResults.Results)
|
||||
}
|
||||
totalSearches++
|
||||
log.Infof("kiro/websearch: iteration %d — got %d search results", iteration+1, resultCount)
|
||||
|
||||
// Send search indicator events to client
|
||||
searchEvents := kiroclaude.GenerateSearchIndicatorEvents(currentQuery, currentToolUseId, searchResults, contentBlockIndex)
|
||||
for _, event := range searchEvents {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: event}:
|
||||
}
|
||||
}
|
||||
contentBlockIndex += 2
|
||||
|
||||
// Inject tool_use + tool_result into Claude payload, then call GAR
|
||||
var err error
|
||||
currentClaudePayload, err = kiroclaude.InjectToolResultsClaude(currentClaudePayload, currentToolUseId, currentQuery, searchResults)
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: failed to inject tool results: %v", err)
|
||||
wsErr = fmt.Errorf("failed to inject tool results: %w", err)
|
||||
e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults)
|
||||
return
|
||||
}
|
||||
|
||||
// Call GAR with modified Claude payload (full translation pipeline)
|
||||
modifiedReq := req
|
||||
modifiedReq.Payload = currentClaudePayload
|
||||
kiroChunks, kiroErr := e.callKiroAndBuffer(ctx, auth, modifiedReq, opts, accessToken, profileArn)
|
||||
if kiroErr != nil {
|
||||
log.Warnf("kiro/websearch: Kiro API failed at iteration %d: %v", iteration+1, kiroErr)
|
||||
wsErr = fmt.Errorf("Kiro API failed at iteration %d: %w", iteration+1, kiroErr)
|
||||
e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults)
|
||||
return
|
||||
}
|
||||
|
||||
// Analyze response
|
||||
analysis := kiroclaude.AnalyzeBufferedStream(kiroChunks)
|
||||
log.Infof("kiro/websearch: iteration %d — stop_reason: %s, has_tool_use: %v",
|
||||
iteration+1, analysis.StopReason, analysis.HasWebSearchToolUse)
|
||||
|
||||
if analysis.HasWebSearchToolUse && analysis.WebSearchQuery != "" && iteration+1 < maxWebSearchIterations {
|
||||
// Model wants another search
|
||||
filteredChunks := kiroclaude.FilterChunksForClient(kiroChunks, analysis.WebSearchToolUseIndex, contentBlockIndex)
|
||||
for _, chunk := range filteredChunks {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: chunk}:
|
||||
}
|
||||
}
|
||||
|
||||
currentQuery = analysis.WebSearchQuery
|
||||
currentToolUseId = analysis.WebSearchToolUseId
|
||||
continue
|
||||
}
|
||||
|
||||
// Model returned final response — stream to client
|
||||
for _, chunk := range kiroChunks {
|
||||
if contentBlockIndex > 0 && len(chunk) > 0 {
|
||||
adjusted, shouldForward := kiroclaude.AdjustSSEChunk(chunk, contentBlockIndex)
|
||||
if !shouldForward {
|
||||
continue
|
||||
}
|
||||
accumulatedOutputLen += len(adjusted)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: adjusted}:
|
||||
}
|
||||
} else {
|
||||
accumulatedOutputLen += len(chunk)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: chunk}:
|
||||
}
|
||||
}
|
||||
}
|
||||
log.Infof("kiro/websearch: completed after %d search iteration(s), total searches: %d", iteration+1, totalSearches)
|
||||
return
|
||||
}
|
||||
|
||||
log.Warnf("kiro/websearch: reached max iterations (%d), stopping search loop", maxWebSearchIterations)
|
||||
}()
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// handleWebSearch handles web_search requests for non-streaming Execute path.
|
||||
// Performs MCP search synchronously, injects results into the request payload,
|
||||
// then calls the normal non-streaming Kiro API path which returns a proper
|
||||
// Claude JSON response (not SSE chunks).
|
||||
func (e *KiroExecutor) handleWebSearch(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) (cliproxyexecutor.Response, error) {
|
||||
// Extract search query from Claude Code's web_search tool_use
|
||||
query := kiroclaude.ExtractSearchQuery(req.Payload)
|
||||
if query == "" {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to extract search query, falling back to normal Execute")
|
||||
// Fall through to normal non-streaming path
|
||||
return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
// Build MCP endpoint using shared region resolution (supports api_region + ProfileARN fallback)
|
||||
region := resolveKiroAPIRegion(auth)
|
||||
mcpEndpoint := kiroclaude.BuildMcpEndpoint(region)
|
||||
|
||||
// Step 1: Fetch/cache tool description (sync)
|
||||
{
|
||||
authAttrs := webSearchAuthAttrs(auth)
|
||||
fetchToolDescription(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
|
||||
}
|
||||
|
||||
// Step 2: Perform MCP search
|
||||
_, mcpRequest := kiroclaude.CreateMcpRequest(query)
|
||||
|
||||
authAttrs := webSearchAuthAttrs(auth)
|
||||
handler := newWebSearchHandler(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
|
||||
mcpResponse, mcpErr := handler.callMcpAPI(mcpRequest)
|
||||
|
||||
var searchResults *kiroclaude.WebSearchResults
|
||||
if mcpErr != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: MCP API call failed: %v, continuing with empty results", mcpErr)
|
||||
} else {
|
||||
searchResults = kiroclaude.ParseSearchResults(mcpResponse)
|
||||
}
|
||||
|
||||
resultCount := 0
|
||||
if searchResults != nil {
|
||||
resultCount = len(searchResults.Results)
|
||||
}
|
||||
log.Infof("kiro/websearch: non-stream: got %d search results", resultCount)
|
||||
|
||||
// Step 3: Replace restrictive web_search tool description (align with streaming path)
|
||||
simplifiedPayload, simplifyErr := kiroclaude.ReplaceWebSearchToolDescription(bytes.Clone(req.Payload))
|
||||
if simplifyErr != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to simplify web_search tool: %v, using original payload", simplifyErr)
|
||||
simplifiedPayload = bytes.Clone(req.Payload)
|
||||
}
|
||||
|
||||
// Step 4: Inject search tool_use + tool_result into Claude payload
|
||||
currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID())
|
||||
modifiedPayload, err := kiroclaude.InjectToolResultsClaude(simplifiedPayload, currentToolUseId, query, searchResults)
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to inject tool results: %v, falling back", err)
|
||||
return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
// Step 5: Call Kiro API via the normal non-streaming path (executeWithRetry)
|
||||
// This path uses parseEventStream → BuildClaudeResponse → TranslateNonStream
|
||||
// to produce a proper Claude JSON response
|
||||
modifiedReq := req
|
||||
modifiedReq.Payload = modifiedPayload
|
||||
|
||||
resp, err := e.executeNonStreamFallback(ctx, auth, modifiedReq, opts, accessToken, profileArn)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Step 6: Inject server_tool_use + web_search_tool_result into response
|
||||
// so Claude Code can display "Did X searches in Ys"
|
||||
indicators := []kiroclaude.SearchIndicator{
|
||||
{
|
||||
ToolUseID: currentToolUseId,
|
||||
Query: query,
|
||||
Results: searchResults,
|
||||
},
|
||||
}
|
||||
injectedPayload, injErr := kiroclaude.InjectSearchIndicatorsInResponse(resp.Payload, indicators)
|
||||
if injErr != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to inject search indicators: %v", injErr)
|
||||
} else {
|
||||
resp.Payload = injectedPayload
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// callKiroAndBuffer calls the Kiro API and buffers all response chunks.
|
||||
// Returns the buffered chunks for analysis before forwarding to client.
|
||||
// Usage reporting is NOT done here — the caller (handleWebSearchStream) manages its own reporter.
|
||||
func (e *KiroExecutor) callKiroAndBuffer(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) ([][]byte, error) {
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("kiro")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
log.Debugf("kiro/websearch GAR request: %d bytes", len(body))
|
||||
|
||||
kiroModelID := e.mapModelToKiro(req.Model)
|
||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||
|
||||
tokenKey := getTokenKey(auth)
|
||||
|
||||
kiroStream, err := e.executeStreamWithRetry(
|
||||
ctx, auth, req, opts, accessToken, effectiveProfileArn,
|
||||
nil, body, from, nil, "", kiroModelID, isAgentic, isChatOnly, tokenKey,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Buffer all chunks
|
||||
var chunks [][]byte
|
||||
for chunk := range kiroStream {
|
||||
if chunk.Err != nil {
|
||||
return chunks, chunk.Err
|
||||
}
|
||||
if len(chunk.Payload) > 0 {
|
||||
chunks = append(chunks, bytes.Clone(chunk.Payload))
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("kiro/websearch GAR response: %d chunks buffered", len(chunks))
|
||||
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
// callKiroDirectStream creates a direct streaming channel to Kiro API without search.
|
||||
func (e *KiroExecutor) callKiroDirectStream(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("kiro")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
kiroModelID := e.mapModelToKiro(req.Model)
|
||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||
|
||||
tokenKey := getTokenKey(auth)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
var streamErr error
|
||||
defer reporter.trackFailure(ctx, &streamErr)
|
||||
|
||||
stream, streamErr := e.executeStreamWithRetry(
|
||||
ctx, auth, req, opts, accessToken, effectiveProfileArn,
|
||||
nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey,
|
||||
)
|
||||
return stream, streamErr
|
||||
}
|
||||
|
||||
// sendFallbackText sends a simple text response when the Kiro API fails during the search loop.
|
||||
// Delegates SSE event construction to kiroclaude.BuildFallbackTextEvents() for alignment
|
||||
// with how streamToChannel() uses BuildClaude*Event() functions.
|
||||
func (e *KiroExecutor) sendFallbackText(
|
||||
ctx context.Context,
|
||||
out chan<- cliproxyexecutor.StreamChunk,
|
||||
contentBlockIndex int,
|
||||
query string,
|
||||
searchResults *kiroclaude.WebSearchResults,
|
||||
) {
|
||||
events := kiroclaude.BuildFallbackTextEvents(contentBlockIndex, query, searchResults)
|
||||
for _, event := range events {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: append(event, '\n', '\n')}:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// executeNonStreamFallback runs the standard non-streaming Execute path for a request.
|
||||
// Used by handleWebSearch after injecting search results, or as a fallback.
|
||||
func (e *KiroExecutor) executeNonStreamFallback(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) (cliproxyexecutor.Response, error) {
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("kiro")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
kiroModelID := e.mapModelToKiro(req.Model)
|
||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||
tokenKey := getTokenKey(auth)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
var err error
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
resp, err := e.executeWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, to, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
@@ -344,7 +344,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
// Inject interleaved thinking hint when both tools and thinking are active
|
||||
hasTools := toolDeclCount > 0
|
||||
thinkingResult := gjson.GetBytes(rawJSON, "thinking")
|
||||
hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && thinkingResult.Get("type").String() == "enabled"
|
||||
thinkingType := thinkingResult.Get("type").String()
|
||||
hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && (thinkingType == "enabled" || thinkingType == "adaptive")
|
||||
isClaudeThinking := util.IsClaudeThinkingModel(modelName)
|
||||
|
||||
if hasTools && hasThinking && isClaudeThinking {
|
||||
@@ -377,12 +378,18 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
|
||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
||||
if t := gjson.GetBytes(rawJSON, "thinking"); enableThoughtTranslate && t.Exists() && t.IsObject() {
|
||||
if t.Get("type").String() == "enabled" {
|
||||
switch t.Get("type").String() {
|
||||
case "enabled":
|
||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||
budget := int(b.Int())
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
case "adaptive":
|
||||
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
|
||||
// to model-specific max capability.
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
}
|
||||
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
|
||||
|
||||
@@ -222,6 +222,10 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
reasoningEffort = effort
|
||||
}
|
||||
}
|
||||
case "adaptive":
|
||||
// Claude adaptive means "enable with max capacity"; keep it as highest level
|
||||
// and let ApplyThinking normalize per target model capability.
|
||||
reasoningEffort = string(thinking.LevelXHigh)
|
||||
case "disabled":
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(0); ok && effort != "" {
|
||||
reasoningEffort = effort
|
||||
|
||||
@@ -90,6 +90,9 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR
|
||||
if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int())
|
||||
}
|
||||
if cachedTokensResult := usageResult.Get("input_tokens_details.cached_tokens"); cachedTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
|
||||
}
|
||||
if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
|
||||
}
|
||||
@@ -205,6 +208,9 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original
|
||||
if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int())
|
||||
}
|
||||
if cachedTokensResult := usageResult.Get("input_tokens_details.cached_tokens"); cachedTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
|
||||
}
|
||||
if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
|
||||
}
|
||||
|
||||
@@ -27,6 +27,9 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte,
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p")
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier")
|
||||
|
||||
// Delete the user field as it is not supported by the Codex upstream.
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "user")
|
||||
|
||||
// Convert role "system" to "developer" in input array to comply with Codex API requirements.
|
||||
rawJSON = convertSystemRoleToDeveloper(rawJSON)
|
||||
|
||||
|
||||
@@ -263,3 +263,20 @@ func TestConvertSystemRoleToDeveloper_AssistantRole(t *testing.T) {
|
||||
t.Errorf("Expected third role 'assistant', got '%s'", thirdRole.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserFieldDeletion(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gpt-5.2",
|
||||
"user": "test-user",
|
||||
"input": [{"role": "user", "content": "Hello"}]
|
||||
}`)
|
||||
|
||||
output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Verify user field is deleted
|
||||
userField := gjson.Get(outputStr, "user")
|
||||
if userField.Exists() {
|
||||
t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,12 +173,18 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
|
||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
||||
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() {
|
||||
if t.Get("type").String() == "enabled" {
|
||||
switch t.Get("type").String() {
|
||||
case "enabled":
|
||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||
budget := int(b.Int())
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
case "adaptive":
|
||||
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
|
||||
// to model-specific max capability.
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
}
|
||||
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
|
||||
|
||||
@@ -154,12 +154,18 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when enabled
|
||||
// Translator only does format conversion, ApplyThinking handles model capability validation.
|
||||
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() {
|
||||
if t.Get("type").String() == "enabled" {
|
||||
switch t.Get("type").String() {
|
||||
case "enabled":
|
||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||
budget := int(b.Int())
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
case "adaptive":
|
||||
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
|
||||
// to model-specific max capability.
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
}
|
||||
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
|
||||
|
||||
@@ -117,19 +117,29 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
switch itemType {
|
||||
case "message":
|
||||
if strings.EqualFold(itemRole, "system") {
|
||||
if contentArray := item.Get("content"); contentArray.Exists() && contentArray.IsArray() {
|
||||
var builder strings.Builder
|
||||
contentArray.ForEach(func(_, contentItem gjson.Result) bool {
|
||||
text := contentItem.Get("text").String()
|
||||
if builder.Len() > 0 && text != "" {
|
||||
builder.WriteByte('\n')
|
||||
}
|
||||
builder.WriteString(text)
|
||||
return true
|
||||
})
|
||||
if !gjson.Get(out, "system_instruction").Exists() {
|
||||
systemInstr := `{"parts":[{"text":""}]}`
|
||||
systemInstr, _ = sjson.Set(systemInstr, "parts.0.text", builder.String())
|
||||
if contentArray := item.Get("content"); contentArray.Exists() {
|
||||
systemInstr := ""
|
||||
if systemInstructionResult := gjson.Get(out, "system_instruction"); systemInstructionResult.Exists() {
|
||||
systemInstr = systemInstructionResult.Raw
|
||||
} else {
|
||||
systemInstr = `{"parts":[]}`
|
||||
}
|
||||
|
||||
if contentArray.IsArray() {
|
||||
contentArray.ForEach(func(_, contentItem gjson.Result) bool {
|
||||
part := `{"text":""}`
|
||||
text := contentItem.Get("text").String()
|
||||
part, _ = sjson.Set(part, "text", text)
|
||||
systemInstr, _ = sjson.SetRaw(systemInstr, "parts.-1", part)
|
||||
return true
|
||||
})
|
||||
} else if contentArray.Type == gjson.String {
|
||||
part := `{"text":""}`
|
||||
part, _ = sjson.Set(part, "text", contentArray.String())
|
||||
systemInstr, _ = sjson.SetRaw(systemInstr, "parts.-1", part)
|
||||
}
|
||||
|
||||
if systemInstr != `{"parts":[]}` {
|
||||
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr)
|
||||
}
|
||||
}
|
||||
@@ -236,8 +246,22 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
})
|
||||
|
||||
flush()
|
||||
}
|
||||
} else if contentArray.Type == gjson.String {
|
||||
effRole := "user"
|
||||
if itemRole != "" {
|
||||
switch strings.ToLower(itemRole) {
|
||||
case "assistant", "model":
|
||||
effRole = "model"
|
||||
default:
|
||||
effRole = strings.ToLower(itemRole)
|
||||
}
|
||||
}
|
||||
|
||||
one := `{"role":"","parts":[{"text":""}]}`
|
||||
one, _ = sjson.Set(one, "role", effRole)
|
||||
one, _ = sjson.Set(one, "parts.0.text", contentArray.String())
|
||||
out, _ = sjson.SetRaw(out, "contents.-1", one)
|
||||
}
|
||||
case "function_call":
|
||||
// Handle function calls - convert to model message with functionCall
|
||||
name := item.Get("name").String()
|
||||
|
||||
@@ -17,6 +17,9 @@ import (
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// remoteWebSearchDescription is a minimal fallback for when dynamic fetch from MCP tools/list hasn't completed yet.
|
||||
const remoteWebSearchDescription = "WebSearch looks up information outside the model's training data. Supports multiple queries to gather comprehensive information."
|
||||
|
||||
// Kiro API request structs - field order determines JSON key order
|
||||
|
||||
// KiroPayload is the top-level request structure for Kiro API
|
||||
@@ -219,26 +222,7 @@ func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isA
|
||||
}
|
||||
|
||||
// Convert Claude tools to Kiro format
|
||||
kiroTools, hasWebSearch := convertClaudeToolsToKiro(tools)
|
||||
|
||||
// If web_search was requested but filtered, inject alternative hint
|
||||
if hasWebSearch {
|
||||
webSearchHint := `[CRITICAL WEB ACCESS INSTRUCTION]
|
||||
You have the Fetch/read_url_content tool available. When the user asks about current events, weather, news, or any information that requires web access:
|
||||
- DO NOT say you cannot search the web
|
||||
- DO NOT refuse to help with web-related queries
|
||||
- IMMEDIATELY use the Fetch tool to access relevant URLs
|
||||
- Use well-known official websites, documentation sites, or API endpoints
|
||||
- Construct appropriate URLs based on the query context
|
||||
|
||||
IMPORTANT: Always attempt to fetch information FIRST before declining. You CAN access the web via Fetch.`
|
||||
if systemPrompt != "" {
|
||||
systemPrompt += "\n" + webSearchHint
|
||||
} else {
|
||||
systemPrompt = webSearchHint
|
||||
}
|
||||
log.Infof("kiro: injected web_search alternative hint (tool was filtered)")
|
||||
}
|
||||
kiroTools := convertClaudeToolsToKiro(tools)
|
||||
|
||||
// Thinking mode implementation:
|
||||
// Kiro API supports official thinking/reasoning mode via <thinking_mode> tag.
|
||||
@@ -527,27 +511,15 @@ func ensureKiroInputSchema(parameters interface{}) interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
// convertClaudeToolsToKiro converts Claude tools to Kiro format.
|
||||
// Returns the converted tools and a boolean indicating if web_search was filtered.
|
||||
func convertClaudeToolsToKiro(tools gjson.Result) ([]KiroToolWrapper, bool) {
|
||||
// convertClaudeToolsToKiro converts Claude tools to Kiro format
|
||||
func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper {
|
||||
var kiroTools []KiroToolWrapper
|
||||
hasWebSearch := false
|
||||
if !tools.IsArray() {
|
||||
return kiroTools, hasWebSearch
|
||||
return kiroTools
|
||||
}
|
||||
|
||||
for _, tool := range tools.Array() {
|
||||
name := tool.Get("name").String()
|
||||
|
||||
// Filter out web_search/websearch tools (Kiro API doesn't support them)
|
||||
// This matches the behavior in AIClient-2-API/claude-kiro.js
|
||||
nameLower := strings.ToLower(name)
|
||||
if nameLower == "web_search" || nameLower == "websearch" {
|
||||
log.Debugf("kiro: skipping unsupported tool: %s", name)
|
||||
hasWebSearch = true
|
||||
continue
|
||||
}
|
||||
|
||||
description := tool.Get("description").String()
|
||||
inputSchemaResult := tool.Get("input_schema")
|
||||
var inputSchema interface{}
|
||||
@@ -569,6 +541,18 @@ func convertClaudeToolsToKiro(tools gjson.Result) ([]KiroToolWrapper, bool) {
|
||||
log.Debugf("kiro: tool '%s' has empty description, using default: %s", name, description)
|
||||
}
|
||||
|
||||
// Rename web_search → remote_web_search for Kiro API compatibility
|
||||
if name == "web_search" {
|
||||
name = "remote_web_search"
|
||||
// Prefer dynamically fetched description, fall back to hardcoded constant
|
||||
if cached := GetWebSearchDescription(); cached != "" {
|
||||
description = cached
|
||||
} else {
|
||||
description = remoteWebSearchDescription
|
||||
}
|
||||
log.Debugf("kiro: renamed tool web_search → remote_web_search")
|
||||
}
|
||||
|
||||
// Truncate long descriptions (individual tool limit)
|
||||
if len(description) > kirocommon.KiroMaxToolDescLen {
|
||||
truncLen := kirocommon.KiroMaxToolDescLen - 30
|
||||
@@ -591,7 +575,7 @@ func convertClaudeToolsToKiro(tools gjson.Result) ([]KiroToolWrapper, bool) {
|
||||
// This prevents 500 errors when Claude Code sends too many tools
|
||||
kiroTools = compressToolsIfNeeded(kiroTools)
|
||||
|
||||
return kiroTools, hasWebSearch
|
||||
return kiroTools
|
||||
}
|
||||
|
||||
// processMessages processes Claude messages and builds Kiro history
|
||||
@@ -602,6 +586,17 @@ func processMessages(messages gjson.Result, modelID, origin string) ([]KiroHisto
|
||||
|
||||
// Merge adjacent messages with the same role
|
||||
messagesArray := kirocommon.MergeAdjacentMessages(messages.Array())
|
||||
|
||||
// FIX: Kiro API requires history to start with a user message.
|
||||
// Some clients (e.g., OpenClaw) send conversations starting with an assistant message,
|
||||
// which is valid for the Claude API but causes "Improperly formed request" on Kiro.
|
||||
// Prepend a placeholder user message so the history alternation is correct.
|
||||
if len(messagesArray) > 0 && messagesArray[0].Get("role").String() == "assistant" {
|
||||
placeholder := `{"role":"user","content":"."}`
|
||||
messagesArray = append([]gjson.Result{gjson.Parse(placeholder)}, messagesArray...)
|
||||
log.Infof("kiro: messages started with assistant role, prepended placeholder user message for Kiro API compatibility")
|
||||
}
|
||||
|
||||
for i, msg := range messagesArray {
|
||||
role := msg.Get("role").String()
|
||||
isLastMessage := i == len(messagesArray)-1
|
||||
@@ -654,6 +649,57 @@ func processMessages(messages gjson.Result, modelID, origin string) ([]KiroHisto
|
||||
}
|
||||
}
|
||||
|
||||
// POST-PROCESSING: Remove orphaned tool_results that have no matching tool_use
|
||||
// in any assistant message. This happens when Claude Code compaction truncates
|
||||
// the conversation and removes the assistant message containing the tool_use,
|
||||
// but keeps the user message with the corresponding tool_result.
|
||||
// Without this fix, Kiro API returns "Improperly formed request".
|
||||
validToolUseIDs := make(map[string]bool)
|
||||
for _, h := range history {
|
||||
if h.AssistantResponseMessage != nil {
|
||||
for _, tu := range h.AssistantResponseMessage.ToolUses {
|
||||
validToolUseIDs[tu.ToolUseID] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Filter orphaned tool results from history user messages
|
||||
for i, h := range history {
|
||||
if h.UserInputMessage != nil && h.UserInputMessage.UserInputMessageContext != nil {
|
||||
ctx := h.UserInputMessage.UserInputMessageContext
|
||||
if len(ctx.ToolResults) > 0 {
|
||||
filtered := make([]KiroToolResult, 0, len(ctx.ToolResults))
|
||||
for _, tr := range ctx.ToolResults {
|
||||
if validToolUseIDs[tr.ToolUseID] {
|
||||
filtered = append(filtered, tr)
|
||||
} else {
|
||||
log.Debugf("kiro: dropping orphaned tool_result in history[%d]: toolUseId=%s (no matching tool_use)", i, tr.ToolUseID)
|
||||
}
|
||||
}
|
||||
ctx.ToolResults = filtered
|
||||
if len(ctx.ToolResults) == 0 && len(ctx.Tools) == 0 {
|
||||
h.UserInputMessage.UserInputMessageContext = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Filter orphaned tool results from current message
|
||||
if len(currentToolResults) > 0 {
|
||||
filtered := make([]KiroToolResult, 0, len(currentToolResults))
|
||||
for _, tr := range currentToolResults {
|
||||
if validToolUseIDs[tr.ToolUseID] {
|
||||
filtered = append(filtered, tr)
|
||||
} else {
|
||||
log.Debugf("kiro: dropping orphaned tool_result in currentMessage: toolUseId=%s (no matching tool_use)", tr.ToolUseID)
|
||||
}
|
||||
}
|
||||
if len(filtered) != len(currentToolResults) {
|
||||
log.Infof("kiro: dropped %d orphaned tool_result(s) from currentMessage (compaction artifact)", len(currentToolResults)-len(filtered))
|
||||
}
|
||||
currentToolResults = filtered
|
||||
}
|
||||
|
||||
return history, currentUserMsg, currentToolResults
|
||||
}
|
||||
|
||||
@@ -876,6 +922,11 @@ func BuildAssistantMessageStruct(msg gjson.Result) KiroAssistantResponseMessage
|
||||
})
|
||||
}
|
||||
|
||||
// Rename web_search → remote_web_search to match convertClaudeToolsToKiro
|
||||
if toolName == "web_search" {
|
||||
toolName = "remote_web_search"
|
||||
}
|
||||
|
||||
toolUses = append(toolUses, KiroToolUse{
|
||||
ToolUseID: toolUseID,
|
||||
Name: toolName,
|
||||
|
||||
@@ -183,4 +183,124 @@ func PendingTagSuffix(buffer, tag string) int {
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateSearchIndicatorEvents generates ONLY the search indicator SSE events
|
||||
// (server_tool_use + web_search_tool_result) without text summary or message termination.
|
||||
// These events trigger Claude Code's search indicator UI.
|
||||
// The caller is responsible for sending message_start before and message_delta/stop after.
|
||||
func GenerateSearchIndicatorEvents(
|
||||
query string,
|
||||
toolUseID string,
|
||||
searchResults *WebSearchResults,
|
||||
startIndex int,
|
||||
) [][]byte {
|
||||
events := make([][]byte, 0, 5)
|
||||
|
||||
// 1. content_block_start (server_tool_use)
|
||||
event1 := map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": startIndex,
|
||||
"content_block": map[string]interface{}{
|
||||
"id": toolUseID,
|
||||
"type": "server_tool_use",
|
||||
"name": "web_search",
|
||||
"input": map[string]interface{}{},
|
||||
},
|
||||
}
|
||||
data1, _ := json.Marshal(event1)
|
||||
events = append(events, []byte("event: content_block_start\ndata: "+string(data1)+"\n\n"))
|
||||
|
||||
// 2. content_block_delta (input_json_delta)
|
||||
inputJSON, _ := json.Marshal(map[string]string{"query": query})
|
||||
event2 := map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": startIndex,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "input_json_delta",
|
||||
"partial_json": string(inputJSON),
|
||||
},
|
||||
}
|
||||
data2, _ := json.Marshal(event2)
|
||||
events = append(events, []byte("event: content_block_delta\ndata: "+string(data2)+"\n\n"))
|
||||
|
||||
// 3. content_block_stop (server_tool_use)
|
||||
event3 := map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": startIndex,
|
||||
}
|
||||
data3, _ := json.Marshal(event3)
|
||||
events = append(events, []byte("event: content_block_stop\ndata: "+string(data3)+"\n\n"))
|
||||
|
||||
// 4. content_block_start (web_search_tool_result)
|
||||
searchContent := make([]map[string]interface{}, 0)
|
||||
if searchResults != nil {
|
||||
for _, r := range searchResults.Results {
|
||||
snippet := ""
|
||||
if r.Snippet != nil {
|
||||
snippet = *r.Snippet
|
||||
}
|
||||
searchContent = append(searchContent, map[string]interface{}{
|
||||
"type": "web_search_result",
|
||||
"title": r.Title,
|
||||
"url": r.URL,
|
||||
"encrypted_content": snippet,
|
||||
"page_age": nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
event4 := map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": startIndex + 1,
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": toolUseID,
|
||||
"content": searchContent,
|
||||
},
|
||||
}
|
||||
data4, _ := json.Marshal(event4)
|
||||
events = append(events, []byte("event: content_block_start\ndata: "+string(data4)+"\n\n"))
|
||||
|
||||
// 5. content_block_stop (web_search_tool_result)
|
||||
event5 := map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": startIndex + 1,
|
||||
}
|
||||
data5, _ := json.Marshal(event5)
|
||||
events = append(events, []byte("event: content_block_stop\ndata: "+string(data5)+"\n\n"))
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
// BuildFallbackTextEvents generates SSE events for a fallback text response
|
||||
// when the Kiro API fails during the search loop. Uses BuildClaude*Event()
|
||||
// functions to align with streamToChannel patterns.
|
||||
// Returns raw SSE byte slices ready to be sent to the client channel.
|
||||
func BuildFallbackTextEvents(contentBlockIndex int, query string, results *WebSearchResults) [][]byte {
|
||||
summary := FormatSearchContextPrompt(query, results)
|
||||
outputTokens := len(summary) / 4
|
||||
if len(summary) > 0 && outputTokens == 0 {
|
||||
outputTokens = 1
|
||||
}
|
||||
|
||||
var events [][]byte
|
||||
|
||||
// content_block_start (text)
|
||||
events = append(events, BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", ""))
|
||||
|
||||
// content_block_delta (text_delta)
|
||||
events = append(events, BuildClaudeStreamEvent(summary, contentBlockIndex))
|
||||
|
||||
// content_block_stop
|
||||
events = append(events, BuildClaudeContentBlockStopEvent(contentBlockIndex))
|
||||
|
||||
// message_delta with end_turn
|
||||
events = append(events, BuildClaudeMessageDeltaEvent("end_turn", usage.Detail{
|
||||
OutputTokens: int64(outputTokens),
|
||||
}))
|
||||
|
||||
// message_stop
|
||||
events = append(events, BuildClaudeMessageStopOnlyEvent())
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
350
internal/translator/kiro/claude/kiro_claude_stream_parser.go
Normal file
350
internal/translator/kiro/claude/kiro_claude_stream_parser.go
Normal file
@@ -0,0 +1,350 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// sseEvent represents a Server-Sent Event
|
||||
type sseEvent struct {
|
||||
Event string
|
||||
Data interface{}
|
||||
}
|
||||
|
||||
// ToSSEString converts the event to SSE wire format
|
||||
func (e *sseEvent) ToSSEString() string {
|
||||
dataBytes, _ := json.Marshal(e.Data)
|
||||
return "event: " + e.Event + "\ndata: " + string(dataBytes) + "\n\n"
|
||||
}
|
||||
|
||||
// AdjustStreamIndices adjusts content block indices in SSE event data by adding an offset.
|
||||
// It also suppresses duplicate message_start events (returns shouldForward=false).
|
||||
// This is used to combine search indicator events (indices 0,1) with Kiro model response events.
|
||||
//
|
||||
// The data parameter is a single SSE "data:" line payload (JSON).
|
||||
// Returns: adjusted data, shouldForward (false = skip this event).
|
||||
func AdjustStreamIndices(data []byte, offset int) ([]byte, bool) {
|
||||
if len(data) == 0 {
|
||||
return data, true
|
||||
}
|
||||
|
||||
// Quick check: parse the JSON
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal(data, &event); err != nil {
|
||||
// Not valid JSON, pass through
|
||||
return data, true
|
||||
}
|
||||
|
||||
eventType, _ := event["type"].(string)
|
||||
|
||||
// Suppress duplicate message_start events
|
||||
if eventType == "message_start" {
|
||||
return data, false
|
||||
}
|
||||
|
||||
// Adjust index for content_block events
|
||||
switch eventType {
|
||||
case "content_block_start", "content_block_delta", "content_block_stop":
|
||||
if idx, ok := event["index"].(float64); ok {
|
||||
event["index"] = int(idx) + offset
|
||||
adjusted, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return data, true
|
||||
}
|
||||
return adjusted, true
|
||||
}
|
||||
}
|
||||
|
||||
// Pass through all other events unchanged (message_delta, message_stop, ping, etc.)
|
||||
return data, true
|
||||
}
|
||||
|
||||
// AdjustSSEChunk processes a raw SSE chunk (potentially containing multiple "event:/data:" pairs)
|
||||
// and adjusts content block indices. Suppresses duplicate message_start events.
|
||||
// Returns the adjusted chunk and whether it should be forwarded.
|
||||
func AdjustSSEChunk(chunk []byte, offset int) ([]byte, bool) {
|
||||
chunkStr := string(chunk)
|
||||
|
||||
// Fast path: if no "data:" prefix, pass through
|
||||
if !strings.Contains(chunkStr, "data: ") {
|
||||
return chunk, true
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
hasContent := false
|
||||
|
||||
lines := strings.Split(chunkStr, "\n")
|
||||
for i := 0; i < len(lines); i++ {
|
||||
line := lines[i]
|
||||
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
dataPayload := strings.TrimPrefix(line, "data: ")
|
||||
dataPayload = strings.TrimSpace(dataPayload)
|
||||
|
||||
if dataPayload == "[DONE]" {
|
||||
result.WriteString(line + "\n")
|
||||
hasContent = true
|
||||
continue
|
||||
}
|
||||
|
||||
adjusted, shouldForward := AdjustStreamIndices([]byte(dataPayload), offset)
|
||||
if !shouldForward {
|
||||
// Skip this event and its preceding "event:" line
|
||||
// Also skip the trailing empty line
|
||||
continue
|
||||
}
|
||||
|
||||
result.WriteString("data: " + string(adjusted) + "\n")
|
||||
hasContent = true
|
||||
} else if strings.HasPrefix(line, "event: ") {
|
||||
// Check if the next data line will be suppressed
|
||||
if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
|
||||
dataPayload := strings.TrimPrefix(lines[i+1], "data: ")
|
||||
dataPayload = strings.TrimSpace(dataPayload)
|
||||
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(dataPayload), &event); err == nil {
|
||||
if eventType, ok := event["type"].(string); ok && eventType == "message_start" {
|
||||
// Skip both the event: and data: lines
|
||||
i++ // skip the data: line too
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
result.WriteString(line + "\n")
|
||||
hasContent = true
|
||||
} else {
|
||||
result.WriteString(line + "\n")
|
||||
if strings.TrimSpace(line) != "" {
|
||||
hasContent = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !hasContent {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return []byte(result.String()), true
|
||||
}
|
||||
|
||||
// BufferedStreamResult contains the analysis of buffered SSE chunks from a Kiro API response.
|
||||
type BufferedStreamResult struct {
|
||||
// StopReason is the detected stop_reason from the stream (e.g., "end_turn", "tool_use")
|
||||
StopReason string
|
||||
// WebSearchQuery is the extracted query if the model requested another web_search
|
||||
WebSearchQuery string
|
||||
// WebSearchToolUseId is the tool_use ID from the model's response (needed for toolResults)
|
||||
WebSearchToolUseId string
|
||||
// HasWebSearchToolUse indicates whether the model requested web_search
|
||||
HasWebSearchToolUse bool
|
||||
// WebSearchToolUseIndex is the content_block index of the web_search tool_use
|
||||
WebSearchToolUseIndex int
|
||||
}
|
||||
|
||||
// AnalyzeBufferedStream scans buffered SSE chunks to detect stop_reason and web_search tool_use.
|
||||
// This is used in the search loop to determine if the model wants another search round.
|
||||
func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult {
|
||||
result := BufferedStreamResult{WebSearchToolUseIndex: -1}
|
||||
|
||||
// Track tool use state across chunks
|
||||
var currentToolName string
|
||||
var currentToolIndex int = -1
|
||||
var toolInputBuilder strings.Builder
|
||||
|
||||
for _, chunk := range chunks {
|
||||
chunkStr := string(chunk)
|
||||
lines := strings.Split(chunkStr, "\n")
|
||||
for _, line := range lines {
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
dataPayload := strings.TrimPrefix(line, "data: ")
|
||||
dataPayload = strings.TrimSpace(dataPayload)
|
||||
if dataPayload == "[DONE]" || dataPayload == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(dataPayload), &event); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
eventType, _ := event["type"].(string)
|
||||
|
||||
switch eventType {
|
||||
case "message_delta":
|
||||
// Extract stop_reason from message_delta
|
||||
if delta, ok := event["delta"].(map[string]interface{}); ok {
|
||||
if sr, ok := delta["stop_reason"].(string); ok && sr != "" {
|
||||
result.StopReason = sr
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_start":
|
||||
// Detect tool_use content blocks
|
||||
if cb, ok := event["content_block"].(map[string]interface{}); ok {
|
||||
if cbType, ok := cb["type"].(string); ok && cbType == "tool_use" {
|
||||
if name, ok := cb["name"].(string); ok {
|
||||
currentToolName = strings.ToLower(name)
|
||||
if idx, ok := event["index"].(float64); ok {
|
||||
currentToolIndex = int(idx)
|
||||
}
|
||||
// Capture tool use ID for toolResults handshake
|
||||
if id, ok := cb["id"].(string); ok {
|
||||
result.WebSearchToolUseId = id
|
||||
}
|
||||
toolInputBuilder.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_delta":
|
||||
// Accumulate tool input JSON
|
||||
if currentToolName != "" {
|
||||
if delta, ok := event["delta"].(map[string]interface{}); ok {
|
||||
if deltaType, ok := delta["type"].(string); ok && deltaType == "input_json_delta" {
|
||||
if partial, ok := delta["partial_json"].(string); ok {
|
||||
toolInputBuilder.WriteString(partial)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_stop":
|
||||
// Finalize tool use detection
|
||||
if currentToolName == "web_search" || currentToolName == "websearch" || currentToolName == "remote_web_search" {
|
||||
result.HasWebSearchToolUse = true
|
||||
result.WebSearchToolUseIndex = currentToolIndex
|
||||
// Extract query from accumulated input JSON
|
||||
inputJSON := toolInputBuilder.String()
|
||||
var input map[string]string
|
||||
if err := json.Unmarshal([]byte(inputJSON), &input); err == nil {
|
||||
if q, ok := input["query"]; ok {
|
||||
result.WebSearchQuery = q
|
||||
}
|
||||
}
|
||||
log.Debugf("kiro/websearch: detected web_search tool_use")
|
||||
}
|
||||
currentToolName = ""
|
||||
currentToolIndex = -1
|
||||
toolInputBuilder.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// FilterChunksForClient processes buffered SSE chunks and removes web_search tool_use
|
||||
// content blocks. This prevents the client from seeing "Tool use" prompts for web_search
|
||||
// when the proxy is handling the search loop internally.
|
||||
// Also suppresses message_start and message_delta/message_stop events since those
|
||||
// are managed by the outer handleWebSearchStream.
|
||||
func FilterChunksForClient(chunks [][]byte, wsToolIndex int, indexOffset int) [][]byte {
|
||||
var filtered [][]byte
|
||||
|
||||
for _, chunk := range chunks {
|
||||
chunkStr := string(chunk)
|
||||
lines := strings.Split(chunkStr, "\n")
|
||||
|
||||
var resultBuilder strings.Builder
|
||||
hasContent := false
|
||||
|
||||
for i := 0; i < len(lines); i++ {
|
||||
line := lines[i]
|
||||
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
dataPayload := strings.TrimPrefix(line, "data: ")
|
||||
dataPayload = strings.TrimSpace(dataPayload)
|
||||
|
||||
if dataPayload == "[DONE]" {
|
||||
// Skip [DONE] — the outer loop manages stream termination
|
||||
continue
|
||||
}
|
||||
|
||||
var event map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(dataPayload), &event); err != nil {
|
||||
resultBuilder.WriteString(line + "\n")
|
||||
hasContent = true
|
||||
continue
|
||||
}
|
||||
|
||||
eventType, _ := event["type"].(string)
|
||||
|
||||
// Skip message_start (outer loop sends its own)
|
||||
if eventType == "message_start" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip message_delta and message_stop (outer loop manages these)
|
||||
if eventType == "message_delta" || eventType == "message_stop" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this event belongs to the web_search tool_use block
|
||||
if wsToolIndex >= 0 {
|
||||
if idx, ok := event["index"].(float64); ok && int(idx) == wsToolIndex {
|
||||
// Skip events for the web_search tool_use block
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Apply index offset for remaining events
|
||||
if indexOffset > 0 {
|
||||
switch eventType {
|
||||
case "content_block_start", "content_block_delta", "content_block_stop":
|
||||
if idx, ok := event["index"].(float64); ok {
|
||||
event["index"] = int(idx) + indexOffset
|
||||
adjusted, err := json.Marshal(event)
|
||||
if err == nil {
|
||||
resultBuilder.WriteString("data: " + string(adjusted) + "\n")
|
||||
hasContent = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resultBuilder.WriteString(line + "\n")
|
||||
hasContent = true
|
||||
} else if strings.HasPrefix(line, "event: ") {
|
||||
// Check if the next data line will be suppressed
|
||||
if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
|
||||
nextData := strings.TrimPrefix(lines[i+1], "data: ")
|
||||
nextData = strings.TrimSpace(nextData)
|
||||
|
||||
var nextEvent map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(nextData), &nextEvent); err == nil {
|
||||
nextType, _ := nextEvent["type"].(string)
|
||||
if nextType == "message_start" || nextType == "message_delta" || nextType == "message_stop" {
|
||||
i++ // skip the data line
|
||||
continue
|
||||
}
|
||||
if wsToolIndex >= 0 {
|
||||
if idx, ok := nextEvent["index"].(float64); ok && int(idx) == wsToolIndex {
|
||||
i++ // skip the data line
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
resultBuilder.WriteString(line + "\n")
|
||||
hasContent = true
|
||||
} else {
|
||||
resultBuilder.WriteString(line + "\n")
|
||||
if strings.TrimSpace(line) != "" {
|
||||
hasContent = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasContent {
|
||||
filtered = append(filtered, []byte(resultBuilder.String()))
|
||||
}
|
||||
}
|
||||
|
||||
return filtered
|
||||
}
|
||||
495
internal/translator/kiro/claude/kiro_websearch.go
Normal file
495
internal/translator/kiro/claude/kiro_websearch.go
Normal file
@@ -0,0 +1,495 @@
|
||||
// Package claude provides web search functionality for Kiro translator.
|
||||
// This file implements detection, MCP request/response types, and pure data
|
||||
// transformation utilities for web search. SSE event generation, stream analysis,
|
||||
// and HTTP I/O logic reside in the executor package (kiro_executor.go).
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// cachedToolDescription stores the dynamically-fetched web_search tool description.
|
||||
// Written by the executor via SetWebSearchDescription, read by the translator
|
||||
// when building the remote_web_search tool for Kiro API requests.
|
||||
var cachedToolDescription atomic.Value // stores string
|
||||
|
||||
// GetWebSearchDescription returns the cached web_search tool description,
|
||||
// or empty string if not yet fetched. Lock-free via atomic.Value.
|
||||
func GetWebSearchDescription() string {
|
||||
if v := cachedToolDescription.Load(); v != nil {
|
||||
return v.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// SetWebSearchDescription stores the dynamically-fetched web_search tool description.
|
||||
// Called by the executor after fetching from MCP tools/list.
|
||||
func SetWebSearchDescription(desc string) {
|
||||
cachedToolDescription.Store(desc)
|
||||
}
|
||||
|
||||
// McpRequest represents a JSON-RPC 2.0 request to Kiro MCP API
|
||||
type McpRequest struct {
|
||||
ID string `json:"id"`
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
Method string `json:"method"`
|
||||
Params McpParams `json:"params"`
|
||||
}
|
||||
|
||||
// McpParams represents MCP request parameters
|
||||
type McpParams struct {
|
||||
Name string `json:"name"`
|
||||
Arguments McpArguments `json:"arguments"`
|
||||
}
|
||||
|
||||
// McpArgumentsMeta represents the _meta field in MCP arguments
|
||||
type McpArgumentsMeta struct {
|
||||
IsValid bool `json:"_isValid"`
|
||||
ActivePath []string `json:"_activePath"`
|
||||
CompletedPaths [][]string `json:"_completedPaths"`
|
||||
}
|
||||
|
||||
// McpArguments represents MCP request arguments
|
||||
type McpArguments struct {
|
||||
Query string `json:"query"`
|
||||
Meta *McpArgumentsMeta `json:"_meta,omitempty"`
|
||||
}
|
||||
|
||||
// McpResponse represents a JSON-RPC 2.0 response from Kiro MCP API
|
||||
type McpResponse struct {
|
||||
Error *McpError `json:"error,omitempty"`
|
||||
ID string `json:"id"`
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
Result *McpResult `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
// McpError represents an MCP error
|
||||
type McpError struct {
|
||||
Code *int `json:"code,omitempty"`
|
||||
Message *string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// McpResult represents MCP result
|
||||
type McpResult struct {
|
||||
Content []McpContent `json:"content"`
|
||||
IsError bool `json:"isError"`
|
||||
}
|
||||
|
||||
// McpContent represents MCP content item
|
||||
type McpContent struct {
|
||||
ContentType string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// WebSearchResults represents parsed search results
|
||||
type WebSearchResults struct {
|
||||
Results []WebSearchResult `json:"results"`
|
||||
TotalResults *int `json:"totalResults,omitempty"`
|
||||
Query *string `json:"query,omitempty"`
|
||||
Error *string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// WebSearchResult represents a single search result
|
||||
type WebSearchResult struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Snippet *string `json:"snippet,omitempty"`
|
||||
PublishedDate *int64 `json:"publishedDate,omitempty"`
|
||||
ID *string `json:"id,omitempty"`
|
||||
Domain *string `json:"domain,omitempty"`
|
||||
MaxVerbatimWordLimit *int `json:"maxVerbatimWordLimit,omitempty"`
|
||||
PublicDomain *bool `json:"publicDomain,omitempty"`
|
||||
}
|
||||
|
||||
// isWebSearchTool checks if a tool name or type indicates a web_search tool.
|
||||
func isWebSearchTool(name, toolType string) bool {
|
||||
return name == "web_search" ||
|
||||
strings.HasPrefix(toolType, "web_search") ||
|
||||
toolType == "web_search_20250305"
|
||||
}
|
||||
|
||||
// HasWebSearchTool checks if the request contains ONLY a web_search tool.
|
||||
// Returns true only if tools array has exactly one tool named "web_search".
|
||||
// Only intercept pure web_search requests (single-tool array).
|
||||
func HasWebSearchTool(body []byte) bool {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if !tools.IsArray() {
|
||||
return false
|
||||
}
|
||||
|
||||
toolsArray := tools.Array()
|
||||
if len(toolsArray) != 1 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if the single tool is web_search
|
||||
tool := toolsArray[0]
|
||||
|
||||
// Check both name and type fields for web_search detection
|
||||
name := strings.ToLower(tool.Get("name").String())
|
||||
toolType := strings.ToLower(tool.Get("type").String())
|
||||
|
||||
return isWebSearchTool(name, toolType)
|
||||
}
|
||||
|
||||
// ExtractSearchQuery extracts the search query from the request.
|
||||
// Reads messages[0].content and removes "Perform a web search for the query: " prefix.
|
||||
func ExtractSearchQuery(body []byte) string {
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.IsArray() || len(messages.Array()) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
firstMsg := messages.Array()[0]
|
||||
content := firstMsg.Get("content")
|
||||
|
||||
var text string
|
||||
if content.IsArray() {
|
||||
// Array format: [{"type": "text", "text": "..."}]
|
||||
for _, block := range content.Array() {
|
||||
if block.Get("type").String() == "text" {
|
||||
text = block.Get("text").String()
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// String format
|
||||
text = content.String()
|
||||
}
|
||||
|
||||
// Remove prefix "Perform a web search for the query: "
|
||||
const prefix = "Perform a web search for the query: "
|
||||
if strings.HasPrefix(text, prefix) {
|
||||
text = text[len(prefix):]
|
||||
}
|
||||
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
// generateRandomID8 generates an 8-character random lowercase alphanumeric string
|
||||
func generateRandomID8() string {
|
||||
u := uuid.New()
|
||||
return strings.ToLower(strings.ReplaceAll(u.String(), "-", "")[:8])
|
||||
}
|
||||
|
||||
// CreateMcpRequest creates an MCP request for web search.
|
||||
// Returns (toolUseID, McpRequest)
|
||||
// ID format: web_search_tooluse_{22 random}_{timestamp_millis}_{8 random}
|
||||
func CreateMcpRequest(query string) (string, *McpRequest) {
|
||||
random22 := GenerateToolUseID()
|
||||
timestamp := time.Now().UnixMilli()
|
||||
random8 := generateRandomID8()
|
||||
|
||||
requestID := fmt.Sprintf("web_search_tooluse_%s_%d_%s", random22, timestamp, random8)
|
||||
|
||||
// tool_use_id format: srvtoolu_{32 hex chars}
|
||||
toolUseID := "srvtoolu_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:32]
|
||||
|
||||
request := &McpRequest{
|
||||
ID: requestID,
|
||||
JSONRPC: "2.0",
|
||||
Method: "tools/call",
|
||||
Params: McpParams{
|
||||
Name: "web_search",
|
||||
Arguments: McpArguments{
|
||||
Query: query,
|
||||
Meta: &McpArgumentsMeta{
|
||||
IsValid: true,
|
||||
ActivePath: []string{"query"},
|
||||
CompletedPaths: [][]string{{"query"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return toolUseID, request
|
||||
}
|
||||
|
||||
// GenerateToolUseID generates a Kiro-style tool use ID (base62-like UUID)
|
||||
func GenerateToolUseID() string {
|
||||
return strings.ReplaceAll(uuid.New().String(), "-", "")[:22]
|
||||
}
|
||||
|
||||
// ReplaceWebSearchToolDescription replaces the web_search tool description with
|
||||
// a minimal version that allows re-search without the restrictive "do not search
|
||||
// non-coding topics" instruction from the original Kiro tools/list response.
|
||||
// This keeps the tool available so the model can request additional searches.
|
||||
func ReplaceWebSearchToolDescription(body []byte) ([]byte, error) {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if !tools.IsArray() {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
var updated []json.RawMessage
|
||||
for _, tool := range tools.Array() {
|
||||
name := strings.ToLower(tool.Get("name").String())
|
||||
toolType := strings.ToLower(tool.Get("type").String())
|
||||
|
||||
if isWebSearchTool(name, toolType) {
|
||||
// Replace with a minimal web_search tool definition
|
||||
minimalTool := map[string]interface{}{
|
||||
"name": "web_search",
|
||||
"description": "Search the web for information. Use this when the previous search results are insufficient or when you need additional information on a different aspect of the query. Provide a refined or different search query.",
|
||||
"input_schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"query": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The search query to execute",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
"additionalProperties": false,
|
||||
},
|
||||
}
|
||||
minimalJSON, err := json.Marshal(minimalTool)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("failed to marshal minimal tool: %w", err)
|
||||
}
|
||||
updated = append(updated, json.RawMessage(minimalJSON))
|
||||
} else {
|
||||
updated = append(updated, json.RawMessage(tool.Raw))
|
||||
}
|
||||
}
|
||||
|
||||
updatedJSON, err := json.Marshal(updated)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("failed to marshal updated tools: %w", err)
|
||||
}
|
||||
result, err := sjson.SetRawBytes(body, "tools", updatedJSON)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("failed to set updated tools: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FormatSearchContextPrompt formats search results as a structured text block
|
||||
// for injection into the system prompt.
|
||||
func FormatSearchContextPrompt(query string, results *WebSearchResults) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("[Web Search Results for \"%s\"]\n", query))
|
||||
|
||||
if results != nil && len(results.Results) > 0 {
|
||||
for i, r := range results.Results {
|
||||
sb.WriteString(fmt.Sprintf("%d. %s - %s\n", i+1, r.Title, r.URL))
|
||||
if r.Snippet != nil && *r.Snippet != "" {
|
||||
snippet := *r.Snippet
|
||||
if len(snippet) > 500 {
|
||||
snippet = snippet[:500] + "..."
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" %s\n", snippet))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
sb.WriteString("No results found.\n")
|
||||
}
|
||||
|
||||
sb.WriteString("[End Web Search Results]")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// FormatToolResultText formats search results as JSON text for the toolResults content field.
|
||||
// This matches the format observed in Kiro IDE HAR captures.
|
||||
func FormatToolResultText(results *WebSearchResults) string {
|
||||
if results == nil || len(results.Results) == 0 {
|
||||
return "No search results found."
|
||||
}
|
||||
|
||||
text := fmt.Sprintf("Found %d search result(s):\n\n", len(results.Results))
|
||||
resultJSON, err := json.MarshalIndent(results.Results, "", " ")
|
||||
if err != nil {
|
||||
return text + "Error formatting results."
|
||||
}
|
||||
return text + string(resultJSON)
|
||||
}
|
||||
|
||||
// InjectToolResultsClaude modifies a Claude-format JSON payload to append
|
||||
// tool_use (assistant) and tool_result (user) messages to the messages array.
|
||||
// BuildKiroPayload correctly translates:
|
||||
// - assistant tool_use → KiroAssistantResponseMessage.toolUses
|
||||
// - user tool_result → KiroUserInputMessageContext.toolResults
|
||||
//
|
||||
// This produces the exact same GAR request format as the Kiro IDE (HAR captures).
|
||||
// IMPORTANT: The web_search tool must remain in the "tools" array for this to work.
|
||||
// Use ReplaceWebSearchToolDescription to keep the tool available with a minimal description.
|
||||
func InjectToolResultsClaude(claudePayload []byte, toolUseId, query string, results *WebSearchResults) ([]byte, error) {
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal(claudePayload, &payload); err != nil {
|
||||
return claudePayload, fmt.Errorf("failed to parse claude payload: %w", err)
|
||||
}
|
||||
|
||||
messages, _ := payload["messages"].([]interface{})
|
||||
|
||||
// 1. Append assistant message with tool_use (matches HAR: assistantResponseMessage.toolUses)
|
||||
assistantMsg := map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": []interface{}{
|
||||
map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": toolUseId,
|
||||
"name": "web_search",
|
||||
"input": map[string]interface{}{"query": query},
|
||||
},
|
||||
},
|
||||
}
|
||||
messages = append(messages, assistantMsg)
|
||||
|
||||
// 2. Append user message with tool_result + search behavior instructions.
|
||||
// NOTE: We embed search instructions HERE (not in system prompt) because
|
||||
// BuildKiroPayload clears the system prompt when len(history) > 0,
|
||||
// which is always true after injecting assistant + user messages.
|
||||
now := time.Now()
|
||||
searchGuidance := fmt.Sprintf(`<search_guidance>
|
||||
Current date: %s (%s)
|
||||
|
||||
IMPORTANT: Evaluate the search results above carefully. If the results are:
|
||||
- Mostly spam, SEO junk, or unrelated websites
|
||||
- Missing actual information about the query topic
|
||||
- Outdated or not matching the requested time frame
|
||||
|
||||
Then you MUST use the web_search tool again with a refined query. Try:
|
||||
- Rephrasing in English for better coverage
|
||||
- Using more specific keywords
|
||||
- Adding date context
|
||||
|
||||
Do NOT apologize for bad results without first attempting a re-search.
|
||||
</search_guidance>`, now.Format("January 2, 2006"), now.Format("Monday"))
|
||||
|
||||
userMsg := map[string]interface{}{
|
||||
"role": "user",
|
||||
"content": []interface{}{
|
||||
map[string]interface{}{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": toolUseId,
|
||||
"content": FormatToolResultText(results),
|
||||
},
|
||||
map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": searchGuidance,
|
||||
},
|
||||
},
|
||||
}
|
||||
messages = append(messages, userMsg)
|
||||
|
||||
payload["messages"] = messages
|
||||
|
||||
result, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return claudePayload, fmt.Errorf("failed to marshal updated payload: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("kiro/websearch: injected tool_use+tool_result (toolUseId=%s, messages=%d)",
|
||||
toolUseId, len(messages))
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// InjectSearchIndicatorsInResponse prepends server_tool_use + web_search_tool_result
|
||||
// content blocks into a non-streaming Claude JSON response. Claude Code counts
|
||||
// server_tool_use blocks to display "Did X searches in Ys".
|
||||
//
|
||||
// Input response: {"content": [{"type":"text","text":"..."}], ...}
|
||||
// Output response: {"content": [{"type":"server_tool_use",...}, {"type":"web_search_tool_result",...}, {"type":"text","text":"..."}], ...}
|
||||
func InjectSearchIndicatorsInResponse(responsePayload []byte, searches []SearchIndicator) ([]byte, error) {
|
||||
if len(searches) == 0 {
|
||||
return responsePayload, nil
|
||||
}
|
||||
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(responsePayload, &resp); err != nil {
|
||||
return responsePayload, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
existingContent, _ := resp["content"].([]interface{})
|
||||
|
||||
// Build new content: search indicators first, then existing content
|
||||
newContent := make([]interface{}, 0, len(searches)*2+len(existingContent))
|
||||
|
||||
for _, s := range searches {
|
||||
// server_tool_use block
|
||||
newContent = append(newContent, map[string]interface{}{
|
||||
"type": "server_tool_use",
|
||||
"id": s.ToolUseID,
|
||||
"name": "web_search",
|
||||
"input": map[string]interface{}{"query": s.Query},
|
||||
})
|
||||
|
||||
// web_search_tool_result block
|
||||
searchContent := make([]map[string]interface{}, 0)
|
||||
if s.Results != nil {
|
||||
for _, r := range s.Results.Results {
|
||||
snippet := ""
|
||||
if r.Snippet != nil {
|
||||
snippet = *r.Snippet
|
||||
}
|
||||
searchContent = append(searchContent, map[string]interface{}{
|
||||
"type": "web_search_result",
|
||||
"title": r.Title,
|
||||
"url": r.URL,
|
||||
"encrypted_content": snippet,
|
||||
"page_age": nil,
|
||||
})
|
||||
}
|
||||
}
|
||||
newContent = append(newContent, map[string]interface{}{
|
||||
"type": "web_search_tool_result",
|
||||
"tool_use_id": s.ToolUseID,
|
||||
"content": searchContent,
|
||||
})
|
||||
}
|
||||
|
||||
// Append existing content blocks
|
||||
newContent = append(newContent, existingContent...)
|
||||
resp["content"] = newContent
|
||||
|
||||
result, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return responsePayload, fmt.Errorf("failed to marshal response: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("kiro/websearch: injected %d search indicator(s) into non-stream response", len(searches))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SearchIndicator holds the data for one search operation to inject into a response.
|
||||
type SearchIndicator struct {
|
||||
ToolUseID string
|
||||
Query string
|
||||
Results *WebSearchResults
|
||||
}
|
||||
|
||||
// BuildMcpEndpoint constructs the MCP endpoint URL for the given AWS region.
|
||||
// Centralizes the URL pattern used by both handleWebSearch and handleWebSearchStream.
|
||||
func BuildMcpEndpoint(region string) string {
|
||||
return fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region)
|
||||
}
|
||||
|
||||
// ParseSearchResults extracts WebSearchResults from MCP response
|
||||
func ParseSearchResults(response *McpResponse) *WebSearchResults {
|
||||
if response == nil || response.Result == nil || len(response.Result.Content) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
content := response.Result.Content[0]
|
||||
if content.ContentType != "text" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var results WebSearchResults
|
||||
if err := json.Unmarshal([]byte(content.Text), &results); err != nil {
|
||||
log.Warnf("kiro/websearch: failed to parse search results: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &results
|
||||
}
|
||||
@@ -36,8 +36,14 @@ func MergeAdjacentMessages(messages []gjson.Result) []gjson.Result {
|
||||
if currentRole == lastRole {
|
||||
// Merge content from current message into last message
|
||||
mergedContent := mergeMessageContent(lastMsg, msg)
|
||||
// Create a new merged message JSON
|
||||
mergedMsg := createMergedMessage(lastRole, mergedContent)
|
||||
var mergedToolCalls []interface{}
|
||||
if currentRole == "assistant" {
|
||||
// Preserve assistant tool_calls when adjacent assistant messages are merged.
|
||||
mergedToolCalls = mergeToolCalls(lastMsg.Get("tool_calls"), msg.Get("tool_calls"))
|
||||
}
|
||||
|
||||
// Create a new merged message JSON.
|
||||
mergedMsg := createMergedMessage(lastRole, mergedContent, mergedToolCalls)
|
||||
merged[len(merged)-1] = gjson.Parse(mergedMsg)
|
||||
} else {
|
||||
merged = append(merged, msg)
|
||||
@@ -121,12 +127,34 @@ func blockToMap(block gjson.Result) map[string]interface{} {
|
||||
return result
|
||||
}
|
||||
|
||||
// createMergedMessage creates a JSON string for a merged message
|
||||
func createMergedMessage(role string, content string) string {
|
||||
// createMergedMessage creates a JSON string for a merged message.
|
||||
// toolCalls is optional and only emitted for assistant role.
|
||||
func createMergedMessage(role string, content string, toolCalls []interface{}) string {
|
||||
msg := map[string]interface{}{
|
||||
"role": role,
|
||||
"content": json.RawMessage(content),
|
||||
}
|
||||
if role == "assistant" && len(toolCalls) > 0 {
|
||||
msg["tool_calls"] = toolCalls
|
||||
}
|
||||
result, _ := json.Marshal(msg)
|
||||
return string(result)
|
||||
}
|
||||
}
|
||||
|
||||
// mergeToolCalls combines tool_calls from two assistant messages while preserving order.
|
||||
func mergeToolCalls(tc1, tc2 gjson.Result) []interface{} {
|
||||
var merged []interface{}
|
||||
|
||||
if tc1.IsArray() {
|
||||
for _, tc := range tc1.Array() {
|
||||
merged = append(merged, tc.Value())
|
||||
}
|
||||
}
|
||||
if tc2.IsArray() {
|
||||
for _, tc := range tc2.Array() {
|
||||
merged = append(merged, tc.Value())
|
||||
}
|
||||
}
|
||||
|
||||
return merged
|
||||
}
|
||||
|
||||
106
internal/translator/kiro/common/message_merge_test.go
Normal file
106
internal/translator/kiro/common/message_merge_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func parseMessages(t *testing.T, raw string) []gjson.Result {
|
||||
t.Helper()
|
||||
parsed := gjson.Parse(raw)
|
||||
if !parsed.IsArray() {
|
||||
t.Fatalf("expected JSON array, got: %s", raw)
|
||||
}
|
||||
return parsed.Array()
|
||||
}
|
||||
|
||||
func TestMergeAdjacentMessages_AssistantMergePreservesToolCalls(t *testing.T) {
|
||||
messages := parseMessages(t, `[
|
||||
{"role":"assistant","content":"part1"},
|
||||
{
|
||||
"role":"assistant",
|
||||
"content":"part2",
|
||||
"tool_calls":[
|
||||
{
|
||||
"id":"call_1",
|
||||
"type":"function",
|
||||
"function":{"name":"Read","arguments":"{}"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{"role":"tool","tool_call_id":"call_1","content":"ok"}
|
||||
]`)
|
||||
|
||||
merged := MergeAdjacentMessages(messages)
|
||||
if len(merged) != 2 {
|
||||
t.Fatalf("expected 2 messages after merge, got %d", len(merged))
|
||||
}
|
||||
|
||||
assistant := merged[0]
|
||||
if assistant.Get("role").String() != "assistant" {
|
||||
t.Fatalf("expected first message role assistant, got %q", assistant.Get("role").String())
|
||||
}
|
||||
|
||||
toolCalls := assistant.Get("tool_calls")
|
||||
if !toolCalls.IsArray() || len(toolCalls.Array()) != 1 {
|
||||
t.Fatalf("expected assistant.tool_calls length 1, got: %s", toolCalls.Raw)
|
||||
}
|
||||
if toolCalls.Array()[0].Get("id").String() != "call_1" {
|
||||
t.Fatalf("expected tool call id call_1, got %q", toolCalls.Array()[0].Get("id").String())
|
||||
}
|
||||
|
||||
contentRaw := assistant.Get("content").Raw
|
||||
if !strings.Contains(contentRaw, "part1") || !strings.Contains(contentRaw, "part2") {
|
||||
t.Fatalf("expected merged content to contain both parts, got: %s", contentRaw)
|
||||
}
|
||||
|
||||
if merged[1].Get("role").String() != "tool" {
|
||||
t.Fatalf("expected second message role tool, got %q", merged[1].Get("role").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeAdjacentMessages_AssistantMergeCombinesMultipleToolCalls(t *testing.T) {
|
||||
messages := parseMessages(t, `[
|
||||
{
|
||||
"role":"assistant",
|
||||
"content":"first",
|
||||
"tool_calls":[
|
||||
{"id":"call_1","type":"function","function":{"name":"Read","arguments":"{}"}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role":"assistant",
|
||||
"content":"second",
|
||||
"tool_calls":[
|
||||
{"id":"call_2","type":"function","function":{"name":"Write","arguments":"{}"}}
|
||||
]
|
||||
}
|
||||
]`)
|
||||
|
||||
merged := MergeAdjacentMessages(messages)
|
||||
if len(merged) != 1 {
|
||||
t.Fatalf("expected 1 message after merge, got %d", len(merged))
|
||||
}
|
||||
|
||||
toolCalls := merged[0].Get("tool_calls").Array()
|
||||
if len(toolCalls) != 2 {
|
||||
t.Fatalf("expected 2 merged tool calls, got %d", len(toolCalls))
|
||||
}
|
||||
if toolCalls[0].Get("id").String() != "call_1" || toolCalls[1].Get("id").String() != "call_2" {
|
||||
t.Fatalf("unexpected merged tool call ids: %q, %q", toolCalls[0].Get("id").String(), toolCalls[1].Get("id").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeAdjacentMessages_ToolMessagesRemainUnmerged(t *testing.T) {
|
||||
messages := parseMessages(t, `[
|
||||
{"role":"tool","tool_call_id":"call_1","content":"r1"},
|
||||
{"role":"tool","tool_call_id":"call_2","content":"r2"}
|
||||
]`)
|
||||
|
||||
merged := MergeAdjacentMessages(messages)
|
||||
if len(merged) != 2 {
|
||||
t.Fatalf("expected tool messages to remain separate, got %d", len(merged))
|
||||
}
|
||||
}
|
||||
@@ -75,6 +75,10 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
case "adaptive":
|
||||
// Claude adaptive means "enable with max capacity"; keep it as highest level
|
||||
// and let ApplyThinking normalize per target model capability.
|
||||
out, _ = sjson.Set(out, "reasoning_effort", string(thinking.LevelXHigh))
|
||||
case "disabled":
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(0); ok && effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
|
||||
@@ -428,8 +428,9 @@ func flattenTypeArrays(jsonStr string) string {
|
||||
|
||||
func removeUnsupportedKeywords(jsonStr string) string {
|
||||
keywords := append(unsupportedConstraints,
|
||||
"$schema", "$defs", "definitions", "const", "$ref", "additionalProperties",
|
||||
"propertyNames", // Gemini doesn't support property name validation
|
||||
"$schema", "$defs", "definitions", "const", "$ref", "$id", "additionalProperties",
|
||||
"propertyNames", "patternProperties", // Gemini doesn't support these schema keywords
|
||||
"enumTitles", "prefill", // Claude/OpenCode schema metadata fields unsupported by Gemini
|
||||
)
|
||||
|
||||
deletePaths := make([]string, 0)
|
||||
|
||||
@@ -870,6 +870,57 @@ func TestCleanJSONSchemaForAntigravity_BooleanEnumToString(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_RemovesGeminiUnsupportedMetadataFields(t *testing.T) {
|
||||
input := `{
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"$id": "root-schema",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"payload": {
|
||||
"type": "object",
|
||||
"prefill": "hello",
|
||||
"properties": {
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": ["a", "b"],
|
||||
"enumTitles": ["A", "B"]
|
||||
}
|
||||
},
|
||||
"patternProperties": {
|
||||
"^x-": {"type": "string"}
|
||||
}
|
||||
},
|
||||
"$id": {
|
||||
"type": "string",
|
||||
"description": "property name should not be removed"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"payload": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": ["a", "b"],
|
||||
"description": "Allowed: a, b"
|
||||
}
|
||||
}
|
||||
},
|
||||
"$id": {
|
||||
"type": "string",
|
||||
"description": "property name should not be removed"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestRemoveExtensionFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -92,6 +93,9 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e
|
||||
status = coreauth.StatusDisabled
|
||||
}
|
||||
|
||||
// Read per-account excluded models from the OAuth JSON file
|
||||
perAccountExcluded := extractExcludedModelsFromMetadata(metadata)
|
||||
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: provider,
|
||||
@@ -108,11 +112,23 @@ func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, e
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, nil, "oauth")
|
||||
// Read priority from auth file
|
||||
if rawPriority, ok := metadata["priority"]; ok {
|
||||
switch v := rawPriority.(type) {
|
||||
case float64:
|
||||
a.Attributes["priority"] = strconv.Itoa(int(v))
|
||||
case string:
|
||||
priority := strings.TrimSpace(v)
|
||||
if _, errAtoi := strconv.Atoi(priority); errAtoi == nil {
|
||||
a.Attributes["priority"] = priority
|
||||
}
|
||||
}
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
|
||||
if provider == "gemini-cli" {
|
||||
if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
|
||||
for _, v := range virtuals {
|
||||
ApplyAuthExcludedModelsMeta(v, cfg, nil, "oauth")
|
||||
ApplyAuthExcludedModelsMeta(v, cfg, perAccountExcluded, "oauth")
|
||||
}
|
||||
out = append(out, a)
|
||||
out = append(out, virtuals...)
|
||||
@@ -167,6 +183,10 @@ func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]an
|
||||
if authPath != "" {
|
||||
attrs["path"] = authPath
|
||||
}
|
||||
// Propagate priority from primary auth to virtual auths
|
||||
if priorityVal, hasPriority := primary.Attributes["priority"]; hasPriority && priorityVal != "" {
|
||||
attrs["priority"] = priorityVal
|
||||
}
|
||||
metadataCopy := map[string]any{
|
||||
"email": email,
|
||||
"project_id": projectID,
|
||||
@@ -239,3 +259,40 @@ func buildGeminiVirtualID(baseID, projectID string) string {
|
||||
replacer := strings.NewReplacer("/", "_", "\\", "_", " ", "_")
|
||||
return fmt.Sprintf("%s::%s", baseID, replacer.Replace(project))
|
||||
}
|
||||
|
||||
// extractExcludedModelsFromMetadata reads per-account excluded models from the OAuth JSON metadata.
|
||||
// Supports both "excluded_models" and "excluded-models" keys, and accepts both []string and []interface{}.
|
||||
func extractExcludedModelsFromMetadata(metadata map[string]any) []string {
|
||||
if metadata == nil {
|
||||
return nil
|
||||
}
|
||||
// Try both key formats
|
||||
raw, ok := metadata["excluded_models"]
|
||||
if !ok {
|
||||
raw, ok = metadata["excluded-models"]
|
||||
}
|
||||
if !ok || raw == nil {
|
||||
return nil
|
||||
}
|
||||
var stringSlice []string
|
||||
switch v := raw.(type) {
|
||||
case []string:
|
||||
stringSlice = v
|
||||
case []interface{}:
|
||||
stringSlice = make([]string, 0, len(v))
|
||||
for _, item := range v {
|
||||
if s, ok := item.(string); ok {
|
||||
stringSlice = append(stringSlice, s)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
result := make([]string, 0, len(stringSlice))
|
||||
for _, s := range stringSlice {
|
||||
if trimmed := strings.TrimSpace(s); trimmed != "" {
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -297,6 +297,117 @@ func TestFileSynthesizer_Synthesize_PrefixValidation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_PriorityParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
priority any
|
||||
want string
|
||||
hasValue bool
|
||||
}{
|
||||
{
|
||||
name: "string with spaces",
|
||||
priority: " 10 ",
|
||||
want: "10",
|
||||
hasValue: true,
|
||||
},
|
||||
{
|
||||
name: "number",
|
||||
priority: 8,
|
||||
want: "8",
|
||||
hasValue: true,
|
||||
},
|
||||
{
|
||||
name: "invalid string",
|
||||
priority: "1x",
|
||||
hasValue: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
authData := map[string]any{
|
||||
"type": "claude",
|
||||
"priority": tt.priority,
|
||||
}
|
||||
data, _ := json.Marshal(authData)
|
||||
errWriteFile := os.WriteFile(filepath.Join(tempDir, "auth.json"), data, 0644)
|
||||
if errWriteFile != nil {
|
||||
t.Fatalf("failed to write auth file: %v", errWriteFile)
|
||||
}
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, errSynthesize := synth.Synthesize(ctx)
|
||||
if errSynthesize != nil {
|
||||
t.Fatalf("unexpected error: %v", errSynthesize)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
value, ok := auths[0].Attributes["priority"]
|
||||
if tt.hasValue {
|
||||
if !ok {
|
||||
t.Fatal("expected priority attribute to be set")
|
||||
}
|
||||
if value != tt.want {
|
||||
t.Fatalf("expected priority %q, got %q", tt.want, value)
|
||||
}
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
t.Fatalf("expected priority attribute to be absent, got %q", value)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_OAuthExcludedModelsMerged(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
authData := map[string]any{
|
||||
"type": "claude",
|
||||
"excluded_models": []string{"custom-model", "MODEL-B"},
|
||||
}
|
||||
data, _ := json.Marshal(authData)
|
||||
errWriteFile := os.WriteFile(filepath.Join(tempDir, "auth.json"), data, 0644)
|
||||
if errWriteFile != nil {
|
||||
t.Fatalf("failed to write auth file: %v", errWriteFile)
|
||||
}
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
OAuthExcludedModels: map[string][]string{
|
||||
"claude": {"shared", "model-b"},
|
||||
},
|
||||
},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, errSynthesize := synth.Synthesize(ctx)
|
||||
if errSynthesize != nil {
|
||||
t.Fatalf("unexpected error: %v", errSynthesize)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
got := auths[0].Attributes["excluded_models"]
|
||||
want := "custom-model,model-b,shared"
|
||||
if got != want {
|
||||
t.Fatalf("expected excluded_models %q, got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizeGeminiVirtualAuths_NilInputs(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
@@ -533,6 +644,7 @@ func TestFileSynthesizer_Synthesize_MultiProjectGemini(t *testing.T) {
|
||||
"type": "gemini",
|
||||
"email": "multi@example.com",
|
||||
"project_id": "project-a, project-b, project-c",
|
||||
"priority": " 10 ",
|
||||
}
|
||||
data, _ := json.Marshal(authData)
|
||||
err := os.WriteFile(filepath.Join(tempDir, "gemini-multi.json"), data, 0644)
|
||||
@@ -565,6 +677,9 @@ func TestFileSynthesizer_Synthesize_MultiProjectGemini(t *testing.T) {
|
||||
if primary.Status != coreauth.StatusDisabled {
|
||||
t.Errorf("expected primary status disabled, got %s", primary.Status)
|
||||
}
|
||||
if gotPriority := primary.Attributes["priority"]; gotPriority != "10" {
|
||||
t.Errorf("expected primary priority 10, got %q", gotPriority)
|
||||
}
|
||||
|
||||
// Remaining auths should be virtuals
|
||||
for i := 1; i < 4; i++ {
|
||||
@@ -575,6 +690,9 @@ func TestFileSynthesizer_Synthesize_MultiProjectGemini(t *testing.T) {
|
||||
if v.Attributes["gemini_virtual_parent"] != primary.ID {
|
||||
t.Errorf("expected virtual %d parent to be %s, got %s", i, primary.ID, v.Attributes["gemini_virtual_parent"])
|
||||
}
|
||||
if gotPriority := v.Attributes["priority"]; gotPriority != "10" {
|
||||
t.Errorf("expected virtual %d priority 10, got %q", i, gotPriority)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -53,6 +53,8 @@ func (g *StableIDGenerator) Next(kind string, parts ...string) (string, string)
|
||||
|
||||
// ApplyAuthExcludedModelsMeta applies excluded models metadata to an auth entry.
|
||||
// It computes a hash of excluded models and sets the auth_kind attribute.
|
||||
// For OAuth entries, perKey (from the JSON file's excluded-models field) is merged
|
||||
// with the global oauth-excluded-models config for the provider.
|
||||
func ApplyAuthExcludedModelsMeta(auth *coreauth.Auth, cfg *config.Config, perKey []string, authKind string) {
|
||||
if auth == nil || cfg == nil {
|
||||
return
|
||||
@@ -72,9 +74,13 @@ func ApplyAuthExcludedModelsMeta(auth *coreauth.Auth, cfg *config.Config, perKey
|
||||
}
|
||||
if authKindKey == "apikey" {
|
||||
add(perKey)
|
||||
} else if cfg.OAuthExcludedModels != nil {
|
||||
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
add(cfg.OAuthExcludedModels[providerKey])
|
||||
} else {
|
||||
// For OAuth: merge per-account excluded models with global provider-level exclusions
|
||||
add(perKey)
|
||||
if cfg.OAuthExcludedModels != nil {
|
||||
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
add(cfg.OAuthExcludedModels[providerKey])
|
||||
}
|
||||
}
|
||||
combined := make([]string, 0, len(seen))
|
||||
for k := range seen {
|
||||
@@ -88,6 +94,10 @@ func ApplyAuthExcludedModelsMeta(auth *coreauth.Auth, cfg *config.Config, perKey
|
||||
if hash != "" {
|
||||
auth.Attributes["excluded_models_hash"] = hash
|
||||
}
|
||||
// Store the combined excluded models list so that routing can read it at runtime
|
||||
if len(combined) > 0 {
|
||||
auth.Attributes["excluded_models"] = strings.Join(combined, ",")
|
||||
}
|
||||
if authKind != "" {
|
||||
auth.Attributes["auth_kind"] = authKind
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
@@ -200,6 +201,30 @@ func TestApplyAuthExcludedModelsMeta(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyAuthExcludedModelsMeta_OAuthMergeWritesCombinedModels(t *testing.T) {
|
||||
auth := &coreauth.Auth{
|
||||
Provider: "claude",
|
||||
Attributes: make(map[string]string),
|
||||
}
|
||||
cfg := &config.Config{
|
||||
OAuthExcludedModels: map[string][]string{
|
||||
"claude": {"global-a", "shared"},
|
||||
},
|
||||
}
|
||||
|
||||
ApplyAuthExcludedModelsMeta(auth, cfg, []string{"per", "SHARED"}, "oauth")
|
||||
|
||||
const wantCombined = "global-a,per,shared"
|
||||
if gotCombined := auth.Attributes["excluded_models"]; gotCombined != wantCombined {
|
||||
t.Fatalf("expected excluded_models=%q, got %q", wantCombined, gotCombined)
|
||||
}
|
||||
|
||||
expectedHash := diff.ComputeExcludedModelsHash([]string{"global-a", "per", "shared"})
|
||||
if gotHash := auth.Attributes["excluded_models_hash"]; gotHash != expectedHash {
|
||||
t.Fatalf("expected excluded_models_hash=%q, got %q", expectedHash, gotHash)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddConfigHeadersToAttrs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -185,8 +185,7 @@ func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJ
|
||||
func (h *GeminiCLIAPIHandler) forwardCLIStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
var keepAliveInterval *time.Duration
|
||||
if alt != "" {
|
||||
disabled := time.Duration(0)
|
||||
keepAliveInterval = &disabled
|
||||
keepAliveInterval = new(time.Duration(0))
|
||||
}
|
||||
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
|
||||
@@ -300,8 +300,7 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin
|
||||
func (h *GeminiAPIHandler) forwardGeminiStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
var keepAliveInterval *time.Duration
|
||||
if alt != "" {
|
||||
disabled := time.Duration(0)
|
||||
keepAliveInterval = &disabled
|
||||
keepAliveInterval = new(time.Duration(0))
|
||||
}
|
||||
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
|
||||
@@ -28,8 +28,7 @@ func (AntigravityAuthenticator) Provider() string { return "antigravity" }
|
||||
|
||||
// RefreshLead instructs the manager to refresh five minutes before expiry.
|
||||
func (AntigravityAuthenticator) RefreshLead() *time.Duration {
|
||||
lead := 5 * time.Minute
|
||||
return &lead
|
||||
return new(5 * time.Minute)
|
||||
}
|
||||
|
||||
// Login launches a local OAuth flow to obtain antigravity tokens and persists them.
|
||||
|
||||
@@ -32,8 +32,7 @@ func (a *ClaudeAuthenticator) Provider() string {
|
||||
}
|
||||
|
||||
func (a *ClaudeAuthenticator) RefreshLead() *time.Duration {
|
||||
d := 4 * time.Hour
|
||||
return &d
|
||||
return new(4 * time.Hour)
|
||||
}
|
||||
|
||||
func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
|
||||
@@ -34,8 +34,7 @@ func (a *CodexAuthenticator) Provider() string {
|
||||
}
|
||||
|
||||
func (a *CodexAuthenticator) RefreshLead() *time.Duration {
|
||||
d := 5 * 24 * time.Hour
|
||||
return &d
|
||||
return new(5 * 24 * time.Hour)
|
||||
}
|
||||
|
||||
func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
|
||||
@@ -4,8 +4,10 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -186,15 +188,21 @@ func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth,
|
||||
if provider == "" {
|
||||
provider = "unknown"
|
||||
}
|
||||
if provider == "antigravity" {
|
||||
if provider == "antigravity" || provider == "gemini" {
|
||||
projectID := ""
|
||||
if pid, ok := metadata["project_id"].(string); ok {
|
||||
projectID = strings.TrimSpace(pid)
|
||||
}
|
||||
if projectID == "" {
|
||||
accessToken := ""
|
||||
if token, ok := metadata["access_token"].(string); ok {
|
||||
accessToken = strings.TrimSpace(token)
|
||||
accessToken := extractAccessToken(metadata)
|
||||
// For gemini type, the stored access_token is likely expired (~1h lifetime).
|
||||
// Refresh it using the long-lived refresh_token before querying.
|
||||
if provider == "gemini" {
|
||||
if tokenMap, ok := metadata["token"].(map[string]any); ok {
|
||||
if refreshed, errRefresh := refreshGeminiAccessToken(tokenMap, http.DefaultClient); errRefresh == nil {
|
||||
accessToken = refreshed
|
||||
}
|
||||
}
|
||||
}
|
||||
if accessToken != "" {
|
||||
fetchedProjectID, errFetch := FetchAntigravityProjectID(context.Background(), accessToken, http.DefaultClient)
|
||||
@@ -313,6 +321,67 @@ func (s *FileTokenStore) baseDirSnapshot() string {
|
||||
return s.baseDir
|
||||
}
|
||||
|
||||
func extractAccessToken(metadata map[string]any) string {
|
||||
if at, ok := metadata["access_token"].(string); ok {
|
||||
if v := strings.TrimSpace(at); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
if tokenMap, ok := metadata["token"].(map[string]any); ok {
|
||||
if at, ok := tokenMap["access_token"].(string); ok {
|
||||
if v := strings.TrimSpace(at); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func refreshGeminiAccessToken(tokenMap map[string]any, httpClient *http.Client) (string, error) {
|
||||
refreshToken, _ := tokenMap["refresh_token"].(string)
|
||||
clientID, _ := tokenMap["client_id"].(string)
|
||||
clientSecret, _ := tokenMap["client_secret"].(string)
|
||||
tokenURI, _ := tokenMap["token_uri"].(string)
|
||||
|
||||
if refreshToken == "" || clientID == "" || clientSecret == "" {
|
||||
return "", fmt.Errorf("missing refresh credentials")
|
||||
}
|
||||
if tokenURI == "" {
|
||||
tokenURI = "https://oauth2.googleapis.com/token"
|
||||
}
|
||||
|
||||
data := url.Values{
|
||||
"grant_type": {"refresh_token"},
|
||||
"refresh_token": {refreshToken},
|
||||
"client_id": {clientID},
|
||||
"client_secret": {clientSecret},
|
||||
}
|
||||
|
||||
resp, err := httpClient.PostForm(tokenURI, data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("refresh request: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("refresh failed: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if errUnmarshal := json.Unmarshal(body, &result); errUnmarshal != nil {
|
||||
return "", fmt.Errorf("decode refresh response: %w", errUnmarshal)
|
||||
}
|
||||
|
||||
newAccessToken, _ := result["access_token"].(string)
|
||||
if newAccessToken == "" {
|
||||
return "", fmt.Errorf("no access_token in refresh response")
|
||||
}
|
||||
|
||||
tokenMap["access_token"] = newAccessToken
|
||||
return newAccessToken, nil
|
||||
}
|
||||
|
||||
// jsonEqual compares two JSON blobs by parsing them into Go objects and deep comparing.
|
||||
func jsonEqual(a, b []byte) bool {
|
||||
var objA any
|
||||
|
||||
80
sdk/auth/filestore_test.go
Normal file
80
sdk/auth/filestore_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package auth
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestExtractAccessToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
metadata map[string]any
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"antigravity top-level access_token",
|
||||
map[string]any{"access_token": "tok-abc"},
|
||||
"tok-abc",
|
||||
},
|
||||
{
|
||||
"gemini nested token.access_token",
|
||||
map[string]any{
|
||||
"token": map[string]any{"access_token": "tok-nested"},
|
||||
},
|
||||
"tok-nested",
|
||||
},
|
||||
{
|
||||
"top-level takes precedence over nested",
|
||||
map[string]any{
|
||||
"access_token": "tok-top",
|
||||
"token": map[string]any{"access_token": "tok-nested"},
|
||||
},
|
||||
"tok-top",
|
||||
},
|
||||
{
|
||||
"empty metadata",
|
||||
map[string]any{},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"whitespace-only access_token",
|
||||
map[string]any{"access_token": " "},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"wrong type access_token",
|
||||
map[string]any{"access_token": 12345},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"token is not a map",
|
||||
map[string]any{"token": "not-a-map"},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"nested whitespace-only",
|
||||
map[string]any{
|
||||
"token": map[string]any{"access_token": " "},
|
||||
},
|
||||
"",
|
||||
},
|
||||
{
|
||||
"fallback to nested when top-level empty",
|
||||
map[string]any{
|
||||
"access_token": "",
|
||||
"token": map[string]any{"access_token": "tok-fallback"},
|
||||
},
|
||||
"tok-fallback",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := extractAccessToken(tt.metadata)
|
||||
if got != tt.expected {
|
||||
t.Errorf("extractAccessToken() = %q, want %q", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -26,8 +26,7 @@ func (a *IFlowAuthenticator) Provider() string { return "iflow" }
|
||||
|
||||
// RefreshLead indicates how soon before expiry a refresh should be attempted.
|
||||
func (a *IFlowAuthenticator) RefreshLead() *time.Duration {
|
||||
d := 24 * time.Hour
|
||||
return &d
|
||||
return new(24 * time.Hour)
|
||||
}
|
||||
|
||||
// Login performs the OAuth code flow using a local callback server.
|
||||
|
||||
@@ -27,8 +27,7 @@ func (a *QwenAuthenticator) Provider() string {
|
||||
}
|
||||
|
||||
func (a *QwenAuthenticator) RefreshLead() *time.Duration {
|
||||
d := 3 * time.Hour
|
||||
return &d
|
||||
return new(3 * time.Hour)
|
||||
}
|
||||
|
||||
func (a *QwenAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
|
||||
@@ -599,8 +599,7 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
|
||||
return cliproxyexecutor.Response{}, errCtx
|
||||
}
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errExec, &se) && se != nil {
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
@@ -655,8 +654,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
||||
return cliproxyexecutor.Response{}, errCtx
|
||||
}
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errExec, &se) && se != nil {
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
if ra := retryAfterFromError(errExec); ra != nil {
|
||||
@@ -710,8 +708,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
return nil, errCtx
|
||||
}
|
||||
rerr := &Error{Message: errStream.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errStream, &se) && se != nil {
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||
@@ -732,8 +729,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
if chunk.Err != nil && !failed {
|
||||
failed = true
|
||||
rerr := &Error{Message: chunk.Err.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(chunk.Err, &se) && se != nil {
|
||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](chunk.Err); ok && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
|
||||
@@ -1431,8 +1427,7 @@ func retryAfterFromError(err error) *time.Duration {
|
||||
if retryAfter == nil {
|
||||
return nil
|
||||
}
|
||||
val := *retryAfter
|
||||
return &val
|
||||
return new(*retryAfter)
|
||||
}
|
||||
|
||||
func statusCodeFromResult(err *Error) int {
|
||||
|
||||
@@ -79,6 +79,24 @@ func TestResolveOAuthUpstreamModel_SuffixPreservation(t *testing.T) {
|
||||
input: "gemini-2.5-pro(none)",
|
||||
want: "gemini-2.5-pro-exp-03-25(none)",
|
||||
},
|
||||
{
|
||||
name: "github-copilot suffix preserved",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
"github-copilot": {{Name: "claude-opus-4.6", Alias: "opus"}},
|
||||
},
|
||||
channel: "github-copilot",
|
||||
input: "opus(medium)",
|
||||
want: "claude-opus-4.6(medium)",
|
||||
},
|
||||
{
|
||||
name: "github-copilot no suffix",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
"github-copilot": {{Name: "claude-opus-4.6", Alias: "opus"}},
|
||||
},
|
||||
channel: "github-copilot",
|
||||
input: "opus",
|
||||
want: "claude-opus-4.6",
|
||||
},
|
||||
{
|
||||
name: "kimi suffix preserved",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
@@ -174,6 +192,8 @@ func createAuthForChannel(channel string) *Auth {
|
||||
return &Auth{Provider: "kimi"}
|
||||
case "kiro":
|
||||
return &Auth{Provider: "kiro"}
|
||||
case "github-copilot":
|
||||
return &Auth{Provider: "github-copilot"}
|
||||
default:
|
||||
return &Auth{Provider: channel}
|
||||
}
|
||||
@@ -187,6 +207,22 @@ func TestOAuthModelAliasChannel_Kimi(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthModelAliasChannel_GitHubCopilot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := OAuthModelAliasChannel("github-copilot", ""); got != "github-copilot" {
|
||||
t.Fatalf("OAuthModelAliasChannel() = %q, want %q", got, "github-copilot")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthModelAliasChannel_Kiro(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := OAuthModelAliasChannel("kiro", ""); got != "kiro" {
|
||||
t.Fatalf("OAuthModelAliasChannel() = %q, want %q", got, "kiro")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyOAuthModelAlias_SuffixPreservation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -767,6 +767,13 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
provider = "openai-compatibility"
|
||||
}
|
||||
excluded := s.oauthExcludedModels(provider, authKind)
|
||||
// The synthesizer pre-merges per-account and global exclusions into the "excluded_models" attribute.
|
||||
// If this attribute is present, it represents the complete list of exclusions and overrides the global config.
|
||||
if a.Attributes != nil {
|
||||
if val, ok := a.Attributes["excluded_models"]; ok && strings.TrimSpace(val) != "" {
|
||||
excluded = strings.Split(val, ",")
|
||||
}
|
||||
}
|
||||
var models []*ModelInfo
|
||||
switch provider {
|
||||
case "gemini":
|
||||
|
||||
65
sdk/cliproxy/service_excluded_models_test.go
Normal file
65
sdk/cliproxy/service_excluded_models_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package cliproxy
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestRegisterModelsForAuth_UsesPreMergedExcludedModelsAttribute(t *testing.T) {
|
||||
service := &Service{
|
||||
cfg: &config.Config{
|
||||
OAuthExcludedModels: map[string][]string{
|
||||
"gemini-cli": {"gemini-2.5-pro"},
|
||||
},
|
||||
},
|
||||
}
|
||||
auth := &coreauth.Auth{
|
||||
ID: "auth-gemini-cli",
|
||||
Provider: "gemini-cli",
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{
|
||||
"auth_kind": "oauth",
|
||||
"excluded_models": "gemini-2.5-flash",
|
||||
},
|
||||
}
|
||||
|
||||
registry := GlobalModelRegistry()
|
||||
registry.UnregisterClient(auth.ID)
|
||||
t.Cleanup(func() {
|
||||
registry.UnregisterClient(auth.ID)
|
||||
})
|
||||
|
||||
service.registerModelsForAuth(auth)
|
||||
|
||||
models := registry.GetAvailableModelsByProvider("gemini-cli")
|
||||
if len(models) == 0 {
|
||||
t.Fatal("expected gemini-cli models to be registered")
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
if model == nil {
|
||||
continue
|
||||
}
|
||||
modelID := strings.TrimSpace(model.ID)
|
||||
if strings.EqualFold(modelID, "gemini-2.5-flash") {
|
||||
t.Fatalf("expected model %q to be excluded by auth attribute", modelID)
|
||||
}
|
||||
}
|
||||
|
||||
seenGlobalExcluded := false
|
||||
for _, model := range models {
|
||||
if model == nil {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(model.ID), "gemini-2.5-pro") {
|
||||
seenGlobalExcluded = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !seenGlobalExcluded {
|
||||
t.Fatal("expected global excluded model to be present when attribute override is set")
|
||||
}
|
||||
}
|
||||
@@ -1316,6 +1316,122 @@ func TestThinkingE2EMatrix_Suffix(t *testing.T) {
|
||||
includeThoughts: "true",
|
||||
expectErr: false,
|
||||
},
|
||||
|
||||
// GitHub Copilot tests: gpt-5, gpt-5.1, gpt-5.2 (Levels=low/medium/high, some with none/xhigh)
|
||||
// Testing /chat/completions endpoint (openai format) - with suffix
|
||||
|
||||
// Case 112: OpenAI to gpt-5, level high → high
|
||||
{
|
||||
name: "112",
|
||||
from: "openai",
|
||||
to: "github-copilot",
|
||||
model: "gpt-5(high)",
|
||||
inputJSON: `{"model":"gpt-5(high)","messages":[{"role":"user","content":"hi"}]}`,
|
||||
expectField: "reasoning_effort",
|
||||
expectValue: "high",
|
||||
expectErr: false,
|
||||
},
|
||||
// Case 113: OpenAI to gpt-5, level none → clamped to low (ZeroAllowed=false)
|
||||
{
|
||||
name: "113",
|
||||
from: "openai",
|
||||
to: "github-copilot",
|
||||
model: "gpt-5(none)",
|
||||
inputJSON: `{"model":"gpt-5(none)","messages":[{"role":"user","content":"hi"}]}`,
|
||||
expectField: "reasoning_effort",
|
||||
expectValue: "low",
|
||||
expectErr: false,
|
||||
},
|
||||
// Case 114: OpenAI to gpt-5.1, level none → none (ZeroAllowed=true)
|
||||
{
|
||||
name: "114",
|
||||
from: "openai",
|
||||
to: "github-copilot",
|
||||
model: "gpt-5.1(none)",
|
||||
inputJSON: `{"model":"gpt-5.1(none)","messages":[{"role":"user","content":"hi"}]}`,
|
||||
expectField: "reasoning_effort",
|
||||
expectValue: "none",
|
||||
expectErr: false,
|
||||
},
|
||||
// Case 115: OpenAI to gpt-5.2, level xhigh → xhigh
|
||||
{
|
||||
name: "115",
|
||||
from: "openai",
|
||||
to: "github-copilot",
|
||||
model: "gpt-5.2(xhigh)",
|
||||
inputJSON: `{"model":"gpt-5.2(xhigh)","messages":[{"role":"user","content":"hi"}]}`,
|
||||
expectField: "reasoning_effort",
|
||||
expectValue: "xhigh",
|
||||
expectErr: false,
|
||||
},
|
||||
// Case 116: OpenAI to gpt-5, level xhigh (out of range) → error
|
||||
{
|
||||
name: "116",
|
||||
from: "openai",
|
||||
to: "github-copilot",
|
||||
model: "gpt-5(xhigh)",
|
||||
inputJSON: `{"model":"gpt-5(xhigh)","messages":[{"role":"user","content":"hi"}]}`,
|
||||
expectField: "",
|
||||
expectErr: true,
|
||||
},
|
||||
// Case 117: Claude to gpt-5.1, budget 0 → none (ZeroAllowed=true)
|
||||
{
|
||||
name: "117",
|
||||
from: "claude",
|
||||
to: "github-copilot",
|
||||
model: "gpt-5.1(0)",
|
||||
inputJSON: `{"model":"gpt-5.1(0)","messages":[{"role":"user","content":"hi"}]}`,
|
||||
expectField: "reasoning_effort",
|
||||
expectValue: "none",
|
||||
expectErr: false,
|
||||
},
|
||||
|
||||
// GitHub Copilot tests: /responses endpoint (codex format) - with suffix
|
||||
|
||||
// Case 118: OpenAI-Response to gpt-5-codex, level high → high
|
||||
{
|
||||
name: "118",
|
||||
from: "openai-response",
|
||||
to: "github-copilot",
|
||||
model: "gpt-5-codex(high)",
|
||||
inputJSON: `{"model":"gpt-5-codex(high)","input":[{"role":"user","content":"hi"}]}`,
|
||||
expectField: "reasoning.effort",
|
||||
expectValue: "high",
|
||||
expectErr: false,
|
||||
},
|
||||
// Case 119: OpenAI-Response to gpt-5.2-codex, level xhigh → xhigh
|
||||
{
|
||||
name: "119",
|
||||
from: "openai-response",
|
||||
to: "github-copilot",
|
||||
model: "gpt-5.2-codex(xhigh)",
|
||||
inputJSON: `{"model":"gpt-5.2-codex(xhigh)","input":[{"role":"user","content":"hi"}]}`,
|
||||
expectField: "reasoning.effort",
|
||||
expectValue: "xhigh",
|
||||
expectErr: false,
|
||||
},
|
||||
// Case 120: OpenAI-Response to gpt-5.2-codex, level none → none
|
||||
{
|
||||
name: "120",
|
||||
from: "openai-response",
|
||||
to: "github-copilot",
|
||||
model: "gpt-5.2-codex(none)",
|
||||
inputJSON: `{"model":"gpt-5.2-codex(none)","input":[{"role":"user","content":"hi"}]}`,
|
||||
expectField: "reasoning.effort",
|
||||
expectValue: "none",
|
||||
expectErr: false,
|
||||
},
|
||||
// Case 121: OpenAI-Response to gpt-5-codex, level none → clamped to low (ZeroAllowed=false)
|
||||
{
|
||||
name: "121",
|
||||
from: "openai-response",
|
||||
to: "github-copilot",
|
||||
model: "gpt-5-codex(none)",
|
||||
inputJSON: `{"model":"gpt-5-codex(none)","input":[{"role":"user","content":"hi"}]}`,
|
||||
expectField: "reasoning.effort",
|
||||
expectValue: "low",
|
||||
expectErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
runThinkingTests(t, cases)
|
||||
@@ -2585,6 +2701,251 @@ func TestThinkingE2EMatrix_Body(t *testing.T) {
|
||||
includeThoughts: "true",
|
||||
expectErr: false,
|
||||
},
|
||||
|
||||
// GitHub Copilot tests: gpt-5, gpt-5.1, gpt-5.2 (Levels=low/medium/high, some with none/xhigh)
|
||||
// Testing /chat/completions endpoint (openai format) - with body params
|
||||
|
||||
// Case 112: OpenAI to gpt-5, reasoning_effort=high → high
|
||||
{
|
||||
name: "112",
|
||||
from: "openai",
|
||||
to: "github-copilot",
|
||||
model: "gpt-5",
|
||||
inputJSON: `{"model":"gpt-5","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"high"}`,
|
||||
expectField: "reasoning_effort",
|
||||
expectValue: "high",
|
||||
expectErr: false,
|
||||
},
|
||||
// Case 113: OpenAI to gpt-5, reasoning_effort=none → clamped to low (ZeroAllowed=false)
|
||||
{
|
||||
name: "113",
|
||||
from: "openai",
|
||||
to: "github-copilot",
|
||||
model: "gpt-5",
|
||||
inputJSON: `{"model":"gpt-5","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"none"}`,
|
||||
expectField: "reasoning_effort",
|
||||
expectValue: "low",
|
||||
expectErr: false,
|
||||
},
|
||||
// Case 114: OpenAI to gpt-5.1, reasoning_effort=none → none (ZeroAllowed=true)
|
||||
{
|
||||
name: "114",
|
||||
from: "openai",
|
||||
to: "github-copilot",
|
||||
model: "gpt-5.1",
|
||||
inputJSON: `{"model":"gpt-5.1","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"none"}`,
|
||||
expectField: "reasoning_effort",
|
||||
expectValue: "none",
|
||||
expectErr: false,
|
||||
},
|
||||
// Case 115: OpenAI to gpt-5.2, reasoning_effort=xhigh → xhigh
|
||||
{
|
||||
name: "115",
|
||||
from: "openai",
|
||||
to: "github-copilot",
|
||||
model: "gpt-5.2",
|
||||
inputJSON: `{"model":"gpt-5.2","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"xhigh"}`,
|
||||
expectField: "reasoning_effort",
|
||||
expectValue: "xhigh",
|
||||
expectErr: false,
|
||||
},
|
||||
// Case 116: OpenAI to gpt-5, reasoning_effort=xhigh (out of range) → error
|
||||
{
|
||||
name: "116",
|
||||
from: "openai",
|
||||
to: "github-copilot",
|
||||
model: "gpt-5",
|
||||
inputJSON: `{"model":"gpt-5","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"xhigh"}`,
|
||||
expectField: "",
|
||||
expectErr: true,
|
||||
},
|
||||
// Case 117: Claude to gpt-5.1, thinking.budget_tokens=0 → none (ZeroAllowed=true)
|
||||
{
|
||||
name: "117",
|
||||
from: "claude",
|
||||
to: "github-copilot",
|
||||
model: "gpt-5.1",
|
||||
inputJSON: `{"model":"gpt-5.1","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":0}}`,
|
||||
expectField: "reasoning_effort",
|
||||
expectValue: "none",
|
||||
expectErr: false,
|
||||
},
|
||||
|
||||
// GitHub Copilot tests: /responses endpoint (codex format) - with body params
|
||||
|
||||
// Case 118: OpenAI-Response to gpt-5-codex, reasoning.effort=high → high
|
||||
{
|
||||
name: "118",
|
||||
from: "openai-response",
|
||||
to: "github-copilot",
|
||||
model: "gpt-5-codex",
|
||||
inputJSON: `{"model":"gpt-5-codex","input":[{"role":"user","content":"hi"}],"reasoning":{"effort":"high"}}`,
|
||||
expectField: "reasoning.effort",
|
||||
expectValue: "high",
|
||||
expectErr: false,
|
||||
},
|
||||
// Case 119: OpenAI-Response to gpt-5.2-codex, reasoning.effort=xhigh → xhigh
|
||||
{
|
||||
name: "119",
|
||||
from: "openai-response",
|
||||
to: "github-copilot",
|
||||
model: "gpt-5.2-codex",
|
||||
inputJSON: `{"model":"gpt-5.2-codex","input":[{"role":"user","content":"hi"}],"reasoning":{"effort":"xhigh"}}`,
|
||||
expectField: "reasoning.effort",
|
||||
expectValue: "xhigh",
|
||||
expectErr: false,
|
||||
},
|
||||
// Case 120: OpenAI-Response to gpt-5.2-codex, reasoning.effort=none → none
|
||||
{
|
||||
name: "120",
|
||||
from: "openai-response",
|
||||
to: "github-copilot",
|
||||
model: "gpt-5.2-codex",
|
||||
inputJSON: `{"model":"gpt-5.2-codex","input":[{"role":"user","content":"hi"}],"reasoning":{"effort":"none"}}`,
|
||||
expectField: "reasoning.effort",
|
||||
expectValue: "none",
|
||||
expectErr: false,
|
||||
},
|
||||
// Case 121: OpenAI-Response to gpt-5-codex, reasoning.effort=none → clamped to low (ZeroAllowed=false)
|
||||
{
|
||||
name: "121",
|
||||
from: "openai-response",
|
||||
to: "github-copilot",
|
||||
model: "gpt-5-codex",
|
||||
inputJSON: `{"model":"gpt-5-codex","input":[{"role":"user","content":"hi"}],"reasoning":{"effort":"none"}}`,
|
||||
expectField: "reasoning.effort",
|
||||
expectValue: "low",
|
||||
expectErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
runThinkingTests(t, cases)
|
||||
}
|
||||
|
||||
// TestThinkingE2EClaudeAdaptive_Body tests Claude thinking.type=adaptive extended body-only cases.
|
||||
// These cases validate that adaptive means "thinking enabled without explicit budget", and
|
||||
// cross-protocol conversion should resolve to target-model maximum thinking capability.
|
||||
func TestThinkingE2EClaudeAdaptive_Body(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
uid := fmt.Sprintf("thinking-e2e-claude-adaptive-%d", time.Now().UnixNano())
|
||||
|
||||
reg.RegisterClient(uid, "test", getTestModels())
|
||||
defer reg.UnregisterClient(uid)
|
||||
|
||||
cases := []thinkingTestCase{
|
||||
// A1: Claude adaptive to OpenAI level model -> highest supported level
|
||||
{
|
||||
name: "A1",
|
||||
from: "claude",
|
||||
to: "openai",
|
||||
model: "level-model",
|
||||
inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`,
|
||||
expectField: "reasoning_effort",
|
||||
expectValue: "high",
|
||||
expectErr: false,
|
||||
},
|
||||
// A2: Claude adaptive to Gemini level subset model -> highest supported level
|
||||
{
|
||||
name: "A2",
|
||||
from: "claude",
|
||||
to: "gemini",
|
||||
model: "level-subset-model",
|
||||
inputJSON: `{"model":"level-subset-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`,
|
||||
expectField: "generationConfig.thinkingConfig.thinkingLevel",
|
||||
expectValue: "high",
|
||||
includeThoughts: "true",
|
||||
expectErr: false,
|
||||
},
|
||||
// A3: Claude adaptive to Gemini budget model -> max budget
|
||||
{
|
||||
name: "A3",
|
||||
from: "claude",
|
||||
to: "gemini",
|
||||
model: "gemini-budget-model",
|
||||
inputJSON: `{"model":"gemini-budget-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`,
|
||||
expectField: "generationConfig.thinkingConfig.thinkingBudget",
|
||||
expectValue: "20000",
|
||||
includeThoughts: "true",
|
||||
expectErr: false,
|
||||
},
|
||||
// A4: Claude adaptive to Gemini mixed model -> highest supported level
|
||||
{
|
||||
name: "A4",
|
||||
from: "claude",
|
||||
to: "gemini",
|
||||
model: "gemini-mixed-model",
|
||||
inputJSON: `{"model":"gemini-mixed-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`,
|
||||
expectField: "generationConfig.thinkingConfig.thinkingLevel",
|
||||
expectValue: "high",
|
||||
includeThoughts: "true",
|
||||
expectErr: false,
|
||||
},
|
||||
// A5: Claude adaptive passthrough for same protocol
|
||||
{
|
||||
name: "A5",
|
||||
from: "claude",
|
||||
to: "claude",
|
||||
model: "claude-budget-model",
|
||||
inputJSON: `{"model":"claude-budget-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`,
|
||||
expectField: "thinking.type",
|
||||
expectValue: "adaptive",
|
||||
expectErr: false,
|
||||
},
|
||||
// A6: Claude adaptive to Antigravity budget model -> max budget
|
||||
{
|
||||
name: "A6",
|
||||
from: "claude",
|
||||
to: "antigravity",
|
||||
model: "antigravity-budget-model",
|
||||
inputJSON: `{"model":"antigravity-budget-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`,
|
||||
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
|
||||
expectValue: "20000",
|
||||
includeThoughts: "true",
|
||||
expectErr: false,
|
||||
},
|
||||
// A7: Claude adaptive to iFlow GLM -> enabled boolean
|
||||
{
|
||||
name: "A7",
|
||||
from: "claude",
|
||||
to: "iflow",
|
||||
model: "glm-test",
|
||||
inputJSON: `{"model":"glm-test","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`,
|
||||
expectField: "chat_template_kwargs.enable_thinking",
|
||||
expectValue: "true",
|
||||
expectErr: false,
|
||||
},
|
||||
// A8: Claude adaptive to iFlow MiniMax -> enabled boolean
|
||||
{
|
||||
name: "A8",
|
||||
from: "claude",
|
||||
to: "iflow",
|
||||
model: "minimax-test",
|
||||
inputJSON: `{"model":"minimax-test","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`,
|
||||
expectField: "reasoning_split",
|
||||
expectValue: "true",
|
||||
expectErr: false,
|
||||
},
|
||||
// A9: Claude adaptive to Codex level model -> highest supported level
|
||||
{
|
||||
name: "A9",
|
||||
from: "claude",
|
||||
to: "codex",
|
||||
model: "level-model",
|
||||
inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`,
|
||||
expectField: "reasoning.effort",
|
||||
expectValue: "high",
|
||||
expectErr: false,
|
||||
},
|
||||
// A10: Claude adaptive on non-thinking model should still be stripped
|
||||
{
|
||||
name: "A10",
|
||||
from: "claude",
|
||||
to: "openai",
|
||||
model: "no-thinking-model",
|
||||
inputJSON: `{"model":"no-thinking-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`,
|
||||
expectField: "",
|
||||
expectErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
runThinkingTests(t, cases)
|
||||
@@ -2684,6 +3045,51 @@ func getTestModels() []*registry.ModelInfo {
|
||||
DisplayName: "MiniMax Test Model",
|
||||
Thinking: ®istry.ThinkingSupport{Levels: []string{"none", "auto", "minimal", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5",
|
||||
Object: "model",
|
||||
Created: 1700000000,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "GPT-5",
|
||||
Thinking: ®istry.ThinkingSupport{Levels: []string{"low", "medium", "high"}, ZeroAllowed: false, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.1",
|
||||
Object: "model",
|
||||
Created: 1700000000,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "GPT-5.1",
|
||||
Thinking: ®istry.ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.2",
|
||||
Object: "model",
|
||||
Created: 1700000000,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "GPT-5.2",
|
||||
Thinking: ®istry.ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5-codex",
|
||||
Object: "model",
|
||||
Created: 1700000000,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "GPT-5 Codex",
|
||||
Thinking: ®istry.ThinkingSupport{Levels: []string{"low", "medium", "high"}, ZeroAllowed: false, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.2-codex",
|
||||
Object: "model",
|
||||
Created: 1700000000,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "GPT-5.2 Codex",
|
||||
Thinking: ®istry.ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2702,6 +3108,15 @@ func runThinkingTests(t *testing.T, cases []thinkingTestCase) {
|
||||
translateTo = "openai"
|
||||
applyTo = "iflow"
|
||||
}
|
||||
if tc.to == "github-copilot" {
|
||||
if tc.from == "openai-response" {
|
||||
translateTo = "codex"
|
||||
applyTo = "codex"
|
||||
} else {
|
||||
translateTo = "openai"
|
||||
applyTo = "openai"
|
||||
}
|
||||
}
|
||||
|
||||
body := sdktranslator.TranslateRequest(
|
||||
sdktranslator.FromString(tc.from),
|
||||
|
||||
Reference in New Issue
Block a user