mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-09 15:25:17 +00:00
Compare commits
151 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bc61bf36b2 | ||
|
|
7726a44ca2 | ||
|
|
dc55fb0ce3 | ||
|
|
e9dd44e623 | ||
|
|
cc8c4ffb5f | ||
|
|
1510bfcb6f | ||
|
|
bcd2208b51 | ||
|
|
09b19f5c4e | ||
|
|
7b01ca0e2e | ||
|
|
fe6fc628ed | ||
|
|
8192eeabc8 | ||
|
|
c3f1cdd7e5 | ||
|
|
c6bd91b86b | ||
|
|
349ddcaa89 | ||
|
|
938a799263 | ||
|
|
e17d4f8d98 | ||
|
|
c8cae1f74d | ||
|
|
0040d78496 | ||
|
|
896de027cc | ||
|
|
fc329ebf37 | ||
|
|
91841a5519 | ||
|
|
eaab1d6824 | ||
|
|
0cfe310df6 | ||
|
|
918b6955e4 | ||
|
|
532fbf00d4 | ||
|
|
45b6fffd7f | ||
|
|
5a3eb08739 | ||
|
|
0dff329162 | ||
|
|
49c1740b47 | ||
|
|
3fbee51e9f | ||
|
|
a3dc56d2a0 | ||
|
|
63643c44a1 | ||
|
|
1d93608dbe | ||
|
|
d125b7de92 | ||
|
|
d5654ee316 | ||
|
|
3b34521ad9 | ||
|
|
7197fb350b | ||
|
|
6e349bfcc7 | ||
|
|
234056072d | ||
|
|
76330f4bff | ||
|
|
d468eec6ec | ||
|
|
9bc6cc5b41 | ||
|
|
d109be159c | ||
|
|
eddf31e55b | ||
|
|
7e9d0db6aa | ||
|
|
2f1874ede5 | ||
|
|
6b83585b53 | ||
|
|
78ef04fcf1 | ||
|
|
b7e4f00c5f | ||
|
|
c20507c15e | ||
|
|
f7d0019df7 | ||
|
|
52364af5bf | ||
|
|
f410dd0440 | ||
|
|
eb5582c17c | ||
|
|
1c6cb2bec3 | ||
|
|
80b5e79e75 | ||
|
|
d182e893b6 | ||
|
|
2e8d49a641 | ||
|
|
6abd7d27d9 | ||
|
|
8fa12af403 | ||
|
|
77586ed7d3 | ||
|
|
394497fb2f | ||
|
|
fc7b6ef086 | ||
|
|
98edcad39d | ||
|
|
1187aa8222 | ||
|
|
a35d66443b | ||
|
|
40ad4a42ea | ||
|
|
dc9b4dd017 | ||
|
|
68cb81a258 | ||
|
|
16693053f5 | ||
|
|
4e3bad3907 | ||
|
|
c874f19f2a | ||
|
|
f5f26f0cbe | ||
|
|
e7e3ca1efb | ||
|
|
4b00312fef | ||
|
|
c5fd3db01e | ||
|
|
e35ffaa925 | ||
|
|
f870a9d2a7 | ||
|
|
165e03f3a7 | ||
|
|
86bdb7808c | ||
|
|
b4e034be1c | ||
|
|
84fcebf538 | ||
|
|
74d9a1ffed | ||
|
|
a5a25dec57 | ||
|
|
c71905e5e8 | ||
|
|
bc78d668ac | ||
|
|
e93eebc2e9 | ||
|
|
5bd0896ad7 | ||
|
|
09ecfbcaed | ||
|
|
f0bd14b64f | ||
|
|
14f044ce4f | ||
|
|
88872baffc | ||
|
|
dbecf5330e | ||
|
|
1c0e102637 | ||
|
|
6b6b343922 | ||
|
|
f7d82fda3f | ||
|
|
706590c62a | ||
|
|
25c6b479c7 | ||
|
|
7cf9ff0345 | ||
|
|
209d74062a | ||
|
|
d86b13c9cb | ||
|
|
075e3ab69e | ||
|
|
49ef22ab78 | ||
|
|
ae4638712e | ||
|
|
c1c9483752 | ||
|
|
6c65fdf54b | ||
|
|
4874253d1e | ||
|
|
b72250349f | ||
|
|
116573311f | ||
|
|
4af712544d | ||
|
|
3f9c9591bd | ||
|
|
1548c567ab | ||
|
|
5b23fc570c | ||
|
|
04e1c7a05a | ||
|
|
9181e72204 | ||
|
|
b854ee4680 | ||
|
|
533a6bd15c | ||
|
|
45546c1cf7 | ||
|
|
e2169e3987 | ||
|
|
e85305c815 | ||
|
|
8d4554bf17 | ||
|
|
f628e4dcbb | ||
|
|
7accae4b6a | ||
|
|
3354fae391 | ||
|
|
4939865f6d | ||
|
|
3da7f7482e | ||
|
|
9072b029b2 | ||
|
|
c296cfb8c0 | ||
|
|
2707377fcb | ||
|
|
259f586ff7 | ||
|
|
d885b81f23 | ||
|
|
fe6bffd080 | ||
|
|
1a81e8a98a | ||
|
|
0b889c6028 | ||
|
|
f6bb0011f9 | ||
|
|
fcdd91895e | ||
|
|
8dc4fc4ff5 | ||
|
|
9e9a860bda | ||
|
|
6cd32028c3 | ||
|
|
ebd58ef33a | ||
|
|
92791194e5 | ||
|
|
1f7c58f7ce | ||
|
|
a275db3fdb | ||
|
|
233be6272a | ||
|
|
47cb52385e | ||
|
|
ba168ec003 | ||
|
|
a12e22c66f | ||
|
|
4c50a7281a | ||
|
|
80d3fa384e | ||
|
|
b45ede0b71 | ||
|
|
a406ca2d5a |
3
.github/workflows/docker-image.yml
vendored
3
.github/workflows/docker-image.yml
vendored
@@ -1,13 +1,14 @@
|
||||
name: docker-image
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
tags:
|
||||
- v*
|
||||
|
||||
env:
|
||||
APP_NAME: CLIProxyAPI
|
||||
DOCKERHUB_REPO: eceasy/cli-proxy-api-plus
|
||||
DOCKERHUB_REPO: ${{ secrets.DOCKERHUB_USERNAME }}/cli-proxy-api-plus
|
||||
|
||||
jobs:
|
||||
docker_amd64:
|
||||
|
||||
BIN
assets/aicodemirror.png
Normal file
BIN
assets/aicodemirror.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 45 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 51 KiB |
@@ -77,6 +77,7 @@ func main() {
|
||||
var noBrowser bool
|
||||
var oauthCallbackPort int
|
||||
var antigravityLogin bool
|
||||
var kimiLogin bool
|
||||
var kiroLogin bool
|
||||
var kiroGoogleLogin bool
|
||||
var kiroAWSLogin bool
|
||||
@@ -102,6 +103,7 @@ func main() {
|
||||
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
|
||||
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
|
||||
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
||||
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
|
||||
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
|
||||
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
|
||||
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
|
||||
@@ -473,7 +475,7 @@ func main() {
|
||||
}
|
||||
|
||||
// Register built-in access providers before constructing services.
|
||||
configaccess.Register()
|
||||
configaccess.Register(&cfg.SDKConfig)
|
||||
|
||||
// Handle different command modes based on the provided flags.
|
||||
|
||||
@@ -501,6 +503,8 @@ func main() {
|
||||
cmd.DoIFlowLogin(cfg, options)
|
||||
} else if iflowCookie {
|
||||
cmd.DoIFlowCookieAuth(cfg, options)
|
||||
} else if kimiLogin {
|
||||
cmd.DoKimiLogin(cfg, options)
|
||||
} else if kiroLogin {
|
||||
// For Kiro auth, default to incognito mode for multi-account support
|
||||
// Users can explicitly override with --no-incognito
|
||||
|
||||
@@ -40,6 +40,11 @@ api-keys:
|
||||
# Enable debug logging
|
||||
debug: false
|
||||
|
||||
# Enable pprof HTTP debug server (host:port). Keep it bound to localhost for safety.
|
||||
pprof:
|
||||
enable: false
|
||||
addr: "127.0.0.1:8316"
|
||||
|
||||
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
|
||||
commercial-mode: false
|
||||
|
||||
@@ -231,10 +236,10 @@ nonstream-keepalive-interval: 0
|
||||
|
||||
# Global OAuth model name aliases (per channel)
|
||||
# These aliases rename model IDs for both model listing and request routing.
|
||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
|
||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
|
||||
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
||||
# You can repeat the same name with different aliases to expose multiple client model names.
|
||||
#oauth-model-alias:
|
||||
# oauth-model-alias:
|
||||
# antigravity:
|
||||
# - name: "rev19-uic3-1p"
|
||||
# alias: "gemini-2.5-computer-use-preview-10-2025"
|
||||
@@ -260,9 +265,6 @@ nonstream-keepalive-interval: 0
|
||||
# aistudio:
|
||||
# - name: "gemini-2.5-pro"
|
||||
# alias: "g2.5p"
|
||||
# antigravity:
|
||||
# - name: "gemini-3-pro-preview"
|
||||
# alias: "g3p"
|
||||
# claude:
|
||||
# - name: "claude-sonnet-4-5-20250929"
|
||||
# alias: "cs4.5"
|
||||
@@ -275,6 +277,9 @@ nonstream-keepalive-interval: 0
|
||||
# iflow:
|
||||
# - name: "glm-4.7"
|
||||
# alias: "glm-god"
|
||||
# kimi:
|
||||
# - name: "kimi-k2.5"
|
||||
# alias: "k2.5"
|
||||
# kiro:
|
||||
# - name: "kiro-claude-opus-4-5"
|
||||
# alias: "op45"
|
||||
@@ -304,6 +309,8 @@ nonstream-keepalive-interval: 0
|
||||
# - "vision-model"
|
||||
# iflow:
|
||||
# - "tstars2.0"
|
||||
# kimi:
|
||||
# - "kimi-k2-thinking"
|
||||
# kiro:
|
||||
# - "kiro-claude-haiku-4-5"
|
||||
# github-copilot:
|
||||
|
||||
@@ -7,80 +7,71 @@ The `github.com/router-for-me/CLIProxyAPI/v6/sdk/access` package centralizes inb
|
||||
```go
|
||||
import (
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
```
|
||||
|
||||
Add the module with `go get github.com/router-for-me/CLIProxyAPI/v6/sdk/access`.
|
||||
|
||||
## Provider Registry
|
||||
|
||||
Providers are registered globally and then attached to a `Manager` as a snapshot:
|
||||
|
||||
- `RegisterProvider(type, provider)` installs a pre-initialized provider instance.
|
||||
- Registration order is preserved the first time each `type` is seen.
|
||||
- `RegisteredProviders()` returns the providers in that order.
|
||||
|
||||
## Manager Lifecycle
|
||||
|
||||
```go
|
||||
manager := sdkaccess.NewManager()
|
||||
providers, err := sdkaccess.BuildProviders(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
manager.SetProviders(providers)
|
||||
manager.SetProviders(sdkaccess.RegisteredProviders())
|
||||
```
|
||||
|
||||
* `NewManager` constructs an empty manager.
|
||||
* `SetProviders` replaces the provider slice using a defensive copy.
|
||||
* `Providers` retrieves a snapshot that can be iterated safely from other goroutines.
|
||||
* `BuildProviders` translates `config.Config` access declarations into runnable providers. When the config omits explicit providers but defines inline API keys, the helper auto-installs the built-in `config-api-key` provider.
|
||||
|
||||
If the manager itself is `nil` or no providers are configured, the call returns `nil, nil`, allowing callers to treat access control as disabled.
|
||||
|
||||
## Authenticating Requests
|
||||
|
||||
```go
|
||||
result, err := manager.Authenticate(ctx, req)
|
||||
result, authErr := manager.Authenticate(ctx, req)
|
||||
switch {
|
||||
case err == nil:
|
||||
case authErr == nil:
|
||||
// Authentication succeeded; result describes the provider and principal.
|
||||
case errors.Is(err, sdkaccess.ErrNoCredentials):
|
||||
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeNoCredentials):
|
||||
// No recognizable credentials were supplied.
|
||||
case errors.Is(err, sdkaccess.ErrInvalidCredential):
|
||||
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeInvalidCredential):
|
||||
// Supplied credentials were present but rejected.
|
||||
default:
|
||||
// Transport-level failure was returned by a provider.
|
||||
// Internal/transport failure was returned by a provider.
|
||||
}
|
||||
```
|
||||
|
||||
`Manager.Authenticate` walks the configured providers in order. It returns on the first success, skips providers that surface `ErrNotHandled`, and tracks whether any provider reported `ErrNoCredentials` or `ErrInvalidCredential` for downstream error reporting.
|
||||
|
||||
If the manager itself is `nil` or no providers are registered, the call returns `nil, nil`, allowing callers to treat access control as disabled without branching on errors.
|
||||
`Manager.Authenticate` walks the configured providers in order. It returns on the first success, skips providers that return `AuthErrorCodeNotHandled`, and aggregates `AuthErrorCodeNoCredentials` / `AuthErrorCodeInvalidCredential` for a final result.
|
||||
|
||||
Each `Result` includes the provider identifier, the resolved principal, and optional metadata (for example, which header carried the credential).
|
||||
|
||||
## Configuration Layout
|
||||
## Built-in `config-api-key` Provider
|
||||
|
||||
The manager expects access providers under the `auth.providers` key inside `config.yaml`:
|
||||
The proxy includes one built-in access provider:
|
||||
|
||||
- `config-api-key`: Validates API keys declared under top-level `api-keys`.
|
||||
- Credential sources: `Authorization: Bearer`, `X-Goog-Api-Key`, `X-Api-Key`, `?key=`, `?auth_token=`
|
||||
- Metadata: `Result.Metadata["source"]` is set to the matched source label.
|
||||
|
||||
In the CLI server and `sdk/cliproxy`, this provider is registered automatically based on the loaded configuration.
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: inline-api
|
||||
type: config-api-key
|
||||
api-keys:
|
||||
- sk-test-123
|
||||
- sk-prod-456
|
||||
api-keys:
|
||||
- sk-test-123
|
||||
- sk-prod-456
|
||||
```
|
||||
|
||||
Fields map directly to `config.AccessProvider`: `name` labels the provider, `type` selects the registered factory, `sdk` can name an external module, `api-keys` seeds inline credentials, and `config` passes provider-specific options.
|
||||
## Loading Providers from External Go Modules
|
||||
|
||||
### Loading providers from external SDK modules
|
||||
|
||||
To consume a provider shipped in another Go module, point the `sdk` field at the module path and import it for its registration side effect:
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: partner-auth
|
||||
type: partner-token
|
||||
sdk: github.com/acme/xplatform/sdk/access/providers/partner
|
||||
config:
|
||||
region: us-west-2
|
||||
audience: cli-proxy
|
||||
```
|
||||
To consume a provider shipped in another Go module, import it for its registration side effect:
|
||||
|
||||
```go
|
||||
import (
|
||||
@@ -89,19 +80,11 @@ import (
|
||||
)
|
||||
```
|
||||
|
||||
The blank identifier import ensures `init` runs so `sdkaccess.RegisterProvider` executes before `BuildProviders` is called.
|
||||
|
||||
## Built-in Providers
|
||||
|
||||
The SDK ships with one provider out of the box:
|
||||
|
||||
- `config-api-key`: Validates API keys declared inline or under top-level `api-keys`. It accepts the key from `Authorization: Bearer`, `X-Goog-Api-Key`, `X-Api-Key`, or the `?key=` query string and reports `ErrInvalidCredential` when no match is found.
|
||||
|
||||
Additional providers can be delivered by third-party packages. When a provider package is imported, it registers itself with `sdkaccess.RegisterProvider`.
|
||||
The blank identifier import ensures `init` runs so `sdkaccess.RegisterProvider` executes before you call `RegisteredProviders()` (or before `cliproxy.NewBuilder().Build()`).
|
||||
|
||||
### Metadata and auditing
|
||||
|
||||
`Result.Metadata` carries provider-specific context. The built-in `config-api-key` provider, for example, stores the credential source (`authorization`, `x-goog-api-key`, `x-api-key`, or `query-key`). Populate this map in custom providers to enrich logs and downstream auditing.
|
||||
`Result.Metadata` carries provider-specific context. The built-in `config-api-key` provider, for example, stores the credential source (`authorization`, `x-goog-api-key`, `x-api-key`, `query-key`, `query-auth-token`). Populate this map in custom providers to enrich logs and downstream auditing.
|
||||
|
||||
## Writing Custom Providers
|
||||
|
||||
@@ -110,13 +93,13 @@ type customProvider struct{}
|
||||
|
||||
func (p *customProvider) Identifier() string { return "my-provider" }
|
||||
|
||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, error) {
|
||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, *sdkaccess.AuthError) {
|
||||
token := r.Header.Get("X-Custom")
|
||||
if token == "" {
|
||||
return nil, sdkaccess.ErrNoCredentials
|
||||
return nil, sdkaccess.NewNotHandledError()
|
||||
}
|
||||
if token != "expected" {
|
||||
return nil, sdkaccess.ErrInvalidCredential
|
||||
return nil, sdkaccess.NewInvalidCredentialError()
|
||||
}
|
||||
return &sdkaccess.Result{
|
||||
Provider: p.Identifier(),
|
||||
@@ -126,51 +109,46 @@ func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sd
|
||||
}
|
||||
|
||||
func init() {
|
||||
sdkaccess.RegisterProvider("custom", func(cfg *config.AccessProvider, root *config.Config) (sdkaccess.Provider, error) {
|
||||
return &customProvider{}, nil
|
||||
})
|
||||
sdkaccess.RegisterProvider("custom", &customProvider{})
|
||||
}
|
||||
```
|
||||
|
||||
A provider must implement `Identifier()` and `Authenticate()`. To expose it to configuration, call `RegisterProvider` inside `init`. Provider factories receive the specific `AccessProvider` block plus the full root configuration for contextual needs.
|
||||
A provider must implement `Identifier()` and `Authenticate()`. To make it available to the access manager, call `RegisterProvider` inside `init` with an initialized provider instance.
|
||||
|
||||
## Error Semantics
|
||||
|
||||
- `ErrNoCredentials`: no credentials were present or recognized by any provider.
|
||||
- `ErrInvalidCredential`: at least one provider processed the credentials but rejected them.
|
||||
- `ErrNotHandled`: instructs the manager to fall through to the next provider without affecting aggregate error reporting.
|
||||
- `NewNoCredentialsError()` (`AuthErrorCodeNoCredentials`): no credentials were present or recognized. (HTTP 401)
|
||||
- `NewInvalidCredentialError()` (`AuthErrorCodeInvalidCredential`): credentials were present but rejected. (HTTP 401)
|
||||
- `NewNotHandledError()` (`AuthErrorCodeNotHandled`): fall through to the next provider.
|
||||
- `NewInternalAuthError(message, cause)` (`AuthErrorCodeInternal`): transport/system failure. (HTTP 500)
|
||||
|
||||
Return custom errors to surface transport failures; they propagate immediately to the caller instead of being masked.
|
||||
Errors propagate immediately to the caller unless they are classified as `not_handled` / `no_credentials` / `invalid_credential` and can be aggregated by the manager.
|
||||
|
||||
## Integration with cliproxy Service
|
||||
|
||||
`sdk/cliproxy` wires `@sdk/access` automatically when you build a CLI service via `cliproxy.NewBuilder`. Supplying a preconfigured manager allows you to extend or override the default providers:
|
||||
`sdk/cliproxy` wires `@sdk/access` automatically when you build a CLI service via `cliproxy.NewBuilder`. Supplying a manager lets you reuse the same instance in your host process:
|
||||
|
||||
```go
|
||||
coreCfg, _ := config.LoadConfig("config.yaml")
|
||||
providers, _ := sdkaccess.BuildProviders(coreCfg)
|
||||
manager := sdkaccess.NewManager()
|
||||
manager.SetProviders(providers)
|
||||
accessManager := sdkaccess.NewManager()
|
||||
|
||||
svc, _ := cliproxy.NewBuilder().
|
||||
WithConfig(coreCfg).
|
||||
WithAccessManager(manager).
|
||||
WithConfigPath("config.yaml").
|
||||
WithRequestAccessManager(accessManager).
|
||||
Build()
|
||||
```
|
||||
|
||||
The service reuses the manager for every inbound request, ensuring consistent authentication across embedded deployments and the canonical CLI binary.
|
||||
Register any custom providers (typically via blank imports) before calling `Build()` so they are present in the global registry snapshot.
|
||||
|
||||
### Hot reloading providers
|
||||
### Hot reloading
|
||||
|
||||
When configuration changes, rebuild providers and swap them into the manager:
|
||||
When configuration changes, refresh any config-backed providers and then reset the manager's provider chain:
|
||||
|
||||
```go
|
||||
providers, err := sdkaccess.BuildProviders(newCfg)
|
||||
if err != nil {
|
||||
log.Errorf("reload auth providers failed: %v", err)
|
||||
return
|
||||
}
|
||||
accessManager.SetProviders(providers)
|
||||
// configaccess is github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access
|
||||
configaccess.Register(&newCfg.SDKConfig)
|
||||
accessManager.SetProviders(sdkaccess.RegisteredProviders())
|
||||
```
|
||||
|
||||
This mirrors the behaviour in `cliproxy.Service.refreshAccessProviders` and `api.Server.applyAccessConfig`, enabling runtime updates without restarting the process.
|
||||
This mirrors the behaviour in `internal/access.ApplyAccessProviders`, enabling runtime updates without restarting the process.
|
||||
|
||||
@@ -7,80 +7,71 @@
|
||||
```go
|
||||
import (
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
```
|
||||
|
||||
通过 `go get github.com/router-for-me/CLIProxyAPI/v6/sdk/access` 添加依赖。
|
||||
|
||||
## Provider Registry
|
||||
|
||||
访问提供者是全局注册,然后以快照形式挂到 `Manager` 上:
|
||||
|
||||
- `RegisterProvider(type, provider)` 注册一个已经初始化好的 provider 实例。
|
||||
- 每个 `type` 第一次出现时会记录其注册顺序。
|
||||
- `RegisteredProviders()` 会按该顺序返回 provider 列表。
|
||||
|
||||
## 管理器生命周期
|
||||
|
||||
```go
|
||||
manager := sdkaccess.NewManager()
|
||||
providers, err := sdkaccess.BuildProviders(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
manager.SetProviders(providers)
|
||||
manager.SetProviders(sdkaccess.RegisteredProviders())
|
||||
```
|
||||
|
||||
- `NewManager` 创建空管理器。
|
||||
- `SetProviders` 替换提供者切片并做防御性拷贝。
|
||||
- `Providers` 返回适合并发读取的快照。
|
||||
- `BuildProviders` 将 `config.Config` 中的访问配置转换成可运行的提供者。当配置没有显式声明但包含顶层 `api-keys` 时,会自动挂载内建的 `config-api-key` 提供者。
|
||||
|
||||
如果管理器本身为 `nil` 或未配置任何 provider,调用会返回 `nil, nil`,可视为关闭访问控制。
|
||||
|
||||
## 认证请求
|
||||
|
||||
```go
|
||||
result, err := manager.Authenticate(ctx, req)
|
||||
result, authErr := manager.Authenticate(ctx, req)
|
||||
switch {
|
||||
case err == nil:
|
||||
case authErr == nil:
|
||||
// Authentication succeeded; result carries provider and principal.
|
||||
case errors.Is(err, sdkaccess.ErrNoCredentials):
|
||||
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeNoCredentials):
|
||||
// No recognizable credentials were supplied.
|
||||
case errors.Is(err, sdkaccess.ErrInvalidCredential):
|
||||
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeInvalidCredential):
|
||||
// Credentials were present but rejected.
|
||||
default:
|
||||
// Provider surfaced a transport-level failure.
|
||||
}
|
||||
```
|
||||
|
||||
`Manager.Authenticate` 按配置顺序遍历提供者。遇到成功立即返回,`ErrNotHandled` 会继续尝试下一个;若发现 `ErrNoCredentials` 或 `ErrInvalidCredential`,会在遍历结束后汇总给调用方。
|
||||
|
||||
若管理器本身为 `nil` 或尚未注册提供者,调用会返回 `nil, nil`,让调用方无需针对错误做额外分支即可关闭访问控制。
|
||||
`Manager.Authenticate` 会按顺序遍历 provider:遇到成功立即返回,`AuthErrorCodeNotHandled` 会继续尝试下一个;`AuthErrorCodeNoCredentials` / `AuthErrorCodeInvalidCredential` 会在遍历结束后汇总给调用方。
|
||||
|
||||
`Result` 提供认证提供者标识、解析出的主体以及可选元数据(例如凭证来源)。
|
||||
|
||||
## 配置结构
|
||||
## 内建 `config-api-key` Provider
|
||||
|
||||
在 `config.yaml` 的 `auth.providers` 下定义访问提供者:
|
||||
代理内置一个访问提供者:
|
||||
|
||||
- `config-api-key`:校验 `config.yaml` 顶层的 `api-keys`。
|
||||
- 凭证来源:`Authorization: Bearer`、`X-Goog-Api-Key`、`X-Api-Key`、`?key=`、`?auth_token=`
|
||||
- 元数据:`Result.Metadata["source"]` 会写入匹配到的来源标识
|
||||
|
||||
在 CLI 服务端与 `sdk/cliproxy` 中,该 provider 会根据加载到的配置自动注册。
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: inline-api
|
||||
type: config-api-key
|
||||
api-keys:
|
||||
- sk-test-123
|
||||
- sk-prod-456
|
||||
api-keys:
|
||||
- sk-test-123
|
||||
- sk-prod-456
|
||||
```
|
||||
|
||||
条目映射到 `config.AccessProvider`:`name` 指定实例名,`type` 选择注册的工厂,`sdk` 可引用第三方模块,`api-keys` 提供内联凭证,`config` 用于传递特定选项。
|
||||
## 引入外部 Go 模块提供者
|
||||
|
||||
### 引入外部 SDK 提供者
|
||||
|
||||
若要消费其它 Go 模块输出的访问提供者,可在配置里填写 `sdk` 字段并在代码中引入该包,利用其 `init` 注册过程:
|
||||
|
||||
```yaml
|
||||
auth:
|
||||
providers:
|
||||
- name: partner-auth
|
||||
type: partner-token
|
||||
sdk: github.com/acme/xplatform/sdk/access/providers/partner
|
||||
config:
|
||||
region: us-west-2
|
||||
audience: cli-proxy
|
||||
```
|
||||
若要消费其它 Go 模块输出的访问提供者,直接用空白标识符导入以触发其 `init` 注册即可:
|
||||
|
||||
```go
|
||||
import (
|
||||
@@ -89,19 +80,11 @@ import (
|
||||
)
|
||||
```
|
||||
|
||||
通过空白标识符导入即可确保 `init` 调用,先于 `BuildProviders` 完成 `sdkaccess.RegisterProvider`。
|
||||
|
||||
## 内建提供者
|
||||
|
||||
当前 SDK 默认内置:
|
||||
|
||||
- `config-api-key`:校验配置中的 API Key。它从 `Authorization: Bearer`、`X-Goog-Api-Key`、`X-Api-Key` 以及查询参数 `?key=` 提取凭证,不匹配时抛出 `ErrInvalidCredential`。
|
||||
|
||||
导入第三方包即可通过 `sdkaccess.RegisterProvider` 注册更多类型。
|
||||
空白导入可确保 `init` 先执行,从而在你调用 `RegisteredProviders()`(或 `cliproxy.NewBuilder().Build()`)之前完成 `sdkaccess.RegisterProvider`。
|
||||
|
||||
### 元数据与审计
|
||||
|
||||
`Result.Metadata` 用于携带提供者特定的上下文信息。内建的 `config-api-key` 会记录凭证来源(`authorization`、`x-goog-api-key`、`x-api-key` 或 `query-key`)。自定义提供者同样可以填充该 Map,以便丰富日志与审计场景。
|
||||
`Result.Metadata` 用于携带提供者特定的上下文信息。内建的 `config-api-key` 会记录凭证来源(`authorization`、`x-goog-api-key`、`x-api-key`、`query-key`、`query-auth-token`)。自定义提供者同样可以填充该 Map,以便丰富日志与审计场景。
|
||||
|
||||
## 编写自定义提供者
|
||||
|
||||
@@ -110,13 +93,13 @@ type customProvider struct{}
|
||||
|
||||
func (p *customProvider) Identifier() string { return "my-provider" }
|
||||
|
||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, error) {
|
||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, *sdkaccess.AuthError) {
|
||||
token := r.Header.Get("X-Custom")
|
||||
if token == "" {
|
||||
return nil, sdkaccess.ErrNoCredentials
|
||||
return nil, sdkaccess.NewNotHandledError()
|
||||
}
|
||||
if token != "expected" {
|
||||
return nil, sdkaccess.ErrInvalidCredential
|
||||
return nil, sdkaccess.NewInvalidCredentialError()
|
||||
}
|
||||
return &sdkaccess.Result{
|
||||
Provider: p.Identifier(),
|
||||
@@ -126,51 +109,46 @@ func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sd
|
||||
}
|
||||
|
||||
func init() {
|
||||
sdkaccess.RegisterProvider("custom", func(cfg *config.AccessProvider, root *config.Config) (sdkaccess.Provider, error) {
|
||||
return &customProvider{}, nil
|
||||
})
|
||||
sdkaccess.RegisterProvider("custom", &customProvider{})
|
||||
}
|
||||
```
|
||||
|
||||
自定义提供者需要实现 `Identifier()` 与 `Authenticate()`。在 `init` 中调用 `RegisterProvider` 暴露给配置层,工厂函数既能读取当前条目,也能访问完整根配置。
|
||||
自定义提供者需要实现 `Identifier()` 与 `Authenticate()`。在 `init` 中用已初始化实例调用 `RegisterProvider` 注册到全局 registry。
|
||||
|
||||
## 错误语义
|
||||
|
||||
- `ErrNoCredentials`:任何提供者都未识别到凭证。
|
||||
- `ErrInvalidCredential`:至少一个提供者处理了凭证但判定无效。
|
||||
- `ErrNotHandled`:告诉管理器跳到下一个提供者,不影响最终错误统计。
|
||||
- `NewNoCredentialsError()`(`AuthErrorCodeNoCredentials`):未提供或未识别到凭证。(HTTP 401)
|
||||
- `NewInvalidCredentialError()`(`AuthErrorCodeInvalidCredential`):凭证存在但校验失败。(HTTP 401)
|
||||
- `NewNotHandledError()`(`AuthErrorCodeNotHandled`):告诉管理器跳到下一个 provider。
|
||||
- `NewInternalAuthError(message, cause)`(`AuthErrorCodeInternal`):网络/系统错误。(HTTP 500)
|
||||
|
||||
自定义错误(例如网络异常)会马上冒泡返回。
|
||||
除可汇总的 `not_handled` / `no_credentials` / `invalid_credential` 外,其它错误会立即冒泡返回。
|
||||
|
||||
## 与 cliproxy 集成
|
||||
|
||||
使用 `sdk/cliproxy` 构建服务时会自动接入 `@sdk/access`。如果需要扩展内置行为,可传入自定义管理器:
|
||||
使用 `sdk/cliproxy` 构建服务时会自动接入 `@sdk/access`。如果希望在宿主进程里复用同一个 `Manager` 实例,可传入自定义管理器:
|
||||
|
||||
```go
|
||||
coreCfg, _ := config.LoadConfig("config.yaml")
|
||||
providers, _ := sdkaccess.BuildProviders(coreCfg)
|
||||
manager := sdkaccess.NewManager()
|
||||
manager.SetProviders(providers)
|
||||
accessManager := sdkaccess.NewManager()
|
||||
|
||||
svc, _ := cliproxy.NewBuilder().
|
||||
WithConfig(coreCfg).
|
||||
WithAccessManager(manager).
|
||||
WithConfigPath("config.yaml").
|
||||
WithRequestAccessManager(accessManager).
|
||||
Build()
|
||||
```
|
||||
|
||||
服务会复用该管理器处理每一个入站请求,实现与 CLI 二进制一致的访问控制体验。
|
||||
请在调用 `Build()` 之前完成自定义 provider 的注册(通常通过空白导入触发 `init`),以确保它们被包含在全局 registry 的快照中。
|
||||
|
||||
### 动态热更新提供者
|
||||
|
||||
当配置发生变化时,可以重新构建提供者并替换当前列表:
|
||||
当配置发生变化时,刷新依赖配置的 provider,然后重置 manager 的 provider 链:
|
||||
|
||||
```go
|
||||
providers, err := sdkaccess.BuildProviders(newCfg)
|
||||
if err != nil {
|
||||
log.Errorf("reload auth providers failed: %v", err)
|
||||
return
|
||||
}
|
||||
accessManager.SetProviders(providers)
|
||||
// configaccess is github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access
|
||||
configaccess.Register(&newCfg.SDKConfig)
|
||||
accessManager.SetProviders(sdkaccess.RegisteredProviders())
|
||||
```
|
||||
|
||||
这一流程与 `cliproxy.Service.refreshAccessProviders` 和 `api.Server.applyAccessConfig` 保持一致,避免为更新访问策略而重启进程。
|
||||
这一流程与 `internal/access.ApplyAccessProviders` 保持一致,避免为更新访问策略而重启进程。
|
||||
|
||||
@@ -4,19 +4,28 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
var registerOnce sync.Once
|
||||
|
||||
// Register ensures the config-access provider is available to the access manager.
|
||||
func Register() {
|
||||
registerOnce.Do(func() {
|
||||
sdkaccess.RegisterProvider(sdkconfig.AccessProviderTypeConfigAPIKey, newProvider)
|
||||
})
|
||||
func Register(cfg *sdkconfig.SDKConfig) {
|
||||
if cfg == nil {
|
||||
sdkaccess.UnregisterProvider(sdkaccess.AccessProviderTypeConfigAPIKey)
|
||||
return
|
||||
}
|
||||
|
||||
keys := normalizeKeys(cfg.APIKeys)
|
||||
if len(keys) == 0 {
|
||||
sdkaccess.UnregisterProvider(sdkaccess.AccessProviderTypeConfigAPIKey)
|
||||
return
|
||||
}
|
||||
|
||||
sdkaccess.RegisterProvider(
|
||||
sdkaccess.AccessProviderTypeConfigAPIKey,
|
||||
newProvider(sdkaccess.DefaultAccessProviderName, keys),
|
||||
)
|
||||
}
|
||||
|
||||
type provider struct {
|
||||
@@ -24,34 +33,31 @@ type provider struct {
|
||||
keys map[string]struct{}
|
||||
}
|
||||
|
||||
func newProvider(cfg *sdkconfig.AccessProvider, _ *sdkconfig.SDKConfig) (sdkaccess.Provider, error) {
|
||||
name := cfg.Name
|
||||
if name == "" {
|
||||
name = sdkconfig.DefaultAccessProviderName
|
||||
func newProvider(name string, keys []string) *provider {
|
||||
providerName := strings.TrimSpace(name)
|
||||
if providerName == "" {
|
||||
providerName = sdkaccess.DefaultAccessProviderName
|
||||
}
|
||||
keys := make(map[string]struct{}, len(cfg.APIKeys))
|
||||
for _, key := range cfg.APIKeys {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
keys[key] = struct{}{}
|
||||
keySet := make(map[string]struct{}, len(keys))
|
||||
for _, key := range keys {
|
||||
keySet[key] = struct{}{}
|
||||
}
|
||||
return &provider{name: name, keys: keys}, nil
|
||||
return &provider{name: providerName, keys: keySet}
|
||||
}
|
||||
|
||||
func (p *provider) Identifier() string {
|
||||
if p == nil || p.name == "" {
|
||||
return sdkconfig.DefaultAccessProviderName
|
||||
return sdkaccess.DefaultAccessProviderName
|
||||
}
|
||||
return p.name
|
||||
}
|
||||
|
||||
func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.Result, error) {
|
||||
func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.Result, *sdkaccess.AuthError) {
|
||||
if p == nil {
|
||||
return nil, sdkaccess.ErrNotHandled
|
||||
return nil, sdkaccess.NewNotHandledError()
|
||||
}
|
||||
if len(p.keys) == 0 {
|
||||
return nil, sdkaccess.ErrNotHandled
|
||||
return nil, sdkaccess.NewNotHandledError()
|
||||
}
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
authHeaderGoogle := r.Header.Get("X-Goog-Api-Key")
|
||||
@@ -63,7 +69,7 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
|
||||
queryAuthToken = r.URL.Query().Get("auth_token")
|
||||
}
|
||||
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" && queryAuthToken == "" {
|
||||
return nil, sdkaccess.ErrNoCredentials
|
||||
return nil, sdkaccess.NewNoCredentialsError()
|
||||
}
|
||||
|
||||
apiKey := extractBearerToken(authHeader)
|
||||
@@ -94,7 +100,7 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
|
||||
}
|
||||
}
|
||||
|
||||
return nil, sdkaccess.ErrInvalidCredential
|
||||
return nil, sdkaccess.NewInvalidCredentialError()
|
||||
}
|
||||
|
||||
func extractBearerToken(header string) string {
|
||||
@@ -110,3 +116,26 @@ func extractBearerToken(header string) string {
|
||||
}
|
||||
return strings.TrimSpace(parts[1])
|
||||
}
|
||||
|
||||
func normalizeKeys(keys []string) []string {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
normalized := make([]string, 0, len(keys))
|
||||
seen := make(map[string]struct{}, len(keys))
|
||||
for _, key := range keys {
|
||||
trimmedKey := strings.TrimSpace(key)
|
||||
if trimmedKey == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[trimmedKey]; exists {
|
||||
continue
|
||||
}
|
||||
seen[trimmedKey] = struct{}{}
|
||||
normalized = append(normalized, trimmedKey)
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
return nil
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
@@ -6,9 +6,9 @@ import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
sdkConfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -17,26 +17,26 @@ import (
|
||||
// ordered provider slice along with the identifiers of providers that were added, updated, or
|
||||
// removed compared to the previous configuration.
|
||||
func ReconcileProviders(oldCfg, newCfg *config.Config, existing []sdkaccess.Provider) (result []sdkaccess.Provider, added, updated, removed []string, err error) {
|
||||
_ = oldCfg
|
||||
if newCfg == nil {
|
||||
return nil, nil, nil, nil, nil
|
||||
}
|
||||
|
||||
result = sdkaccess.RegisteredProviders()
|
||||
|
||||
existingMap := make(map[string]sdkaccess.Provider, len(existing))
|
||||
for _, provider := range existing {
|
||||
if provider == nil {
|
||||
providerID := identifierFromProvider(provider)
|
||||
if providerID == "" {
|
||||
continue
|
||||
}
|
||||
existingMap[provider.Identifier()] = provider
|
||||
existingMap[providerID] = provider
|
||||
}
|
||||
|
||||
oldCfgMap := accessProviderMap(oldCfg)
|
||||
newEntries := collectProviderEntries(newCfg)
|
||||
|
||||
result = make([]sdkaccess.Provider, 0, len(newEntries))
|
||||
finalIDs := make(map[string]struct{}, len(newEntries))
|
||||
finalIDs := make(map[string]struct{}, len(result))
|
||||
|
||||
isInlineProvider := func(id string) bool {
|
||||
return strings.EqualFold(id, sdkConfig.DefaultAccessProviderName)
|
||||
return strings.EqualFold(id, sdkaccess.DefaultAccessProviderName)
|
||||
}
|
||||
appendChange := func(list *[]string, id string) {
|
||||
if isInlineProvider(id) {
|
||||
@@ -45,85 +45,28 @@ func ReconcileProviders(oldCfg, newCfg *config.Config, existing []sdkaccess.Prov
|
||||
*list = append(*list, id)
|
||||
}
|
||||
|
||||
for _, providerCfg := range newEntries {
|
||||
key := providerIdentifier(providerCfg)
|
||||
if key == "" {
|
||||
for _, provider := range result {
|
||||
providerID := identifierFromProvider(provider)
|
||||
if providerID == "" {
|
||||
continue
|
||||
}
|
||||
finalIDs[providerID] = struct{}{}
|
||||
|
||||
forceRebuild := strings.EqualFold(strings.TrimSpace(providerCfg.Type), sdkConfig.AccessProviderTypeConfigAPIKey)
|
||||
if oldCfgProvider, ok := oldCfgMap[key]; ok {
|
||||
isAliased := oldCfgProvider == providerCfg
|
||||
if !forceRebuild && !isAliased && providerConfigEqual(oldCfgProvider, providerCfg) {
|
||||
if existingProvider, okExisting := existingMap[key]; okExisting {
|
||||
result = append(result, existingProvider)
|
||||
finalIDs[key] = struct{}{}
|
||||
continue
|
||||
}
|
||||
}
|
||||
existingProvider, exists := existingMap[providerID]
|
||||
if !exists {
|
||||
appendChange(&added, providerID)
|
||||
continue
|
||||
}
|
||||
|
||||
provider, buildErr := sdkaccess.BuildProvider(providerCfg, &newCfg.SDKConfig)
|
||||
if buildErr != nil {
|
||||
return nil, nil, nil, nil, buildErr
|
||||
}
|
||||
if _, ok := oldCfgMap[key]; ok {
|
||||
if _, existed := existingMap[key]; existed {
|
||||
appendChange(&updated, key)
|
||||
} else {
|
||||
appendChange(&added, key)
|
||||
}
|
||||
} else {
|
||||
appendChange(&added, key)
|
||||
}
|
||||
result = append(result, provider)
|
||||
finalIDs[key] = struct{}{}
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
if inline := sdkConfig.MakeInlineAPIKeyProvider(newCfg.APIKeys); inline != nil {
|
||||
key := providerIdentifier(inline)
|
||||
if key != "" {
|
||||
if oldCfgProvider, ok := oldCfgMap[key]; ok {
|
||||
if providerConfigEqual(oldCfgProvider, inline) {
|
||||
if existingProvider, okExisting := existingMap[key]; okExisting {
|
||||
result = append(result, existingProvider)
|
||||
finalIDs[key] = struct{}{}
|
||||
goto inlineDone
|
||||
}
|
||||
}
|
||||
}
|
||||
provider, buildErr := sdkaccess.BuildProvider(inline, &newCfg.SDKConfig)
|
||||
if buildErr != nil {
|
||||
return nil, nil, nil, nil, buildErr
|
||||
}
|
||||
if _, existed := existingMap[key]; existed {
|
||||
appendChange(&updated, key)
|
||||
} else if _, hadOld := oldCfgMap[key]; hadOld {
|
||||
appendChange(&updated, key)
|
||||
} else {
|
||||
appendChange(&added, key)
|
||||
}
|
||||
result = append(result, provider)
|
||||
finalIDs[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
inlineDone:
|
||||
}
|
||||
|
||||
removedSet := make(map[string]struct{})
|
||||
for id := range existingMap {
|
||||
if _, ok := finalIDs[id]; !ok {
|
||||
if isInlineProvider(id) {
|
||||
continue
|
||||
}
|
||||
removedSet[id] = struct{}{}
|
||||
if !providerInstanceEqual(existingProvider, provider) {
|
||||
appendChange(&updated, providerID)
|
||||
}
|
||||
}
|
||||
|
||||
removed = make([]string, 0, len(removedSet))
|
||||
for id := range removedSet {
|
||||
removed = append(removed, id)
|
||||
for providerID := range existingMap {
|
||||
if _, exists := finalIDs[providerID]; exists {
|
||||
continue
|
||||
}
|
||||
appendChange(&removed, providerID)
|
||||
}
|
||||
|
||||
sort.Strings(added)
|
||||
@@ -142,6 +85,7 @@ func ApplyAccessProviders(manager *sdkaccess.Manager, oldCfg, newCfg *config.Con
|
||||
}
|
||||
|
||||
existing := manager.Providers()
|
||||
configaccess.Register(&newCfg.SDKConfig)
|
||||
providers, added, updated, removed, err := ReconcileProviders(oldCfg, newCfg, existing)
|
||||
if err != nil {
|
||||
log.Errorf("failed to reconcile request auth providers: %v", err)
|
||||
@@ -160,111 +104,24 @@ func ApplyAccessProviders(manager *sdkaccess.Manager, oldCfg, newCfg *config.Con
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func accessProviderMap(cfg *config.Config) map[string]*sdkConfig.AccessProvider {
|
||||
result := make(map[string]*sdkConfig.AccessProvider)
|
||||
if cfg == nil {
|
||||
return result
|
||||
}
|
||||
for i := range cfg.Access.Providers {
|
||||
providerCfg := &cfg.Access.Providers[i]
|
||||
if providerCfg.Type == "" {
|
||||
continue
|
||||
}
|
||||
key := providerIdentifier(providerCfg)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
result[key] = providerCfg
|
||||
}
|
||||
if len(result) == 0 && len(cfg.APIKeys) > 0 {
|
||||
if provider := sdkConfig.MakeInlineAPIKeyProvider(cfg.APIKeys); provider != nil {
|
||||
if key := providerIdentifier(provider); key != "" {
|
||||
result[key] = provider
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func collectProviderEntries(cfg *config.Config) []*sdkConfig.AccessProvider {
|
||||
entries := make([]*sdkConfig.AccessProvider, 0, len(cfg.Access.Providers))
|
||||
for i := range cfg.Access.Providers {
|
||||
providerCfg := &cfg.Access.Providers[i]
|
||||
if providerCfg.Type == "" {
|
||||
continue
|
||||
}
|
||||
if key := providerIdentifier(providerCfg); key != "" {
|
||||
entries = append(entries, providerCfg)
|
||||
}
|
||||
}
|
||||
if len(entries) == 0 && len(cfg.APIKeys) > 0 {
|
||||
if inline := sdkConfig.MakeInlineAPIKeyProvider(cfg.APIKeys); inline != nil {
|
||||
entries = append(entries, inline)
|
||||
}
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
func providerIdentifier(provider *sdkConfig.AccessProvider) string {
|
||||
func identifierFromProvider(provider sdkaccess.Provider) string {
|
||||
if provider == nil {
|
||||
return ""
|
||||
}
|
||||
if name := strings.TrimSpace(provider.Name); name != "" {
|
||||
return name
|
||||
}
|
||||
typ := strings.TrimSpace(provider.Type)
|
||||
if typ == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.EqualFold(typ, sdkConfig.AccessProviderTypeConfigAPIKey) {
|
||||
return sdkConfig.DefaultAccessProviderName
|
||||
}
|
||||
return typ
|
||||
return strings.TrimSpace(provider.Identifier())
|
||||
}
|
||||
|
||||
func providerConfigEqual(a, b *sdkConfig.AccessProvider) bool {
|
||||
func providerInstanceEqual(a, b sdkaccess.Provider) bool {
|
||||
if a == nil || b == nil {
|
||||
return a == nil && b == nil
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(a.Type), strings.TrimSpace(b.Type)) {
|
||||
if reflect.TypeOf(a) != reflect.TypeOf(b) {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(a.SDK) != strings.TrimSpace(b.SDK) {
|
||||
return false
|
||||
valueA := reflect.ValueOf(a)
|
||||
valueB := reflect.ValueOf(b)
|
||||
if valueA.Kind() == reflect.Pointer && valueB.Kind() == reflect.Pointer {
|
||||
return valueA.Pointer() == valueB.Pointer()
|
||||
}
|
||||
if !stringSetEqual(a.APIKeys, b.APIKeys) {
|
||||
return false
|
||||
}
|
||||
if len(a.Config) != len(b.Config) {
|
||||
return false
|
||||
}
|
||||
if len(a.Config) > 0 && !reflect.DeepEqual(a.Config, b.Config) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func stringSetEqual(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
if len(a) == 0 {
|
||||
return true
|
||||
}
|
||||
seen := make(map[string]int, len(a))
|
||||
for _, val := range a {
|
||||
seen[val]++
|
||||
}
|
||||
for _, val := range b {
|
||||
count := seen[val]
|
||||
if count == 0 {
|
||||
return false
|
||||
}
|
||||
if count == 1 {
|
||||
delete(seen, val)
|
||||
} else {
|
||||
seen[val] = count - 1
|
||||
}
|
||||
}
|
||||
return len(seen) == 0
|
||||
return reflect.DeepEqual(a, b)
|
||||
}
|
||||
|
||||
@@ -13,12 +13,13 @@ import (
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/proxy"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
const defaultAPICallTimeout = 60 * time.Second
|
||||
@@ -55,6 +56,7 @@ type apiCallResponse struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
Header map[string][]string `json:"header"`
|
||||
Body string `json:"body"`
|
||||
Quota *QuotaSnapshots `json:"quota,omitempty"`
|
||||
}
|
||||
|
||||
// APICall makes a generic HTTP request on behalf of the management API caller.
|
||||
@@ -97,6 +99,8 @@ type apiCallResponse struct {
|
||||
// - status_code: Upstream HTTP status code.
|
||||
// - header: Upstream response headers.
|
||||
// - body: Upstream response body as string.
|
||||
// - quota (optional): For GitHub Copilot enterprise accounts, contains quota_snapshots
|
||||
// with details for chat, completions, and premium_interactions.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
@@ -236,6 +240,13 @@ func (h *Handler) APICall(c *gin.Context) {
|
||||
Body: string(respBody),
|
||||
}
|
||||
|
||||
// If this is a GitHub Copilot token endpoint response, try to enrich with quota information
|
||||
if resp.StatusCode == http.StatusOK &&
|
||||
strings.Contains(urlStr, "copilot_internal") &&
|
||||
strings.Contains(urlStr, "/token") {
|
||||
response = h.enrichCopilotTokenResponse(c.Request.Context(), response, auth, urlStr)
|
||||
}
|
||||
|
||||
// Return response in the same format as the request
|
||||
if isCBOR {
|
||||
cborData, errMarshal := cbor.Marshal(response)
|
||||
@@ -735,3 +746,344 @@ func buildProxyTransport(proxyStr string) *http.Transport {
|
||||
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
|
||||
return nil
|
||||
}
|
||||
|
||||
// QuotaDetail represents quota information for a specific resource type
|
||||
type QuotaDetail struct {
|
||||
Entitlement float64 `json:"entitlement"`
|
||||
OverageCount float64 `json:"overage_count"`
|
||||
OveragePermitted bool `json:"overage_permitted"`
|
||||
PercentRemaining float64 `json:"percent_remaining"`
|
||||
QuotaID string `json:"quota_id"`
|
||||
QuotaRemaining float64 `json:"quota_remaining"`
|
||||
Remaining float64 `json:"remaining"`
|
||||
Unlimited bool `json:"unlimited"`
|
||||
}
|
||||
|
||||
// QuotaSnapshots contains quota details for different resource types
|
||||
type QuotaSnapshots struct {
|
||||
Chat QuotaDetail `json:"chat"`
|
||||
Completions QuotaDetail `json:"completions"`
|
||||
PremiumInteractions QuotaDetail `json:"premium_interactions"`
|
||||
}
|
||||
|
||||
// CopilotUsageResponse represents the GitHub Copilot usage information
|
||||
type CopilotUsageResponse struct {
|
||||
AccessTypeSKU string `json:"access_type_sku"`
|
||||
AnalyticsTrackingID string `json:"analytics_tracking_id"`
|
||||
AssignedDate string `json:"assigned_date"`
|
||||
CanSignupForLimited bool `json:"can_signup_for_limited"`
|
||||
ChatEnabled bool `json:"chat_enabled"`
|
||||
CopilotPlan string `json:"copilot_plan"`
|
||||
OrganizationLoginList []interface{} `json:"organization_login_list"`
|
||||
OrganizationList []interface{} `json:"organization_list"`
|
||||
QuotaResetDate string `json:"quota_reset_date"`
|
||||
QuotaSnapshots QuotaSnapshots `json:"quota_snapshots"`
|
||||
}
|
||||
|
||||
type copilotQuotaRequest struct {
|
||||
AuthIndexSnake *string `json:"auth_index"`
|
||||
AuthIndexCamel *string `json:"authIndex"`
|
||||
AuthIndexPascal *string `json:"AuthIndex"`
|
||||
}
|
||||
|
||||
// GetCopilotQuota fetches GitHub Copilot quota information from the /copilot_internal/user endpoint.
|
||||
//
|
||||
// Endpoint:
|
||||
//
|
||||
// GET /v0/management/copilot-quota
|
||||
//
|
||||
// Query Parameters (optional):
|
||||
// - auth_index: The credential "auth_index" from GET /v0/management/auth-files.
|
||||
// If omitted, uses the first available GitHub Copilot credential.
|
||||
//
|
||||
// Response:
|
||||
//
|
||||
// Returns the CopilotUsageResponse with quota_snapshots containing detailed quota information
|
||||
// for chat, completions, and premium_interactions.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// curl -sS -X GET "http://127.0.0.1:8317/v0/management/copilot-quota?auth_index=<AUTH_INDEX>" \
|
||||
// -H "Authorization: Bearer <MANAGEMENT_KEY>"
|
||||
func (h *Handler) GetCopilotQuota(c *gin.Context) {
|
||||
authIndex := strings.TrimSpace(c.Query("auth_index"))
|
||||
if authIndex == "" {
|
||||
authIndex = strings.TrimSpace(c.Query("authIndex"))
|
||||
}
|
||||
if authIndex == "" {
|
||||
authIndex = strings.TrimSpace(c.Query("AuthIndex"))
|
||||
}
|
||||
|
||||
auth := h.findCopilotAuth(authIndex)
|
||||
if auth == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "no github copilot credential found"})
|
||||
return
|
||||
}
|
||||
|
||||
token, tokenErr := h.resolveTokenForAuth(c.Request.Context(), auth)
|
||||
if tokenErr != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to refresh copilot token"})
|
||||
return
|
||||
}
|
||||
if token == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "copilot token not found"})
|
||||
return
|
||||
}
|
||||
|
||||
apiURL := "https://api.github.com/copilot_internal/user"
|
||||
req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, apiURL, nil)
|
||||
if errNewRequest != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to build request"})
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("User-Agent", "CLIProxyAPIPlus")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: defaultAPICallTimeout,
|
||||
Transport: h.apiCallTransport(auth),
|
||||
}
|
||||
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
log.WithError(errDo).Debug("copilot quota request failed")
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "request failed"})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("response body close error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
respBody, errReadAll := io.ReadAll(resp.Body)
|
||||
if errReadAll != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "failed to read response"})
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": "github api request failed",
|
||||
"status_code": resp.StatusCode,
|
||||
"body": string(respBody),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var usage CopilotUsageResponse
|
||||
if errUnmarshal := json.Unmarshal(respBody, &usage); errUnmarshal != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to parse response"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, usage)
|
||||
}
|
||||
|
||||
// findCopilotAuth locates a GitHub Copilot credential by auth_index or returns the first available one
|
||||
func (h *Handler) findCopilotAuth(authIndex string) *coreauth.Auth {
|
||||
if h == nil || h.authManager == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
auths := h.authManager.List()
|
||||
var firstCopilot *coreauth.Auth
|
||||
|
||||
for _, auth := range auths {
|
||||
if auth == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
if provider != "copilot" && provider != "github" && provider != "github-copilot" {
|
||||
continue
|
||||
}
|
||||
|
||||
if firstCopilot == nil {
|
||||
firstCopilot = auth
|
||||
}
|
||||
|
||||
if authIndex != "" {
|
||||
auth.EnsureIndex()
|
||||
if auth.Index == authIndex {
|
||||
return auth
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return firstCopilot
|
||||
}
|
||||
|
||||
// enrichCopilotTokenResponse fetches quota information and adds it to the Copilot token response body
|
||||
func (h *Handler) enrichCopilotTokenResponse(ctx context.Context, response apiCallResponse, auth *coreauth.Auth, originalURL string) apiCallResponse {
|
||||
if auth == nil || response.Body == "" {
|
||||
return response
|
||||
}
|
||||
|
||||
// Parse the token response to check if it's enterprise (null limited_user_quotas)
|
||||
var tokenResp map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(response.Body), &tokenResp); err != nil {
|
||||
log.WithError(err).Debug("enrichCopilotTokenResponse: failed to parse copilot token response")
|
||||
return response
|
||||
}
|
||||
|
||||
// Get the GitHub token to call the copilot_internal/user endpoint
|
||||
token, tokenErr := h.resolveTokenForAuth(ctx, auth)
|
||||
if tokenErr != nil {
|
||||
log.WithError(tokenErr).Debug("enrichCopilotTokenResponse: failed to resolve token")
|
||||
return response
|
||||
}
|
||||
if token == "" {
|
||||
return response
|
||||
}
|
||||
|
||||
// Fetch quota information from /copilot_internal/user
|
||||
// Derive the base URL from the original token request to support proxies and test servers
|
||||
parsedURL, errParse := url.Parse(originalURL)
|
||||
if errParse != nil {
|
||||
log.WithError(errParse).Debug("enrichCopilotTokenResponse: failed to parse URL")
|
||||
return response
|
||||
}
|
||||
quotaURL := fmt.Sprintf("%s://%s/copilot_internal/user", parsedURL.Scheme, parsedURL.Host)
|
||||
|
||||
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodGet, quotaURL, nil)
|
||||
if errNewRequest != nil {
|
||||
log.WithError(errNewRequest).Debug("enrichCopilotTokenResponse: failed to build request")
|
||||
return response
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("User-Agent", "CLIProxyAPIPlus")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: defaultAPICallTimeout,
|
||||
Transport: h.apiCallTransport(auth),
|
||||
}
|
||||
|
||||
quotaResp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
log.WithError(errDo).Debug("enrichCopilotTokenResponse: quota fetch HTTP request failed")
|
||||
return response
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if errClose := quotaResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("quota response body close error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
if quotaResp.StatusCode != http.StatusOK {
|
||||
return response
|
||||
}
|
||||
|
||||
quotaBody, errReadAll := io.ReadAll(quotaResp.Body)
|
||||
if errReadAll != nil {
|
||||
log.WithError(errReadAll).Debug("enrichCopilotTokenResponse: failed to read response")
|
||||
return response
|
||||
}
|
||||
|
||||
// Parse the quota response
|
||||
var quotaData CopilotUsageResponse
|
||||
if err := json.Unmarshal(quotaBody, "aData); err != nil {
|
||||
log.WithError(err).Debug("enrichCopilotTokenResponse: failed to parse response")
|
||||
return response
|
||||
}
|
||||
|
||||
// Check if this is an enterprise account by looking for quota_snapshots in the response
|
||||
// Enterprise accounts have quota_snapshots, non-enterprise have limited_user_quotas
|
||||
var quotaRaw map[string]interface{}
|
||||
if err := json.Unmarshal(quotaBody, "aRaw); err == nil {
|
||||
if _, hasQuotaSnapshots := quotaRaw["quota_snapshots"]; hasQuotaSnapshots {
|
||||
// Enterprise account - has quota_snapshots
|
||||
tokenResp["quota_snapshots"] = quotaData.QuotaSnapshots
|
||||
tokenResp["access_type_sku"] = quotaData.AccessTypeSKU
|
||||
tokenResp["copilot_plan"] = quotaData.CopilotPlan
|
||||
|
||||
// Add quota reset date for enterprise (quota_reset_date_utc)
|
||||
if quotaResetDateUTC, ok := quotaRaw["quota_reset_date_utc"]; ok {
|
||||
tokenResp["quota_reset_date"] = quotaResetDateUTC
|
||||
} else if quotaData.QuotaResetDate != "" {
|
||||
tokenResp["quota_reset_date"] = quotaData.QuotaResetDate
|
||||
}
|
||||
} else {
|
||||
// Non-enterprise account - build quota from limited_user_quotas and monthly_quotas
|
||||
var quotaSnapshots QuotaSnapshots
|
||||
|
||||
// Get monthly quotas (total entitlement) and limited_user_quotas (remaining)
|
||||
monthlyQuotas, hasMonthly := quotaRaw["monthly_quotas"].(map[string]interface{})
|
||||
limitedQuotas, hasLimited := quotaRaw["limited_user_quotas"].(map[string]interface{})
|
||||
|
||||
// Process chat quota
|
||||
if hasMonthly && hasLimited {
|
||||
if chatTotal, ok := monthlyQuotas["chat"].(float64); ok {
|
||||
chatRemaining := chatTotal // default to full if no limited quota
|
||||
if chatLimited, ok := limitedQuotas["chat"].(float64); ok {
|
||||
chatRemaining = chatLimited
|
||||
}
|
||||
percentRemaining := 0.0
|
||||
if chatTotal > 0 {
|
||||
percentRemaining = (chatRemaining / chatTotal) * 100.0
|
||||
}
|
||||
quotaSnapshots.Chat = QuotaDetail{
|
||||
Entitlement: chatTotal,
|
||||
Remaining: chatRemaining,
|
||||
QuotaRemaining: chatRemaining,
|
||||
PercentRemaining: percentRemaining,
|
||||
QuotaID: "chat",
|
||||
Unlimited: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Process completions quota
|
||||
if completionsTotal, ok := monthlyQuotas["completions"].(float64); ok {
|
||||
completionsRemaining := completionsTotal // default to full if no limited quota
|
||||
if completionsLimited, ok := limitedQuotas["completions"].(float64); ok {
|
||||
completionsRemaining = completionsLimited
|
||||
}
|
||||
percentRemaining := 0.0
|
||||
if completionsTotal > 0 {
|
||||
percentRemaining = (completionsRemaining / completionsTotal) * 100.0
|
||||
}
|
||||
quotaSnapshots.Completions = QuotaDetail{
|
||||
Entitlement: completionsTotal,
|
||||
Remaining: completionsRemaining,
|
||||
QuotaRemaining: completionsRemaining,
|
||||
PercentRemaining: percentRemaining,
|
||||
QuotaID: "completions",
|
||||
Unlimited: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Premium interactions don't exist for non-enterprise, leave as zero values
|
||||
quotaSnapshots.PremiumInteractions = QuotaDetail{
|
||||
QuotaID: "premium_interactions",
|
||||
Unlimited: false,
|
||||
}
|
||||
|
||||
// Add quota_snapshots to the token response
|
||||
tokenResp["quota_snapshots"] = quotaSnapshots
|
||||
tokenResp["access_type_sku"] = quotaData.AccessTypeSKU
|
||||
tokenResp["copilot_plan"] = quotaData.CopilotPlan
|
||||
|
||||
// Add quota reset date for non-enterprise (limited_user_reset_date)
|
||||
if limitedResetDate, ok := quotaRaw["limited_user_reset_date"]; ok {
|
||||
tokenResp["quota_reset_date"] = limitedResetDate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Re-serialize the enriched response
|
||||
enrichedBody, errMarshal := json.Marshal(tokenResp)
|
||||
if errMarshal != nil {
|
||||
log.WithError(errMarshal).Debug("failed to marshal enriched response")
|
||||
return response
|
||||
}
|
||||
|
||||
response.Body = string(enrichedBody)
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
|
||||
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
@@ -1613,6 +1614,82 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
|
||||
fmt.Println("Initializing Kimi authentication...")
|
||||
|
||||
state := fmt.Sprintf("kmi-%d", time.Now().UnixNano())
|
||||
// Initialize Kimi auth service
|
||||
kimiAuth := kimi.NewKimiAuth(h.cfg)
|
||||
|
||||
// Generate authorization URL
|
||||
deviceFlow, errStartDeviceFlow := kimiAuth.StartDeviceFlow(ctx)
|
||||
if errStartDeviceFlow != nil {
|
||||
log.Errorf("Failed to generate authorization URL: %v", errStartDeviceFlow)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
|
||||
return
|
||||
}
|
||||
authURL := deviceFlow.VerificationURIComplete
|
||||
if authURL == "" {
|
||||
authURL = deviceFlow.VerificationURI
|
||||
}
|
||||
|
||||
RegisterOAuthSession(state, "kimi")
|
||||
|
||||
go func() {
|
||||
fmt.Println("Waiting for authentication...")
|
||||
authBundle, errWaitForAuthorization := kimiAuth.WaitForAuthorization(ctx, deviceFlow)
|
||||
if errWaitForAuthorization != nil {
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Printf("Authentication failed: %v\n", errWaitForAuthorization)
|
||||
return
|
||||
}
|
||||
|
||||
// Create token storage
|
||||
tokenStorage := kimiAuth.CreateTokenStorage(authBundle)
|
||||
|
||||
metadata := map[string]any{
|
||||
"type": "kimi",
|
||||
"access_token": authBundle.TokenData.AccessToken,
|
||||
"refresh_token": authBundle.TokenData.RefreshToken,
|
||||
"token_type": authBundle.TokenData.TokenType,
|
||||
"scope": authBundle.TokenData.Scope,
|
||||
"timestamp": time.Now().UnixMilli(),
|
||||
}
|
||||
if authBundle.TokenData.ExpiresAt > 0 {
|
||||
expired := time.Unix(authBundle.TokenData.ExpiresAt, 0).UTC().Format(time.RFC3339)
|
||||
metadata["expired"] = expired
|
||||
}
|
||||
if strings.TrimSpace(authBundle.DeviceID) != "" {
|
||||
metadata["device_id"] = strings.TrimSpace(authBundle.DeviceID)
|
||||
}
|
||||
|
||||
fileName := fmt.Sprintf("kimi-%d.json", time.Now().UnixMilli())
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kimi",
|
||||
FileName: fileName,
|
||||
Label: "Kimi User",
|
||||
Storage: tokenStorage,
|
||||
Metadata: metadata,
|
||||
}
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
fmt.Println("You can now use Kimi services through this CLI")
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("kimi")
|
||||
}()
|
||||
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
|
||||
|
||||
@@ -109,14 +109,13 @@ func (h *Handler) GetAPIKeys(c *gin.Context) { c.JSON(200, gin.H{"api-keys": h.c
|
||||
func (h *Handler) PutAPIKeys(c *gin.Context) {
|
||||
h.putStringList(c, func(v []string) {
|
||||
h.cfg.APIKeys = append([]string(nil), v...)
|
||||
h.cfg.Access.Providers = nil
|
||||
}, nil)
|
||||
}
|
||||
func (h *Handler) PatchAPIKeys(c *gin.Context) {
|
||||
h.patchStringList(c, &h.cfg.APIKeys, func() { h.cfg.Access.Providers = nil })
|
||||
h.patchStringList(c, &h.cfg.APIKeys, func() {})
|
||||
}
|
||||
func (h *Handler) DeleteAPIKeys(c *gin.Context) {
|
||||
h.deleteFromStringList(c, &h.cfg.APIKeys, func() { h.cfg.Access.Providers = nil })
|
||||
h.deleteFromStringList(c, &h.cfg.APIKeys, func() {})
|
||||
}
|
||||
|
||||
// gemini-api-key: []GeminiKey
|
||||
|
||||
@@ -122,7 +122,7 @@ func (rw *ResponseRewriter) Flush() {
|
||||
}
|
||||
|
||||
// modelFieldPaths lists all JSON paths where model name may appear
|
||||
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
|
||||
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
|
||||
|
||||
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
|
||||
// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility
|
||||
|
||||
110
internal/api/modules/amp/response_rewriter_test.go
Normal file
110
internal/api/modules/amp/response_rewriter_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRewriteModelInResponse_TopLevel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
input := []byte(`{"id":"resp_1","model":"gpt-5.3-codex","output":[]}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
expected := `{"id":"resp_1","model":"gpt-5.2-codex","output":[]}`
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteModelInResponse_ResponseModel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
input := []byte(`{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.3-codex","status":"completed"}}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
expected := `{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.2-codex","status":"completed"}}`
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteModelInResponse_ResponseCreated(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
input := []byte(`{"type":"response.created","response":{"id":"resp_1","model":"gpt-5.3-codex","status":"in_progress"}}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
expected := `{"type":"response.created","response":{"id":"resp_1","model":"gpt-5.2-codex","status":"in_progress"}}`
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteModelInResponse_NoModelField(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
input := []byte(`{"type":"response.output_item.added","item":{"id":"item_1","type":"message"}}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
if string(result) != string(input) {
|
||||
t.Errorf("expected no modification, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteModelInResponse_EmptyOriginalModel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: ""}
|
||||
|
||||
input := []byte(`{"model":"gpt-5.3-codex"}`)
|
||||
result := rw.rewriteModelInResponse(input)
|
||||
|
||||
if string(result) != string(input) {
|
||||
t.Errorf("expected no modification when originalModel is empty, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteStreamChunk_SSEWithResponseModel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
chunk := []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-5.3-codex\",\"status\":\"completed\"}}\n\n")
|
||||
result := rw.rewriteStreamChunk(chunk)
|
||||
|
||||
expected := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-5.2-codex\",\"status\":\"completed\"}}\n\n"
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteStreamChunk_MultipleEvents(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||
|
||||
chunk := []byte("data: {\"type\":\"response.created\",\"response\":{\"model\":\"gpt-5.3-codex\"}}\n\ndata: {\"type\":\"response.output_item.added\",\"item\":{\"id\":\"item_1\"}}\n\n")
|
||||
result := rw.rewriteStreamChunk(chunk)
|
||||
|
||||
if string(result) == string(chunk) {
|
||||
t.Error("expected response.model to be rewritten in SSE stream")
|
||||
}
|
||||
if !contains(result, []byte(`"model":"gpt-5.2-codex"`)) {
|
||||
t.Errorf("expected rewritten model in output, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteStreamChunk_MessageModel(t *testing.T) {
|
||||
rw := &ResponseRewriter{originalModel: "claude-opus-4.5"}
|
||||
|
||||
chunk := []byte("data: {\"message\":{\"model\":\"claude-sonnet-4\",\"role\":\"assistant\"}}\n\n")
|
||||
result := rw.rewriteStreamChunk(chunk)
|
||||
|
||||
expected := "data: {\"message\":{\"model\":\"claude-opus-4.5\",\"role\":\"assistant\"}}\n\n"
|
||||
if string(result) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func contains(data, substr []byte) bool {
|
||||
for i := 0; i <= len(data)-len(substr); i++ {
|
||||
if string(data[i:i+len(substr)]) == string(substr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -649,6 +649,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
||||
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
||||
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
||||
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
|
||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
||||
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
|
||||
@@ -682,14 +683,17 @@ func (s *Server) serveManagementControlPanel(c *gin.Context) {
|
||||
|
||||
if _, err := os.Stat(filePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
// Synchronously ensure management.html is available with a detached context.
|
||||
// Control panel bootstrap should not be canceled by client disconnects.
|
||||
if !managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository) {
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
log.WithError(err).Error("failed to stat management control panel asset")
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
log.WithError(err).Error("failed to stat management control panel asset")
|
||||
c.AbortWithStatus(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
c.File(filePath)
|
||||
@@ -979,10 +983,6 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
|
||||
s.handlers.UpdateClients(&cfg.SDKConfig)
|
||||
|
||||
if !cfg.RemoteManagement.DisableControlPanel {
|
||||
staticDir := managementasset.StaticDir(s.configFilePath)
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
||||
}
|
||||
if s.mgmt != nil {
|
||||
s.mgmt.SetConfig(cfg)
|
||||
s.mgmt.SetAuthManager(s.handlers.AuthManager)
|
||||
@@ -1061,14 +1061,10 @@ func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case errors.Is(err, sdkaccess.ErrNoCredentials):
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Missing API key"})
|
||||
case errors.Is(err, sdkaccess.ErrInvalidCredential):
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"})
|
||||
default:
|
||||
statusCode := err.HTTPStatusCode()
|
||||
if statusCode >= http.StatusInternalServerError {
|
||||
log.Errorf("authentication middleware error: %v", err)
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Authentication service error"})
|
||||
}
|
||||
c.AbortWithStatusJSON(statusCode, gin.H{"error": err.Message})
|
||||
}
|
||||
}
|
||||
|
||||
396
internal/auth/kimi/kimi.go
Normal file
396
internal/auth/kimi/kimi.go
Normal file
@@ -0,0 +1,396 @@
|
||||
// Package kimi provides authentication and token management for Kimi (Moonshot AI) API.
|
||||
// It handles the RFC 8628 OAuth2 Device Authorization Grant flow for secure authentication.
|
||||
package kimi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// kimiClientID is Kimi Code's OAuth client ID.
|
||||
kimiClientID = "17e5f671-d194-4dfb-9706-5516cb48c098"
|
||||
// kimiOAuthHost is the OAuth server endpoint.
|
||||
kimiOAuthHost = "https://auth.kimi.com"
|
||||
// kimiDeviceCodeURL is the endpoint for requesting device codes.
|
||||
kimiDeviceCodeURL = kimiOAuthHost + "/api/oauth/device_authorization"
|
||||
// kimiTokenURL is the endpoint for exchanging device codes for tokens.
|
||||
kimiTokenURL = kimiOAuthHost + "/api/oauth/token"
|
||||
// KimiAPIBaseURL is the base URL for Kimi API requests.
|
||||
KimiAPIBaseURL = "https://api.kimi.com/coding"
|
||||
// defaultPollInterval is the default interval for polling token endpoint.
|
||||
defaultPollInterval = 5 * time.Second
|
||||
// maxPollDuration is the maximum time to wait for user authorization.
|
||||
maxPollDuration = 15 * time.Minute
|
||||
// refreshThresholdSeconds is when to refresh token before expiry (5 minutes).
|
||||
refreshThresholdSeconds = 300
|
||||
)
|
||||
|
||||
// KimiAuth handles Kimi authentication flow.
|
||||
type KimiAuth struct {
|
||||
deviceClient *DeviceFlowClient
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewKimiAuth creates a new KimiAuth service instance.
|
||||
func NewKimiAuth(cfg *config.Config) *KimiAuth {
|
||||
return &KimiAuth{
|
||||
deviceClient: NewDeviceFlowClient(cfg),
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// StartDeviceFlow initiates the device flow authentication.
|
||||
func (k *KimiAuth) StartDeviceFlow(ctx context.Context) (*DeviceCodeResponse, error) {
|
||||
return k.deviceClient.RequestDeviceCode(ctx)
|
||||
}
|
||||
|
||||
// WaitForAuthorization polls for user authorization and returns the auth bundle.
|
||||
func (k *KimiAuth) WaitForAuthorization(ctx context.Context, deviceCode *DeviceCodeResponse) (*KimiAuthBundle, error) {
|
||||
tokenData, err := k.deviceClient.PollForToken(ctx, deviceCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &KimiAuthBundle{
|
||||
TokenData: tokenData,
|
||||
DeviceID: k.deviceClient.deviceID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateTokenStorage creates a new KimiTokenStorage from auth bundle.
|
||||
func (k *KimiAuth) CreateTokenStorage(bundle *KimiAuthBundle) *KimiTokenStorage {
|
||||
expired := ""
|
||||
if bundle.TokenData.ExpiresAt > 0 {
|
||||
expired = time.Unix(bundle.TokenData.ExpiresAt, 0).UTC().Format(time.RFC3339)
|
||||
}
|
||||
return &KimiTokenStorage{
|
||||
AccessToken: bundle.TokenData.AccessToken,
|
||||
RefreshToken: bundle.TokenData.RefreshToken,
|
||||
TokenType: bundle.TokenData.TokenType,
|
||||
Scope: bundle.TokenData.Scope,
|
||||
DeviceID: strings.TrimSpace(bundle.DeviceID),
|
||||
Expired: expired,
|
||||
Type: "kimi",
|
||||
}
|
||||
}
|
||||
|
||||
// DeviceFlowClient handles the OAuth2 device flow for Kimi.
|
||||
type DeviceFlowClient struct {
|
||||
httpClient *http.Client
|
||||
cfg *config.Config
|
||||
deviceID string
|
||||
}
|
||||
|
||||
// NewDeviceFlowClient creates a new device flow client.
|
||||
func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient {
|
||||
return NewDeviceFlowClientWithDeviceID(cfg, "")
|
||||
}
|
||||
|
||||
// NewDeviceFlowClientWithDeviceID creates a new device flow client with the specified device ID.
|
||||
func NewDeviceFlowClientWithDeviceID(cfg *config.Config, deviceID string) *DeviceFlowClient {
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
if cfg != nil {
|
||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||
}
|
||||
resolvedDeviceID := strings.TrimSpace(deviceID)
|
||||
if resolvedDeviceID == "" {
|
||||
resolvedDeviceID = getOrCreateDeviceID()
|
||||
}
|
||||
return &DeviceFlowClient{
|
||||
httpClient: client,
|
||||
cfg: cfg,
|
||||
deviceID: resolvedDeviceID,
|
||||
}
|
||||
}
|
||||
|
||||
// getOrCreateDeviceID returns an in-memory device ID for the current authentication flow.
|
||||
func getOrCreateDeviceID() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
// getDeviceModel returns a device model string.
|
||||
func getDeviceModel() string {
|
||||
osName := runtime.GOOS
|
||||
arch := runtime.GOARCH
|
||||
|
||||
switch osName {
|
||||
case "darwin":
|
||||
return fmt.Sprintf("macOS %s", arch)
|
||||
case "windows":
|
||||
return fmt.Sprintf("Windows %s", arch)
|
||||
case "linux":
|
||||
return fmt.Sprintf("Linux %s", arch)
|
||||
default:
|
||||
return fmt.Sprintf("%s %s", osName, arch)
|
||||
}
|
||||
}
|
||||
|
||||
// getHostname returns the machine hostname.
|
||||
func getHostname() string {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
return "unknown"
|
||||
}
|
||||
return hostname
|
||||
}
|
||||
|
||||
// commonHeaders returns headers required for Kimi API requests.
|
||||
func (c *DeviceFlowClient) commonHeaders() map[string]string {
|
||||
return map[string]string{
|
||||
"X-Msh-Platform": "cli-proxy-api",
|
||||
"X-Msh-Version": "1.0.0",
|
||||
"X-Msh-Device-Name": getHostname(),
|
||||
"X-Msh-Device-Model": getDeviceModel(),
|
||||
"X-Msh-Device-Id": c.deviceID,
|
||||
}
|
||||
}
|
||||
|
||||
// RequestDeviceCode initiates the device flow by requesting a device code from Kimi.
|
||||
func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", kimiClientID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, kimiDeviceCodeURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to create device code request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
for k, v := range c.commonHeaders() {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: device code request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi device code: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to read device code response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("kimi: device code request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var deviceCode DeviceCodeResponse
|
||||
if err = json.Unmarshal(bodyBytes, &deviceCode); err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to parse device code response: %w", err)
|
||||
}
|
||||
|
||||
return &deviceCode, nil
|
||||
}
|
||||
|
||||
// PollForToken polls the token endpoint until the user authorizes or the device code expires.
|
||||
func (c *DeviceFlowClient) PollForToken(ctx context.Context, deviceCode *DeviceCodeResponse) (*KimiTokenData, error) {
|
||||
if deviceCode == nil {
|
||||
return nil, fmt.Errorf("kimi: device code is nil")
|
||||
}
|
||||
|
||||
interval := time.Duration(deviceCode.Interval) * time.Second
|
||||
if interval < defaultPollInterval {
|
||||
interval = defaultPollInterval
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(maxPollDuration)
|
||||
if deviceCode.ExpiresIn > 0 {
|
||||
codeDeadline := time.Now().Add(time.Duration(deviceCode.ExpiresIn) * time.Second)
|
||||
if codeDeadline.Before(deadline) {
|
||||
deadline = codeDeadline
|
||||
}
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("kimi: context cancelled: %w", ctx.Err())
|
||||
case <-ticker.C:
|
||||
if time.Now().After(deadline) {
|
||||
return nil, fmt.Errorf("kimi: device code expired")
|
||||
}
|
||||
|
||||
token, pollErr, shouldContinue := c.exchangeDeviceCode(ctx, deviceCode.DeviceCode)
|
||||
if token != nil {
|
||||
return token, nil
|
||||
}
|
||||
if !shouldContinue {
|
||||
return nil, pollErr
|
||||
}
|
||||
// Continue polling
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// exchangeDeviceCode attempts to exchange the device code for an access token.
|
||||
// Returns (token, error, shouldContinue).
|
||||
func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode string) (*KimiTokenData, error, bool) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", kimiClientID)
|
||||
data.Set("device_code", deviceCode)
|
||||
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, kimiTokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to create token request: %w", err), false
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
for k, v := range c.commonHeaders() {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: token request failed: %w", err), false
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi token exchange: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to read token response: %w", err), false
|
||||
}
|
||||
|
||||
// Parse response - Kimi returns 200 for both success and pending states
|
||||
var oauthResp struct {
|
||||
Error string `json:"error"`
|
||||
ErrorDescription string `json:"error_description"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn float64 `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
if err = json.Unmarshal(bodyBytes, &oauthResp); err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to parse token response: %w", err), false
|
||||
}
|
||||
|
||||
if oauthResp.Error != "" {
|
||||
switch oauthResp.Error {
|
||||
case "authorization_pending":
|
||||
return nil, nil, true // Continue polling
|
||||
case "slow_down":
|
||||
return nil, nil, true // Continue polling (with increased interval handled by caller)
|
||||
case "expired_token":
|
||||
return nil, fmt.Errorf("kimi: device code expired"), false
|
||||
case "access_denied":
|
||||
return nil, fmt.Errorf("kimi: access denied by user"), false
|
||||
default:
|
||||
return nil, fmt.Errorf("kimi: OAuth error: %s - %s", oauthResp.Error, oauthResp.ErrorDescription), false
|
||||
}
|
||||
}
|
||||
|
||||
if oauthResp.AccessToken == "" {
|
||||
return nil, fmt.Errorf("kimi: empty access token in response"), false
|
||||
}
|
||||
|
||||
var expiresAt int64
|
||||
if oauthResp.ExpiresIn > 0 {
|
||||
expiresAt = time.Now().Unix() + int64(oauthResp.ExpiresIn)
|
||||
}
|
||||
|
||||
return &KimiTokenData{
|
||||
AccessToken: oauthResp.AccessToken,
|
||||
RefreshToken: oauthResp.RefreshToken,
|
||||
TokenType: oauthResp.TokenType,
|
||||
ExpiresAt: expiresAt,
|
||||
Scope: oauthResp.Scope,
|
||||
}, nil, false
|
||||
}
|
||||
|
||||
// RefreshToken exchanges a refresh token for a new access token.
|
||||
func (c *DeviceFlowClient) RefreshToken(ctx context.Context, refreshToken string) (*KimiTokenData, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", kimiClientID)
|
||||
data.Set("grant_type", "refresh_token")
|
||||
data.Set("refresh_token", refreshToken)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, kimiTokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to create refresh request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
for k, v := range c.commonHeaders() {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: refresh request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi refresh token: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to read refresh response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
return nil, fmt.Errorf("kimi: refresh token rejected (status %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("kimi: refresh failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn float64 `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
if err = json.Unmarshal(bodyBytes, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to parse refresh response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken == "" {
|
||||
return nil, fmt.Errorf("kimi: empty access token in refresh response")
|
||||
}
|
||||
|
||||
var expiresAt int64
|
||||
if tokenResp.ExpiresIn > 0 {
|
||||
expiresAt = time.Now().Unix() + int64(tokenResp.ExpiresIn)
|
||||
}
|
||||
|
||||
return &KimiTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
TokenType: tokenResp.TokenType,
|
||||
ExpiresAt: expiresAt,
|
||||
Scope: tokenResp.Scope,
|
||||
}, nil
|
||||
}
|
||||
116
internal/auth/kimi/token.go
Normal file
116
internal/auth/kimi/token.go
Normal file
@@ -0,0 +1,116 @@
|
||||
// Package kimi provides authentication and token management functionality
|
||||
// for Kimi (Moonshot AI) services. It handles OAuth2 device flow token storage,
|
||||
// serialization, and retrieval for maintaining authenticated sessions with the Kimi API.
|
||||
package kimi
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
)
|
||||
|
||||
// KimiTokenStorage stores OAuth2 token information for Kimi API authentication.
|
||||
type KimiTokenStorage struct {
|
||||
// AccessToken is the OAuth2 access token used for authenticating API requests.
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is the OAuth2 refresh token used to obtain new access tokens.
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
// TokenType is the type of token, typically "Bearer".
|
||||
TokenType string `json:"token_type"`
|
||||
// Scope is the OAuth2 scope granted to the token.
|
||||
Scope string `json:"scope,omitempty"`
|
||||
// DeviceID is the OAuth device flow identifier used for Kimi requests.
|
||||
DeviceID string `json:"device_id,omitempty"`
|
||||
// Expired is the RFC3339 timestamp when the access token expires.
|
||||
Expired string `json:"expired,omitempty"`
|
||||
// Type indicates the authentication provider type, always "kimi" for this storage.
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// KimiTokenData holds the raw OAuth token response from Kimi.
|
||||
type KimiTokenData struct {
|
||||
// AccessToken is the OAuth2 access token.
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is the OAuth2 refresh token.
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
// TokenType is the type of token, typically "Bearer".
|
||||
TokenType string `json:"token_type"`
|
||||
// ExpiresAt is the Unix timestamp when the token expires.
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
// Scope is the OAuth2 scope granted to the token.
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
// KimiAuthBundle bundles authentication data for storage.
|
||||
type KimiAuthBundle struct {
|
||||
// TokenData contains the OAuth token information.
|
||||
TokenData *KimiTokenData
|
||||
// DeviceID is the device identifier used during OAuth device flow.
|
||||
DeviceID string
|
||||
}
|
||||
|
||||
// DeviceCodeResponse represents Kimi's device code response.
|
||||
type DeviceCodeResponse struct {
|
||||
// DeviceCode is the device verification code.
|
||||
DeviceCode string `json:"device_code"`
|
||||
// UserCode is the code the user must enter at the verification URI.
|
||||
UserCode string `json:"user_code"`
|
||||
// VerificationURI is the URL where the user should enter the code.
|
||||
VerificationURI string `json:"verification_uri,omitempty"`
|
||||
// VerificationURIComplete is the URL with the code pre-filled.
|
||||
VerificationURIComplete string `json:"verification_uri_complete"`
|
||||
// ExpiresIn is the number of seconds until the device code expires.
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
// Interval is the minimum number of seconds to wait between polling requests.
|
||||
Interval int `json:"interval"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Kimi token storage to a JSON file.
|
||||
func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "kimi"
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
|
||||
f, err := os.Create(authFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
encoder := json.NewEncoder(f)
|
||||
encoder.SetIndent("", " ")
|
||||
if err = encoder.Encode(ts); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsExpired checks if the token has expired.
|
||||
func (ts *KimiTokenStorage) IsExpired() bool {
|
||||
if ts.Expired == "" {
|
||||
return false // No expiry set, assume valid
|
||||
}
|
||||
t, err := time.Parse(time.RFC3339, ts.Expired)
|
||||
if err != nil {
|
||||
return true // Has expiry string but can't parse
|
||||
}
|
||||
// Consider expired if within refresh threshold
|
||||
return time.Now().Add(time.Duration(refreshThresholdSeconds) * time.Second).After(t)
|
||||
}
|
||||
|
||||
// NeedsRefresh checks if the token should be refreshed.
|
||||
func (ts *KimiTokenStorage) NeedsRefresh() bool {
|
||||
if ts.RefreshToken == "" {
|
||||
return false // Can't refresh without refresh token
|
||||
}
|
||||
return ts.IsExpired()
|
||||
}
|
||||
@@ -92,7 +92,7 @@ const KiroIDETokenFile = ".aws/sso/cache/kiro-auth-token.json"
|
||||
|
||||
// Default retry configuration for file reading
|
||||
const (
|
||||
defaultTokenReadMaxAttempts = 10 // Maximum retry attempts
|
||||
defaultTokenReadMaxAttempts = 10 // Maximum retry attempts
|
||||
defaultTokenReadBaseDelay = 50 * time.Millisecond // Base delay between retries
|
||||
)
|
||||
|
||||
@@ -301,7 +301,7 @@ func ListKiroTokenFiles() ([]string, error) {
|
||||
}
|
||||
|
||||
cacheDir := filepath.Join(homeDir, ".aws", "sso", "cache")
|
||||
|
||||
|
||||
// Check if directory exists
|
||||
if _, err := os.Stat(cacheDir); os.IsNotExist(err) {
|
||||
return nil, nil // No token files
|
||||
@@ -488,14 +488,16 @@ func ExtractIDCIdentifier(startURL string) string {
|
||||
|
||||
// GenerateTokenFileName generates a unique filename for token storage.
|
||||
// Priority: email > startUrl identifier (for IDC) > authMethod only
|
||||
// Format: kiro-{authMethod}-{identifier}.json
|
||||
// Email is unique, so no sequence suffix needed. Sequence is only added
|
||||
// when email is unavailable to prevent filename collisions.
|
||||
// Format: kiro-{authMethod}-{identifier}[-{seq}].json
|
||||
func GenerateTokenFileName(tokenData *KiroTokenData) string {
|
||||
authMethod := tokenData.AuthMethod
|
||||
if authMethod == "" {
|
||||
authMethod = "unknown"
|
||||
}
|
||||
|
||||
// Priority 1: Use email if available
|
||||
// Priority 1: Use email if available (no sequence needed, email is unique)
|
||||
if tokenData.Email != "" {
|
||||
// Sanitize email for filename (replace @ and . with -)
|
||||
sanitizedEmail := tokenData.Email
|
||||
@@ -504,14 +506,17 @@ func GenerateTokenFileName(tokenData *KiroTokenData) string {
|
||||
return fmt.Sprintf("kiro-%s-%s.json", authMethod, sanitizedEmail)
|
||||
}
|
||||
|
||||
// Priority 2: For IDC, use startUrl identifier
|
||||
// Generate sequence only when email is unavailable
|
||||
seq := time.Now().UnixNano() % 100000
|
||||
|
||||
// Priority 2: For IDC, use startUrl identifier with sequence
|
||||
if authMethod == "idc" && tokenData.StartURL != "" {
|
||||
identifier := ExtractIDCIdentifier(tokenData.StartURL)
|
||||
if identifier != "" {
|
||||
return fmt.Sprintf("kiro-%s-%s.json", authMethod, identifier)
|
||||
return fmt.Sprintf("kiro-%s-%s-%05d.json", authMethod, identifier, seq)
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 3: Fallback to authMethod only
|
||||
return fmt.Sprintf("kiro-%s.json", authMethod)
|
||||
// Priority 3: Fallback to authMethod only with sequence
|
||||
return fmt.Sprintf("kiro-%s-%05d.json", authMethod, seq)
|
||||
}
|
||||
|
||||
@@ -238,7 +238,7 @@ func (k *KiroAuth) ListAvailableModels(ctx context.Context, tokenData *KiroToken
|
||||
Description string `json:"description"`
|
||||
RateMultiplier float64 `json:"rateMultiplier"`
|
||||
RateUnit string `json:"rateUnit"`
|
||||
TokenLimits struct {
|
||||
TokenLimits *struct {
|
||||
MaxInputTokens int `json:"maxInputTokens"`
|
||||
} `json:"tokenLimits"`
|
||||
} `json:"models"`
|
||||
@@ -250,13 +250,17 @@ func (k *KiroAuth) ListAvailableModels(ctx context.Context, tokenData *KiroToken
|
||||
|
||||
models := make([]*KiroModel, 0, len(result.Models))
|
||||
for _, m := range result.Models {
|
||||
maxInputTokens := 0
|
||||
if m.TokenLimits != nil {
|
||||
maxInputTokens = m.TokenLimits.MaxInputTokens
|
||||
}
|
||||
models = append(models, &KiroModel{
|
||||
ModelID: m.ModelID,
|
||||
ModelName: m.ModelName,
|
||||
Description: m.Description,
|
||||
RateMultiplier: m.RateMultiplier,
|
||||
RateUnit: m.RateUnit,
|
||||
MaxInputTokens: m.TokenLimits.MaxInputTokens,
|
||||
MaxInputTokens: maxInputTokens,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ func newAuthManager() *sdkAuth.Manager {
|
||||
sdkAuth.NewQwenAuthenticator(),
|
||||
sdkAuth.NewIFlowAuthenticator(),
|
||||
sdkAuth.NewAntigravityAuthenticator(),
|
||||
sdkAuth.NewKimiAuthenticator(),
|
||||
sdkAuth.NewKiroAuthenticator(),
|
||||
sdkAuth.NewGitHubCopilotAuthenticator(),
|
||||
)
|
||||
|
||||
44
internal/cmd/kimi_login.go
Normal file
44
internal/cmd/kimi_login.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DoKimiLogin triggers the OAuth device flow for Kimi (Moonshot AI) and saves tokens.
|
||||
// It initiates the device flow authentication, displays the verification URL for the user,
|
||||
// and waits for authorization before saving the tokens.
|
||||
//
|
||||
// Parameters:
|
||||
// - cfg: The application configuration containing proxy and auth directory settings
|
||||
// - options: Login options including browser behavior settings
|
||||
func DoKimiLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: options.Prompt,
|
||||
}
|
||||
|
||||
record, savedPath, err := manager.Login(context.Background(), "kimi", cfg, authOpts)
|
||||
if err != nil {
|
||||
log.Errorf("Kimi authentication failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
if record != nil && record.Label != "" {
|
||||
fmt.Printf("Authenticated as %s\n", record.Label)
|
||||
}
|
||||
fmt.Println("Kimi authentication successful!")
|
||||
}
|
||||
@@ -18,7 +18,10 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const DefaultPanelGitHubRepository = "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
|
||||
const (
|
||||
DefaultPanelGitHubRepository = "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
|
||||
DefaultPprofAddr = "127.0.0.1:8316"
|
||||
)
|
||||
|
||||
// Config represents the application's configuration, loaded from a YAML file.
|
||||
type Config struct {
|
||||
@@ -41,6 +44,9 @@ type Config struct {
|
||||
// Debug enables or disables debug-level logging and other debug features.
|
||||
Debug bool `yaml:"debug" json:"debug"`
|
||||
|
||||
// Pprof config controls the optional pprof HTTP debug server.
|
||||
Pprof PprofConfig `yaml:"pprof" json:"pprof"`
|
||||
|
||||
// CommercialMode disables high-overhead HTTP middleware features to minimize per-request memory usage.
|
||||
CommercialMode bool `yaml:"commercial-mode" json:"commercial-mode"`
|
||||
|
||||
@@ -134,6 +140,14 @@ type TLSConfig struct {
|
||||
Key string `yaml:"key" json:"key"`
|
||||
}
|
||||
|
||||
// PprofConfig holds pprof HTTP server settings.
|
||||
type PprofConfig struct {
|
||||
// Enable toggles the pprof HTTP debug server.
|
||||
Enable bool `yaml:"enable" json:"enable"`
|
||||
// Addr is the host:port address for the pprof HTTP server.
|
||||
Addr string `yaml:"addr" json:"addr"`
|
||||
}
|
||||
|
||||
// RemoteManagement holds management API configuration under 'remote-management'.
|
||||
type RemoteManagement struct {
|
||||
// AllowRemote toggles remote (non-localhost) access to management API.
|
||||
@@ -521,14 +535,15 @@ func LoadConfig(configFile string) (*Config, error) {
|
||||
// If optional is true and the file is missing, it returns an empty Config.
|
||||
// If optional is true and the file is empty or invalid, it returns an empty Config.
|
||||
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Perform oauth-model-alias migration before loading config.
|
||||
// This migrates oauth-model-mappings to oauth-model-alias if needed.
|
||||
if migrated, err := MigrateOAuthModelAlias(configFile); err != nil {
|
||||
// Log warning but don't fail - config loading should still work
|
||||
fmt.Printf("Warning: oauth-model-alias migration failed: %v\n", err)
|
||||
} else if migrated {
|
||||
fmt.Println("Migrated oauth-model-mappings to oauth-model-alias")
|
||||
}
|
||||
// NOTE: Startup oauth-model-alias migration is intentionally disabled.
|
||||
// Reason: avoid mutating config.yaml during server startup.
|
||||
// Re-enable the block below if automatic startup migration is needed again.
|
||||
// if migrated, err := MigrateOAuthModelAlias(configFile); err != nil {
|
||||
// // Log warning but don't fail - config loading should still work
|
||||
// fmt.Printf("Warning: oauth-model-alias migration failed: %v\n", err)
|
||||
// } else if migrated {
|
||||
// fmt.Println("Migrated oauth-model-mappings to oauth-model-alias")
|
||||
// }
|
||||
|
||||
// Read the entire configuration file into memory.
|
||||
data, err := os.ReadFile(configFile)
|
||||
@@ -556,6 +571,8 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
cfg.ErrorLogsMaxFiles = 10
|
||||
cfg.UsageStatisticsEnabled = false
|
||||
cfg.DisableCooling = false
|
||||
cfg.Pprof.Enable = false
|
||||
cfg.Pprof.Addr = DefaultPprofAddr
|
||||
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
|
||||
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
|
||||
cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force)
|
||||
@@ -567,18 +584,21 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
||||
}
|
||||
|
||||
var legacy legacyConfigData
|
||||
if errLegacy := yaml.Unmarshal(data, &legacy); errLegacy == nil {
|
||||
if cfg.migrateLegacyGeminiKeys(legacy.LegacyGeminiKeys) {
|
||||
cfg.legacyMigrationPending = true
|
||||
}
|
||||
if cfg.migrateLegacyOpenAICompatibilityKeys(legacy.OpenAICompat) {
|
||||
cfg.legacyMigrationPending = true
|
||||
}
|
||||
if cfg.migrateLegacyAmpConfig(&legacy) {
|
||||
cfg.legacyMigrationPending = true
|
||||
}
|
||||
}
|
||||
// NOTE: Startup legacy key migration is intentionally disabled.
|
||||
// Reason: avoid mutating config.yaml during server startup.
|
||||
// Re-enable the block below if automatic startup migration is needed again.
|
||||
// var legacy legacyConfigData
|
||||
// if errLegacy := yaml.Unmarshal(data, &legacy); errLegacy == nil {
|
||||
// if cfg.migrateLegacyGeminiKeys(legacy.LegacyGeminiKeys) {
|
||||
// cfg.legacyMigrationPending = true
|
||||
// }
|
||||
// if cfg.migrateLegacyOpenAICompatibilityKeys(legacy.OpenAICompat) {
|
||||
// cfg.legacyMigrationPending = true
|
||||
// }
|
||||
// if cfg.migrateLegacyAmpConfig(&legacy) {
|
||||
// cfg.legacyMigrationPending = true
|
||||
// }
|
||||
// }
|
||||
|
||||
// Hash remote management key if plaintext is detected (nested)
|
||||
// We consider a value to be already hashed if it looks like a bcrypt hash ($2a$, $2b$, or $2y$ prefix).
|
||||
@@ -599,6 +619,11 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
|
||||
}
|
||||
|
||||
cfg.Pprof.Addr = strings.TrimSpace(cfg.Pprof.Addr)
|
||||
if cfg.Pprof.Addr == "" {
|
||||
cfg.Pprof.Addr = DefaultPprofAddr
|
||||
}
|
||||
|
||||
if cfg.LogsMaxTotalSizeMB < 0 {
|
||||
cfg.LogsMaxTotalSizeMB = 0
|
||||
}
|
||||
@@ -607,9 +632,6 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
cfg.ErrorLogsMaxFiles = 10
|
||||
}
|
||||
|
||||
// Sync request authentication providers with inline API keys for backwards compatibility.
|
||||
syncInlineAccessProvider(&cfg)
|
||||
|
||||
// Sanitize Gemini API key configuration and migrate legacy entries.
|
||||
cfg.SanitizeGeminiKeys()
|
||||
|
||||
@@ -637,17 +659,20 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Validate raw payload rules and drop invalid entries.
|
||||
cfg.SanitizePayloadRules()
|
||||
|
||||
if cfg.legacyMigrationPending {
|
||||
fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
|
||||
if !optional && configFile != "" {
|
||||
if err := SaveConfigPreserveComments(configFile, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to persist migrated legacy config: %w", err)
|
||||
}
|
||||
fmt.Println("Legacy configuration normalized and persisted.")
|
||||
} else {
|
||||
fmt.Println("Legacy configuration normalized in memory; persistence skipped.")
|
||||
}
|
||||
}
|
||||
// NOTE: Legacy migration persistence is intentionally disabled together with
|
||||
// startup legacy migration to keep startup read-only for config.yaml.
|
||||
// Re-enable the block below if automatic startup migration is needed again.
|
||||
// if cfg.legacyMigrationPending {
|
||||
// fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
|
||||
// if !optional && configFile != "" {
|
||||
// if err := SaveConfigPreserveComments(configFile, &cfg); err != nil {
|
||||
// return nil, fmt.Errorf("failed to persist migrated legacy config: %w", err)
|
||||
// }
|
||||
// fmt.Println("Legacy configuration normalized and persisted.")
|
||||
// } else {
|
||||
// fmt.Println("Legacy configuration normalized in memory; persistence skipped.")
|
||||
// }
|
||||
// }
|
||||
|
||||
// Return the populated configuration struct.
|
||||
return &cfg, nil
|
||||
@@ -711,8 +736,32 @@ func payloadRawString(value any) ([]byte, bool) {
|
||||
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
|
||||
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
|
||||
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
|
||||
// It also injects default aliases for channels that have built-in defaults (e.g., kiro)
|
||||
// when no user-configured aliases exist for those channels.
|
||||
func (cfg *Config) SanitizeOAuthModelAlias() {
|
||||
if cfg == nil || len(cfg.OAuthModelAlias) == 0 {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Inject default Kiro aliases if no user-configured kiro aliases exist
|
||||
if cfg.OAuthModelAlias == nil {
|
||||
cfg.OAuthModelAlias = make(map[string][]OAuthModelAlias)
|
||||
}
|
||||
if _, hasKiro := cfg.OAuthModelAlias["kiro"]; !hasKiro {
|
||||
// Check case-insensitive too
|
||||
found := false
|
||||
for k := range cfg.OAuthModelAlias {
|
||||
if strings.EqualFold(strings.TrimSpace(k), "kiro") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
cfg.OAuthModelAlias["kiro"] = defaultKiroAliases()
|
||||
}
|
||||
}
|
||||
|
||||
if len(cfg.OAuthModelAlias) == 0 {
|
||||
return
|
||||
}
|
||||
out := make(map[string][]OAuthModelAlias, len(cfg.OAuthModelAlias))
|
||||
@@ -860,18 +909,6 @@ func normalizeModelPrefix(prefix string) string {
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func syncInlineAccessProvider(cfg *Config) {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
if len(cfg.APIKeys) == 0 {
|
||||
if provider := cfg.ConfigAPIKeyProvider(); provider != nil && len(provider.APIKeys) > 0 {
|
||||
cfg.APIKeys = append([]string(nil), provider.APIKeys...)
|
||||
}
|
||||
}
|
||||
cfg.Access.Providers = nil
|
||||
}
|
||||
|
||||
// looksLikeBcrypt returns true if the provided string appears to be a bcrypt hash.
|
||||
func looksLikeBcrypt(s string) bool {
|
||||
return len(s) > 4 && (s[:4] == "$2a$" || s[:4] == "$2b$" || s[:4] == "$2y$")
|
||||
@@ -959,7 +996,7 @@ func hashSecret(secret string) (string, error) {
|
||||
// SaveConfigPreserveComments writes the config back to YAML while preserving existing comments
|
||||
// and key ordering by loading the original file into a yaml.Node tree and updating values in-place.
|
||||
func SaveConfigPreserveComments(configFile string, cfg *Config) error {
|
||||
persistCfg := sanitizeConfigForPersist(cfg)
|
||||
persistCfg := cfg
|
||||
// Load original YAML as a node tree to preserve comments and ordering.
|
||||
data, err := os.ReadFile(configFile)
|
||||
if err != nil {
|
||||
@@ -1027,16 +1064,6 @@ func SaveConfigPreserveComments(configFile string, cfg *Config) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func sanitizeConfigForPersist(cfg *Config) *Config {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
clone := *cfg
|
||||
clone.SDKConfig = cfg.SDKConfig
|
||||
clone.SDKConfig.Access = AccessConfig{}
|
||||
return &clone
|
||||
}
|
||||
|
||||
// SaveConfigPreserveCommentsUpdateNestedScalar updates a nested scalar key path like ["a","b"]
|
||||
// while preserving comments and positions.
|
||||
func SaveConfigPreserveCommentsUpdateNestedScalar(configFile string, path []string, value string) error {
|
||||
@@ -1133,8 +1160,13 @@ func getOrCreateMapValue(mapNode *yaml.Node, key string) *yaml.Node {
|
||||
|
||||
// mergeMappingPreserve merges keys from src into dst mapping node while preserving
|
||||
// key order and comments of existing keys in dst. New keys are only added if their
|
||||
// value is non-zero to avoid polluting the config with defaults.
|
||||
func mergeMappingPreserve(dst, src *yaml.Node) {
|
||||
// value is non-zero and not a known default to avoid polluting the config with defaults.
|
||||
func mergeMappingPreserve(dst, src *yaml.Node, path ...[]string) {
|
||||
var currentPath []string
|
||||
if len(path) > 0 {
|
||||
currentPath = path[0]
|
||||
}
|
||||
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
@@ -1148,16 +1180,19 @@ func mergeMappingPreserve(dst, src *yaml.Node) {
|
||||
sk := src.Content[i]
|
||||
sv := src.Content[i+1]
|
||||
idx := findMapKeyIndex(dst, sk.Value)
|
||||
childPath := appendPath(currentPath, sk.Value)
|
||||
if idx >= 0 {
|
||||
// Merge into existing value node (always update, even to zero values)
|
||||
dv := dst.Content[idx+1]
|
||||
mergeNodePreserve(dv, sv)
|
||||
mergeNodePreserve(dv, sv, childPath)
|
||||
} else {
|
||||
// New key: only add if value is non-zero to avoid polluting config with defaults
|
||||
if isZeroValueNode(sv) {
|
||||
// New key: only add if value is non-zero and not a known default
|
||||
candidate := deepCopyNode(sv)
|
||||
pruneKnownDefaultsInNewNode(childPath, candidate)
|
||||
if isKnownDefaultValue(childPath, candidate) {
|
||||
continue
|
||||
}
|
||||
dst.Content = append(dst.Content, deepCopyNode(sk), deepCopyNode(sv))
|
||||
dst.Content = append(dst.Content, deepCopyNode(sk), candidate)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1165,7 +1200,12 @@ func mergeMappingPreserve(dst, src *yaml.Node) {
|
||||
// mergeNodePreserve merges src into dst for scalars, mappings and sequences while
|
||||
// reusing destination nodes to keep comments and anchors. For sequences, it updates
|
||||
// in-place by index.
|
||||
func mergeNodePreserve(dst, src *yaml.Node) {
|
||||
func mergeNodePreserve(dst, src *yaml.Node, path ...[]string) {
|
||||
var currentPath []string
|
||||
if len(path) > 0 {
|
||||
currentPath = path[0]
|
||||
}
|
||||
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
@@ -1174,7 +1214,7 @@ func mergeNodePreserve(dst, src *yaml.Node) {
|
||||
if dst.Kind != yaml.MappingNode {
|
||||
copyNodeShallow(dst, src)
|
||||
}
|
||||
mergeMappingPreserve(dst, src)
|
||||
mergeMappingPreserve(dst, src, currentPath)
|
||||
case yaml.SequenceNode:
|
||||
// Preserve explicit null style if dst was null and src is empty sequence
|
||||
if dst.Kind == yaml.ScalarNode && dst.Tag == "!!null" && len(src.Content) == 0 {
|
||||
@@ -1197,7 +1237,7 @@ func mergeNodePreserve(dst, src *yaml.Node) {
|
||||
dst.Content[i] = deepCopyNode(src.Content[i])
|
||||
continue
|
||||
}
|
||||
mergeNodePreserve(dst.Content[i], src.Content[i])
|
||||
mergeNodePreserve(dst.Content[i], src.Content[i], currentPath)
|
||||
if dst.Content[i] != nil && src.Content[i] != nil &&
|
||||
dst.Content[i].Kind == yaml.MappingNode && src.Content[i].Kind == yaml.MappingNode {
|
||||
pruneMissingMapKeys(dst.Content[i], src.Content[i])
|
||||
@@ -1239,6 +1279,94 @@ func findMapKeyIndex(mapNode *yaml.Node, key string) int {
|
||||
return -1
|
||||
}
|
||||
|
||||
// appendPath appends a key to the path, returning a new slice to avoid modifying the original.
|
||||
func appendPath(path []string, key string) []string {
|
||||
if len(path) == 0 {
|
||||
return []string{key}
|
||||
}
|
||||
newPath := make([]string, len(path)+1)
|
||||
copy(newPath, path)
|
||||
newPath[len(path)] = key
|
||||
return newPath
|
||||
}
|
||||
|
||||
// isKnownDefaultValue returns true if the given node at the specified path
|
||||
// represents a known default value that should not be written to the config file.
|
||||
// This prevents non-zero defaults from polluting the config.
|
||||
func isKnownDefaultValue(path []string, node *yaml.Node) bool {
|
||||
// First check if it's a zero value
|
||||
if isZeroValueNode(node) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Match known non-zero defaults by exact dotted path.
|
||||
if len(path) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
fullPath := strings.Join(path, ".")
|
||||
|
||||
// Check string defaults
|
||||
if node.Kind == yaml.ScalarNode && node.Tag == "!!str" {
|
||||
switch fullPath {
|
||||
case "pprof.addr":
|
||||
return node.Value == DefaultPprofAddr
|
||||
case "remote-management.panel-github-repository":
|
||||
return node.Value == DefaultPanelGitHubRepository
|
||||
case "routing.strategy":
|
||||
return node.Value == "round-robin"
|
||||
}
|
||||
}
|
||||
|
||||
// Check integer defaults
|
||||
if node.Kind == yaml.ScalarNode && node.Tag == "!!int" {
|
||||
switch fullPath {
|
||||
case "error-logs-max-files":
|
||||
return node.Value == "10"
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// pruneKnownDefaultsInNewNode removes default-valued descendants from a new node
|
||||
// before it is appended into the destination YAML tree.
|
||||
func pruneKnownDefaultsInNewNode(path []string, node *yaml.Node) {
|
||||
if node == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch node.Kind {
|
||||
case yaml.MappingNode:
|
||||
filtered := make([]*yaml.Node, 0, len(node.Content))
|
||||
for i := 0; i+1 < len(node.Content); i += 2 {
|
||||
keyNode := node.Content[i]
|
||||
valueNode := node.Content[i+1]
|
||||
if keyNode == nil || valueNode == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
childPath := appendPath(path, keyNode.Value)
|
||||
if isKnownDefaultValue(childPath, valueNode) {
|
||||
continue
|
||||
}
|
||||
|
||||
pruneKnownDefaultsInNewNode(childPath, valueNode)
|
||||
if (valueNode.Kind == yaml.MappingNode || valueNode.Kind == yaml.SequenceNode) &&
|
||||
len(valueNode.Content) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
filtered = append(filtered, keyNode, valueNode)
|
||||
}
|
||||
node.Content = filtered
|
||||
case yaml.SequenceNode:
|
||||
for _, child := range node.Content {
|
||||
pruneKnownDefaultsInNewNode(path, child)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isZeroValueNode returns true if the YAML node represents a zero/default value
|
||||
// that should not be written as a new key to preserve config cleanliness.
|
||||
// For mappings and sequences, recursively checks if all children are zero values.
|
||||
|
||||
@@ -17,6 +17,29 @@ var antigravityModelConversionTable = map[string]string{
|
||||
"gemini-claude-sonnet-4-5": "claude-sonnet-4-5",
|
||||
"gemini-claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
|
||||
"gemini-claude-opus-4-5-thinking": "claude-opus-4-5-thinking",
|
||||
"gemini-claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
|
||||
}
|
||||
|
||||
// defaultKiroAliases returns the default oauth-model-alias configuration
|
||||
// for the kiro channel. Maps kiro-prefixed model names to standard Claude model
|
||||
// names so that clients like Claude Code can use standard names directly.
|
||||
func defaultKiroAliases() []OAuthModelAlias {
|
||||
return []OAuthModelAlias{
|
||||
// Sonnet 4.5
|
||||
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5-20250929", Fork: true},
|
||||
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5", Fork: true},
|
||||
// Sonnet 4
|
||||
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4-20250514", Fork: true},
|
||||
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4", Fork: true},
|
||||
// Opus 4.6
|
||||
{Name: "kiro-claude-opus-4-6", Alias: "claude-opus-4-6", Fork: true},
|
||||
// Opus 4.5
|
||||
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5-20251101", Fork: true},
|
||||
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5", Fork: true},
|
||||
// Haiku 4.5
|
||||
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5-20251001", Fork: true},
|
||||
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5", Fork: true},
|
||||
}
|
||||
}
|
||||
|
||||
// defaultAntigravityAliases returns the default oauth-model-alias configuration
|
||||
@@ -30,6 +53,7 @@ func defaultAntigravityAliases() []OAuthModelAlias {
|
||||
{Name: "claude-sonnet-4-5", Alias: "gemini-claude-sonnet-4-5"},
|
||||
{Name: "claude-sonnet-4-5-thinking", Alias: "gemini-claude-sonnet-4-5-thinking"},
|
||||
{Name: "claude-opus-4-5-thinking", Alias: "gemini-claude-opus-4-5-thinking"},
|
||||
{Name: "claude-opus-4-6-thinking", Alias: "gemini-claude-opus-4-6-thinking"},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -131,6 +131,9 @@ func TestMigrateOAuthModelAlias_ConvertsAntigravityModels(t *testing.T) {
|
||||
if !strings.Contains(content, "claude-opus-4-5-thinking") {
|
||||
t.Fatal("expected missing default alias claude-opus-4-5-thinking to be added")
|
||||
}
|
||||
if !strings.Contains(content, "claude-opus-4-6-thinking") {
|
||||
t.Fatal("expected missing default alias claude-opus-4-6-thinking to be added")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateOAuthModelAlias_AddsDefaultIfNeitherExists(t *testing.T) {
|
||||
|
||||
@@ -54,3 +54,88 @@ func TestSanitizeOAuthModelAlias_AllowsMultipleAliasesForSameName(t *testing.T)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_InjectsDefaultKiroAliases(t *testing.T) {
|
||||
// When no kiro aliases are configured, defaults should be injected
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
"codex": {
|
||||
{Name: "gpt-5", Alias: "g5"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
kiroAliases := cfg.OAuthModelAlias["kiro"]
|
||||
if len(kiroAliases) == 0 {
|
||||
t.Fatal("expected default kiro aliases to be injected")
|
||||
}
|
||||
|
||||
// Check that standard Claude model names are present
|
||||
aliasSet := make(map[string]bool)
|
||||
for _, a := range kiroAliases {
|
||||
aliasSet[a.Alias] = true
|
||||
}
|
||||
expectedAliases := []string{
|
||||
"claude-sonnet-4-5-20250929",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-sonnet-4-20250514",
|
||||
"claude-sonnet-4",
|
||||
"claude-opus-4-6",
|
||||
"claude-opus-4-5-20251101",
|
||||
"claude-opus-4-5",
|
||||
"claude-haiku-4-5-20251001",
|
||||
"claude-haiku-4-5",
|
||||
}
|
||||
for _, expected := range expectedAliases {
|
||||
if !aliasSet[expected] {
|
||||
t.Fatalf("expected default kiro alias %q to be present", expected)
|
||||
}
|
||||
}
|
||||
|
||||
// All should have fork=true
|
||||
for _, a := range kiroAliases {
|
||||
if !a.Fork {
|
||||
t.Fatalf("expected all default kiro aliases to have fork=true, got fork=false for %q", a.Alias)
|
||||
}
|
||||
}
|
||||
|
||||
// Codex aliases should still be preserved
|
||||
if len(cfg.OAuthModelAlias["codex"]) != 1 {
|
||||
t.Fatal("expected codex aliases to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_DoesNotOverrideUserKiroAliases(t *testing.T) {
|
||||
// When user has configured kiro aliases, defaults should NOT be injected
|
||||
cfg := &Config{
|
||||
OAuthModelAlias: map[string][]OAuthModelAlias{
|
||||
"kiro": {
|
||||
{Name: "kiro-claude-sonnet-4", Alias: "my-custom-sonnet", Fork: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
kiroAliases := cfg.OAuthModelAlias["kiro"]
|
||||
if len(kiroAliases) != 1 {
|
||||
t.Fatalf("expected 1 user-configured kiro alias, got %d", len(kiroAliases))
|
||||
}
|
||||
if kiroAliases[0].Alias != "my-custom-sonnet" {
|
||||
t.Fatalf("expected user alias to be preserved, got %q", kiroAliases[0].Alias)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeOAuthModelAlias_InjectsDefaultKiroWhenEmpty(t *testing.T) {
|
||||
// When OAuthModelAlias is nil, kiro defaults should still be injected
|
||||
cfg := &Config{}
|
||||
|
||||
cfg.SanitizeOAuthModelAlias()
|
||||
|
||||
kiroAliases := cfg.OAuthModelAlias["kiro"]
|
||||
if len(kiroAliases) == 0 {
|
||||
t.Fatal("expected default kiro aliases to be injected when OAuthModelAlias is nil")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,9 +20,6 @@ type SDKConfig struct {
|
||||
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
||||
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
||||
|
||||
// Access holds request authentication provider configuration.
|
||||
Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"`
|
||||
|
||||
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
||||
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
||||
|
||||
@@ -42,65 +39,3 @@ type StreamingConfig struct {
|
||||
// <= 0 disables bootstrap retries. Default is 0.
|
||||
BootstrapRetries int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
|
||||
}
|
||||
|
||||
// AccessConfig groups request authentication providers.
|
||||
type AccessConfig struct {
|
||||
// Providers lists configured authentication providers.
|
||||
Providers []AccessProvider `yaml:"providers,omitempty" json:"providers,omitempty"`
|
||||
}
|
||||
|
||||
// AccessProvider describes a request authentication provider entry.
|
||||
type AccessProvider struct {
|
||||
// Name is the instance identifier for the provider.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Type selects the provider implementation registered via the SDK.
|
||||
Type string `yaml:"type" json:"type"`
|
||||
|
||||
// SDK optionally names a third-party SDK module providing this provider.
|
||||
SDK string `yaml:"sdk,omitempty" json:"sdk,omitempty"`
|
||||
|
||||
// APIKeys lists inline keys for providers that require them.
|
||||
APIKeys []string `yaml:"api-keys,omitempty" json:"api-keys,omitempty"`
|
||||
|
||||
// Config passes provider-specific options to the implementation.
|
||||
Config map[string]any `yaml:"config,omitempty" json:"config,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
// AccessProviderTypeConfigAPIKey is the built-in provider validating inline API keys.
|
||||
AccessProviderTypeConfigAPIKey = "config-api-key"
|
||||
|
||||
// DefaultAccessProviderName is applied when no provider name is supplied.
|
||||
DefaultAccessProviderName = "config-inline"
|
||||
)
|
||||
|
||||
// ConfigAPIKeyProvider returns the first inline API key provider if present.
|
||||
func (c *SDKConfig) ConfigAPIKeyProvider() *AccessProvider {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
for i := range c.Access.Providers {
|
||||
if c.Access.Providers[i].Type == AccessProviderTypeConfigAPIKey {
|
||||
if c.Access.Providers[i].Name == "" {
|
||||
c.Access.Providers[i].Name = DefaultAccessProviderName
|
||||
}
|
||||
return &c.Access.Providers[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MakeInlineAPIKeyProvider constructs an inline API key provider configuration.
|
||||
// It returns nil when no keys are supplied.
|
||||
func MakeInlineAPIKeyProvider(keys []string) *AccessProvider {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
provider := &AccessProvider{
|
||||
Name: DefaultAccessProviderName,
|
||||
Type: AccessProviderTypeConfigAPIKey,
|
||||
APIKeys: append([]string(nil), keys...),
|
||||
}
|
||||
return provider
|
||||
}
|
||||
|
||||
@@ -132,7 +132,10 @@ func ResolveLogDirectory(cfg *config.Config) string {
|
||||
return logDir
|
||||
}
|
||||
if !isDirWritable(logDir) {
|
||||
authDir := strings.TrimSpace(cfg.AuthDir)
|
||||
authDir, err := util.ResolveAuthDir(cfg.AuthDir)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to resolve auth-dir %q for log directory: %v", cfg.AuthDir, err)
|
||||
}
|
||||
if authDir != "" {
|
||||
logDir = filepath.Join(authDir, "logs")
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -28,6 +29,7 @@ const (
|
||||
defaultManagementFallbackURL = "https://cpamc.router-for.me/"
|
||||
managementAssetName = "management.html"
|
||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||
managementSyncMinInterval = 30 * time.Second
|
||||
updateCheckInterval = 3 * time.Hour
|
||||
)
|
||||
|
||||
@@ -37,11 +39,10 @@ const ManagementFileName = managementAssetName
|
||||
var (
|
||||
lastUpdateCheckMu sync.Mutex
|
||||
lastUpdateCheckTime time.Time
|
||||
|
||||
currentConfigPtr atomic.Pointer[config.Config]
|
||||
disableControlPanel atomic.Bool
|
||||
schedulerOnce sync.Once
|
||||
schedulerConfigPath atomic.Value
|
||||
sfGroup singleflight.Group
|
||||
)
|
||||
|
||||
// SetCurrentConfig stores the latest configuration snapshot for management asset decisions.
|
||||
@@ -50,16 +51,7 @@ func SetCurrentConfig(cfg *config.Config) {
|
||||
currentConfigPtr.Store(nil)
|
||||
return
|
||||
}
|
||||
|
||||
prevDisabled := disableControlPanel.Load()
|
||||
currentConfigPtr.Store(cfg)
|
||||
disableControlPanel.Store(cfg.RemoteManagement.DisableControlPanel)
|
||||
|
||||
if prevDisabled && !cfg.RemoteManagement.DisableControlPanel {
|
||||
lastUpdateCheckMu.Lock()
|
||||
lastUpdateCheckTime = time.Time{}
|
||||
lastUpdateCheckMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// StartAutoUpdater launches a background goroutine that periodically ensures the management asset is up to date.
|
||||
@@ -92,7 +84,7 @@ func runAutoUpdater(ctx context.Context) {
|
||||
log.Debug("management asset auto-updater skipped: config not yet available")
|
||||
return
|
||||
}
|
||||
if disableControlPanel.Load() {
|
||||
if cfg.RemoteManagement.DisableControlPanel {
|
||||
log.Debug("management asset auto-updater skipped: control panel disabled")
|
||||
return
|
||||
}
|
||||
@@ -181,103 +173,106 @@ func FilePath(configFilePath string) string {
|
||||
}
|
||||
|
||||
// EnsureLatestManagementHTML checks the latest management.html asset and updates the local copy when needed.
|
||||
// The function is designed to run in a background goroutine and will never panic.
|
||||
// It enforces a 3-hour rate limit to avoid frequent checks on config/auth file changes.
|
||||
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string, panelRepository string) {
|
||||
// It coalesces concurrent sync attempts and returns whether the asset exists after the sync attempt.
|
||||
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string, panelRepository string) bool {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
if disableControlPanel.Load() {
|
||||
log.Debug("management asset sync skipped: control panel disabled by configuration")
|
||||
return
|
||||
}
|
||||
|
||||
staticDir = strings.TrimSpace(staticDir)
|
||||
if staticDir == "" {
|
||||
log.Debug("management asset sync skipped: empty static directory")
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
localPath := filepath.Join(staticDir, managementAssetName)
|
||||
localFileMissing := false
|
||||
if _, errStat := os.Stat(localPath); errStat != nil {
|
||||
if errors.Is(errStat, os.ErrNotExist) {
|
||||
localFileMissing = true
|
||||
} else {
|
||||
log.WithError(errStat).Debug("failed to stat local management asset")
|
||||
}
|
||||
}
|
||||
|
||||
// Rate limiting: check only once every 3 hours
|
||||
lastUpdateCheckMu.Lock()
|
||||
now := time.Now()
|
||||
timeSinceLastCheck := now.Sub(lastUpdateCheckTime)
|
||||
if timeSinceLastCheck < updateCheckInterval {
|
||||
_, _, _ = sfGroup.Do(localPath, func() (interface{}, error) {
|
||||
lastUpdateCheckMu.Lock()
|
||||
now := time.Now()
|
||||
timeSinceLastAttempt := now.Sub(lastUpdateCheckTime)
|
||||
if !lastUpdateCheckTime.IsZero() && timeSinceLastAttempt < managementSyncMinInterval {
|
||||
lastUpdateCheckMu.Unlock()
|
||||
log.Debugf(
|
||||
"management asset sync skipped by throttle: last attempt %v ago (interval %v)",
|
||||
timeSinceLastAttempt.Round(time.Second),
|
||||
managementSyncMinInterval,
|
||||
)
|
||||
return nil, nil
|
||||
}
|
||||
lastUpdateCheckTime = now
|
||||
lastUpdateCheckMu.Unlock()
|
||||
log.Debugf("management asset update check skipped: last check was %v ago (interval: %v)", timeSinceLastCheck.Round(time.Second), updateCheckInterval)
|
||||
return
|
||||
}
|
||||
lastUpdateCheckTime = now
|
||||
lastUpdateCheckMu.Unlock()
|
||||
|
||||
if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
|
||||
log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
|
||||
return
|
||||
}
|
||||
|
||||
releaseURL := resolveReleaseURL(panelRepository)
|
||||
client := newHTTPClient(proxyURL)
|
||||
|
||||
localHash, err := fileSHA256(localPath)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
log.WithError(err).Debug("failed to read local management asset hash")
|
||||
}
|
||||
localHash = ""
|
||||
}
|
||||
|
||||
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return
|
||||
localFileMissing := false
|
||||
if _, errStat := os.Stat(localPath); errStat != nil {
|
||||
if errors.Is(errStat, os.ErrNotExist) {
|
||||
localFileMissing = true
|
||||
} else {
|
||||
log.WithError(errStat).Debug("failed to stat local management asset")
|
||||
}
|
||||
return
|
||||
}
|
||||
log.WithError(err).Warn("failed to fetch latest management release information")
|
||||
return
|
||||
}
|
||||
|
||||
if remoteHash != "" && localHash != "" && strings.EqualFold(remoteHash, localHash) {
|
||||
log.Debug("management asset is already up to date")
|
||||
return
|
||||
}
|
||||
if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
|
||||
log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to download management asset, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return
|
||||
releaseURL := resolveReleaseURL(panelRepository)
|
||||
client := newHTTPClient(proxyURL)
|
||||
|
||||
localHash, err := fileSHA256(localPath)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
log.WithError(err).Debug("failed to read local management asset hash")
|
||||
}
|
||||
return
|
||||
localHash = ""
|
||||
}
|
||||
log.WithError(err).Warn("failed to download management asset")
|
||||
return
|
||||
}
|
||||
|
||||
if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
|
||||
log.Warnf("remote digest mismatch for management asset: expected %s got %s", remoteHash, downloadedHash)
|
||||
}
|
||||
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
log.WithError(err).Warn("failed to fetch latest management release information")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if err = atomicWriteFile(localPath, data); err != nil {
|
||||
log.WithError(err).Warn("failed to update management asset on disk")
|
||||
return
|
||||
}
|
||||
if remoteHash != "" && localHash != "" && strings.EqualFold(remoteHash, localHash) {
|
||||
log.Debug("management asset is already up to date")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
|
||||
data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
|
||||
if err != nil {
|
||||
if localFileMissing {
|
||||
log.WithError(err).Warn("failed to download management asset, trying fallback page")
|
||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
log.WithError(err).Warn("failed to download management asset")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
|
||||
log.Warnf("remote digest mismatch for management asset: expected %s got %s", remoteHash, downloadedHash)
|
||||
}
|
||||
|
||||
if err = atomicWriteFile(localPath, data); err != nil {
|
||||
log.WithError(err).Warn("failed to update management asset on disk")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
_, err := os.Stat(localPath)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, localPath string) bool {
|
||||
|
||||
@@ -19,7 +19,10 @@ import (
|
||||
// - codex
|
||||
// - qwen
|
||||
// - iflow
|
||||
// - kiro
|
||||
// - github-copilot
|
||||
// - kiro
|
||||
// - amazonq
|
||||
// - antigravity (returns static overrides only)
|
||||
func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||
key := strings.ToLower(strings.TrimSpace(channel))
|
||||
@@ -42,6 +45,10 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||
return GetIFlowModels()
|
||||
case "github-copilot":
|
||||
return GetGitHubCopilotModels()
|
||||
case "kiro":
|
||||
return GetKiroModels()
|
||||
case "amazonq":
|
||||
return GetAmazonQModels()
|
||||
case "antigravity":
|
||||
cfg := GetAntigravityModelConfig()
|
||||
if len(cfg) == 0 {
|
||||
@@ -86,6 +93,9 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
GetOpenAIModels(),
|
||||
GetQwenModels(),
|
||||
GetIFlowModels(),
|
||||
GetGitHubCopilotModels(),
|
||||
GetKiroModels(),
|
||||
GetAmazonQModels(),
|
||||
}
|
||||
for _, models := range allModels {
|
||||
for _, m := range models {
|
||||
@@ -267,6 +277,18 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4.6",
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "github-copilot",
|
||||
Type: "github-copilot",
|
||||
DisplayName: "Claude Opus 4.6",
|
||||
Description: "Anthropic Claude Opus 4.6 via GitHub Copilot",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
SupportedEndpoints: []string{"/chat/completions"},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4",
|
||||
Object: "model",
|
||||
@@ -366,6 +388,18 @@ func GetKiroModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-opus-4-6",
|
||||
Object: "model",
|
||||
Created: 1736899200, // 2025-01-15
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro Claude Opus 4.6",
|
||||
Description: "Claude Opus 4.6 via Kiro (2.2x credit)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-opus-4-5",
|
||||
Object: "model",
|
||||
@@ -415,6 +449,18 @@ func GetKiroModels() []*ModelInfo {
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
// --- Agentic Variants (Optimized for coding agents with chunked writes) ---
|
||||
{
|
||||
ID: "kiro-claude-opus-4-6-agentic",
|
||||
Object: "model",
|
||||
Created: 1736899200, // 2025-01-15
|
||||
OwnedBy: "aws",
|
||||
Type: "kiro",
|
||||
DisplayName: "Kiro Claude Opus 4.6 (Agentic)",
|
||||
Description: "Claude Opus 4.6 optimized for coding agents (chunked writes)",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kiro-claude-opus-4-5-agentic",
|
||||
Object: "model",
|
||||
|
||||
@@ -15,7 +15,7 @@ func GetClaudeModels() []*ModelInfo {
|
||||
DisplayName: "Claude 4.5 Haiku",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
// Thinking: not supported for Haiku models
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-5-20250929",
|
||||
@@ -28,6 +28,18 @@ func GetClaudeModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-6",
|
||||
Object: "model",
|
||||
Created: 1770318000, // 2026-02-05
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.6 Opus",
|
||||
Description: "Premium model combining maximum intelligence with practical performance",
|
||||
ContextLength: 1000000,
|
||||
MaxCompletionTokens: 128000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-5-20251101",
|
||||
Object: "model",
|
||||
@@ -716,6 +728,20 @@ func GetOpenAIModels() []*ModelInfo {
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.3-codex",
|
||||
Object: "model",
|
||||
Created: 1770307200,
|
||||
OwnedBy: "openai",
|
||||
Type: "openai",
|
||||
Version: "gpt-5.3",
|
||||
DisplayName: "GPT 5.3 Codex",
|
||||
Description: "Stable version of GPT 5.3 Codex, The best model for coding and agentic tasks across domains.",
|
||||
ContextLength: 400000,
|
||||
MaxCompletionTokens: 128000,
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -803,6 +829,7 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
|
||||
{ID: "kimi-k2.5", DisplayName: "Kimi-K2.5", Description: "Moonshot Kimi K2.5", Created: 1769443200, Thinking: iFlowThinkingSupport},
|
||||
}
|
||||
models := make([]*ModelInfo, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
@@ -839,8 +866,50 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
|
||||
"gpt-oss-120b-medium": {},
|
||||
"tab_flash_lite_preview": {},
|
||||
}
|
||||
}
|
||||
|
||||
// GetKimiModels returns the standard Kimi (Moonshot AI) model definitions
|
||||
func GetKimiModels() []*ModelInfo {
|
||||
return []*ModelInfo{
|
||||
{
|
||||
ID: "kimi-k2",
|
||||
Object: "model",
|
||||
Created: 1752192000, // 2025-07-11
|
||||
OwnedBy: "moonshot",
|
||||
Type: "kimi",
|
||||
DisplayName: "Kimi K2",
|
||||
Description: "Kimi K2 - Moonshot AI's flagship coding model",
|
||||
ContextLength: 131072,
|
||||
MaxCompletionTokens: 32768,
|
||||
},
|
||||
{
|
||||
ID: "kimi-k2-thinking",
|
||||
Object: "model",
|
||||
Created: 1762387200, // 2025-11-06
|
||||
OwnedBy: "moonshot",
|
||||
Type: "kimi",
|
||||
DisplayName: "Kimi K2 Thinking",
|
||||
Description: "Kimi K2 Thinking - Extended reasoning model",
|
||||
ContextLength: 131072,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kimi-k2.5",
|
||||
Object: "model",
|
||||
Created: 1769472000, // 2026-01-26
|
||||
OwnedBy: "moonshot",
|
||||
Type: "kimi",
|
||||
DisplayName: "Kimi K2.5",
|
||||
Description: "Kimi K2.5 - Latest Moonshot AI coding model with improved capabilities",
|
||||
ContextLength: 131072,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,7 +141,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
URL: endpoint,
|
||||
Method: http.MethodPost,
|
||||
Headers: wsReq.Headers.Clone(),
|
||||
Body: bytes.Clone(body.payload),
|
||||
Body: body.payload,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
@@ -156,14 +156,14 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone())
|
||||
if len(wsResp.Body) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(wsResp.Body))
|
||||
appendAPIResponseChunk(ctx, e.cfg, wsResp.Body)
|
||||
}
|
||||
if wsResp.Status < 200 || wsResp.Status >= 300 {
|
||||
return resp, statusErr{code: wsResp.Status, msg: string(wsResp.Body)}
|
||||
}
|
||||
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), bytes.Clone(translatedReq), bytes.Clone(wsResp.Body), ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out))}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -199,7 +199,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
URL: endpoint,
|
||||
Method: http.MethodPost,
|
||||
Headers: wsReq.Headers.Clone(),
|
||||
Body: bytes.Clone(body.payload),
|
||||
Body: body.payload,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
@@ -225,7 +225,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
}
|
||||
var body bytes.Buffer
|
||||
if len(firstEvent.Payload) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(firstEvent.Payload))
|
||||
appendAPIResponseChunk(ctx, e.cfg, firstEvent.Payload)
|
||||
body.Write(firstEvent.Payload)
|
||||
}
|
||||
if firstEvent.Type == wsrelay.MessageTypeStreamEnd {
|
||||
@@ -244,7 +244,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
metadataLogged = true
|
||||
}
|
||||
if len(event.Payload) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
|
||||
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
||||
body.Write(event.Payload)
|
||||
}
|
||||
if event.Type == wsrelay.MessageTypeStreamEnd {
|
||||
@@ -274,12 +274,12 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
}
|
||||
case wsrelay.MessageTypeStreamChunk:
|
||||
if len(event.Payload) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
|
||||
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
||||
filtered := FilterSSEUsageMetadata(event.Payload)
|
||||
if detail, ok := parseGeminiStreamUsage(filtered); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(filtered), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
|
||||
}
|
||||
@@ -293,9 +293,9 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
metadataLogged = true
|
||||
}
|
||||
if len(event.Payload) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
|
||||
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(event.Payload), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
|
||||
}
|
||||
@@ -350,7 +350,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
|
||||
URL: endpoint,
|
||||
Method: http.MethodPost,
|
||||
Headers: wsReq.Headers.Clone(),
|
||||
Body: bytes.Clone(body.payload),
|
||||
Body: body.payload,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
@@ -364,7 +364,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone())
|
||||
if len(resp.Body) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(resp.Body))
|
||||
appendAPIResponseChunk(ctx, e.cfg, resp.Body)
|
||||
}
|
||||
if resp.Status < 200 || resp.Status >= 300 {
|
||||
return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)}
|
||||
@@ -373,7 +373,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
|
||||
if totalTokens <= 0 {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("wsrelay: totalTokens missing in response")
|
||||
}
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, bytes.Clone(resp.Body))
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, resp.Body)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
}
|
||||
|
||||
@@ -393,12 +393,13 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, stream)
|
||||
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, translatedPayload{}, err
|
||||
|
||||
@@ -133,12 +133,13 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -230,7 +231,7 @@ attemptLoop:
|
||||
|
||||
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bodyBytes, ¶m)
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
reporter.ensurePublished(ctx)
|
||||
return resp, nil
|
||||
@@ -274,12 +275,13 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -433,7 +435,7 @@ attemptLoop:
|
||||
|
||||
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, resp.Payload, ¶m)
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
reporter.ensurePublished(ctx)
|
||||
|
||||
@@ -665,12 +667,13 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -800,12 +803,12 @@ attemptLoop:
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(payload), ¶m)
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
}
|
||||
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, []byte("[DONE]"), ¶m)
|
||||
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("[DONE]"), ¶m)
|
||||
for i := range tail {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(tail[i])}
|
||||
}
|
||||
@@ -872,7 +875,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
|
||||
// Prepare payload once (doesn't depend on baseURL)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -1280,51 +1283,40 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
payload = geminiToAntigravity(modelName, payload, projectID)
|
||||
payload, _ = sjson.SetBytes(payload, "model", modelName)
|
||||
|
||||
if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") {
|
||||
strJSON := string(payload)
|
||||
paths := make([]string, 0)
|
||||
util.Walk(gjson.ParseBytes(payload), "", "parametersJsonSchema", &paths)
|
||||
for _, p := range paths {
|
||||
strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
||||
}
|
||||
|
||||
// Use the centralized schema cleaner to handle unsupported keywords,
|
||||
// const->enum conversion, and flattening of types/anyOf.
|
||||
strJSON = util.CleanJSONSchemaForAntigravity(strJSON)
|
||||
payload = []byte(strJSON)
|
||||
} else {
|
||||
strJSON := string(payload)
|
||||
paths := make([]string, 0)
|
||||
util.Walk(gjson.Parse(strJSON), "", "parametersJsonSchema", &paths)
|
||||
for _, p := range paths {
|
||||
strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
||||
}
|
||||
// Clean tool schemas for Gemini to remove unsupported JSON Schema keywords
|
||||
// without adding empty-schema placeholders.
|
||||
strJSON = util.CleanJSONSchemaForGemini(strJSON)
|
||||
payload = []byte(strJSON)
|
||||
useAntigravitySchema := strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high")
|
||||
payloadStr := string(payload)
|
||||
paths := make([]string, 0)
|
||||
util.Walk(gjson.Parse(payloadStr), "", "parametersJsonSchema", &paths)
|
||||
for _, p := range paths {
|
||||
payloadStr, _ = util.RenameKey(payloadStr, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
||||
}
|
||||
|
||||
if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") {
|
||||
systemInstructionPartsResult := gjson.GetBytes(payload, "request.systemInstruction.parts")
|
||||
payload, _ = sjson.SetBytes(payload, "request.systemInstruction.role", "user")
|
||||
payload, _ = sjson.SetBytes(payload, "request.systemInstruction.parts.0.text", systemInstruction)
|
||||
payload, _ = sjson.SetBytes(payload, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
|
||||
if useAntigravitySchema {
|
||||
payloadStr = util.CleanJSONSchemaForAntigravity(payloadStr)
|
||||
} else {
|
||||
payloadStr = util.CleanJSONSchemaForGemini(payloadStr)
|
||||
}
|
||||
|
||||
if useAntigravitySchema {
|
||||
systemInstructionPartsResult := gjson.Get(payloadStr, "request.systemInstruction.parts")
|
||||
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.role", "user")
|
||||
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.0.text", systemInstruction)
|
||||
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
|
||||
|
||||
if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
|
||||
for _, partResult := range systemInstructionPartsResult.Array() {
|
||||
payload, _ = sjson.SetRawBytes(payload, "request.systemInstruction.parts.-1", []byte(partResult.Raw))
|
||||
payloadStr, _ = sjson.SetRaw(payloadStr, "request.systemInstruction.parts.-1", partResult.Raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(modelName, "claude") {
|
||||
payload, _ = sjson.SetBytes(payload, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
payloadStr, _ = sjson.Set(payloadStr, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
} else {
|
||||
payload, _ = sjson.DeleteBytes(payload, "request.generationConfig.maxOutputTokens")
|
||||
payloadStr, _ = sjson.Delete(payloadStr, "request.generationConfig.maxOutputTokens")
|
||||
}
|
||||
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), strings.NewReader(payloadStr))
|
||||
if errReq != nil {
|
||||
return nil, errReq
|
||||
}
|
||||
@@ -1346,11 +1338,15 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
var payloadLog []byte
|
||||
if e.cfg != nil && e.cfg.RequestLog {
|
||||
payloadLog = []byte(payloadStr)
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: requestURL.String(),
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: payload,
|
||||
Body: payloadLog,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
|
||||
@@ -100,12 +100,13 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
to := sdktranslator.FromString("claude")
|
||||
// Use streaming translation to preserve function calling, except for claude.
|
||||
stream := from != to
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, stream)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
@@ -216,7 +217,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
to,
|
||||
from,
|
||||
req.Model,
|
||||
bytes.Clone(opts.OriginalRequest),
|
||||
opts.OriginalRequest,
|
||||
bodyForTranslation,
|
||||
data,
|
||||
¶m,
|
||||
@@ -240,12 +241,13 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("claude")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
@@ -381,7 +383,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
to,
|
||||
from,
|
||||
req.Model,
|
||||
bytes.Clone(opts.OriginalRequest),
|
||||
opts.OriginalRequest,
|
||||
bodyForTranslation,
|
||||
bytes.Clone(line),
|
||||
¶m,
|
||||
@@ -411,7 +413,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
to := sdktranslator.FromString("claude")
|
||||
// Use streaming translation to preserve function calling, except for claude.
|
||||
stream := from != to
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, stream)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
if !strings.HasPrefix(baseModel, "claude-3-5-haiku") {
|
||||
|
||||
@@ -27,6 +27,11 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
codexClientVersion = "0.98.0"
|
||||
codexUserAgent = "codex_cli_rs/0.98.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
|
||||
)
|
||||
|
||||
var dataTag = []byte("data:")
|
||||
|
||||
// CodexExecutor is a stateless executor for Codex (OpenAI Responses API entrypoint).
|
||||
@@ -88,12 +93,13 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -176,7 +182,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
}
|
||||
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(originalPayload), body, line, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -197,12 +203,13 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai-response")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -265,7 +272,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
reporter.publish(ctx, parseOpenAIUsage(data))
|
||||
reporter.ensurePublished(ctx)
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(originalPayload), body, data, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -286,12 +293,13 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -378,7 +386,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}
|
||||
}
|
||||
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(originalPayload), body, bytes.Clone(line), ¶m)
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
@@ -397,7 +405,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
body, err := thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -634,10 +642,10 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
||||
ginHeaders = ginCtx.Request.Header
|
||||
}
|
||||
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Version", "0.21.0")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion)
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Openai-Beta", "responses=experimental")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", "codex_cli_rs/0.50.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", codexUserAgent)
|
||||
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
@@ -119,12 +119,13 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -223,7 +224,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
if httpResp.StatusCode >= 200 && httpResp.StatusCode < 300 {
|
||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), payload, data, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -272,12 +273,13 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
|
||||
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -399,14 +401,14 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), ¶m)
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
@@ -428,12 +430,12 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
||||
var param any
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, data, ¶m)
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
|
||||
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
@@ -485,7 +487,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
// The loop variable attemptModel is only used as the concrete model id sent to the upstream
|
||||
// Gemini CLI endpoint when iterating fallback variants.
|
||||
for range models {
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
payload, err = thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
|
||||
@@ -116,12 +116,13 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
// Official Gemini API via API key or OAuth bearer
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -203,7 +204,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -222,12 +223,13 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -318,12 +320,12 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if detail, ok := parseGeminiStreamUsage(payload); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(payload), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
@@ -344,7 +346,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
|
||||
@@ -318,12 +318,13 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body = sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body = sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -417,7 +418,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -432,12 +433,13 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -521,7 +523,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -536,12 +538,13 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -632,12 +635,12 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
if detail, ok := parseGeminiStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, []byte("[DONE]"), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
@@ -660,12 +663,13 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -756,12 +760,12 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
if detail, ok := parseGeminiStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, []byte("[DONE]"), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
@@ -781,7 +785,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -865,7 +869,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
|
||||
@@ -7,12 +7,14 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
copilotauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
@@ -33,11 +35,11 @@ const (
|
||||
maxScannerBufferSize = 20_971_520
|
||||
|
||||
// Copilot API header values.
|
||||
copilotUserAgent = "GithubCopilot/1.0"
|
||||
copilotEditorVersion = "vscode/1.100.0"
|
||||
copilotPluginVersion = "copilot/1.300.0"
|
||||
copilotUserAgent = "GitHubCopilotChat/0.35.0"
|
||||
copilotEditorVersion = "vscode/1.107.0"
|
||||
copilotPluginVersion = "copilot-chat/0.35.0"
|
||||
copilotIntegrationID = "vscode-chat"
|
||||
copilotOpenAIIntent = "conversation-panel"
|
||||
copilotOpenAIIntent = "conversation-edits"
|
||||
)
|
||||
|
||||
// GitHubCopilotExecutor handles requests to the GitHub Copilot API.
|
||||
@@ -77,7 +79,7 @@ func (e *GitHubCopilotExecutor) PrepareRequest(req *http.Request, auth *cliproxy
|
||||
if errToken != nil {
|
||||
return errToken
|
||||
}
|
||||
e.applyHeaders(req, apiToken)
|
||||
e.applyHeaders(req, apiToken, nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -120,6 +122,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = e.normalizeModel(req.Model, body)
|
||||
body = flattenAssistantContent(body)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "stream", false)
|
||||
@@ -133,7 +136,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
e.applyHeaders(httpReq, apiToken)
|
||||
e.applyHeaders(httpReq, apiToken, body)
|
||||
|
||||
// Add Copilot-Vision-Request header if the request contains vision content
|
||||
if detectVisionContent(body) {
|
||||
@@ -225,6 +228,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
body = e.normalizeModel(req.Model, body)
|
||||
body = flattenAssistantContent(body)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, _ = sjson.SetBytes(body, "stream", true)
|
||||
@@ -242,7 +246,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e.applyHeaders(httpReq, apiToken)
|
||||
e.applyHeaders(httpReq, apiToken, body)
|
||||
|
||||
// Add Copilot-Vision-Request header if the request contains vision content
|
||||
if detectVisionContent(body) {
|
||||
@@ -414,7 +418,7 @@ func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *clipro
|
||||
}
|
||||
|
||||
// applyHeaders sets the required headers for GitHub Copilot API requests.
|
||||
func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string) {
|
||||
func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string, body []byte) {
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
r.Header.Set("Authorization", "Bearer "+apiToken)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
@@ -424,6 +428,20 @@ func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string) {
|
||||
r.Header.Set("Openai-Intent", copilotOpenAIIntent)
|
||||
r.Header.Set("Copilot-Integration-Id", copilotIntegrationID)
|
||||
r.Header.Set("X-Request-Id", uuid.NewString())
|
||||
|
||||
initiator := "user"
|
||||
if len(body) > 0 {
|
||||
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
|
||||
arr := messages.Array()
|
||||
if len(arr) > 0 {
|
||||
lastRole := arr[len(arr)-1].Get("role").String()
|
||||
if lastRole != "" && lastRole != "user" {
|
||||
initiator = "agent"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
r.Header.Set("X-Initiator", initiator)
|
||||
}
|
||||
|
||||
// detectVisionContent checks if the request body contains vision/image content.
|
||||
@@ -454,9 +472,14 @@ func detectVisionContent(body []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// normalizeModel is a no-op as GitHub Copilot accepts model names directly.
|
||||
// Model mapping should be done at the registry level if needed.
|
||||
func (e *GitHubCopilotExecutor) normalizeModel(_ string, body []byte) []byte {
|
||||
// normalizeModel strips the suffix (e.g. "(medium)") from the model name
|
||||
// before sending to GitHub Copilot, as the upstream API does not accept
|
||||
// suffixed model identifiers.
|
||||
func (e *GitHubCopilotExecutor) normalizeModel(model string, body []byte) []byte {
|
||||
baseModel := thinking.ParseSuffix(model).ModelName
|
||||
if baseModel != model {
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
@@ -464,6 +487,38 @@ func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format) bool {
|
||||
return sourceFormat.String() == "openai-response"
|
||||
}
|
||||
|
||||
// flattenAssistantContent converts assistant message content from array format
|
||||
// to a joined string. GitHub Copilot requires assistant content as a string;
|
||||
// sending it as an array causes Claude models to re-answer all previous prompts.
|
||||
func flattenAssistantContent(body []byte) []byte {
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return body
|
||||
}
|
||||
result := body
|
||||
for i, msg := range messages.Array() {
|
||||
if msg.Get("role").String() != "assistant" {
|
||||
continue
|
||||
}
|
||||
content := msg.Get("content")
|
||||
if !content.Exists() || !content.IsArray() {
|
||||
continue
|
||||
}
|
||||
var textParts []string
|
||||
for _, part := range content.Array() {
|
||||
if part.Get("type").String() == "text" {
|
||||
if t := part.Get("text").String(); t != "" {
|
||||
textParts = append(textParts, t)
|
||||
}
|
||||
}
|
||||
}
|
||||
joined := strings.Join(textParts, "")
|
||||
path := fmt.Sprintf("messages.%d.content", i)
|
||||
result, _ = sjson.SetBytes(result, path, joined)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// isHTTPSuccess checks if the status code indicates success (2xx).
|
||||
func isHTTPSuccess(statusCode int) bool {
|
||||
return statusCode >= 200 && statusCode < 300
|
||||
|
||||
54
internal/runtime/executor/github_copilot_executor_test.go
Normal file
54
internal/runtime/executor/github_copilot_executor_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestGitHubCopilotNormalizeModel_StripsSuffix(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
wantModel string
|
||||
}{
|
||||
{
|
||||
name: "suffix stripped",
|
||||
model: "claude-opus-4.6(medium)",
|
||||
wantModel: "claude-opus-4.6",
|
||||
},
|
||||
{
|
||||
name: "no suffix unchanged",
|
||||
model: "claude-opus-4.6",
|
||||
wantModel: "claude-opus-4.6",
|
||||
},
|
||||
{
|
||||
name: "different suffix stripped",
|
||||
model: "gpt-4o(high)",
|
||||
wantModel: "gpt-4o",
|
||||
},
|
||||
{
|
||||
name: "numeric suffix stripped",
|
||||
model: "gemini-2.5-pro(8192)",
|
||||
wantModel: "gemini-2.5-pro",
|
||||
},
|
||||
}
|
||||
|
||||
e := &GitHubCopilotExecutor{}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
body := []byte(`{"model":"` + tt.model + `","messages":[]}`)
|
||||
got := e.normalizeModel(tt.model, body)
|
||||
|
||||
gotModel := gjson.GetBytes(got, "model").String()
|
||||
if gotModel != tt.wantModel {
|
||||
t.Fatalf("normalizeModel() model = %q, want %q", gotModel, tt.wantModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,12 +4,16 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
@@ -87,12 +91,13 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier())
|
||||
@@ -163,7 +168,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
var param any
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -189,12 +194,13 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier())
|
||||
@@ -274,7 +280,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
@@ -296,7 +302,7 @@ func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
enc, err := tokenizerForModel(baseModel)
|
||||
if err != nil {
|
||||
@@ -451,6 +457,20 @@ func applyIFlowHeaders(r *http.Request, apiKey string, stream bool) {
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
r.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
r.Header.Set("User-Agent", iflowUserAgent)
|
||||
|
||||
// Generate session-id
|
||||
sessionID := "session-" + generateUUID()
|
||||
r.Header.Set("session-id", sessionID)
|
||||
|
||||
// Generate timestamp and signature
|
||||
timestamp := time.Now().UnixMilli()
|
||||
r.Header.Set("x-iflow-timestamp", fmt.Sprintf("%d", timestamp))
|
||||
|
||||
signature := createIFlowSignature(iflowUserAgent, sessionID, timestamp, apiKey)
|
||||
if signature != "" {
|
||||
r.Header.Set("x-iflow-signature", signature)
|
||||
}
|
||||
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
} else {
|
||||
@@ -458,6 +478,23 @@ func applyIFlowHeaders(r *http.Request, apiKey string, stream bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// createIFlowSignature generates HMAC-SHA256 signature for iFlow API requests.
|
||||
// The signature payload format is: userAgent:sessionId:timestamp
|
||||
func createIFlowSignature(userAgent, sessionID string, timestamp int64, apiKey string) string {
|
||||
if apiKey == "" {
|
||||
return ""
|
||||
}
|
||||
payload := fmt.Sprintf("%s:%s:%d", userAgent, sessionID, timestamp)
|
||||
h := hmac.New(sha256.New, []byte(apiKey))
|
||||
h.Write([]byte(payload))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// generateUUID generates a random UUID v4 string.
|
||||
func generateUUID() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
func iflowCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||
if a == nil {
|
||||
return "", ""
|
||||
|
||||
618
internal/runtime/executor/kimi_executor.go
Normal file
618
internal/runtime/executor/kimi_executor.go
Normal file
@@ -0,0 +1,618 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
kimiauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// KimiExecutor is a stateless executor for Kimi API using OpenAI-compatible chat completions.
|
||||
type KimiExecutor struct {
|
||||
ClaudeExecutor
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewKimiExecutor creates a new Kimi executor.
|
||||
func NewKimiExecutor(cfg *config.Config) *KimiExecutor { return &KimiExecutor{cfg: cfg} }
|
||||
|
||||
// Identifier returns the executor identifier.
|
||||
func (e *KimiExecutor) Identifier() string { return "kimi" }
|
||||
|
||||
// PrepareRequest injects Kimi credentials into the outgoing HTTP request.
|
||||
func (e *KimiExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
token := kimiCreds(auth)
|
||||
if strings.TrimSpace(token) != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest injects Kimi credentials into the request and executes it.
|
||||
func (e *KimiExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("kimi executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
// Execute performs a non-streaming chat completion request to Kimi.
|
||||
func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
from := opts.SourceFormat
|
||||
if from.String() == "claude" {
|
||||
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
|
||||
return e.ClaudeExecutor.Execute(ctx, auth, req, opts)
|
||||
}
|
||||
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
token := kimiCreds(auth)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := bytes.Clone(originalPayloadSource)
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
// Strip kimi- prefix for upstream API
|
||||
upstreamModel := stripKimiPrefix(baseModel)
|
||||
body, err = sjson.SetBytes(body, "model", upstreamModel)
|
||||
if err != nil {
|
||||
return resp, fmt.Errorf("kimi executor: failed to set model in payload: %w", err)
|
||||
}
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "kimi", e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, err = normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
url := kimiauth.KimiAPIBaseURL + "/v1/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
applyKimiHeadersWithAuth(httpReq, token, false, auth)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: body,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseOpenAIUsage(data))
|
||||
var param any
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming chat completion request to Kimi.
|
||||
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
from := opts.SourceFormat
|
||||
if from.String() == "claude" {
|
||||
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
|
||||
return e.ClaudeExecutor.ExecuteStream(ctx, auth, req, opts)
|
||||
}
|
||||
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
token := kimiCreds(auth)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := bytes.Clone(originalPayloadSource)
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
|
||||
// Strip kimi- prefix for upstream API
|
||||
upstreamModel := stripKimiPrefix(baseModel)
|
||||
body, err = sjson.SetBytes(body, "model", upstreamModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi executor: failed to set model in payload: %w", err)
|
||||
}
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "kimi", e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err = sjson.SetBytes(body, "stream_options.include_usage", true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi executor: failed to set stream_options in payload: %w", err)
|
||||
}
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
body, err = normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := kimiauth.KimiAPIBaseURL + "/v1/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
applyKimiHeadersWithAuth(httpReq, token, true, auth)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: body,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi executor: close response body error: %v", errClose)
|
||||
}
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
scanner.Buffer(nil, 1_048_576) // 1MB
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
}
|
||||
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range doneChunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(doneChunks[i])}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// CountTokens estimates token count for Kimi requests.
|
||||
func (e *KimiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
|
||||
return e.ClaudeExecutor.CountTokens(ctx, auth, req, opts)
|
||||
}
|
||||
|
||||
func normalizeKimiToolMessageLinks(body []byte) ([]byte, error) {
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
out := body
|
||||
pending := make([]string, 0)
|
||||
patched := 0
|
||||
patchedReasoning := 0
|
||||
ambiguous := 0
|
||||
latestReasoning := ""
|
||||
hasLatestReasoning := false
|
||||
|
||||
removePending := func(id string) {
|
||||
for idx := range pending {
|
||||
if pending[idx] != id {
|
||||
continue
|
||||
}
|
||||
pending = append(pending[:idx], pending[idx+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
msgs := messages.Array()
|
||||
for msgIdx := range msgs {
|
||||
msg := msgs[msgIdx]
|
||||
role := strings.TrimSpace(msg.Get("role").String())
|
||||
switch role {
|
||||
case "assistant":
|
||||
reasoning := msg.Get("reasoning_content")
|
||||
if reasoning.Exists() {
|
||||
reasoningText := reasoning.String()
|
||||
if strings.TrimSpace(reasoningText) != "" {
|
||||
latestReasoning = reasoningText
|
||||
hasLatestReasoning = true
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls := msg.Get("tool_calls")
|
||||
if !toolCalls.Exists() || !toolCalls.IsArray() || len(toolCalls.Array()) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if !reasoning.Exists() || strings.TrimSpace(reasoning.String()) == "" {
|
||||
reasoningText := fallbackAssistantReasoning(msg, hasLatestReasoning, latestReasoning)
|
||||
path := fmt.Sprintf("messages.%d.reasoning_content", msgIdx)
|
||||
next, err := sjson.SetBytes(out, path, reasoningText)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi executor: failed to set assistant reasoning_content: %w", err)
|
||||
}
|
||||
out = next
|
||||
patchedReasoning++
|
||||
}
|
||||
|
||||
for _, tc := range toolCalls.Array() {
|
||||
id := strings.TrimSpace(tc.Get("id").String())
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
pending = append(pending, id)
|
||||
}
|
||||
case "tool":
|
||||
toolCallID := strings.TrimSpace(msg.Get("tool_call_id").String())
|
||||
if toolCallID == "" {
|
||||
toolCallID = strings.TrimSpace(msg.Get("call_id").String())
|
||||
if toolCallID != "" {
|
||||
path := fmt.Sprintf("messages.%d.tool_call_id", msgIdx)
|
||||
next, err := sjson.SetBytes(out, path, toolCallID)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi executor: failed to set tool_call_id from call_id: %w", err)
|
||||
}
|
||||
out = next
|
||||
patched++
|
||||
}
|
||||
}
|
||||
if toolCallID == "" {
|
||||
if len(pending) == 1 {
|
||||
toolCallID = pending[0]
|
||||
path := fmt.Sprintf("messages.%d.tool_call_id", msgIdx)
|
||||
next, err := sjson.SetBytes(out, path, toolCallID)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi executor: failed to infer tool_call_id: %w", err)
|
||||
}
|
||||
out = next
|
||||
patched++
|
||||
} else if len(pending) > 1 {
|
||||
ambiguous++
|
||||
}
|
||||
}
|
||||
if toolCallID != "" {
|
||||
removePending(toolCallID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if patched > 0 || patchedReasoning > 0 {
|
||||
log.WithFields(log.Fields{
|
||||
"patched_tool_messages": patched,
|
||||
"patched_reasoning_messages": patchedReasoning,
|
||||
}).Debug("kimi executor: normalized tool message fields")
|
||||
}
|
||||
if ambiguous > 0 {
|
||||
log.WithFields(log.Fields{
|
||||
"ambiguous_tool_messages": ambiguous,
|
||||
"pending_tool_calls": len(pending),
|
||||
}).Warn("kimi executor: tool messages missing tool_call_id with ambiguous candidates")
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func fallbackAssistantReasoning(msg gjson.Result, hasLatest bool, latest string) string {
|
||||
if hasLatest && strings.TrimSpace(latest) != "" {
|
||||
return latest
|
||||
}
|
||||
|
||||
content := msg.Get("content")
|
||||
if content.Type == gjson.String {
|
||||
if text := strings.TrimSpace(content.String()); text != "" {
|
||||
return text
|
||||
}
|
||||
}
|
||||
if content.IsArray() {
|
||||
parts := make([]string, 0, len(content.Array()))
|
||||
for _, item := range content.Array() {
|
||||
text := strings.TrimSpace(item.Get("text").String())
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, text)
|
||||
}
|
||||
if len(parts) > 0 {
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
}
|
||||
|
||||
return "[reasoning unavailable]"
|
||||
}
|
||||
|
||||
// Refresh refreshes the Kimi token using the refresh token.
|
||||
func (e *KimiExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
log.Debugf("kimi executor: refresh called")
|
||||
if auth == nil {
|
||||
return nil, fmt.Errorf("kimi executor: auth is nil")
|
||||
}
|
||||
// Expect refresh_token in metadata for OAuth-based accounts
|
||||
var refreshToken string
|
||||
if auth.Metadata != nil {
|
||||
if v, ok := auth.Metadata["refresh_token"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
refreshToken = v
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(refreshToken) == "" {
|
||||
// Nothing to refresh
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
client := kimiauth.NewDeviceFlowClientWithDeviceID(e.cfg, resolveKimiDeviceID(auth))
|
||||
td, err := client.RefreshToken(ctx, refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if auth.Metadata == nil {
|
||||
auth.Metadata = make(map[string]any)
|
||||
}
|
||||
auth.Metadata["access_token"] = td.AccessToken
|
||||
if td.RefreshToken != "" {
|
||||
auth.Metadata["refresh_token"] = td.RefreshToken
|
||||
}
|
||||
if td.ExpiresAt > 0 {
|
||||
exp := time.Unix(td.ExpiresAt, 0).UTC().Format(time.RFC3339)
|
||||
auth.Metadata["expired"] = exp
|
||||
}
|
||||
auth.Metadata["type"] = "kimi"
|
||||
now := time.Now().Format(time.RFC3339)
|
||||
auth.Metadata["last_refresh"] = now
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// applyKimiHeaders sets required headers for Kimi API requests.
|
||||
// Headers match kimi-cli client for compatibility.
|
||||
func applyKimiHeaders(r *http.Request, token string, stream bool) {
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
r.Header.Set("Authorization", "Bearer "+token)
|
||||
// Match kimi-cli headers exactly
|
||||
r.Header.Set("User-Agent", "KimiCLI/1.10.6")
|
||||
r.Header.Set("X-Msh-Platform", "kimi_cli")
|
||||
r.Header.Set("X-Msh-Version", "1.10.6")
|
||||
r.Header.Set("X-Msh-Device-Name", getKimiHostname())
|
||||
r.Header.Set("X-Msh-Device-Model", getKimiDeviceModel())
|
||||
r.Header.Set("X-Msh-Device-Id", getKimiDeviceID())
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
return
|
||||
}
|
||||
r.Header.Set("Accept", "application/json")
|
||||
}
|
||||
|
||||
func resolveKimiDeviceIDFromAuth(auth *cliproxyauth.Auth) string {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
deviceIDRaw, ok := auth.Metadata["device_id"]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
deviceID, ok := deviceIDRaw.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
return strings.TrimSpace(deviceID)
|
||||
}
|
||||
|
||||
func resolveKimiDeviceIDFromStorage(auth *cliproxyauth.Auth) string {
|
||||
if auth == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
storage, ok := auth.Storage.(*kimiauth.KimiTokenStorage)
|
||||
if !ok || storage == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return strings.TrimSpace(storage.DeviceID)
|
||||
}
|
||||
|
||||
func resolveKimiDeviceID(auth *cliproxyauth.Auth) string {
|
||||
deviceID := resolveKimiDeviceIDFromAuth(auth)
|
||||
if deviceID != "" {
|
||||
return deviceID
|
||||
}
|
||||
return resolveKimiDeviceIDFromStorage(auth)
|
||||
}
|
||||
|
||||
func applyKimiHeadersWithAuth(r *http.Request, token string, stream bool, auth *cliproxyauth.Auth) {
|
||||
applyKimiHeaders(r, token, stream)
|
||||
|
||||
if deviceID := resolveKimiDeviceID(auth); deviceID != "" {
|
||||
r.Header.Set("X-Msh-Device-Id", deviceID)
|
||||
}
|
||||
}
|
||||
|
||||
// getKimiHostname returns the machine hostname.
|
||||
func getKimiHostname() string {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
return "unknown"
|
||||
}
|
||||
return hostname
|
||||
}
|
||||
|
||||
// getKimiDeviceModel returns a device model string matching kimi-cli format.
|
||||
func getKimiDeviceModel() string {
|
||||
return fmt.Sprintf("%s %s", runtime.GOOS, runtime.GOARCH)
|
||||
}
|
||||
|
||||
// getKimiDeviceID returns a stable device ID, matching kimi-cli storage location.
|
||||
func getKimiDeviceID() string {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "cli-proxy-api-device"
|
||||
}
|
||||
// Check kimi-cli's device_id location first (platform-specific)
|
||||
var kimiShareDir string
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
kimiShareDir = filepath.Join(homeDir, "Library", "Application Support", "kimi")
|
||||
case "windows":
|
||||
appData := os.Getenv("APPDATA")
|
||||
if appData == "" {
|
||||
appData = filepath.Join(homeDir, "AppData", "Roaming")
|
||||
}
|
||||
kimiShareDir = filepath.Join(appData, "kimi")
|
||||
default: // linux and other unix-like
|
||||
kimiShareDir = filepath.Join(homeDir, ".local", "share", "kimi")
|
||||
}
|
||||
deviceIDPath := filepath.Join(kimiShareDir, "device_id")
|
||||
if data, err := os.ReadFile(deviceIDPath); err == nil {
|
||||
return strings.TrimSpace(string(data))
|
||||
}
|
||||
return "cli-proxy-api-device"
|
||||
}
|
||||
|
||||
// kimiCreds extracts the access token from auth.
|
||||
func kimiCreds(a *cliproxyauth.Auth) (token string) {
|
||||
if a == nil {
|
||||
return ""
|
||||
}
|
||||
// Check metadata first (OAuth flow stores tokens here)
|
||||
if a.Metadata != nil {
|
||||
if v, ok := a.Metadata["access_token"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
// Fallback to attributes (API key style)
|
||||
if a.Attributes != nil {
|
||||
if v := a.Attributes["access_token"]; v != "" {
|
||||
return v
|
||||
}
|
||||
if v := a.Attributes["api_key"]; v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// stripKimiPrefix removes the "kimi-" prefix from model names for the upstream API.
|
||||
func stripKimiPrefix(model string) string {
|
||||
model = strings.TrimSpace(model)
|
||||
if strings.HasPrefix(strings.ToLower(model), "kimi-") {
|
||||
return model[5:]
|
||||
}
|
||||
return model
|
||||
}
|
||||
205
internal/runtime/executor/kimi_executor_test.go
Normal file
205
internal/runtime/executor/kimi_executor_test.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_UsesCallIDFallback(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"list_directory:1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]},
|
||||
{"role":"tool","call_id":"list_directory:1","content":"[]"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.1.tool_call_id").String()
|
||||
if got != "list_directory:1" {
|
||||
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "list_directory:1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_InferSinglePendingID(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"call_123","type":"function","function":{"name":"read_file","arguments":"{}"}}]},
|
||||
{"role":"tool","content":"file-content"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.1.tool_call_id").String()
|
||||
if got != "call_123" {
|
||||
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "call_123")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_AmbiguousMissingIDIsNotInferred(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[
|
||||
{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}},
|
||||
{"id":"call_2","type":"function","function":{"name":"read_file","arguments":"{}"}}
|
||||
]},
|
||||
{"role":"tool","content":"result-without-id"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
if gjson.GetBytes(out, "messages.1.tool_call_id").Exists() {
|
||||
t.Fatalf("messages.1.tool_call_id should be absent for ambiguous case, got %q", gjson.GetBytes(out, "messages.1.tool_call_id").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_PreservesExistingToolCallID(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]},
|
||||
{"role":"tool","tool_call_id":"call_1","call_id":"different-id","content":"result"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.1.tool_call_id").String()
|
||||
if got != "call_1" {
|
||||
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "call_1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_InheritsPreviousReasoningForAssistantToolCalls(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","content":"plan","reasoning_content":"previous reasoning"},
|
||||
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.1.reasoning_content").String()
|
||||
if got != "previous reasoning" {
|
||||
t.Fatalf("messages.1.reasoning_content = %q, want %q", got, "previous reasoning")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_InsertsFallbackReasoningWhenMissing(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
reasoning := gjson.GetBytes(out, "messages.0.reasoning_content")
|
||||
if !reasoning.Exists() {
|
||||
t.Fatalf("messages.0.reasoning_content should exist")
|
||||
}
|
||||
if reasoning.String() != "[reasoning unavailable]" {
|
||||
t.Fatalf("messages.0.reasoning_content = %q, want %q", reasoning.String(), "[reasoning unavailable]")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_UsesContentAsReasoningFallback(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","content":[{"type":"text","text":"first line"},{"type":"text","text":"second line"}],"tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.0.reasoning_content").String()
|
||||
if got != "first line\nsecond line" {
|
||||
t.Fatalf("messages.0.reasoning_content = %q, want %q", got, "first line\nsecond line")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_ReplacesEmptyReasoningContent(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","content":"assistant summary","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}],"reasoning_content":""}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.0.reasoning_content").String()
|
||||
if got != "assistant summary" {
|
||||
t.Fatalf("messages.0.reasoning_content = %q, want %q", got, "assistant summary")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_PreservesExistingAssistantReasoning(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}],"reasoning_content":"keep me"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
got := gjson.GetBytes(out, "messages.0.reasoning_content").String()
|
||||
if got != "keep me" {
|
||||
t.Fatalf("messages.0.reasoning_content = %q, want %q", got, "keep me")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_RepairsIDsAndReasoningTogether(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}],"reasoning_content":"r1"},
|
||||
{"role":"tool","call_id":"call_1","content":"[]"},
|
||||
{"role":"assistant","tool_calls":[{"id":"call_2","type":"function","function":{"name":"read_file","arguments":"{}"}}]},
|
||||
{"role":"tool","call_id":"call_2","content":"file"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
if got := gjson.GetBytes(out, "messages.1.tool_call_id").String(); got != "call_1" {
|
||||
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "call_1")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.3.tool_call_id").String(); got != "call_2" {
|
||||
t.Fatalf("messages.3.tool_call_id = %q, want %q", got, "call_2")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.2.reasoning_content").String(); got != "r1" {
|
||||
t.Fatalf("messages.2.reasoning_content = %q, want %q", got, "r1")
|
||||
}
|
||||
}
|
||||
@@ -519,8 +519,12 @@ func buildKiroPayloadForFormat(body []byte, modelID, profileArn, origin string,
|
||||
case "openai":
|
||||
log.Debugf("kiro: using OpenAI payload builder for source format: %s", sourceFormat.String())
|
||||
return kiroopenai.BuildKiroPayloadFromOpenAI(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil)
|
||||
case "kiro":
|
||||
// Body is already in Kiro format — pass through directly (used by callKiroRawAndBuffer)
|
||||
log.Debugf("kiro: body already in Kiro format, passing through directly")
|
||||
return body, false
|
||||
default:
|
||||
// Default to Claude format (also handles "claude", "kiro", etc.)
|
||||
// Default to Claude format
|
||||
log.Debugf("kiro: using Claude payload builder for source format: %s", sourceFormat.String())
|
||||
return kiroclaude.BuildKiroPayload(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil)
|
||||
}
|
||||
@@ -636,6 +640,13 @@ func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
rateLimiter.WaitForToken(tokenKey)
|
||||
log.Debugf("kiro: rate limiter cleared for token %s", tokenKey)
|
||||
|
||||
// Check for pure web_search request
|
||||
// Route to MCP endpoint instead of normal Kiro API
|
||||
if kiroclaude.HasWebSearchTool(req.Payload) {
|
||||
log.Infof("kiro: detected pure web_search request (non-stream), routing to MCP endpoint")
|
||||
return e.handleWebSearch(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
@@ -1057,6 +1068,13 @@ func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
rateLimiter.WaitForToken(tokenKey)
|
||||
log.Debugf("kiro: stream rate limiter cleared for token %s", tokenKey)
|
||||
|
||||
// Check for pure web_search request
|
||||
// Route to MCP endpoint instead of normal Kiro API
|
||||
if kiroclaude.HasWebSearchTool(req.Payload) {
|
||||
log.Infof("kiro: detected pure web_search request, routing to MCP endpoint")
|
||||
return e.handleWebSearchStream(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
@@ -1681,6 +1699,7 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
|
||||
modelMap := map[string]string{
|
||||
// Amazon Q format (amazonq- prefix) - same API as Kiro
|
||||
"amazonq-auto": "auto",
|
||||
"amazonq-claude-opus-4-6": "claude-opus-4.6",
|
||||
"amazonq-claude-opus-4-5": "claude-opus-4.5",
|
||||
"amazonq-claude-sonnet-4-5": "claude-sonnet-4.5",
|
||||
"amazonq-claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
|
||||
@@ -1688,6 +1707,7 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
|
||||
"amazonq-claude-sonnet-4-20250514": "claude-sonnet-4",
|
||||
"amazonq-claude-haiku-4-5": "claude-haiku-4.5",
|
||||
// Kiro format (kiro- prefix) - valid model names that should be preserved
|
||||
"kiro-claude-opus-4-6": "claude-opus-4.6",
|
||||
"kiro-claude-opus-4-5": "claude-opus-4.5",
|
||||
"kiro-claude-sonnet-4-5": "claude-sonnet-4.5",
|
||||
"kiro-claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
|
||||
@@ -1696,6 +1716,8 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
|
||||
"kiro-claude-haiku-4-5": "claude-haiku-4.5",
|
||||
"kiro-auto": "auto",
|
||||
// Native format (no prefix) - used by Kiro IDE directly
|
||||
"claude-opus-4-6": "claude-opus-4.6",
|
||||
"claude-opus-4.6": "claude-opus-4.6",
|
||||
"claude-opus-4-5": "claude-opus-4.5",
|
||||
"claude-opus-4.5": "claude-opus-4.5",
|
||||
"claude-haiku-4-5": "claude-haiku-4.5",
|
||||
@@ -1707,10 +1729,12 @@ func (e *KiroExecutor) mapModelToKiro(model string) string {
|
||||
"claude-sonnet-4-20250514": "claude-sonnet-4",
|
||||
"auto": "auto",
|
||||
// Agentic variants (same backend model IDs, but with special system prompt)
|
||||
"claude-opus-4.6-agentic": "claude-opus-4.6",
|
||||
"claude-opus-4.5-agentic": "claude-opus-4.5",
|
||||
"claude-sonnet-4.5-agentic": "claude-sonnet-4.5",
|
||||
"claude-sonnet-4-agentic": "claude-sonnet-4",
|
||||
"claude-haiku-4.5-agentic": "claude-haiku-4.5",
|
||||
"kiro-claude-opus-4-6-agentic": "claude-opus-4.6",
|
||||
"kiro-claude-opus-4-5-agentic": "claude-opus-4.5",
|
||||
"kiro-claude-sonnet-4-5-agentic": "claude-sonnet-4.5",
|
||||
"kiro-claude-sonnet-4-agentic": "claude-sonnet-4",
|
||||
@@ -2096,6 +2120,22 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroclaude.Ki
|
||||
}
|
||||
}
|
||||
|
||||
case "contextUsageEvent":
|
||||
// Handle context usage events from Kiro API
|
||||
// Format: {"contextUsageEvent": {"contextUsagePercentage": 0.53}}
|
||||
if ctxUsage, ok := event["contextUsageEvent"].(map[string]interface{}); ok {
|
||||
if ctxPct, ok := ctxUsage["contextUsagePercentage"].(float64); ok {
|
||||
upstreamContextPercentage = ctxPct
|
||||
log.Debugf("kiro: parseEventStream received contextUsageEvent: %.2f%%", ctxPct*100)
|
||||
}
|
||||
} else {
|
||||
// Try direct field (fallback)
|
||||
if ctxPct, ok := event["contextUsagePercentage"].(float64); ok {
|
||||
upstreamContextPercentage = ctxPct
|
||||
log.Debugf("kiro: parseEventStream received contextUsagePercentage (direct): %.2f%%", ctxPct*100)
|
||||
}
|
||||
}
|
||||
|
||||
case "error", "exception", "internalServerException", "invalidStateEvent":
|
||||
// Handle error events from Kiro API stream
|
||||
errMsg := ""
|
||||
@@ -2442,8 +2482,9 @@ func (e *KiroExecutor) extractEventTypeFromBytes(headers []byte) string {
|
||||
func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out chan<- cliproxyexecutor.StreamChunk, targetFormat sdktranslator.Format, model string, originalReq, claudeBody []byte, reporter *usageReporter, thinkingEnabled bool) {
|
||||
reader := bufio.NewReaderSize(body, 20*1024*1024) // 20MB buffer to match other providers
|
||||
var totalUsage usage.Detail
|
||||
var hasToolUses bool // Track if any tool uses were emitted
|
||||
var upstreamStopReason string // Track stop_reason from upstream events
|
||||
var hasToolUses bool // Track if any tool uses were emitted
|
||||
var hasTruncatedTools bool // Track if any tool uses were truncated
|
||||
var upstreamStopReason string // Track stop_reason from upstream events
|
||||
|
||||
// Tool use state tracking for input buffering and deduplication
|
||||
processedIDs := make(map[string]bool)
|
||||
@@ -2698,6 +2739,22 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
}
|
||||
}
|
||||
|
||||
case "contextUsageEvent":
|
||||
// Handle context usage events from Kiro API
|
||||
// Format: {"contextUsageEvent": {"contextUsagePercentage": 0.53}}
|
||||
if ctxUsage, ok := event["contextUsageEvent"].(map[string]interface{}); ok {
|
||||
if ctxPct, ok := ctxUsage["contextUsagePercentage"].(float64); ok {
|
||||
upstreamContextPercentage = ctxPct
|
||||
log.Debugf("kiro: streamToChannel received contextUsageEvent: %.2f%%", ctxPct*100)
|
||||
}
|
||||
} else {
|
||||
// Try direct field (fallback)
|
||||
if ctxPct, ok := event["contextUsagePercentage"].(float64); ok {
|
||||
upstreamContextPercentage = ctxPct
|
||||
log.Debugf("kiro: streamToChannel received contextUsagePercentage (direct): %.2f%%", ctxPct*100)
|
||||
}
|
||||
}
|
||||
|
||||
case "error", "exception", "internalServerException":
|
||||
// Handle error events from Kiro API stream
|
||||
errMsg := ""
|
||||
@@ -3221,40 +3278,16 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
_ = signature // Signature can be used for verification if needed
|
||||
|
||||
case "toolUseEvent":
|
||||
// Debug: log raw toolUseEvent payload for large tool inputs
|
||||
if log.IsLevelEnabled(log.DebugLevel) {
|
||||
payloadStr := string(payload)
|
||||
if len(payloadStr) > 500 {
|
||||
payloadStr = payloadStr[:500] + "...[truncated]"
|
||||
}
|
||||
log.Debugf("kiro: raw toolUseEvent payload (%d bytes): %s", len(payload), payloadStr)
|
||||
}
|
||||
// Handle dedicated tool use events with input buffering
|
||||
completedToolUses, newState := kiroclaude.ProcessToolUseEvent(event, currentToolUse, processedIDs)
|
||||
currentToolUse = newState
|
||||
|
||||
// Emit completed tool uses
|
||||
for _, tu := range completedToolUses {
|
||||
// Check for truncated write marker - emit as a Bash tool that echoes the error
|
||||
// This way Claude Code will execute it, see the error, and the agent can retry
|
||||
if tu.Name == "__truncated_write__" {
|
||||
filePath := ""
|
||||
if fp, ok := tu.Input["file_path"].(string); ok && fp != "" {
|
||||
filePath = fp
|
||||
}
|
||||
|
||||
// Create a Bash tool that echoes the error message
|
||||
// This will be executed by Claude Code and the agent will see the result
|
||||
var errorMsg string
|
||||
if filePath != "" {
|
||||
errorMsg = fmt.Sprintf("echo '[WRITE TOOL ERROR] The file content for \"%s\" is too large to be transmitted by the upstream API. You MUST retry by writing the file in smaller chunks: First use Write to create the file with the first 700 lines, then use multiple Edit operations to append the remaining content in chunks of ~700 lines each.'", filePath)
|
||||
} else {
|
||||
errorMsg = "echo '[WRITE TOOL ERROR] The file content is too large to be transmitted by the upstream API. The Write tool input was truncated. You MUST retry by writing the file in smaller chunks: First use Write to create the file with the first 700 lines, then use multiple Edit operations to append the remaining content in chunks of ~700 lines each.'"
|
||||
}
|
||||
|
||||
log.Warnf("kiro: converting truncated write to Bash echo for file: %s", filePath)
|
||||
|
||||
hasToolUses = true
|
||||
// Check if this tool was truncated - emit with SOFT_LIMIT_REACHED marker
|
||||
if tu.IsTruncated {
|
||||
hasTruncatedTools = true
|
||||
log.Infof("kiro: streamToChannel emitting truncated tool with SOFT_LIMIT_REACHED: %s (ID: %s)", tu.Name, tu.ToolUseID)
|
||||
|
||||
// Close text block if open
|
||||
if isTextBlockOpen && contentBlockIndex >= 0 {
|
||||
@@ -3270,8 +3303,8 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
|
||||
contentBlockIndex++
|
||||
|
||||
// Emit as Bash tool_use
|
||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, "Bash")
|
||||
// Emit tool_use with SOFT_LIMIT_REACHED marker input
|
||||
blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name)
|
||||
sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
@@ -3279,16 +3312,14 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
}
|
||||
}
|
||||
|
||||
// Emit the Bash command as input
|
||||
bashInput := map[string]interface{}{
|
||||
"command": errorMsg,
|
||||
// Build SOFT_LIMIT_REACHED marker input
|
||||
markerInput := map[string]interface{}{
|
||||
"_status": "SOFT_LIMIT_REACHED",
|
||||
"_message": "Tool output was truncated. Split content into smaller chunks (max 300 lines). Due to potential model hallucination, you MUST re-fetch the current working directory and generate the correct file_path.",
|
||||
}
|
||||
inputJSON, err := json.Marshal(bashInput)
|
||||
if err != nil {
|
||||
log.Errorf("kiro: failed to marshal bash input for truncated write error: %v", err)
|
||||
continue
|
||||
}
|
||||
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex)
|
||||
|
||||
markerJSON, _ := json.Marshal(markerInput)
|
||||
inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(markerJSON), contentBlockIndex)
|
||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
if chunk != "" {
|
||||
@@ -3296,6 +3327,7 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
}
|
||||
}
|
||||
|
||||
// Close tool_use block
|
||||
blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
|
||||
sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
|
||||
for _, chunk := range sseData {
|
||||
@@ -3304,7 +3336,8 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
}
|
||||
}
|
||||
|
||||
continue // Skip the normal tool_use emission
|
||||
hasToolUses = true // Keep this so stop_reason = tool_use
|
||||
continue
|
||||
}
|
||||
|
||||
hasToolUses = true
|
||||
@@ -3605,7 +3638,12 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
||||
}
|
||||
|
||||
// Determine stop reason: prefer upstream, then detect tool_use, default to end_turn
|
||||
// SOFT_LIMIT_REACHED: Keep stop_reason = "tool_use" so Claude continues the loop
|
||||
stopReason := upstreamStopReason
|
||||
if hasTruncatedTools {
|
||||
// Log that we're using SOFT_LIMIT_REACHED approach
|
||||
log.Infof("kiro: streamToChannel using SOFT_LIMIT_REACHED - keeping stop_reason=tool_use for truncated tools")
|
||||
}
|
||||
if stopReason == "" {
|
||||
if hasToolUses {
|
||||
stopReason = "tool_use"
|
||||
@@ -4076,6 +4114,539 @@ func (e *KiroExecutor) isTokenExpired(accessToken string) bool {
|
||||
return isExpired
|
||||
}
|
||||
|
||||
// NOTE: Message merging functions moved to internal/translator/kiro/common/message_merge.go
|
||||
// NOTE: Tool calling support functions moved to internal/translator/kiro/claude/kiro_claude_tools.go
|
||||
// The executor now uses kiroclaude.* and kirocommon.* functions instead
|
||||
const maxWebSearchIterations = 5
|
||||
|
||||
// handleWebSearchStream handles web_search requests:
|
||||
// Step 1: tools/list (sync) → fetch/cache tool description
|
||||
// Step 2+: MCP search → InjectToolResultsClaude → callKiroAndBuffer loop
|
||||
// Note: We skip the "model decides to search" step because Claude Code already
|
||||
// decided to use web_search. The Kiro tool description restricts non-coding
|
||||
// topics, so asking the model again would cause it to refuse valid searches.
|
||||
func (e *KiroExecutor) handleWebSearchStream(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
// Extract search query from Claude Code's web_search tool_use
|
||||
query := kiroclaude.ExtractSearchQuery(req.Payload)
|
||||
if query == "" {
|
||||
log.Warnf("kiro/websearch: failed to extract search query, falling back to normal flow")
|
||||
return e.callKiroDirectStream(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
// Build MCP endpoint based on region
|
||||
region := kiroDefaultRegion
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
if r, ok := auth.Metadata["api_region"].(string); ok && r != "" {
|
||||
region = r
|
||||
}
|
||||
}
|
||||
mcpEndpoint := fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region)
|
||||
|
||||
// ── Step 1: tools/list (SYNC) — cache tool description ──
|
||||
{
|
||||
tokenKey := getTokenKey(auth)
|
||||
fp := getGlobalFingerprintManager().GetFingerprint(tokenKey)
|
||||
var authAttrs map[string]string
|
||||
if auth != nil {
|
||||
authAttrs = auth.Attributes
|
||||
}
|
||||
kiroclaude.FetchToolDescription(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs)
|
||||
}
|
||||
|
||||
// Create output channel
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
|
||||
go func() {
|
||||
defer close(out)
|
||||
|
||||
// Send message_start event to client
|
||||
messageStartEvent := kiroclaude.SseEvent{
|
||||
Event: "message_start",
|
||||
Data: map[string]interface{}{
|
||||
"type": "message_start",
|
||||
"message": map[string]interface{}{
|
||||
"id": kiroclaude.GenerateMessageID(),
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": req.Model,
|
||||
"content": []interface{}{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": len(req.Payload) / 4,
|
||||
"output_tokens": 0,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: []byte(messageStartEvent.ToSSEString())}:
|
||||
}
|
||||
|
||||
// ── Step 2+: MCP search → InjectToolResultsClaude → callKiroAndBuffer loop ──
|
||||
contentBlockIndex := 0
|
||||
currentQuery := query
|
||||
|
||||
// Replace web_search tool description with a minimal one that allows re-search.
|
||||
// The original tools/list description from Kiro restricts non-coding topics,
|
||||
// but we've already decided to search. We keep the tool so the model can
|
||||
// request additional searches when results are insufficient.
|
||||
simplifiedPayload, simplifyErr := kiroclaude.ReplaceWebSearchToolDescription(bytes.Clone(req.Payload))
|
||||
if simplifyErr != nil {
|
||||
log.Warnf("kiro/websearch: failed to simplify web_search tool: %v, using original payload", simplifyErr)
|
||||
simplifiedPayload = bytes.Clone(req.Payload)
|
||||
}
|
||||
|
||||
currentClaudePayload := simplifiedPayload
|
||||
totalSearches := 0
|
||||
|
||||
// Generate toolUseId for the first iteration (Claude Code already decided to search)
|
||||
currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID())
|
||||
|
||||
for iteration := 0; iteration < maxWebSearchIterations; iteration++ {
|
||||
log.Infof("kiro/websearch: search iteration %d/%d — query: %s",
|
||||
iteration+1, maxWebSearchIterations, currentQuery)
|
||||
|
||||
// MCP search
|
||||
_, mcpRequest := kiroclaude.CreateMcpRequest(currentQuery)
|
||||
tokenKey := getTokenKey(auth)
|
||||
fp := getGlobalFingerprintManager().GetFingerprint(tokenKey)
|
||||
var authAttrs map[string]string
|
||||
if auth != nil {
|
||||
authAttrs = auth.Attributes
|
||||
}
|
||||
handler := kiroclaude.NewWebSearchHandler(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs)
|
||||
mcpResponse, mcpErr := handler.CallMcpAPI(mcpRequest)
|
||||
|
||||
var searchResults *kiroclaude.WebSearchResults
|
||||
if mcpErr != nil {
|
||||
log.Warnf("kiro/websearch: MCP API call failed: %v, continuing with empty results", mcpErr)
|
||||
} else {
|
||||
searchResults = kiroclaude.ParseSearchResults(mcpResponse)
|
||||
}
|
||||
|
||||
resultCount := 0
|
||||
if searchResults != nil {
|
||||
resultCount = len(searchResults.Results)
|
||||
}
|
||||
totalSearches++
|
||||
log.Infof("kiro/websearch: iteration %d — got %d search results", iteration+1, resultCount)
|
||||
|
||||
// Send search indicator events to client
|
||||
searchEvents := kiroclaude.GenerateSearchIndicatorEvents(currentQuery, currentToolUseId, searchResults, contentBlockIndex)
|
||||
for _, event := range searchEvents {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: []byte(event.ToSSEString())}:
|
||||
}
|
||||
}
|
||||
contentBlockIndex += 2
|
||||
|
||||
// Inject tool_use + tool_result into Claude payload, then call GAR
|
||||
var err error
|
||||
currentClaudePayload, err = kiroclaude.InjectToolResultsClaude(currentClaudePayload, currentToolUseId, currentQuery, searchResults)
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: failed to inject tool results: %v", err)
|
||||
e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults)
|
||||
break
|
||||
}
|
||||
|
||||
// Call GAR with modified Claude payload (full translation pipeline)
|
||||
modifiedReq := req
|
||||
modifiedReq.Payload = currentClaudePayload
|
||||
kiroChunks, kiroErr := e.callKiroAndBuffer(ctx, auth, modifiedReq, opts, accessToken, profileArn)
|
||||
if kiroErr != nil {
|
||||
log.Warnf("kiro/websearch: Kiro API failed at iteration %d: %v", iteration+1, kiroErr)
|
||||
e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults)
|
||||
break
|
||||
}
|
||||
|
||||
// Analyze response
|
||||
analysis := kiroclaude.AnalyzeBufferedStream(kiroChunks)
|
||||
log.Infof("kiro/websearch: iteration %d — stop_reason: %s, has_tool_use: %v, query: %s, toolUseId: %s",
|
||||
iteration+1, analysis.StopReason, analysis.HasWebSearchToolUse, analysis.WebSearchQuery, analysis.WebSearchToolUseId)
|
||||
|
||||
if analysis.HasWebSearchToolUse && analysis.WebSearchQuery != "" && iteration+1 < maxWebSearchIterations {
|
||||
// Model wants another search
|
||||
filteredChunks := kiroclaude.FilterChunksForClient(kiroChunks, analysis.WebSearchToolUseIndex, contentBlockIndex)
|
||||
for _, chunk := range filteredChunks {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: chunk}:
|
||||
}
|
||||
}
|
||||
|
||||
currentQuery = analysis.WebSearchQuery
|
||||
currentToolUseId = analysis.WebSearchToolUseId
|
||||
continue
|
||||
}
|
||||
|
||||
// Model returned final response — stream to client
|
||||
for _, chunk := range kiroChunks {
|
||||
if contentBlockIndex > 0 && len(chunk) > 0 {
|
||||
adjusted, shouldForward := kiroclaude.AdjustSSEChunk(chunk, contentBlockIndex)
|
||||
if !shouldForward {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: adjusted}:
|
||||
}
|
||||
} else {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: chunk}:
|
||||
}
|
||||
}
|
||||
}
|
||||
log.Infof("kiro/websearch: completed after %d search iteration(s), total searches: %d", iteration+1, totalSearches)
|
||||
return
|
||||
}
|
||||
|
||||
log.Warnf("kiro/websearch: reached max iterations (%d), stopping search loop", maxWebSearchIterations)
|
||||
}()
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// callKiroAndBuffer calls the Kiro API and buffers all response chunks.
|
||||
// Returns the buffered chunks for analysis before forwarding to client.
|
||||
func (e *KiroExecutor) callKiroAndBuffer(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) ([][]byte, error) {
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("kiro")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
log.Debugf("kiro/websearch GAR request: %d bytes", len(body))
|
||||
|
||||
kiroModelID := e.mapModelToKiro(req.Model)
|
||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||
|
||||
tokenKey := ""
|
||||
if auth != nil {
|
||||
tokenKey = auth.ID
|
||||
}
|
||||
|
||||
kiroStream, err := e.executeStreamWithRetry(
|
||||
ctx, auth, req, opts, accessToken, effectiveProfileArn,
|
||||
nil, body, from, nil, "", kiroModelID, isAgentic, isChatOnly, tokenKey,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Buffer all chunks
|
||||
var chunks [][]byte
|
||||
for chunk := range kiroStream {
|
||||
if chunk.Err != nil {
|
||||
return chunks, chunk.Err
|
||||
}
|
||||
if len(chunk.Payload) > 0 {
|
||||
chunks = append(chunks, bytes.Clone(chunk.Payload))
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("kiro/websearch GAR response: %d chunks buffered", len(chunks))
|
||||
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
// callKiroRawAndBuffer calls the Kiro API with a pre-built Kiro payload (no translation).
|
||||
// Used in the web search loop where the payload is modified directly in Kiro format.
|
||||
func (e *KiroExecutor) callKiroRawAndBuffer(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
kiroBody []byte,
|
||||
) ([][]byte, error) {
|
||||
kiroModelID := e.mapModelToKiro(req.Model)
|
||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||
|
||||
tokenKey := ""
|
||||
if auth != nil {
|
||||
tokenKey = auth.ID
|
||||
}
|
||||
log.Debugf("kiro/websearch GAR raw request: %d bytes", len(kiroBody))
|
||||
|
||||
kiroFormat := sdktranslator.FromString("kiro")
|
||||
kiroStream, err := e.executeStreamWithRetry(
|
||||
ctx, auth, req, opts, accessToken, effectiveProfileArn,
|
||||
nil, kiroBody, kiroFormat, nil, "", kiroModelID, isAgentic, isChatOnly, tokenKey,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Buffer all chunks
|
||||
var chunks [][]byte
|
||||
for chunk := range kiroStream {
|
||||
if chunk.Err != nil {
|
||||
return chunks, chunk.Err
|
||||
}
|
||||
if len(chunk.Payload) > 0 {
|
||||
chunks = append(chunks, bytes.Clone(chunk.Payload))
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("kiro/websearch GAR raw response: %d chunks buffered", len(chunks))
|
||||
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
// callKiroDirectStream creates a direct streaming channel to Kiro API without search.
|
||||
func (e *KiroExecutor) callKiroDirectStream(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("kiro")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
kiroModelID := e.mapModelToKiro(req.Model)
|
||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||
|
||||
tokenKey := ""
|
||||
if auth != nil {
|
||||
tokenKey = auth.ID
|
||||
}
|
||||
|
||||
return e.executeStreamWithRetry(
|
||||
ctx, auth, req, opts, accessToken, effectiveProfileArn,
|
||||
nil, body, from, nil, "", kiroModelID, isAgentic, isChatOnly, tokenKey,
|
||||
)
|
||||
}
|
||||
|
||||
// sendFallbackText sends a simple text response when the Kiro API fails during the search loop.
|
||||
func (e *KiroExecutor) sendFallbackText(
|
||||
ctx context.Context,
|
||||
out chan<- cliproxyexecutor.StreamChunk,
|
||||
contentBlockIndex int,
|
||||
query string,
|
||||
searchResults *kiroclaude.WebSearchResults,
|
||||
) {
|
||||
// Generate a simple text summary from search results
|
||||
summary := kiroclaude.FormatSearchContextPrompt(query, searchResults)
|
||||
|
||||
events := []kiroclaude.SseEvent{
|
||||
{
|
||||
Event: "content_block_start",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": contentBlockIndex,
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Event: "content_block_delta",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": contentBlockIndex,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "text_delta",
|
||||
"text": summary,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Event: "content_block_stop",
|
||||
Data: map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": contentBlockIndex,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, event := range events {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: []byte(event.ToSSEString())}:
|
||||
}
|
||||
}
|
||||
|
||||
// Send message_delta with end_turn and message_stop
|
||||
msgDelta := kiroclaude.SseEvent{
|
||||
Event: "message_delta",
|
||||
Data: map[string]interface{}{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]interface{}{
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"output_tokens": len(summary) / 4,
|
||||
},
|
||||
},
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: []byte(msgDelta.ToSSEString())}:
|
||||
}
|
||||
|
||||
msgStop := kiroclaude.SseEvent{
|
||||
Event: "message_stop",
|
||||
Data: map[string]interface{}{
|
||||
"type": "message_stop",
|
||||
},
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: []byte(msgStop.ToSSEString())}:
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// handleWebSearch handles web_search requests for non-streaming Execute path.
|
||||
// Performs MCP search synchronously, injects results into the request payload,
|
||||
// then calls the normal non-streaming Kiro API path which returns a proper
|
||||
// Claude JSON response (not SSE chunks).
|
||||
func (e *KiroExecutor) handleWebSearch(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) (cliproxyexecutor.Response, error) {
|
||||
// Extract search query from Claude Code's web_search tool_use
|
||||
query := kiroclaude.ExtractSearchQuery(req.Payload)
|
||||
if query == "" {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to extract search query, falling back to normal Execute")
|
||||
// Fall through to normal non-streaming path
|
||||
return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
// Build MCP endpoint based on region
|
||||
region := kiroDefaultRegion
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
if r, ok := auth.Metadata["api_region"].(string); ok && r != "" {
|
||||
region = r
|
||||
}
|
||||
}
|
||||
mcpEndpoint := fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region)
|
||||
|
||||
// Step 1: Fetch/cache tool description (sync)
|
||||
{
|
||||
tokenKey := getTokenKey(auth)
|
||||
fp := getGlobalFingerprintManager().GetFingerprint(tokenKey)
|
||||
var authAttrs map[string]string
|
||||
if auth != nil {
|
||||
authAttrs = auth.Attributes
|
||||
}
|
||||
kiroclaude.FetchToolDescription(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs)
|
||||
}
|
||||
|
||||
// Step 2: Perform MCP search
|
||||
_, mcpRequest := kiroclaude.CreateMcpRequest(query)
|
||||
tokenKey := getTokenKey(auth)
|
||||
fp := getGlobalFingerprintManager().GetFingerprint(tokenKey)
|
||||
var authAttrs map[string]string
|
||||
if auth != nil {
|
||||
authAttrs = auth.Attributes
|
||||
}
|
||||
handler := kiroclaude.NewWebSearchHandler(mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), fp, authAttrs)
|
||||
mcpResponse, mcpErr := handler.CallMcpAPI(mcpRequest)
|
||||
|
||||
var searchResults *kiroclaude.WebSearchResults
|
||||
if mcpErr != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: MCP API call failed: %v, continuing with empty results", mcpErr)
|
||||
} else {
|
||||
searchResults = kiroclaude.ParseSearchResults(mcpResponse)
|
||||
}
|
||||
|
||||
resultCount := 0
|
||||
if searchResults != nil {
|
||||
resultCount = len(searchResults.Results)
|
||||
}
|
||||
log.Infof("kiro/websearch: non-stream: got %d search results for query: %s", resultCount, query)
|
||||
|
||||
// Step 3: Inject search tool_use + tool_result into Claude payload
|
||||
currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID())
|
||||
modifiedPayload, err := kiroclaude.InjectToolResultsClaude(bytes.Clone(req.Payload), currentToolUseId, query, searchResults)
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to inject tool results: %v, falling back", err)
|
||||
return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn)
|
||||
}
|
||||
|
||||
// Step 4: Call Kiro API via the normal non-streaming path (executeWithRetry)
|
||||
// This path uses parseEventStream → BuildClaudeResponse → TranslateNonStream
|
||||
// to produce a proper Claude JSON response
|
||||
modifiedReq := req
|
||||
modifiedReq.Payload = modifiedPayload
|
||||
|
||||
resp, err := e.executeNonStreamFallback(ctx, auth, modifiedReq, opts, accessToken, profileArn)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Step 5: Inject server_tool_use + web_search_tool_result into response
|
||||
// so Claude Code can display "Did X searches in Ys"
|
||||
indicators := []kiroclaude.SearchIndicator{
|
||||
{
|
||||
ToolUseID: currentToolUseId,
|
||||
Query: query,
|
||||
Results: searchResults,
|
||||
},
|
||||
}
|
||||
injectedPayload, injErr := kiroclaude.InjectSearchIndicatorsInResponse(resp.Payload, indicators)
|
||||
if injErr != nil {
|
||||
log.Warnf("kiro/websearch: non-stream: failed to inject search indicators: %v", injErr)
|
||||
} else {
|
||||
resp.Payload = injectedPayload
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// executeNonStreamFallback runs the standard non-streaming Execute path for a request.
|
||||
// Used by handleWebSearch after injecting search results, or as a fallback.
|
||||
func (e *KiroExecutor) executeNonStreamFallback(
|
||||
ctx context.Context,
|
||||
auth *cliproxyauth.Auth,
|
||||
req cliproxyexecutor.Request,
|
||||
opts cliproxyexecutor.Options,
|
||||
accessToken, profileArn string,
|
||||
) (cliproxyexecutor.Response, error) {
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("kiro")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
kiroModelID := e.mapModelToKiro(req.Model)
|
||||
isAgentic, isChatOnly := determineAgenticMode(req.Model)
|
||||
effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
|
||||
tokenKey := getTokenKey(auth)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
var err error
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
resp, err := e.executeWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, to, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
@@ -80,7 +80,7 @@ func recordAPIRequest(ctx context.Context, cfg *config.Config, info upstreamRequ
|
||||
writeHeaders(builder, info.Headers)
|
||||
builder.WriteString("\nBody:\n")
|
||||
if len(info.Body) > 0 {
|
||||
builder.WriteString(string(bytes.Clone(info.Body)))
|
||||
builder.WriteString(string(info.Body))
|
||||
} else {
|
||||
builder.WriteString("<empty>")
|
||||
}
|
||||
@@ -152,7 +152,7 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt
|
||||
if cfg == nil || !cfg.RequestLog {
|
||||
return
|
||||
}
|
||||
data := bytes.TrimSpace(bytes.Clone(chunk))
|
||||
data := bytes.TrimSpace(chunk)
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -88,12 +88,13 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
to = sdktranslator.FromString("openai-response")
|
||||
endpoint = "/responses/compact"
|
||||
}
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), opts.Stream)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
if opts.Alt == "responses/compact" {
|
||||
@@ -170,7 +171,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
reporter.ensurePublished(ctx)
|
||||
// Translate response back to source format when needed
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, body, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -189,12 +190,13 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
|
||||
@@ -283,7 +285,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
|
||||
// OpenAI-compatible streams are SSE: lines typically prefixed with "data: ".
|
||||
// Pass through translator; it yields one or more chunks for the target schema.
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(line), ¶m)
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
@@ -304,7 +306,7 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
modelForCounting := baseModel
|
||||
|
||||
|
||||
@@ -81,12 +81,13 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
@@ -150,7 +151,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
var param any
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -171,12 +172,13 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
@@ -253,12 +255,12 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
}
|
||||
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range doneChunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(doneChunks[i])}
|
||||
}
|
||||
@@ -276,7 +278,7 @@ func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
if strings.TrimSpace(modelName) == "" {
|
||||
|
||||
@@ -7,5 +7,6 @@ import (
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/gemini"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/geminicli"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/iflow"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/kimi"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/openai"
|
||||
)
|
||||
|
||||
@@ -21,6 +21,9 @@ import (
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// gcInterval defines minimum time between garbage collection runs.
|
||||
const gcInterval = 5 * time.Minute
|
||||
|
||||
// GitTokenStore persists token records and auth metadata using git as the backing storage.
|
||||
type GitTokenStore struct {
|
||||
mu sync.Mutex
|
||||
@@ -31,6 +34,7 @@ type GitTokenStore struct {
|
||||
remote string
|
||||
username string
|
||||
password string
|
||||
lastGC time.Time
|
||||
}
|
||||
|
||||
// NewGitTokenStore creates a token store that saves credentials to disk through the
|
||||
@@ -613,6 +617,7 @@ func (s *GitTokenStore) commitAndPushLocked(message string, relPaths ...string)
|
||||
} else if errRewrite := s.rewriteHeadAsSingleCommit(repo, headRef.Name(), commitHash, message, signature); errRewrite != nil {
|
||||
return errRewrite
|
||||
}
|
||||
s.maybeRunGC(repo)
|
||||
if err = repo.Push(&git.PushOptions{Auth: s.gitAuth(), Force: true}); err != nil {
|
||||
if errors.Is(err, git.NoErrAlreadyUpToDate) {
|
||||
return nil
|
||||
@@ -652,6 +657,23 @@ func (s *GitTokenStore) rewriteHeadAsSingleCommit(repo *git.Repository, branch p
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GitTokenStore) maybeRunGC(repo *git.Repository) {
|
||||
now := time.Now()
|
||||
if now.Sub(s.lastGC) < gcInterval {
|
||||
return
|
||||
}
|
||||
s.lastGC = now
|
||||
|
||||
pruneOpts := git.PruneOptions{
|
||||
OnlyObjectsOlderThan: now,
|
||||
Handler: repo.DeleteObject,
|
||||
}
|
||||
if err := repo.Prune(pruneOpts); err != nil && !errors.Is(err, git.ErrLooseObjectsNotSupported) {
|
||||
return
|
||||
}
|
||||
_ = repo.RepackObjects(&git.RepackConfig{})
|
||||
}
|
||||
|
||||
// PersistConfig commits and pushes configuration changes to git.
|
||||
func (s *GitTokenStore) PersistConfig(_ context.Context) error {
|
||||
if err := s.EnsureRepository(); err != nil {
|
||||
|
||||
@@ -18,6 +18,7 @@ var providerAppliers = map[string]ProviderApplier{
|
||||
"codex": nil,
|
||||
"iflow": nil,
|
||||
"antigravity": nil,
|
||||
"kimi": nil,
|
||||
}
|
||||
|
||||
// GetProviderApplier returns the ProviderApplier for the given provider name.
|
||||
@@ -326,6 +327,9 @@ func extractThinkingConfig(body []byte, provider string) ThinkingConfig {
|
||||
return config
|
||||
}
|
||||
return extractOpenAIConfig(body)
|
||||
case "kimi":
|
||||
// Kimi uses OpenAI-compatible reasoning_effort format
|
||||
return extractOpenAIConfig(body)
|
||||
default:
|
||||
return ThinkingConfig{}
|
||||
}
|
||||
@@ -388,7 +392,12 @@ func extractGeminiConfig(body []byte, provider string) ThinkingConfig {
|
||||
}
|
||||
|
||||
// Check thinkingLevel first (Gemini 3 format takes precedence)
|
||||
if level := gjson.GetBytes(body, prefix+".thinkingLevel"); level.Exists() {
|
||||
level := gjson.GetBytes(body, prefix+".thinkingLevel")
|
||||
if !level.Exists() {
|
||||
// Google official Gemini Python SDK sends snake_case field names
|
||||
level = gjson.GetBytes(body, prefix+".thinking_level")
|
||||
}
|
||||
if level.Exists() {
|
||||
value := level.String()
|
||||
switch value {
|
||||
case "none":
|
||||
@@ -401,7 +410,12 @@ func extractGeminiConfig(body []byte, provider string) ThinkingConfig {
|
||||
}
|
||||
|
||||
// Check thinkingBudget (Gemini 2.5 format)
|
||||
if budget := gjson.GetBytes(body, prefix+".thinkingBudget"); budget.Exists() {
|
||||
budget := gjson.GetBytes(body, prefix+".thinkingBudget")
|
||||
if !budget.Exists() {
|
||||
// Google official Gemini Python SDK sends snake_case field names
|
||||
budget = gjson.GetBytes(body, prefix+".thinking_budget")
|
||||
}
|
||||
if budget.Exists() {
|
||||
value := int(budget.Int())
|
||||
switch value {
|
||||
case 0:
|
||||
|
||||
@@ -94,8 +94,10 @@ func (a *Applier) applyCompatible(body []byte, config thinking.ThinkingConfig, m
|
||||
}
|
||||
|
||||
func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||
// Remove conflicting fields to avoid both thinkingLevel and thinkingBudget in output
|
||||
result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_budget")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_level")
|
||||
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
|
||||
@@ -114,28 +116,30 @@ func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig)
|
||||
|
||||
level := string(config.Level)
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingLevel", level)
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
|
||||
// Respect user's explicit includeThoughts setting from original body; default to true if not set
|
||||
// Support both camelCase and snake_case variants
|
||||
includeThoughts := true
|
||||
if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.includeThoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
} else if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
}
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", includeThoughts)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (a *Applier) applyBudgetFormat(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo, isClaude bool) ([]byte, error) {
|
||||
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||
// Remove conflicting fields to avoid both thinkingLevel and thinkingBudget in output
|
||||
result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingLevel")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_level")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_budget")
|
||||
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
|
||||
budget := config.Budget
|
||||
includeThoughts := false
|
||||
switch config.Mode {
|
||||
case thinking.ModeNone:
|
||||
includeThoughts = false
|
||||
case thinking.ModeAuto:
|
||||
includeThoughts = true
|
||||
default:
|
||||
includeThoughts = budget > 0
|
||||
}
|
||||
|
||||
// Apply Claude-specific constraints
|
||||
// Apply Claude-specific constraints first to get the final budget value
|
||||
if isClaude && modelInfo != nil {
|
||||
budget, result = a.normalizeClaudeBudget(budget, result, modelInfo)
|
||||
// Check if budget was removed entirely
|
||||
@@ -144,6 +148,37 @@ func (a *Applier) applyBudgetFormat(body []byte, config thinking.ThinkingConfig,
|
||||
}
|
||||
}
|
||||
|
||||
// For ModeNone, always set includeThoughts to false regardless of user setting.
|
||||
// This ensures that when user requests budget=0 (disable thinking output),
|
||||
// the includeThoughts is correctly set to false even if budget is clamped to min.
|
||||
if config.Mode == thinking.ModeNone {
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", false)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Determine includeThoughts: respect user's explicit setting from original body if provided
|
||||
// Support both camelCase and snake_case variants
|
||||
var includeThoughts bool
|
||||
var userSetIncludeThoughts bool
|
||||
if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.includeThoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
userSetIncludeThoughts = true
|
||||
} else if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
userSetIncludeThoughts = true
|
||||
}
|
||||
|
||||
if !userSetIncludeThoughts {
|
||||
// No explicit setting, use default logic based on mode
|
||||
switch config.Mode {
|
||||
case thinking.ModeAuto:
|
||||
includeThoughts = true
|
||||
default:
|
||||
includeThoughts = budget > 0
|
||||
}
|
||||
}
|
||||
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", includeThoughts)
|
||||
return result, nil
|
||||
|
||||
@@ -118,8 +118,10 @@ func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig)
|
||||
// - ModeNone + Budget>0: forced to think but hide output (includeThoughts=false)
|
||||
// ValidateConfig sets config.Level to the lowest level when ModeNone + Budget > 0.
|
||||
|
||||
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||
// Remove conflicting fields to avoid both thinkingLevel and thinkingBudget in output
|
||||
result, _ := sjson.DeleteBytes(body, "generationConfig.thinkingConfig.thinkingBudget")
|
||||
result, _ = sjson.DeleteBytes(result, "generationConfig.thinkingConfig.thinking_budget")
|
||||
result, _ = sjson.DeleteBytes(result, "generationConfig.thinkingConfig.thinking_level")
|
||||
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||
result, _ = sjson.DeleteBytes(result, "generationConfig.thinkingConfig.include_thoughts")
|
||||
|
||||
@@ -138,29 +140,58 @@ func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig)
|
||||
|
||||
level := string(config.Level)
|
||||
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.thinkingLevel", level)
|
||||
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.includeThoughts", true)
|
||||
|
||||
// Respect user's explicit includeThoughts setting from original body; default to true if not set
|
||||
// Support both camelCase and snake_case variants
|
||||
includeThoughts := true
|
||||
if inc := gjson.GetBytes(body, "generationConfig.thinkingConfig.includeThoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
} else if inc := gjson.GetBytes(body, "generationConfig.thinkingConfig.include_thoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
}
|
||||
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.includeThoughts", includeThoughts)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (a *Applier) applyBudgetFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||
// Remove conflicting fields to avoid both thinkingLevel and thinkingBudget in output
|
||||
result, _ := sjson.DeleteBytes(body, "generationConfig.thinkingConfig.thinkingLevel")
|
||||
result, _ = sjson.DeleteBytes(result, "generationConfig.thinkingConfig.thinking_level")
|
||||
result, _ = sjson.DeleteBytes(result, "generationConfig.thinkingConfig.thinking_budget")
|
||||
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||
result, _ = sjson.DeleteBytes(result, "generationConfig.thinkingConfig.include_thoughts")
|
||||
|
||||
budget := config.Budget
|
||||
// ModeNone semantics:
|
||||
// - ModeNone + Budget=0: completely disable thinking
|
||||
// - ModeNone + Budget>0: forced to think but hide output (includeThoughts=false)
|
||||
// When ZeroAllowed=false, ValidateConfig clamps Budget to Min while preserving ModeNone.
|
||||
includeThoughts := false
|
||||
switch config.Mode {
|
||||
case thinking.ModeNone:
|
||||
includeThoughts = false
|
||||
case thinking.ModeAuto:
|
||||
includeThoughts = true
|
||||
default:
|
||||
includeThoughts = budget > 0
|
||||
|
||||
// For ModeNone, always set includeThoughts to false regardless of user setting.
|
||||
// This ensures that when user requests budget=0 (disable thinking output),
|
||||
// the includeThoughts is correctly set to false even if budget is clamped to min.
|
||||
if config.Mode == thinking.ModeNone {
|
||||
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.includeThoughts", false)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Determine includeThoughts: respect user's explicit setting from original body if provided
|
||||
// Support both camelCase and snake_case variants
|
||||
var includeThoughts bool
|
||||
var userSetIncludeThoughts bool
|
||||
if inc := gjson.GetBytes(body, "generationConfig.thinkingConfig.includeThoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
userSetIncludeThoughts = true
|
||||
} else if inc := gjson.GetBytes(body, "generationConfig.thinkingConfig.include_thoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
userSetIncludeThoughts = true
|
||||
}
|
||||
|
||||
if !userSetIncludeThoughts {
|
||||
// No explicit setting, use default logic based on mode
|
||||
switch config.Mode {
|
||||
case thinking.ModeAuto:
|
||||
includeThoughts = true
|
||||
default:
|
||||
includeThoughts = budget > 0
|
||||
}
|
||||
}
|
||||
|
||||
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
|
||||
@@ -79,8 +79,10 @@ func (a *Applier) applyCompatible(body []byte, config thinking.ThinkingConfig) (
|
||||
}
|
||||
|
||||
func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||
// Remove conflicting fields to avoid both thinkingLevel and thinkingBudget in output
|
||||
result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_budget")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_level")
|
||||
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
|
||||
@@ -99,25 +101,58 @@ func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig)
|
||||
|
||||
level := string(config.Level)
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingLevel", level)
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
|
||||
// Respect user's explicit includeThoughts setting from original body; default to true if not set
|
||||
// Support both camelCase and snake_case variants
|
||||
includeThoughts := true
|
||||
if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.includeThoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
} else if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
}
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", includeThoughts)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (a *Applier) applyBudgetFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||
// Remove conflicting fields to avoid both thinkingLevel and thinkingBudget in output
|
||||
result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingLevel")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_level")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_budget")
|
||||
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
|
||||
budget := config.Budget
|
||||
includeThoughts := false
|
||||
switch config.Mode {
|
||||
case thinking.ModeNone:
|
||||
includeThoughts = false
|
||||
case thinking.ModeAuto:
|
||||
includeThoughts = true
|
||||
default:
|
||||
includeThoughts = budget > 0
|
||||
|
||||
// For ModeNone, always set includeThoughts to false regardless of user setting.
|
||||
// This ensures that when user requests budget=0 (disable thinking output),
|
||||
// the includeThoughts is correctly set to false even if budget is clamped to min.
|
||||
if config.Mode == thinking.ModeNone {
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", false)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Determine includeThoughts: respect user's explicit setting from original body if provided
|
||||
// Support both camelCase and snake_case variants
|
||||
var includeThoughts bool
|
||||
var userSetIncludeThoughts bool
|
||||
if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.includeThoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
userSetIncludeThoughts = true
|
||||
} else if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
userSetIncludeThoughts = true
|
||||
}
|
||||
|
||||
if !userSetIncludeThoughts {
|
||||
// No explicit setting, use default logic based on mode
|
||||
switch config.Mode {
|
||||
case thinking.ModeAuto:
|
||||
includeThoughts = true
|
||||
default:
|
||||
includeThoughts = budget > 0
|
||||
}
|
||||
}
|
||||
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
|
||||
126
internal/thinking/provider/kimi/apply.go
Normal file
126
internal/thinking/provider/kimi/apply.go
Normal file
@@ -0,0 +1,126 @@
|
||||
// Package kimi implements thinking configuration for Kimi (Moonshot AI) models.
|
||||
//
|
||||
// Kimi models use the OpenAI-compatible reasoning_effort format with discrete levels
|
||||
// (low/medium/high). The provider strips any existing thinking config and applies
|
||||
// the unified ThinkingConfig in OpenAI format.
|
||||
package kimi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// Applier implements thinking.ProviderApplier for Kimi models.
|
||||
//
|
||||
// Kimi-specific behavior:
|
||||
// - Output format: reasoning_effort (string: low/medium/high)
|
||||
// - Uses OpenAI-compatible format
|
||||
// - Supports budget-to-level conversion
|
||||
type Applier struct{}
|
||||
|
||||
var _ thinking.ProviderApplier = (*Applier)(nil)
|
||||
|
||||
// NewApplier creates a new Kimi thinking applier.
|
||||
func NewApplier() *Applier {
|
||||
return &Applier{}
|
||||
}
|
||||
|
||||
func init() {
|
||||
thinking.RegisterProvider("kimi", NewApplier())
|
||||
}
|
||||
|
||||
// Apply applies thinking configuration to Kimi request body.
|
||||
//
|
||||
// Expected output format:
|
||||
//
|
||||
// {
|
||||
// "reasoning_effort": "high"
|
||||
// }
|
||||
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
|
||||
if thinking.IsUserDefinedModel(modelInfo) {
|
||||
return applyCompatibleKimi(body, config)
|
||||
}
|
||||
if modelInfo.Thinking == nil {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
body = []byte(`{}`)
|
||||
}
|
||||
|
||||
var effort string
|
||||
switch config.Mode {
|
||||
case thinking.ModeLevel:
|
||||
if config.Level == "" {
|
||||
return body, nil
|
||||
}
|
||||
effort = string(config.Level)
|
||||
case thinking.ModeNone:
|
||||
// Kimi uses "none" to disable thinking
|
||||
effort = string(thinking.LevelNone)
|
||||
case thinking.ModeBudget:
|
||||
// Convert budget to level using threshold mapping
|
||||
level, ok := thinking.ConvertBudgetToLevel(config.Budget)
|
||||
if !ok {
|
||||
return body, nil
|
||||
}
|
||||
effort = level
|
||||
case thinking.ModeAuto:
|
||||
// Auto mode maps to "auto" effort
|
||||
effort = string(thinking.LevelAuto)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
|
||||
if effort == "" {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, err := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi thinking: failed to set reasoning_effort: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// applyCompatibleKimi applies thinking config for user-defined Kimi models.
|
||||
func applyCompatibleKimi(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
body = []byte(`{}`)
|
||||
}
|
||||
|
||||
var effort string
|
||||
switch config.Mode {
|
||||
case thinking.ModeLevel:
|
||||
if config.Level == "" {
|
||||
return body, nil
|
||||
}
|
||||
effort = string(config.Level)
|
||||
case thinking.ModeNone:
|
||||
effort = string(thinking.LevelNone)
|
||||
if config.Level != "" {
|
||||
effort = string(config.Level)
|
||||
}
|
||||
case thinking.ModeAuto:
|
||||
effort = string(thinking.LevelAuto)
|
||||
case thinking.ModeBudget:
|
||||
// Convert budget to level
|
||||
level, ok := thinking.ConvertBudgetToLevel(config.Budget)
|
||||
if !ok {
|
||||
return body, nil
|
||||
}
|
||||
effort = level
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, err := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi thinking: failed to set reasoning_effort: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
@@ -6,7 +6,6 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||
@@ -37,7 +36,7 @@ import (
|
||||
// - []byte: The transformed request data in Gemini CLI API format
|
||||
func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
enableThoughtTranslate := true
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
// system instruction
|
||||
systemInstructionJSON := ""
|
||||
@@ -115,7 +114,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if signatureResult.Exists() && signatureResult.String() != "" {
|
||||
arrayClientSignatures := strings.SplitN(signatureResult.String(), "#", 2)
|
||||
if len(arrayClientSignatures) == 2 {
|
||||
if modelName == arrayClientSignatures[0] {
|
||||
if cache.GetModelGroup(modelName) == arrayClientSignatures[0] {
|
||||
clientSignature = arrayClientSignatures[1]
|
||||
}
|
||||
}
|
||||
@@ -345,7 +344,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
// Inject interleaved thinking hint when both tools and thinking are active
|
||||
hasTools := toolDeclCount > 0
|
||||
thinkingResult := gjson.GetBytes(rawJSON, "thinking")
|
||||
hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && thinkingResult.Get("type").String() == "enabled"
|
||||
thinkingType := thinkingResult.Get("type").String()
|
||||
hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && (thinkingType == "enabled" || thinkingType == "adaptive")
|
||||
isClaudeThinking := util.IsClaudeThinkingModel(modelName)
|
||||
|
||||
if hasTools && hasThinking && isClaudeThinking {
|
||||
@@ -378,12 +378,18 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
|
||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
||||
if t := gjson.GetBytes(rawJSON, "thinking"); enableThoughtTranslate && t.Exists() && t.IsObject() {
|
||||
if t.Get("type").String() == "enabled" {
|
||||
switch t.Get("type").String() {
|
||||
case "enabled":
|
||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||
budget := int(b.Int())
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
case "adaptive":
|
||||
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
|
||||
// to model-specific max capability.
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
}
|
||||
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -34,7 +33,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini API format
|
||||
func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
template := ""
|
||||
template = `{"project":"","request":{},"model":""}`
|
||||
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -28,7 +27,7 @@ const geminiCLIFunctionThoughtSignature = "skip_thought_signature_validator"
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini CLI API format
|
||||
func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
// Base envelope (no default thinkingConfig)
|
||||
out := []byte(`{"project":"","request":{"contents":[]},"model":"gemini-2.5-pro"}`)
|
||||
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/gemini"
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/responses"
|
||||
)
|
||||
|
||||
func ConvertOpenAIResponsesRequestToAntigravity(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON = ConvertOpenAIResponsesRequestToGemini(modelName, rawJSON, stream)
|
||||
return ConvertGeminiRequestToAntigravity(modelName, rawJSON, stream)
|
||||
}
|
||||
|
||||
@@ -6,8 +6,6 @@
|
||||
package geminiCLI
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/claude/gemini"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -30,7 +28,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Claude Code API format
|
||||
func ConvertGeminiCLIRequestToClaude(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
// Extract the inner request object and promote it to the top level
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
@@ -46,7 +45,7 @@ var (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Claude Code API format
|
||||
func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
if account == "" {
|
||||
u, _ := uuid.NewRandom()
|
||||
@@ -116,7 +115,11 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
// Include thoughts configuration for reasoning process visibility
|
||||
// Translator only does format conversion, ApplyThinking handles model capability validation.
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
if thinkingLevel := thinkingConfig.Get("thinkingLevel"); thinkingLevel.Exists() {
|
||||
thinkingLevel := thinkingConfig.Get("thinkingLevel")
|
||||
if !thinkingLevel.Exists() {
|
||||
thinkingLevel = thinkingConfig.Get("thinking_level")
|
||||
}
|
||||
if thinkingLevel.Exists() {
|
||||
level := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
|
||||
switch level {
|
||||
case "":
|
||||
@@ -132,23 +135,29 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
}
|
||||
} else if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
budget := int(thinkingBudget.Int())
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
default:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
} else {
|
||||
thinkingBudget := thinkingConfig.Get("thinkingBudget")
|
||||
if !thinkingBudget.Exists() {
|
||||
thinkingBudget = thinkingConfig.Get("thinking_budget")
|
||||
}
|
||||
if thinkingBudget.Exists() {
|
||||
budget := int(thinkingBudget.Int())
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
default:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
} else if includeThoughts := thinkingConfig.Get("includeThoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
} else if includeThoughts := thinkingConfig.Get("include_thoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
}
|
||||
} else if includeThoughts := thinkingConfig.Get("includeThoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
} else if includeThoughts := thinkingConfig.Get("include_thoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
@@ -44,7 +43,7 @@ var (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Claude Code API format
|
||||
func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
if account == "" {
|
||||
u, _ := uuid.NewRandom()
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
@@ -32,7 +31,7 @@ var (
|
||||
// - max_output_tokens -> max_tokens
|
||||
// - stream passthrough via parameter
|
||||
func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
if account == "" {
|
||||
u, _ := uuid.NewRandom()
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -35,7 +34,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in internal client format
|
||||
func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
template := `{"model":"","instructions":"","input":[]}`
|
||||
|
||||
@@ -223,6 +222,10 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
reasoningEffort = effort
|
||||
}
|
||||
}
|
||||
case "adaptive":
|
||||
// Claude adaptive means "enable with max capacity"; keep it as highest level
|
||||
// and let ApplyThinking normalize per target model capability.
|
||||
reasoningEffort = string(thinking.LevelXHigh)
|
||||
case "disabled":
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(0); ok && effort != "" {
|
||||
reasoningEffort = effort
|
||||
|
||||
@@ -113,10 +113,10 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
p := (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall
|
||||
stopReason := rootResult.Get("response.stop_reason").String()
|
||||
if stopReason != "" {
|
||||
template, _ = sjson.Set(template, "delta.stop_reason", stopReason)
|
||||
} else if p {
|
||||
if p {
|
||||
template, _ = sjson.Set(template, "delta.stop_reason", "tool_use")
|
||||
} else if stopReason == "max_tokens" || stopReason == "stop" {
|
||||
template, _ = sjson.Set(template, "delta.stop_reason", stopReason)
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "delta.stop_reason", "end_turn")
|
||||
}
|
||||
|
||||
@@ -6,8 +6,6 @@
|
||||
package geminiCLI
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/codex/gemini"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -30,7 +28,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Codex API format
|
||||
func ConvertGeminiCLIRequestToCodex(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw)
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName)
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
@@ -37,7 +36,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Codex API format
|
||||
func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
// Base template
|
||||
out := `{"model":"","instructions":"","input":[]}`
|
||||
|
||||
@@ -243,19 +242,30 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
out, _ = sjson.Set(out, "parallel_tool_calls", true)
|
||||
|
||||
// Convert Gemini thinkingConfig to Codex reasoning.effort.
|
||||
// Note: Google official Python SDK sends snake_case fields (thinking_level/thinking_budget).
|
||||
effortSet := false
|
||||
if genConfig := root.Get("generationConfig"); genConfig.Exists() {
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
if thinkingLevel := thinkingConfig.Get("thinkingLevel"); thinkingLevel.Exists() {
|
||||
thinkingLevel := thinkingConfig.Get("thinkingLevel")
|
||||
if !thinkingLevel.Exists() {
|
||||
thinkingLevel = thinkingConfig.Get("thinking_level")
|
||||
}
|
||||
if thinkingLevel.Exists() {
|
||||
effort := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
|
||||
if effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning.effort", effort)
|
||||
effortSet = true
|
||||
}
|
||||
} else if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(int(thinkingBudget.Int())); ok {
|
||||
out, _ = sjson.Set(out, "reasoning.effort", effort)
|
||||
effortSet = true
|
||||
} else {
|
||||
thinkingBudget := thinkingConfig.Get("thinkingBudget")
|
||||
if !thinkingBudget.Exists() {
|
||||
thinkingBudget = thinkingConfig.Get("thinking_budget")
|
||||
}
|
||||
if thinkingBudget.Exists() {
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(int(thinkingBudget.Int())); ok {
|
||||
out, _ = sjson.Set(out, "reasoning.effort", effort)
|
||||
effortSet = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,8 +7,6 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -29,7 +27,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in OpenAI Responses API format
|
||||
func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
// Start with empty JSON object
|
||||
out := `{"instructions":""}`
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -9,7 +8,13 @@ import (
|
||||
)
|
||||
|
||||
func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
inputResult := gjson.GetBytes(rawJSON, "input")
|
||||
if inputResult.Type == gjson.String {
|
||||
input, _ := sjson.Set(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`, "0.content.0.text", inputResult.String())
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "input", []byte(input))
|
||||
}
|
||||
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "stream", true)
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "store", false)
|
||||
|
||||
@@ -35,7 +35,7 @@ const geminiCLIClaudeThoughtSignature = "skip_thought_signature_validator"
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini CLI API format
|
||||
func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
|
||||
|
||||
// Build output Gemini CLI request JSON
|
||||
@@ -116,6 +116,19 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
part, _ = sjson.Set(part, "functionResponse.name", funcName)
|
||||
part, _ = sjson.Set(part, "functionResponse.response.result", responseData)
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
|
||||
|
||||
case "image":
|
||||
source := contentResult.Get("source")
|
||||
if source.Get("type").String() == "base64" {
|
||||
mimeType := source.Get("media_type").String()
|
||||
data := source.Get("data").String()
|
||||
if mimeType != "" && data != "" {
|
||||
part := `{"inlineData":{"mime_type":"","data":""}}`
|
||||
part, _ = sjson.Set(part, "inlineData.mime_type", mimeType)
|
||||
part, _ = sjson.Set(part, "inlineData.data", data)
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -160,12 +173,18 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
|
||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
||||
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() {
|
||||
if t.Get("type").String() == "enabled" {
|
||||
switch t.Get("type").String() {
|
||||
case "enabled":
|
||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||
budget := int(b.Int())
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
case "adaptive":
|
||||
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
|
||||
// to model-specific max capability.
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
}
|
||||
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
@@ -33,7 +32,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini API format
|
||||
func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
template := ""
|
||||
template = `{"project":"","request":{},"model":""}`
|
||||
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -28,7 +27,7 @@ const geminiCLIFunctionThoughtSignature = "skip_thought_signature_validator"
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini CLI API format
|
||||
func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
// Base envelope (no default thinkingConfig)
|
||||
out := []byte(`{"project":"","request":{"contents":[]},"model":"gemini-2.5-pro"}`)
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -77,14 +78,20 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
template, _ = sjson.Set(template, "id", responseIDResult.String())
|
||||
}
|
||||
|
||||
// Extract and set the finish reason.
|
||||
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||
finishReason := ""
|
||||
if stopReasonResult := gjson.GetBytes(rawJSON, "response.stop_reason"); stopReasonResult.Exists() {
|
||||
finishReason = stopReasonResult.String()
|
||||
}
|
||||
if finishReason == "" {
|
||||
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||
finishReason = finishReasonResult.String()
|
||||
}
|
||||
}
|
||||
finishReason = strings.ToLower(finishReason)
|
||||
|
||||
// Extract and set usage metadata (token counts).
|
||||
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
|
||||
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
}
|
||||
@@ -97,6 +104,14 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
// Include cached token count if present (indicates prompt caching is working)
|
||||
if cachedTokenCount > 0 {
|
||||
var err error
|
||||
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
||||
if err != nil {
|
||||
log.Warnf("gemini-cli openai response: failed to set cached_tokens: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process the main content part of the response.
|
||||
@@ -187,6 +202,12 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
if hasFunctionCall {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls")
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls")
|
||||
} else if finishReason != "" && (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex == 0 {
|
||||
// Only pass through specific finish reasons
|
||||
if finishReason == "max_tokens" || finishReason == "stop" {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
|
||||
}
|
||||
}
|
||||
|
||||
return []string{template}
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini-cli/gemini"
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/responses"
|
||||
)
|
||||
|
||||
func ConvertOpenAIResponsesRequestToGeminiCLI(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON = ConvertOpenAIResponsesRequestToGemini(modelName, rawJSON, stream)
|
||||
return ConvertGeminiRequestToGeminiCLI(modelName, rawJSON, stream)
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ const geminiClaudeThoughtSignature = "skip_thought_signature_validator"
|
||||
// Returns:
|
||||
// - []byte: The transformed request in Gemini CLI format.
|
||||
func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
|
||||
|
||||
// Build output Gemini CLI request JSON
|
||||
@@ -154,12 +154,18 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when enabled
|
||||
// Translator only does format conversion, ApplyThinking handles model capability validation.
|
||||
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() {
|
||||
if t.Get("type").String() == "enabled" {
|
||||
switch t.Get("type").String() {
|
||||
case "enabled":
|
||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||
budget := int(b.Int())
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
case "adaptive":
|
||||
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
|
||||
// to model-specific max capability.
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.includeThoughts", true)
|
||||
}
|
||||
}
|
||||
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package geminiCLI
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
@@ -19,7 +18,7 @@ import (
|
||||
// It extracts the model name, system instruction, message contents, and tool declarations
|
||||
// from the raw JSON request and returns them in the format expected by the internal client.
|
||||
func ConvertGeminiCLIRequestToGemini(_ string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw)
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelResult.String())
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
@@ -19,7 +18,7 @@ import (
|
||||
//
|
||||
// It keeps the payload otherwise unchanged.
|
||||
func ConvertGeminiRequestToGemini(_ string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
// Fast path: if no contents field, only attach safety settings
|
||||
contents := gjson.GetBytes(rawJSON, "contents")
|
||||
if !contents.Exists() {
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -28,7 +27,7 @@ const geminiFunctionThoughtSignature = "skip_thought_signature_validator"
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini API format
|
||||
func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
// Base envelope (no default thinkingConfig)
|
||||
out := []byte(`{"contents":[]}`)
|
||||
|
||||
|
||||
@@ -129,11 +129,16 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
candidateIndex := int(candidate.Get("index").Int())
|
||||
template, _ = sjson.Set(template, "choices.0.index", candidateIndex)
|
||||
|
||||
// Extract and set the finish reason.
|
||||
if finishReasonResult := candidate.Get("finishReason"); finishReasonResult.Exists() {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
|
||||
finishReason := ""
|
||||
if stopReasonResult := gjson.GetBytes(rawJSON, "stop_reason"); stopReasonResult.Exists() {
|
||||
finishReason = stopReasonResult.String()
|
||||
}
|
||||
if finishReason == "" {
|
||||
if finishReasonResult := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||
finishReason = finishReasonResult.String()
|
||||
}
|
||||
}
|
||||
finishReason = strings.ToLower(finishReason)
|
||||
|
||||
partsResult := candidate.Get("content.parts")
|
||||
hasFunctionCall := false
|
||||
@@ -225,6 +230,12 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
if hasFunctionCall {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls")
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls")
|
||||
} else if finishReason != "" {
|
||||
// Only pass through specific finish reasons
|
||||
if finishReason == "max_tokens" || finishReason == "stop" {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
|
||||
}
|
||||
}
|
||||
|
||||
responseStrings = append(responseStrings, template)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
@@ -12,7 +11,7 @@ import (
|
||||
const geminiResponsesThoughtSignature = "skip_thought_signature_validator"
|
||||
|
||||
func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
// Note: modelName and stream parameters are part of the fixed method signature
|
||||
_ = modelName // Unused but required by interface
|
||||
@@ -118,19 +117,29 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
switch itemType {
|
||||
case "message":
|
||||
if strings.EqualFold(itemRole, "system") {
|
||||
if contentArray := item.Get("content"); contentArray.Exists() && contentArray.IsArray() {
|
||||
var builder strings.Builder
|
||||
contentArray.ForEach(func(_, contentItem gjson.Result) bool {
|
||||
text := contentItem.Get("text").String()
|
||||
if builder.Len() > 0 && text != "" {
|
||||
builder.WriteByte('\n')
|
||||
}
|
||||
builder.WriteString(text)
|
||||
return true
|
||||
})
|
||||
if !gjson.Get(out, "system_instruction").Exists() {
|
||||
systemInstr := `{"parts":[{"text":""}]}`
|
||||
systemInstr, _ = sjson.Set(systemInstr, "parts.0.text", builder.String())
|
||||
if contentArray := item.Get("content"); contentArray.Exists() {
|
||||
systemInstr := ""
|
||||
if systemInstructionResult := gjson.Get(out, "system_instruction"); systemInstructionResult.Exists() {
|
||||
systemInstr = systemInstructionResult.Raw
|
||||
} else {
|
||||
systemInstr = `{"parts":[]}`
|
||||
}
|
||||
|
||||
if contentArray.IsArray() {
|
||||
contentArray.ForEach(func(_, contentItem gjson.Result) bool {
|
||||
part := `{"text":""}`
|
||||
text := contentItem.Get("text").String()
|
||||
part, _ = sjson.Set(part, "text", text)
|
||||
systemInstr, _ = sjson.SetRaw(systemInstr, "parts.-1", part)
|
||||
return true
|
||||
})
|
||||
} else if contentArray.Type == gjson.String {
|
||||
part := `{"text":""}`
|
||||
part, _ = sjson.Set(part, "text", contentArray.String())
|
||||
systemInstr, _ = sjson.SetRaw(systemInstr, "parts.-1", part)
|
||||
}
|
||||
|
||||
if systemInstr != `{"parts":[]}` {
|
||||
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr)
|
||||
}
|
||||
}
|
||||
@@ -237,8 +246,22 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
})
|
||||
|
||||
flush()
|
||||
}
|
||||
} else if contentArray.Type == gjson.String {
|
||||
effRole := "user"
|
||||
if itemRole != "" {
|
||||
switch strings.ToLower(itemRole) {
|
||||
case "assistant", "model":
|
||||
effRole = "model"
|
||||
default:
|
||||
effRole = strings.ToLower(itemRole)
|
||||
}
|
||||
}
|
||||
|
||||
one := `{"role":"","parts":[{"text":""}]}`
|
||||
one, _ = sjson.Set(one, "role", effRole)
|
||||
one, _ = sjson.Set(one, "parts.0.text", contentArray.String())
|
||||
out, _ = sjson.SetRaw(out, "contents.-1", one)
|
||||
}
|
||||
case "function_call":
|
||||
// Handle function calls - convert to model message with functionCall
|
||||
name := item.Get("name").String()
|
||||
|
||||
@@ -17,6 +17,9 @@ import (
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// remoteWebSearchDescription is a minimal fallback for when dynamic fetch from MCP tools/list hasn't completed yet.
|
||||
const remoteWebSearchDescription = "WebSearch looks up information outside the model's training data. Supports multiple queries to gather comprehensive information."
|
||||
|
||||
// Kiro API request structs - field order determines JSON key order
|
||||
|
||||
// KiroPayload is the top-level request structure for Kiro API
|
||||
@@ -115,9 +118,11 @@ type KiroAssistantResponseMessage struct {
|
||||
|
||||
// KiroToolUse represents a tool invocation by the assistant
|
||||
type KiroToolUse struct {
|
||||
ToolUseID string `json:"toolUseId"`
|
||||
Name string `json:"name"`
|
||||
Input map[string]interface{} `json:"input"`
|
||||
ToolUseID string `json:"toolUseId"`
|
||||
Name string `json:"name"`
|
||||
Input map[string]interface{} `json:"input"`
|
||||
IsTruncated bool `json:"-"` // Internal flag, not serialized
|
||||
TruncationInfo *TruncationInfo `json:"-"` // Truncation details, not serialized
|
||||
}
|
||||
|
||||
// ConvertClaudeRequestToKiro converts a Claude API request to Kiro format.
|
||||
@@ -217,35 +222,16 @@ func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isA
|
||||
}
|
||||
|
||||
// Convert Claude tools to Kiro format
|
||||
kiroTools, hasWebSearch := convertClaudeToolsToKiro(tools)
|
||||
|
||||
// If web_search was requested but filtered, inject alternative hint
|
||||
if hasWebSearch {
|
||||
webSearchHint := `[CRITICAL WEB ACCESS INSTRUCTION]
|
||||
You have the Fetch/read_url_content tool available. When the user asks about current events, weather, news, or any information that requires web access:
|
||||
- DO NOT say you cannot search the web
|
||||
- DO NOT refuse to help with web-related queries
|
||||
- IMMEDIATELY use the Fetch tool to access relevant URLs
|
||||
- Use well-known official websites, documentation sites, or API endpoints
|
||||
- Construct appropriate URLs based on the query context
|
||||
|
||||
IMPORTANT: Always attempt to fetch information FIRST before declining. You CAN access the web via Fetch.`
|
||||
if systemPrompt != "" {
|
||||
systemPrompt += "\n" + webSearchHint
|
||||
} else {
|
||||
systemPrompt = webSearchHint
|
||||
}
|
||||
log.Infof("kiro: injected web_search alternative hint (tool was filtered)")
|
||||
}
|
||||
kiroTools := convertClaudeToolsToKiro(tools)
|
||||
|
||||
// Thinking mode implementation:
|
||||
// Kiro API supports official thinking/reasoning mode via <thinking_mode> tag.
|
||||
// When set to "enabled", Kiro returns reasoning content as official reasoningContentEvent
|
||||
// rather than inline <thinking> tags in assistantResponseEvent.
|
||||
// We use a high max_thinking_length to allow extensive reasoning.
|
||||
// We cap max_thinking_length to reserve space for tool outputs and prevent truncation.
|
||||
if thinkingEnabled {
|
||||
thinkingHint := `<thinking_mode>enabled</thinking_mode>
|
||||
<max_thinking_length>200000</max_thinking_length>`
|
||||
<max_thinking_length>16000</max_thinking_length>`
|
||||
if systemPrompt != "" {
|
||||
systemPrompt = thinkingHint + "\n\n" + systemPrompt
|
||||
} else {
|
||||
@@ -525,27 +511,15 @@ func ensureKiroInputSchema(parameters interface{}) interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
// convertClaudeToolsToKiro converts Claude tools to Kiro format.
|
||||
// Returns the converted tools and a boolean indicating if web_search was filtered.
|
||||
func convertClaudeToolsToKiro(tools gjson.Result) ([]KiroToolWrapper, bool) {
|
||||
// convertClaudeToolsToKiro converts Claude tools to Kiro format
|
||||
func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper {
|
||||
var kiroTools []KiroToolWrapper
|
||||
hasWebSearch := false
|
||||
if !tools.IsArray() {
|
||||
return kiroTools, hasWebSearch
|
||||
return kiroTools
|
||||
}
|
||||
|
||||
for _, tool := range tools.Array() {
|
||||
name := tool.Get("name").String()
|
||||
|
||||
// Filter out web_search/websearch tools (Kiro API doesn't support them)
|
||||
// This matches the behavior in AIClient-2-API/claude-kiro.js
|
||||
nameLower := strings.ToLower(name)
|
||||
if nameLower == "web_search" || nameLower == "websearch" {
|
||||
log.Debugf("kiro: skipping unsupported tool: %s", name)
|
||||
hasWebSearch = true
|
||||
continue
|
||||
}
|
||||
|
||||
description := tool.Get("description").String()
|
||||
inputSchemaResult := tool.Get("input_schema")
|
||||
var inputSchema interface{}
|
||||
@@ -567,6 +541,18 @@ func convertClaudeToolsToKiro(tools gjson.Result) ([]KiroToolWrapper, bool) {
|
||||
log.Debugf("kiro: tool '%s' has empty description, using default: %s", name, description)
|
||||
}
|
||||
|
||||
// Rename web_search → remote_web_search for Kiro API compatibility
|
||||
if name == "web_search" {
|
||||
name = "remote_web_search"
|
||||
// Prefer dynamically fetched description, fall back to hardcoded constant
|
||||
if cached := GetWebSearchDescription(); cached != "" {
|
||||
description = cached
|
||||
} else {
|
||||
description = remoteWebSearchDescription
|
||||
}
|
||||
log.Debugf("kiro: renamed tool web_search → remote_web_search")
|
||||
}
|
||||
|
||||
// Truncate long descriptions (individual tool limit)
|
||||
if len(description) > kirocommon.KiroMaxToolDescLen {
|
||||
truncLen := kirocommon.KiroMaxToolDescLen - 30
|
||||
@@ -589,7 +575,7 @@ func convertClaudeToolsToKiro(tools gjson.Result) ([]KiroToolWrapper, bool) {
|
||||
// This prevents 500 errors when Claude Code sends too many tools
|
||||
kiroTools = compressToolsIfNeeded(kiroTools)
|
||||
|
||||
return kiroTools, hasWebSearch
|
||||
return kiroTools
|
||||
}
|
||||
|
||||
// processMessages processes Claude messages and builds Kiro history
|
||||
@@ -606,18 +592,22 @@ func processMessages(messages gjson.Result, modelID, origin string) ([]KiroHisto
|
||||
|
||||
if role == "user" {
|
||||
userMsg, toolResults := BuildUserMessageStruct(msg, modelID, origin)
|
||||
// CRITICAL: Kiro API requires content to be non-empty for ALL user messages
|
||||
// This includes both history messages and the current message.
|
||||
// When user message contains only tool_result (no text), content will be empty.
|
||||
// This commonly happens in compaction requests from OpenCode.
|
||||
if strings.TrimSpace(userMsg.Content) == "" {
|
||||
if len(toolResults) > 0 {
|
||||
userMsg.Content = kirocommon.DefaultUserContentWithToolResults
|
||||
} else {
|
||||
userMsg.Content = kirocommon.DefaultUserContent
|
||||
}
|
||||
log.Debugf("kiro: user content was empty, using default: %s", userMsg.Content)
|
||||
}
|
||||
if isLastMessage {
|
||||
currentUserMsg = &userMsg
|
||||
currentToolResults = toolResults
|
||||
} else {
|
||||
// CRITICAL: Kiro API requires content to be non-empty for history messages too
|
||||
if strings.TrimSpace(userMsg.Content) == "" {
|
||||
if len(toolResults) > 0 {
|
||||
userMsg.Content = "Tool results provided."
|
||||
} else {
|
||||
userMsg.Content = "Continue"
|
||||
}
|
||||
}
|
||||
// For history messages, embed tool results in context
|
||||
if len(toolResults) > 0 {
|
||||
userMsg.UserInputMessageContext = &KiroUserInputMessageContext{
|
||||
@@ -648,6 +638,57 @@ func processMessages(messages gjson.Result, modelID, origin string) ([]KiroHisto
|
||||
}
|
||||
}
|
||||
|
||||
// POST-PROCESSING: Remove orphaned tool_results that have no matching tool_use
|
||||
// in any assistant message. This happens when Claude Code compaction truncates
|
||||
// the conversation and removes the assistant message containing the tool_use,
|
||||
// but keeps the user message with the corresponding tool_result.
|
||||
// Without this fix, Kiro API returns "Improperly formed request".
|
||||
validToolUseIDs := make(map[string]bool)
|
||||
for _, h := range history {
|
||||
if h.AssistantResponseMessage != nil {
|
||||
for _, tu := range h.AssistantResponseMessage.ToolUses {
|
||||
validToolUseIDs[tu.ToolUseID] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Filter orphaned tool results from history user messages
|
||||
for i, h := range history {
|
||||
if h.UserInputMessage != nil && h.UserInputMessage.UserInputMessageContext != nil {
|
||||
ctx := h.UserInputMessage.UserInputMessageContext
|
||||
if len(ctx.ToolResults) > 0 {
|
||||
filtered := make([]KiroToolResult, 0, len(ctx.ToolResults))
|
||||
for _, tr := range ctx.ToolResults {
|
||||
if validToolUseIDs[tr.ToolUseID] {
|
||||
filtered = append(filtered, tr)
|
||||
} else {
|
||||
log.Debugf("kiro: dropping orphaned tool_result in history[%d]: toolUseId=%s (no matching tool_use)", i, tr.ToolUseID)
|
||||
}
|
||||
}
|
||||
ctx.ToolResults = filtered
|
||||
if len(ctx.ToolResults) == 0 && len(ctx.Tools) == 0 {
|
||||
h.UserInputMessage.UserInputMessageContext = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Filter orphaned tool results from current message
|
||||
if len(currentToolResults) > 0 {
|
||||
filtered := make([]KiroToolResult, 0, len(currentToolResults))
|
||||
for _, tr := range currentToolResults {
|
||||
if validToolUseIDs[tr.ToolUseID] {
|
||||
filtered = append(filtered, tr)
|
||||
} else {
|
||||
log.Debugf("kiro: dropping orphaned tool_result in currentMessage: toolUseId=%s (no matching tool_use)", tr.ToolUseID)
|
||||
}
|
||||
}
|
||||
if len(filtered) != len(currentToolResults) {
|
||||
log.Infof("kiro: dropped %d orphaned tool_result(s) from currentMessage (compaction artifact)", len(currentToolResults)-len(filtered))
|
||||
}
|
||||
currentToolResults = filtered
|
||||
}
|
||||
|
||||
return history, currentUserMsg, currentToolResults
|
||||
}
|
||||
|
||||
@@ -771,7 +812,35 @@ func BuildUserMessageStruct(msg gjson.Result, modelID, origin string) (KiroUserI
|
||||
resultContent := part.Get("content")
|
||||
|
||||
var textContents []KiroTextContent
|
||||
if resultContent.IsArray() {
|
||||
|
||||
// Check if this tool_result contains error from our SOFT_LIMIT_REACHED tool_use
|
||||
// The client will return an error when trying to execute a tool with marker input
|
||||
resultStr := resultContent.String()
|
||||
isSoftLimitError := strings.Contains(resultStr, "SOFT_LIMIT_REACHED") ||
|
||||
strings.Contains(resultStr, "_status") ||
|
||||
strings.Contains(resultStr, "truncated") ||
|
||||
strings.Contains(resultStr, "missing required") ||
|
||||
strings.Contains(resultStr, "invalid input") ||
|
||||
strings.Contains(resultStr, "Error writing file")
|
||||
|
||||
if isError && isSoftLimitError {
|
||||
// Replace error content with SOFT_LIMIT_REACHED guidance
|
||||
log.Infof("kiro: detected SOFT_LIMIT_REACHED in tool_result for %s, replacing with guidance", toolUseID)
|
||||
softLimitMsg := `SOFT_LIMIT_REACHED
|
||||
|
||||
Your previous tool call was incomplete due to API output size limits.
|
||||
The content was PARTIALLY transmitted but NOT executed.
|
||||
|
||||
REQUIRED ACTION:
|
||||
1. Split your content into smaller chunks (max 300 lines per call)
|
||||
2. For file writes: Create file with first chunk, then use append for remaining
|
||||
3. Do NOT regenerate content you already attempted - continue from where you stopped
|
||||
|
||||
STATUS: This is NOT an error. Continue with smaller chunks.`
|
||||
textContents = append(textContents, KiroTextContent{Text: softLimitMsg})
|
||||
// Mark as SUCCESS so Claude doesn't treat it as a failure
|
||||
isError = false
|
||||
} else if resultContent.IsArray() {
|
||||
for _, item := range resultContent.Array() {
|
||||
if item.Get("type").String() == "text" {
|
||||
textContents = append(textContents, KiroTextContent{Text: item.Get("text").String()})
|
||||
@@ -842,6 +911,11 @@ func BuildAssistantMessageStruct(msg gjson.Result) KiroAssistantResponseMessage
|
||||
})
|
||||
}
|
||||
|
||||
// Rename web_search → remote_web_search to match convertClaudeToolsToKiro
|
||||
if toolName == "web_search" {
|
||||
toolName = "remote_web_search"
|
||||
}
|
||||
|
||||
toolUses = append(toolUses, KiroToolUse{
|
||||
ToolUseID: toolUseID,
|
||||
Name: toolName,
|
||||
@@ -853,8 +927,21 @@ func BuildAssistantMessageStruct(msg gjson.Result) KiroAssistantResponseMessage
|
||||
contentBuilder.WriteString(content.String())
|
||||
}
|
||||
|
||||
// CRITICAL FIX: Kiro API requires non-empty content for assistant messages
|
||||
// This can happen with compaction requests where assistant messages have only tool_use
|
||||
// (no text content). Without this fix, Kiro API returns "Improperly formed request" error.
|
||||
finalContent := contentBuilder.String()
|
||||
if strings.TrimSpace(finalContent) == "" {
|
||||
if len(toolUses) > 0 {
|
||||
finalContent = kirocommon.DefaultAssistantContentWithTools
|
||||
} else {
|
||||
finalContent = kirocommon.DefaultAssistantContent
|
||||
}
|
||||
log.Debugf("kiro: assistant content was empty, using default: %s", finalContent)
|
||||
}
|
||||
|
||||
return KiroAssistantResponseMessage{
|
||||
Content: contentBuilder.String(),
|
||||
Content: finalContent,
|
||||
ToolUses: toolUses,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,14 +55,39 @@ func BuildClaudeResponse(content string, toolUses []KiroToolUse, model string, u
|
||||
}
|
||||
}
|
||||
|
||||
// Add tool_use blocks
|
||||
// Add tool_use blocks - emit truncated tools with SOFT_LIMIT_REACHED marker
|
||||
hasTruncatedTools := false
|
||||
for _, toolUse := range toolUses {
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": toolUse.ToolUseID,
|
||||
"name": toolUse.Name,
|
||||
"input": toolUse.Input,
|
||||
})
|
||||
if toolUse.IsTruncated && toolUse.TruncationInfo != nil {
|
||||
// Emit tool_use with SOFT_LIMIT_REACHED marker input
|
||||
hasTruncatedTools = true
|
||||
log.Infof("kiro: buildClaudeResponse emitting truncated tool with SOFT_LIMIT_REACHED: %s (ID: %s)", toolUse.Name, toolUse.ToolUseID)
|
||||
|
||||
markerInput := map[string]interface{}{
|
||||
"_status": "SOFT_LIMIT_REACHED",
|
||||
"_message": "Tool output was truncated. Split content into smaller chunks (max 300 lines). Due to potential model hallucination, you MUST re-fetch the current working directory and generate the correct file_path.",
|
||||
}
|
||||
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": toolUse.ToolUseID,
|
||||
"name": toolUse.Name,
|
||||
"input": markerInput,
|
||||
})
|
||||
} else {
|
||||
// Normal tool use
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": toolUse.ToolUseID,
|
||||
"name": toolUse.Name,
|
||||
"input": toolUse.Input,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Log if we used SOFT_LIMIT_REACHED
|
||||
if hasTruncatedTools {
|
||||
log.Infof("kiro: buildClaudeResponse using SOFT_LIMIT_REACHED - keeping stop_reason=tool_use")
|
||||
}
|
||||
|
||||
// Ensure at least one content block (Claude API requires non-empty content)
|
||||
@@ -74,6 +99,7 @@ func BuildClaudeResponse(content string, toolUses []KiroToolUse, model string, u
|
||||
}
|
||||
|
||||
// Use upstream stopReason; apply fallback logic if not provided
|
||||
// SOFT_LIMIT_REACHED: Keep stop_reason = "tool_use" so Claude continues the loop
|
||||
if stopReason == "" {
|
||||
stopReason = "end_turn"
|
||||
if len(toolUses) > 0 {
|
||||
@@ -201,4 +227,4 @@ func ExtractThinkingFromContent(content string) []map[string]interface{} {
|
||||
}
|
||||
|
||||
return blocks
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,10 +14,11 @@ import (
|
||||
|
||||
// ToolUseState tracks the state of an in-progress tool use during streaming.
|
||||
type ToolUseState struct {
|
||||
ToolUseID string
|
||||
Name string
|
||||
InputBuffer strings.Builder
|
||||
IsComplete bool
|
||||
ToolUseID string
|
||||
Name string
|
||||
InputBuffer strings.Builder
|
||||
IsComplete bool
|
||||
TruncationInfo *TruncationInfo // Truncation detection result (set when complete)
|
||||
}
|
||||
|
||||
// Pre-compiled regex patterns for performance
|
||||
@@ -395,17 +396,6 @@ func ProcessToolUseEvent(event map[string]interface{}, currentToolUse *ToolUseSt
|
||||
isStop = stop
|
||||
}
|
||||
|
||||
// Debug: log when stop event arrives
|
||||
if isStop {
|
||||
log.Debugf("kiro: toolUseEvent stop=true received for tool %s (ID: %s), currentToolUse buffer len: %d",
|
||||
toolName, toolUseID, func() int {
|
||||
if currentToolUse != nil {
|
||||
return currentToolUse.InputBuffer.Len()
|
||||
}
|
||||
return -1
|
||||
}())
|
||||
}
|
||||
|
||||
// Get input - can be string (fragment) or object (complete)
|
||||
var inputFragment string
|
||||
var inputMap map[string]interface{}
|
||||
@@ -477,98 +467,39 @@ func ProcessToolUseEvent(event map[string]interface{}, currentToolUse *ToolUseSt
|
||||
if isStop && currentToolUse != nil {
|
||||
fullInput := currentToolUse.InputBuffer.String()
|
||||
|
||||
// Check for Write tool with empty or missing input - this happens when Kiro API
|
||||
// completely skips sending input for large file writes
|
||||
if currentToolUse.Name == "Write" && len(strings.TrimSpace(fullInput)) == 0 {
|
||||
log.Warnf("kiro: Write tool received no input from upstream API. The file content may be too large to transmit.")
|
||||
// Return nil to skip this tool use - it will be handled as a truncation error
|
||||
// The caller should emit a text block explaining the error instead
|
||||
if processedIDs != nil {
|
||||
processedIDs[currentToolUse.ToolUseID] = true
|
||||
}
|
||||
log.Infof("kiro: skipping Write tool use %s due to empty input (content too large)", currentToolUse.ToolUseID)
|
||||
// Return a special marker tool use that indicates truncation
|
||||
toolUse := KiroToolUse{
|
||||
ToolUseID: currentToolUse.ToolUseID,
|
||||
Name: "__truncated_write__", // Special marker name
|
||||
Input: map[string]interface{}{
|
||||
"error": "Write tool input was not transmitted by upstream API. The file content is too large.",
|
||||
},
|
||||
}
|
||||
toolUses = append(toolUses, toolUse)
|
||||
return toolUses, nil
|
||||
}
|
||||
|
||||
// Repair and parse the accumulated JSON
|
||||
repairedJSON := RepairJSON(fullInput)
|
||||
var finalInput map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(repairedJSON), &finalInput); err != nil {
|
||||
log.Warnf("kiro: failed to parse accumulated tool input: %v, raw: %s", err, fullInput)
|
||||
finalInput = make(map[string]interface{})
|
||||
|
||||
// Check if this is a Write tool with truncated input (missing content field)
|
||||
// This happens when the Kiro API truncates large tool inputs
|
||||
if currentToolUse.Name == "Write" && strings.Contains(fullInput, "file_path") && !strings.Contains(fullInput, "content") {
|
||||
log.Warnf("kiro: Write tool input was truncated by upstream API (content field missing). The file content may be too large.")
|
||||
// Extract file_path if possible for error context
|
||||
filePath := ""
|
||||
if idx := strings.Index(fullInput, "file_path"); idx >= 0 {
|
||||
// Try to extract the file path value
|
||||
rest := fullInput[idx:]
|
||||
if colonIdx := strings.Index(rest, ":"); colonIdx >= 0 {
|
||||
rest = strings.TrimSpace(rest[colonIdx+1:])
|
||||
if len(rest) > 0 && rest[0] == '"' {
|
||||
rest = rest[1:]
|
||||
if endQuote := strings.Index(rest, "\""); endQuote >= 0 {
|
||||
filePath = rest[:endQuote]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if processedIDs != nil {
|
||||
processedIDs[currentToolUse.ToolUseID] = true
|
||||
}
|
||||
// Return a special marker tool use that indicates truncation
|
||||
toolUse := KiroToolUse{
|
||||
ToolUseID: currentToolUse.ToolUseID,
|
||||
Name: "__truncated_write__", // Special marker name
|
||||
Input: map[string]interface{}{
|
||||
"error": "Write tool content was truncated by upstream API. The file content is too large.",
|
||||
"file_path": filePath,
|
||||
},
|
||||
}
|
||||
toolUses = append(toolUses, toolUse)
|
||||
return toolUses, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Additional check: Write tool parsed successfully but missing content field
|
||||
if currentToolUse.Name == "Write" {
|
||||
if _, hasContent := finalInput["content"]; !hasContent {
|
||||
if filePath, hasPath := finalInput["file_path"]; hasPath {
|
||||
log.Warnf("kiro: Write tool input missing 'content' field, likely truncated by upstream API")
|
||||
if processedIDs != nil {
|
||||
processedIDs[currentToolUse.ToolUseID] = true
|
||||
}
|
||||
// Return a special marker tool use that indicates truncation
|
||||
toolUse := KiroToolUse{
|
||||
ToolUseID: currentToolUse.ToolUseID,
|
||||
Name: "__truncated_write__", // Special marker name
|
||||
Input: map[string]interface{}{
|
||||
"error": "Write tool content field was missing. The file content is too large.",
|
||||
"file_path": filePath,
|
||||
},
|
||||
}
|
||||
toolUses = append(toolUses, toolUse)
|
||||
return toolUses, nil
|
||||
}
|
||||
// Detect truncation for all tools
|
||||
truncInfo := DetectTruncation(currentToolUse.Name, currentToolUse.ToolUseID, fullInput, finalInput)
|
||||
if truncInfo.IsTruncated {
|
||||
log.Warnf("kiro: TRUNCATION DETECTED for tool %s (ID: %s): type=%s, raw_size=%d bytes",
|
||||
currentToolUse.Name, currentToolUse.ToolUseID, truncInfo.TruncationType, len(fullInput))
|
||||
log.Warnf("kiro: truncation details: %s", truncInfo.ErrorMessage)
|
||||
if len(truncInfo.ParsedFields) > 0 {
|
||||
log.Infof("kiro: partial fields received: %v", truncInfo.ParsedFields)
|
||||
}
|
||||
// Store truncation info in the state for upstream handling
|
||||
currentToolUse.TruncationInfo = &truncInfo
|
||||
} else {
|
||||
log.Infof("kiro: tool use %s input length: %d bytes (no truncation)", currentToolUse.Name, len(fullInput))
|
||||
}
|
||||
|
||||
// Create the tool use with truncation info if applicable
|
||||
toolUse := KiroToolUse{
|
||||
ToolUseID: currentToolUse.ToolUseID,
|
||||
Name: currentToolUse.Name,
|
||||
Input: finalInput,
|
||||
ToolUseID: currentToolUse.ToolUseID,
|
||||
Name: currentToolUse.Name,
|
||||
Input: finalInput,
|
||||
IsTruncated: truncInfo.IsTruncated,
|
||||
TruncationInfo: nil, // Will be set below if truncated
|
||||
}
|
||||
if truncInfo.IsTruncated {
|
||||
toolUse.TruncationInfo = &truncInfo
|
||||
}
|
||||
toolUses = append(toolUses, toolUse)
|
||||
|
||||
@@ -576,7 +507,7 @@ func ProcessToolUseEvent(event map[string]interface{}, currentToolUse *ToolUseSt
|
||||
processedIDs[currentToolUse.ToolUseID] = true
|
||||
}
|
||||
|
||||
log.Infof("kiro: completed tool use: %s (ID: %s)", currentToolUse.Name, currentToolUse.ToolUseID)
|
||||
log.Infof("kiro: completed tool use: %s (ID: %s, truncated: %v)", currentToolUse.Name, currentToolUse.ToolUseID, truncInfo.IsTruncated)
|
||||
return toolUses, nil
|
||||
}
|
||||
|
||||
@@ -610,4 +541,3 @@ func DeduplicateToolUses(toolUses []KiroToolUse) []KiroToolUse {
|
||||
|
||||
return unique
|
||||
}
|
||||
|
||||
|
||||
1169
internal/translator/kiro/claude/kiro_websearch.go
Normal file
1169
internal/translator/kiro/claude/kiro_websearch.go
Normal file
File diff suppressed because it is too large
Load Diff
270
internal/translator/kiro/claude/kiro_websearch_handler.go
Normal file
270
internal/translator/kiro/claude/kiro_websearch_handler.go
Normal file
@@ -0,0 +1,270 @@
|
||||
// Package claude provides web search handler for Kiro translator.
|
||||
// This file implements the MCP API call and response handling.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Cached web_search tool description fetched from MCP tools/list.
|
||||
// Uses atomic.Pointer[sync.Once] for lock-free reads with retry-on-failure:
|
||||
// - sync.Once prevents race conditions and deduplicates concurrent calls
|
||||
// - On failure, a fresh sync.Once is swapped in to allow retry on next call
|
||||
// - On success, sync.Once stays "done" forever — zero overhead for subsequent calls
|
||||
var (
|
||||
cachedToolDescription atomic.Value // stores string
|
||||
toolDescOnce atomic.Pointer[sync.Once]
|
||||
fallbackFpOnce sync.Once
|
||||
fallbackFp *kiroauth.Fingerprint
|
||||
)
|
||||
|
||||
func init() {
|
||||
toolDescOnce.Store(&sync.Once{})
|
||||
}
|
||||
|
||||
// FetchToolDescription calls MCP tools/list to get the web_search tool description
|
||||
// and caches it. Safe to call concurrently — only one goroutine fetches at a time.
|
||||
// If the fetch fails, subsequent calls will retry. On success, no further fetches occur.
|
||||
// The httpClient parameter allows reusing a shared pooled HTTP client.
|
||||
func FetchToolDescription(mcpEndpoint, authToken string, httpClient *http.Client, fp *kiroauth.Fingerprint, authAttrs map[string]string) {
|
||||
toolDescOnce.Load().Do(func() {
|
||||
handler := NewWebSearchHandler(mcpEndpoint, authToken, httpClient, fp, authAttrs)
|
||||
reqBody := []byte(`{"id":"tools_list","jsonrpc":"2.0","method":"tools/list"}`)
|
||||
log.Debugf("kiro/websearch MCP tools/list request: %d bytes", len(reqBody))
|
||||
|
||||
req, err := http.NewRequest("POST", mcpEndpoint, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: failed to create tools/list request: %v", err)
|
||||
toolDescOnce.Store(&sync.Once{}) // allow retry
|
||||
return
|
||||
}
|
||||
|
||||
// Reuse same headers as CallMcpAPI
|
||||
handler.setMcpHeaders(req)
|
||||
|
||||
resp, err := handler.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
log.Warnf("kiro/websearch: tools/list request failed: %v", err)
|
||||
toolDescOnce.Store(&sync.Once{}) // allow retry
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil || resp.StatusCode != http.StatusOK {
|
||||
log.Warnf("kiro/websearch: tools/list returned status %d", resp.StatusCode)
|
||||
toolDescOnce.Store(&sync.Once{}) // allow retry
|
||||
return
|
||||
}
|
||||
log.Debugf("kiro/websearch MCP tools/list response: [%d] %d bytes", resp.StatusCode, len(body))
|
||||
|
||||
// Parse: {"result":{"tools":[{"name":"web_search","description":"..."}]}}
|
||||
var result struct {
|
||||
Result *struct {
|
||||
Tools []struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
} `json:"tools"`
|
||||
} `json:"result"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil || result.Result == nil {
|
||||
log.Warnf("kiro/websearch: failed to parse tools/list response")
|
||||
toolDescOnce.Store(&sync.Once{}) // allow retry
|
||||
return
|
||||
}
|
||||
|
||||
for _, tool := range result.Result.Tools {
|
||||
if tool.Name == "web_search" && tool.Description != "" {
|
||||
cachedToolDescription.Store(tool.Description)
|
||||
log.Infof("kiro/websearch: cached web_search description from tools/list (%d bytes)", len(tool.Description))
|
||||
return // success — sync.Once stays "done", no more fetches
|
||||
}
|
||||
}
|
||||
|
||||
// web_search tool not found in response
|
||||
toolDescOnce.Store(&sync.Once{}) // allow retry
|
||||
})
|
||||
}
|
||||
|
||||
// GetWebSearchDescription returns the cached web_search tool description,
|
||||
// or empty string if not yet fetched. Lock-free via atomic.Value.
|
||||
func GetWebSearchDescription() string {
|
||||
if v := cachedToolDescription.Load(); v != nil {
|
||||
return v.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// WebSearchHandler handles web search requests via Kiro MCP API
|
||||
type WebSearchHandler struct {
|
||||
McpEndpoint string
|
||||
HTTPClient *http.Client
|
||||
AuthToken string
|
||||
Fingerprint *kiroauth.Fingerprint // optional, for dynamic headers
|
||||
AuthAttrs map[string]string // optional, for custom headers from auth.Attributes
|
||||
}
|
||||
|
||||
// NewWebSearchHandler creates a new WebSearchHandler.
|
||||
// If httpClient is nil, a default client with 30s timeout is used.
|
||||
// If fingerprint is nil, a random one-off fingerprint is generated.
|
||||
// Pass a shared pooled client (e.g. from getKiroPooledHTTPClient) for connection reuse.
|
||||
func NewWebSearchHandler(mcpEndpoint, authToken string, httpClient *http.Client, fp *kiroauth.Fingerprint, authAttrs map[string]string) *WebSearchHandler {
|
||||
if httpClient == nil {
|
||||
httpClient = &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
if fp == nil {
|
||||
// Use a shared fallback fingerprint for callers without token context
|
||||
fallbackFpOnce.Do(func() {
|
||||
mgr := kiroauth.NewFingerprintManager()
|
||||
fallbackFp = mgr.GetFingerprint("mcp-fallback")
|
||||
})
|
||||
fp = fallbackFp
|
||||
}
|
||||
return &WebSearchHandler{
|
||||
McpEndpoint: mcpEndpoint,
|
||||
HTTPClient: httpClient,
|
||||
AuthToken: authToken,
|
||||
Fingerprint: fp,
|
||||
AuthAttrs: authAttrs,
|
||||
}
|
||||
}
|
||||
|
||||
// setMcpHeaders sets standard MCP API headers on the request,
|
||||
// aligned with the GAR request pattern in kiro_executor.go.
|
||||
func (h *WebSearchHandler) setMcpHeaders(req *http.Request) {
|
||||
fp := h.Fingerprint
|
||||
|
||||
// 1. Content-Type & Accept (aligned with GAR)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "*/*")
|
||||
|
||||
// 2. Kiro-specific headers (aligned with GAR)
|
||||
req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
|
||||
req.Header.Set("x-amzn-codewhisperer-optout", "true")
|
||||
|
||||
// 3. Dynamic fingerprint headers
|
||||
req.Header.Set("User-Agent", fp.BuildUserAgent())
|
||||
req.Header.Set("X-Amz-User-Agent", fp.BuildAmzUserAgent())
|
||||
|
||||
// 4. AWS SDK identifiers (casing aligned with GAR)
|
||||
req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
|
||||
req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
|
||||
|
||||
// 5. Authentication
|
||||
req.Header.Set("Authorization", "Bearer "+h.AuthToken)
|
||||
|
||||
// 6. Custom headers from auth attributes
|
||||
util.ApplyCustomHeadersFromAttrs(req, h.AuthAttrs)
|
||||
}
|
||||
|
||||
// mcpMaxRetries is the maximum number of retries for MCP API calls.
|
||||
const mcpMaxRetries = 2
|
||||
|
||||
// CallMcpAPI calls the Kiro MCP API with the given request.
|
||||
// Includes retry logic with exponential backoff for retryable errors,
|
||||
// aligned with the GAR request retry pattern.
|
||||
func (h *WebSearchHandler) CallMcpAPI(request *McpRequest) (*McpResponse, error) {
|
||||
requestBody, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal MCP request: %w", err)
|
||||
}
|
||||
log.Debugf("kiro/websearch MCP request → %s (%d bytes)", h.McpEndpoint, len(requestBody))
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt <= mcpMaxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
backoff := time.Duration(1<<attempt) * time.Second
|
||||
if backoff > 10*time.Second {
|
||||
backoff = 10 * time.Second
|
||||
}
|
||||
log.Warnf("kiro/websearch: MCP retry %d/%d after %v (last error: %v)", attempt, mcpMaxRetries, backoff, lastErr)
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", h.McpEndpoint, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||
}
|
||||
|
||||
h.setMcpHeaders(req)
|
||||
|
||||
resp, err := h.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("MCP API request failed: %w", err)
|
||||
continue // network error → retry
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("failed to read MCP response: %w", err)
|
||||
continue // read error → retry
|
||||
}
|
||||
log.Debugf("kiro/websearch MCP response ← [%d] (%d bytes)", resp.StatusCode, len(body))
|
||||
|
||||
// Retryable HTTP status codes (aligned with GAR: 502, 503, 504)
|
||||
if resp.StatusCode >= 502 && resp.StatusCode <= 504 {
|
||||
lastErr = fmt.Errorf("MCP API returned retryable status %d: %s", resp.StatusCode, string(body))
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("MCP API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var mcpResponse McpResponse
|
||||
if err := json.Unmarshal(body, &mcpResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse MCP response: %w", err)
|
||||
}
|
||||
|
||||
if mcpResponse.Error != nil {
|
||||
code := -1
|
||||
if mcpResponse.Error.Code != nil {
|
||||
code = *mcpResponse.Error.Code
|
||||
}
|
||||
msg := "Unknown error"
|
||||
if mcpResponse.Error.Message != nil {
|
||||
msg = *mcpResponse.Error.Message
|
||||
}
|
||||
return nil, fmt.Errorf("MCP error %d: %s", code, msg)
|
||||
}
|
||||
|
||||
return &mcpResponse, nil
|
||||
}
|
||||
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
// ParseSearchResults extracts WebSearchResults from MCP response
|
||||
func ParseSearchResults(response *McpResponse) *WebSearchResults {
|
||||
if response == nil || response.Result == nil || len(response.Result.Content) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
content := response.Result.Content[0]
|
||||
if content.ContentType != "text" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var results WebSearchResults
|
||||
if err := json.Unmarshal([]byte(content.Text), &results); err != nil {
|
||||
log.Warnf("kiro/websearch: failed to parse search results: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &results
|
||||
}
|
||||
517
internal/translator/kiro/claude/truncation_detector.go
Normal file
517
internal/translator/kiro/claude/truncation_detector.go
Normal file
@@ -0,0 +1,517 @@
|
||||
// Package claude provides truncation detection for Kiro tool call responses.
|
||||
// When Kiro API reaches its output token limit, tool call JSON may be truncated,
|
||||
// resulting in incomplete or unparseable input parameters.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// TruncationInfo contains details about detected truncation in a tool use event.
|
||||
type TruncationInfo struct {
|
||||
IsTruncated bool // Whether truncation was detected
|
||||
TruncationType string // Type of truncation detected
|
||||
ToolName string // Name of the truncated tool
|
||||
ToolUseID string // ID of the truncated tool use
|
||||
RawInput string // The raw (possibly truncated) input string
|
||||
ParsedFields map[string]string // Fields that were successfully parsed before truncation
|
||||
ErrorMessage string // Human-readable error message
|
||||
}
|
||||
|
||||
// TruncationType constants for different truncation scenarios
|
||||
const (
|
||||
TruncationTypeNone = "" // No truncation detected
|
||||
TruncationTypeEmptyInput = "empty_input" // No input data received at all
|
||||
TruncationTypeInvalidJSON = "invalid_json" // JSON is syntactically invalid (truncated mid-value)
|
||||
TruncationTypeMissingFields = "missing_fields" // JSON parsed but critical fields are missing
|
||||
TruncationTypeIncompleteString = "incomplete_string" // String value was cut off mid-content
|
||||
)
|
||||
|
||||
// KnownWriteTools lists tool names that typically write content and have a "content" field.
|
||||
// These tools are checked for content field truncation specifically.
|
||||
var KnownWriteTools = map[string]bool{
|
||||
"Write": true,
|
||||
"write_to_file": true,
|
||||
"fsWrite": true,
|
||||
"create_file": true,
|
||||
"edit_file": true,
|
||||
"apply_diff": true,
|
||||
"str_replace_editor": true,
|
||||
"insert": true,
|
||||
}
|
||||
|
||||
// KnownCommandTools lists tool names that execute commands.
|
||||
var KnownCommandTools = map[string]bool{
|
||||
"Bash": true,
|
||||
"execute": true,
|
||||
"run_command": true,
|
||||
"shell": true,
|
||||
"terminal": true,
|
||||
"execute_python": true,
|
||||
}
|
||||
|
||||
// RequiredFieldsByTool maps tool names to their required fields.
|
||||
// If any of these fields are missing, the tool input is considered truncated.
|
||||
var RequiredFieldsByTool = map[string][]string{
|
||||
"Write": {"file_path", "content"},
|
||||
"write_to_file": {"path", "content"},
|
||||
"fsWrite": {"path", "content"},
|
||||
"create_file": {"path", "content"},
|
||||
"edit_file": {"path"},
|
||||
"apply_diff": {"path", "diff"},
|
||||
"str_replace_editor": {"path", "old_str", "new_str"},
|
||||
"Bash": {"command"},
|
||||
"execute": {"command"},
|
||||
"run_command": {"command"},
|
||||
}
|
||||
|
||||
// DetectTruncation checks if the tool use input appears to be truncated.
|
||||
// It returns detailed information about the truncation status and type.
|
||||
func DetectTruncation(toolName, toolUseID, rawInput string, parsedInput map[string]interface{}) TruncationInfo {
|
||||
info := TruncationInfo{
|
||||
ToolName: toolName,
|
||||
ToolUseID: toolUseID,
|
||||
RawInput: rawInput,
|
||||
ParsedFields: make(map[string]string),
|
||||
}
|
||||
|
||||
// Scenario 1: Empty input buffer - no data received at all
|
||||
if strings.TrimSpace(rawInput) == "" {
|
||||
info.IsTruncated = true
|
||||
info.TruncationType = TruncationTypeEmptyInput
|
||||
info.ErrorMessage = "Tool input was completely empty - API response may have been truncated before tool parameters were transmitted"
|
||||
log.Warnf("kiro: truncation detected [%s] for tool %s (ID: %s): empty input buffer",
|
||||
info.TruncationType, toolName, toolUseID)
|
||||
return info
|
||||
}
|
||||
|
||||
// Scenario 2: JSON parse failure - syntactically invalid JSON
|
||||
if parsedInput == nil || len(parsedInput) == 0 {
|
||||
// Check if the raw input looks like truncated JSON
|
||||
if looksLikeTruncatedJSON(rawInput) {
|
||||
info.IsTruncated = true
|
||||
info.TruncationType = TruncationTypeInvalidJSON
|
||||
info.ParsedFields = extractPartialFields(rawInput)
|
||||
info.ErrorMessage = buildTruncationErrorMessage(toolName, info.TruncationType, info.ParsedFields, rawInput)
|
||||
log.Warnf("kiro: truncation detected [%s] for tool %s (ID: %s): JSON parse failed, raw length=%d bytes",
|
||||
info.TruncationType, toolName, toolUseID, len(rawInput))
|
||||
return info
|
||||
}
|
||||
}
|
||||
|
||||
// Scenario 3: JSON parsed but critical fields are missing
|
||||
if parsedInput != nil {
|
||||
requiredFields, hasRequirements := RequiredFieldsByTool[toolName]
|
||||
if hasRequirements {
|
||||
missingFields := findMissingRequiredFields(parsedInput, requiredFields)
|
||||
if len(missingFields) > 0 {
|
||||
info.IsTruncated = true
|
||||
info.TruncationType = TruncationTypeMissingFields
|
||||
info.ParsedFields = extractParsedFieldNames(parsedInput)
|
||||
info.ErrorMessage = buildMissingFieldsErrorMessage(toolName, missingFields, info.ParsedFields)
|
||||
log.Warnf("kiro: truncation detected [%s] for tool %s (ID: %s): missing required fields: %v",
|
||||
info.TruncationType, toolName, toolUseID, missingFields)
|
||||
return info
|
||||
}
|
||||
}
|
||||
|
||||
// Scenario 4: Check for incomplete string values (very short content for write tools)
|
||||
if isWriteTool(toolName) {
|
||||
if contentTruncation := detectContentTruncation(parsedInput, rawInput); contentTruncation != "" {
|
||||
info.IsTruncated = true
|
||||
info.TruncationType = TruncationTypeIncompleteString
|
||||
info.ParsedFields = extractParsedFieldNames(parsedInput)
|
||||
info.ErrorMessage = contentTruncation
|
||||
log.Warnf("kiro: truncation detected [%s] for tool %s (ID: %s): %s",
|
||||
info.TruncationType, toolName, toolUseID, contentTruncation)
|
||||
return info
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No truncation detected
|
||||
info.IsTruncated = false
|
||||
info.TruncationType = TruncationTypeNone
|
||||
return info
|
||||
}
|
||||
|
||||
// looksLikeTruncatedJSON checks if the raw string appears to be truncated JSON.
|
||||
func looksLikeTruncatedJSON(raw string) bool {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Must start with { to be considered JSON
|
||||
if !strings.HasPrefix(trimmed, "{") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Count brackets to detect imbalance
|
||||
openBraces := strings.Count(trimmed, "{")
|
||||
closeBraces := strings.Count(trimmed, "}")
|
||||
openBrackets := strings.Count(trimmed, "[")
|
||||
closeBrackets := strings.Count(trimmed, "]")
|
||||
|
||||
// Bracket imbalance suggests truncation
|
||||
if openBraces > closeBraces || openBrackets > closeBrackets {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for obvious truncation patterns
|
||||
// - Ends with a quote but no closing brace
|
||||
// - Ends with a colon (mid key-value)
|
||||
// - Ends with a comma (mid object/array)
|
||||
lastChar := trimmed[len(trimmed)-1]
|
||||
if lastChar != '}' && lastChar != ']' {
|
||||
// Check if it's not a complete simple value
|
||||
if lastChar == '"' || lastChar == ':' || lastChar == ',' {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check for unclosed strings (odd number of unescaped quotes)
|
||||
inString := false
|
||||
escaped := false
|
||||
for i := 0; i < len(trimmed); i++ {
|
||||
c := trimmed[i]
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if c == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if c == '"' {
|
||||
inString = !inString
|
||||
}
|
||||
}
|
||||
if inString {
|
||||
return true // Unclosed string
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// extractPartialFields attempts to extract any field names from malformed JSON.
|
||||
// This helps provide context about what was received before truncation.
|
||||
func extractPartialFields(raw string) map[string]string {
|
||||
fields := make(map[string]string)
|
||||
|
||||
// Simple pattern matching for "key": "value" or "key": value patterns
|
||||
// This works even with truncated JSON
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if !strings.HasPrefix(trimmed, "{") {
|
||||
return fields
|
||||
}
|
||||
|
||||
// Remove opening brace
|
||||
content := strings.TrimPrefix(trimmed, "{")
|
||||
|
||||
// Split by comma (rough parsing)
|
||||
parts := strings.Split(content, ",")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if colonIdx := strings.Index(part, ":"); colonIdx > 0 {
|
||||
key := strings.TrimSpace(part[:colonIdx])
|
||||
key = strings.Trim(key, `"`)
|
||||
value := strings.TrimSpace(part[colonIdx+1:])
|
||||
|
||||
// Truncate long values for display
|
||||
if len(value) > 50 {
|
||||
value = value[:50] + "..."
|
||||
}
|
||||
fields[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
|
||||
// extractParsedFieldNames returns the field names from a successfully parsed map.
|
||||
func extractParsedFieldNames(parsed map[string]interface{}) map[string]string {
|
||||
fields := make(map[string]string)
|
||||
for key, val := range parsed {
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
if len(v) > 50 {
|
||||
fields[key] = v[:50] + "..."
|
||||
} else {
|
||||
fields[key] = v
|
||||
}
|
||||
case nil:
|
||||
fields[key] = "<null>"
|
||||
default:
|
||||
// For complex types, just indicate presence
|
||||
fields[key] = "<present>"
|
||||
}
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
// findMissingRequiredFields checks which required fields are missing from the parsed input.
|
||||
func findMissingRequiredFields(parsed map[string]interface{}, required []string) []string {
|
||||
var missing []string
|
||||
for _, field := range required {
|
||||
if _, exists := parsed[field]; !exists {
|
||||
missing = append(missing, field)
|
||||
}
|
||||
}
|
||||
return missing
|
||||
}
|
||||
|
||||
// isWriteTool checks if the tool is a known write/file operation tool.
|
||||
func isWriteTool(toolName string) bool {
|
||||
return KnownWriteTools[toolName]
|
||||
}
|
||||
|
||||
// detectContentTruncation checks if the content field appears truncated for write tools.
|
||||
func detectContentTruncation(parsed map[string]interface{}, rawInput string) string {
|
||||
// Check for content field
|
||||
content, hasContent := parsed["content"]
|
||||
if !hasContent {
|
||||
return ""
|
||||
}
|
||||
|
||||
contentStr, isString := content.(string)
|
||||
if !isString {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Heuristic: if raw input is very large but content is suspiciously short,
|
||||
// it might indicate truncation during JSON repair
|
||||
if len(rawInput) > 1000 && len(contentStr) < 100 {
|
||||
return "content field appears suspiciously short compared to raw input size"
|
||||
}
|
||||
|
||||
// Check for code blocks that appear to be cut off
|
||||
if strings.Contains(contentStr, "```") {
|
||||
openFences := strings.Count(contentStr, "```")
|
||||
if openFences%2 != 0 {
|
||||
return "content contains unclosed code fence (```) suggesting truncation"
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// buildTruncationErrorMessage creates a human-readable error message for truncation.
|
||||
func buildTruncationErrorMessage(toolName, truncationType string, parsedFields map[string]string, rawInput string) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("Tool input was truncated by the API. ")
|
||||
|
||||
switch truncationType {
|
||||
case TruncationTypeEmptyInput:
|
||||
sb.WriteString("No input data was received.")
|
||||
case TruncationTypeInvalidJSON:
|
||||
sb.WriteString("JSON was cut off mid-transmission. ")
|
||||
if len(parsedFields) > 0 {
|
||||
sb.WriteString("Partial fields received: ")
|
||||
first := true
|
||||
for k := range parsedFields {
|
||||
if !first {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
sb.WriteString(k)
|
||||
first = false
|
||||
}
|
||||
}
|
||||
case TruncationTypeMissingFields:
|
||||
sb.WriteString("Required fields are missing from the input.")
|
||||
case TruncationTypeIncompleteString:
|
||||
sb.WriteString("Content appears to be shortened or incomplete.")
|
||||
}
|
||||
|
||||
sb.WriteString(" Received ")
|
||||
sb.WriteString(string(rune(len(rawInput))))
|
||||
sb.WriteString(" bytes. Please retry with smaller content chunks.")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// buildMissingFieldsErrorMessage creates an error message for missing required fields.
|
||||
func buildMissingFieldsErrorMessage(toolName string, missingFields []string, parsedFields map[string]string) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("Tool '")
|
||||
sb.WriteString(toolName)
|
||||
sb.WriteString("' is missing required fields: ")
|
||||
sb.WriteString(strings.Join(missingFields, ", "))
|
||||
sb.WriteString(". Fields received: ")
|
||||
|
||||
first := true
|
||||
for k := range parsedFields {
|
||||
if !first {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
sb.WriteString(k)
|
||||
first = false
|
||||
}
|
||||
|
||||
sb.WriteString(". This usually indicates the API response was truncated.")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// IsTruncated is a convenience function to check if a tool use appears truncated.
|
||||
func IsTruncated(toolName, rawInput string, parsedInput map[string]interface{}) bool {
|
||||
info := DetectTruncation(toolName, "", rawInput, parsedInput)
|
||||
return info.IsTruncated
|
||||
}
|
||||
|
||||
// GetTruncationSummary returns a short summary string for logging.
|
||||
func GetTruncationSummary(info TruncationInfo) string {
|
||||
if !info.IsTruncated {
|
||||
return ""
|
||||
}
|
||||
|
||||
result, _ := json.Marshal(map[string]interface{}{
|
||||
"tool": info.ToolName,
|
||||
"type": info.TruncationType,
|
||||
"parsed_fields": info.ParsedFields,
|
||||
"raw_input_size": len(info.RawInput),
|
||||
})
|
||||
return string(result)
|
||||
}
|
||||
|
||||
// SoftFailureMessage contains the message structure for a truncation soft failure.
|
||||
// This is returned to Claude as a tool_result to guide retry behavior.
|
||||
type SoftFailureMessage struct {
|
||||
Status string // "incomplete" - not an error, just incomplete
|
||||
Reason string // Why the tool call was incomplete
|
||||
Guidance []string // Step-by-step retry instructions
|
||||
Context string // Any context about what was received
|
||||
MaxLineHint int // Suggested maximum lines per chunk
|
||||
}
|
||||
|
||||
// BuildSoftFailureMessage creates a structured message for Claude when truncation is detected.
|
||||
// This follows the "soft failure" pattern:
|
||||
// - For Claude: Clear explanation of what happened and how to fix
|
||||
// - For User: Hidden or minimized (appears as normal processing)
|
||||
//
|
||||
// Key principle: "Conclusion First"
|
||||
// 1. First state what happened (incomplete)
|
||||
// 2. Then explain how to fix (chunked approach)
|
||||
// 3. Provide specific guidance (line limits)
|
||||
func BuildSoftFailureMessage(info TruncationInfo) SoftFailureMessage {
|
||||
msg := SoftFailureMessage{
|
||||
Status: "incomplete",
|
||||
MaxLineHint: 300, // Conservative default
|
||||
}
|
||||
|
||||
// Build reason based on truncation type
|
||||
switch info.TruncationType {
|
||||
case TruncationTypeEmptyInput:
|
||||
msg.Reason = "Your tool call was too large and the input was completely lost during transmission."
|
||||
msg.MaxLineHint = 200
|
||||
case TruncationTypeInvalidJSON:
|
||||
msg.Reason = "Your tool call was truncated mid-transmission, resulting in incomplete JSON."
|
||||
msg.MaxLineHint = 250
|
||||
case TruncationTypeMissingFields:
|
||||
msg.Reason = "Your tool call was partially received but critical fields were cut off."
|
||||
msg.MaxLineHint = 300
|
||||
case TruncationTypeIncompleteString:
|
||||
msg.Reason = "Your tool call content was truncated - the full content did not arrive."
|
||||
msg.MaxLineHint = 350
|
||||
default:
|
||||
msg.Reason = "Your tool call was truncated by the API due to output size limits."
|
||||
}
|
||||
|
||||
// Build context from parsed fields
|
||||
if len(info.ParsedFields) > 0 {
|
||||
var parts []string
|
||||
for k, v := range info.ParsedFields {
|
||||
if len(v) > 30 {
|
||||
v = v[:30] + "..."
|
||||
}
|
||||
parts = append(parts, k+"="+v)
|
||||
}
|
||||
msg.Context = "Received partial data: " + strings.Join(parts, ", ")
|
||||
}
|
||||
|
||||
// Build retry guidance - CRITICAL: Conclusion first approach
|
||||
msg.Guidance = []string{
|
||||
"CONCLUSION: Split your output into smaller chunks and retry.",
|
||||
"",
|
||||
"REQUIRED APPROACH:",
|
||||
"1. For file writes: Write in chunks of ~" + formatInt(msg.MaxLineHint) + " lines maximum",
|
||||
"2. For new files: First create with initial chunk, then append remaining sections",
|
||||
"3. For edits: Make surgical, targeted changes - avoid rewriting entire files",
|
||||
"",
|
||||
"EXAMPLE (writing a 600-line file):",
|
||||
" - Step 1: Write lines 1-300 (create file)",
|
||||
" - Step 2: Append lines 301-600 (extend file)",
|
||||
"",
|
||||
"DO NOT attempt to write the full content again in a single call.",
|
||||
"The API has a hard output limit that cannot be bypassed.",
|
||||
}
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
// formatInt converts an integer to string (helper to avoid strconv import)
|
||||
func formatInt(n int) string {
|
||||
if n == 0 {
|
||||
return "0"
|
||||
}
|
||||
result := ""
|
||||
for n > 0 {
|
||||
result = string(rune('0'+n%10)) + result
|
||||
n /= 10
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// BuildSoftFailureToolResult creates a tool_result content for Claude.
|
||||
// This is what Claude will see when a tool call is truncated.
|
||||
// Returns a string that should be used as the tool_result content.
|
||||
func BuildSoftFailureToolResult(info TruncationInfo) string {
|
||||
msg := BuildSoftFailureMessage(info)
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("TOOL_CALL_INCOMPLETE\n")
|
||||
sb.WriteString("status: ")
|
||||
sb.WriteString(msg.Status)
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString("reason: ")
|
||||
sb.WriteString(msg.Reason)
|
||||
sb.WriteString("\n")
|
||||
|
||||
if msg.Context != "" {
|
||||
sb.WriteString("context: ")
|
||||
sb.WriteString(msg.Context)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
for _, line := range msg.Guidance {
|
||||
if line != "" {
|
||||
sb.WriteString(line)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// CreateTruncationToolResult creates a KiroToolUse that represents a soft failure.
|
||||
// Instead of returning the truncated tool_use, we return a tool with a special
|
||||
// error result that guides Claude to retry with smaller chunks.
|
||||
//
|
||||
// This is the key mechanism for "soft failure":
|
||||
// - stop_reason remains "tool_use" so Claude continues
|
||||
// - The tool_result content explains the issue and how to fix it
|
||||
// - Claude will read this and adjust its approach
|
||||
func CreateTruncationToolResult(info TruncationInfo) KiroToolUse {
|
||||
// We create a pseudo tool_use that represents the failed attempt
|
||||
// The executor will convert this to a tool_result with the guidance message
|
||||
return KiroToolUse{
|
||||
ToolUseID: info.ToolUseID,
|
||||
Name: info.ToolName,
|
||||
Input: nil, // No input since it was truncated
|
||||
IsTruncated: true,
|
||||
TruncationInfo: &info,
|
||||
}
|
||||
}
|
||||
@@ -29,6 +29,26 @@ const (
|
||||
// InlineCodeMarker is the markdown inline code marker (backtick).
|
||||
InlineCodeMarker = "`"
|
||||
|
||||
// DefaultAssistantContentWithTools is the fallback content for assistant messages
|
||||
// that have tool_use but no text content. Kiro API requires non-empty content.
|
||||
// IMPORTANT: Use a minimal neutral string that the model won't mimic in responses.
|
||||
// Previously "I'll help you with that." which caused the model to parrot it back.
|
||||
DefaultAssistantContentWithTools = "."
|
||||
|
||||
// DefaultAssistantContent is the fallback content for assistant messages
|
||||
// that have no content at all. Kiro API requires non-empty content.
|
||||
// IMPORTANT: Use a minimal neutral string that the model won't mimic in responses.
|
||||
// Previously "I understand." which could leak into model behavior.
|
||||
DefaultAssistantContent = "."
|
||||
|
||||
// DefaultUserContentWithToolResults is the fallback content for user messages
|
||||
// that have only tool_result (no text). Kiro API requires non-empty content.
|
||||
DefaultUserContentWithToolResults = "Tool results provided."
|
||||
|
||||
// DefaultUserContent is the fallback content for user messages
|
||||
// that have no content at all. Kiro API requires non-empty content.
|
||||
DefaultUserContent = "Continue"
|
||||
|
||||
// KiroAgenticSystemPrompt is injected only for -agentic models to prevent timeouts on large writes.
|
||||
// AWS Kiro API has a 2-3 minute timeout for large file write operations.
|
||||
KiroAgenticSystemPrompt = `
|
||||
|
||||
@@ -576,9 +576,23 @@ func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]Kir
|
||||
}
|
||||
}
|
||||
|
||||
// Truncate history if too long to prevent Kiro API errors
|
||||
history = truncateHistoryIfNeeded(history)
|
||||
|
||||
return history, currentUserMsg, currentToolResults
|
||||
}
|
||||
|
||||
const kiroMaxHistoryMessages = 50
|
||||
|
||||
func truncateHistoryIfNeeded(history []KiroHistoryMessage) []KiroHistoryMessage {
|
||||
if len(history) <= kiroMaxHistoryMessages {
|
||||
return history
|
||||
}
|
||||
|
||||
log.Debugf("kiro-openai: truncating history from %d to %d messages", len(history), kiroMaxHistoryMessages)
|
||||
return history[len(history)-kiroMaxHistoryMessages:]
|
||||
}
|
||||
|
||||
// buildUserMessageFromOpenAI builds a user message from OpenAI format and extracts tool results
|
||||
func buildUserMessageFromOpenAI(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) {
|
||||
content := msg.Get("content")
|
||||
@@ -645,13 +659,36 @@ func buildAssistantMessageFromOpenAI(msg gjson.Result) KiroAssistantResponseMess
|
||||
contentBuilder.WriteString(content.String())
|
||||
} else if content.IsArray() {
|
||||
for _, part := range content.Array() {
|
||||
if part.Get("type").String() == "text" {
|
||||
partType := part.Get("type").String()
|
||||
switch partType {
|
||||
case "text":
|
||||
contentBuilder.WriteString(part.Get("text").String())
|
||||
case "tool_use":
|
||||
// Handle tool_use in content array (Anthropic/OpenCode format)
|
||||
// This is different from OpenAI's tool_calls format
|
||||
toolUseID := part.Get("id").String()
|
||||
toolName := part.Get("name").String()
|
||||
inputData := part.Get("input")
|
||||
|
||||
inputMap := make(map[string]interface{})
|
||||
if inputData.Exists() && inputData.IsObject() {
|
||||
inputData.ForEach(func(key, value gjson.Result) bool {
|
||||
inputMap[key.String()] = value.Value()
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
toolUses = append(toolUses, KiroToolUse{
|
||||
ToolUseID: toolUseID,
|
||||
Name: toolName,
|
||||
Input: inputMap,
|
||||
})
|
||||
log.Debugf("kiro-openai: extracted tool_use from content array: %s", toolName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle tool_calls
|
||||
// Handle tool_calls (OpenAI format)
|
||||
toolCalls := msg.Get("tool_calls")
|
||||
if toolCalls.IsArray() {
|
||||
for _, tc := range toolCalls.Array() {
|
||||
@@ -677,8 +714,20 @@ func buildAssistantMessageFromOpenAI(msg gjson.Result) KiroAssistantResponseMess
|
||||
}
|
||||
}
|
||||
|
||||
// CRITICAL FIX: Kiro API requires non-empty content for assistant messages
|
||||
// This can happen with compaction requests or error recovery scenarios
|
||||
finalContent := contentBuilder.String()
|
||||
if strings.TrimSpace(finalContent) == "" {
|
||||
if len(toolUses) > 0 {
|
||||
finalContent = kirocommon.DefaultAssistantContentWithTools
|
||||
} else {
|
||||
finalContent = kirocommon.DefaultAssistantContent
|
||||
}
|
||||
log.Debugf("kiro-openai: assistant content was empty, using default: %s", finalContent)
|
||||
}
|
||||
|
||||
return KiroAssistantResponseMessage{
|
||||
Content: contentBuilder.String(),
|
||||
Content: finalContent,
|
||||
ToolUses: toolUses,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
@@ -18,7 +17,7 @@ import (
|
||||
// It extracts the model name, system instruction, message contents, and tool declarations
|
||||
// from the raw JSON request and returns them in the format expected by the OpenAI API.
|
||||
func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
// Base OpenAI Chat Completions API template
|
||||
out := `{"model":"","messages":[]}`
|
||||
|
||||
@@ -76,6 +75,10 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
case "adaptive":
|
||||
// Claude adaptive means "enable with max capacity"; keep it as highest level
|
||||
// and let ApplyThinking normalize per target model capability.
|
||||
out, _ = sjson.Set(out, "reasoning_effort", string(thinking.LevelXHigh))
|
||||
case "disabled":
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(0); ok && effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
|
||||
@@ -6,8 +6,6 @@
|
||||
package geminiCLI
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/gemini"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -17,7 +15,7 @@ import (
|
||||
// It extracts the model name, generation config, message contents, and tool declarations
|
||||
// from the raw JSON request and returns them in the format expected by the OpenAI API.
|
||||
func ConvertGeminiCLIRequestToOpenAI(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw)
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName)
|
||||
if gjson.GetBytes(rawJSON, "systemInstruction").Exists() {
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
@@ -21,7 +20,7 @@ import (
|
||||
// It extracts the model name, generation config, message contents, and tool declarations
|
||||
// from the raw JSON request and returns them in the format expected by the OpenAI API.
|
||||
func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
// Base OpenAI Chat Completions API template
|
||||
out := `{"model":"","messages":[]}`
|
||||
|
||||
@@ -83,16 +82,27 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
}
|
||||
|
||||
// Map Gemini thinkingConfig to OpenAI reasoning_effort.
|
||||
// Always perform conversion to support allowCompat models that may not be in registry
|
||||
// Always perform conversion to support allowCompat models that may not be in registry.
|
||||
// Note: Google official Python SDK sends snake_case fields (thinking_level/thinking_budget).
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
if thinkingLevel := thinkingConfig.Get("thinkingLevel"); thinkingLevel.Exists() {
|
||||
thinkingLevel := thinkingConfig.Get("thinkingLevel")
|
||||
if !thinkingLevel.Exists() {
|
||||
thinkingLevel = thinkingConfig.Get("thinking_level")
|
||||
}
|
||||
if thinkingLevel.Exists() {
|
||||
effort := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
|
||||
if effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
} else if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(int(thinkingBudget.Int())); ok {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
} else {
|
||||
thinkingBudget := thinkingConfig.Get("thinkingBudget")
|
||||
if !thinkingBudget.Exists() {
|
||||
thinkingBudget = thinkingConfig.Get("thinking_budget")
|
||||
}
|
||||
if thinkingBudget.Exists() {
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(int(thinkingBudget.Int())); ok {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
@@ -25,7 +24,7 @@ func ConvertOpenAIRequestToOpenAI(modelName string, inputRawJSON []byte, _ bool)
|
||||
// If there's an error, return the original JSON or handle the error appropriately.
|
||||
// For now, we'll return the original, but in a real scenario, logging or a more robust error
|
||||
// handling mechanism would be needed.
|
||||
return bytes.Clone(inputRawJSON)
|
||||
return inputRawJSON
|
||||
}
|
||||
return updatedJSON
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -28,7 +27,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in OpenAI chat completions format
|
||||
func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON := inputRawJSON
|
||||
// Base OpenAI chat completions template with default values
|
||||
out := `{"model":"","messages":[],"stream":false}`
|
||||
|
||||
@@ -68,7 +67,10 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
|
||||
case "message", "":
|
||||
// Handle regular message conversion
|
||||
role := item.Get("role").String()
|
||||
message := `{"role":"","content":""}`
|
||||
if role == "developer" {
|
||||
role = "user"
|
||||
}
|
||||
message := `{"role":"","content":[]}`
|
||||
message, _ = sjson.Set(message, "role", role)
|
||||
|
||||
if content := item.Get("content"); content.Exists() && content.IsArray() {
|
||||
@@ -82,20 +84,16 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
|
||||
}
|
||||
|
||||
switch contentType {
|
||||
case "input_text":
|
||||
case "input_text", "output_text":
|
||||
text := contentItem.Get("text").String()
|
||||
if messageContent != "" {
|
||||
messageContent += "\n" + text
|
||||
} else {
|
||||
messageContent = text
|
||||
}
|
||||
case "output_text":
|
||||
text := contentItem.Get("text").String()
|
||||
if messageContent != "" {
|
||||
messageContent += "\n" + text
|
||||
} else {
|
||||
messageContent = text
|
||||
}
|
||||
contentPart := `{"type":"text","text":""}`
|
||||
contentPart, _ = sjson.Set(contentPart, "text", text)
|
||||
message, _ = sjson.SetRaw(message, "content.-1", contentPart)
|
||||
case "input_image":
|
||||
imageURL := contentItem.Get("image_url").String()
|
||||
contentPart := `{"type":"image_url","image_url":{"url":""}}`
|
||||
contentPart, _ = sjson.Set(contentPart, "image_url.url", imageURL)
|
||||
message, _ = sjson.SetRaw(message, "content.-1", contentPart)
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -167,7 +165,8 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
|
||||
// Only function tools need structural conversion because Chat Completions nests details under "function".
|
||||
toolType := tool.Get("type").String()
|
||||
if toolType != "" && toolType != "function" && tool.IsObject() {
|
||||
chatCompletionsTools = append(chatCompletionsTools, tool.Value())
|
||||
// Almost all providers lack built-in tools, so we just ignore them.
|
||||
// chatCompletionsTools = append(chatCompletionsTools, tool.Value())
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ func TestIsClaudeThinkingModel(t *testing.T) {
|
||||
// Claude thinking models - should return true
|
||||
{"claude-sonnet-4-5-thinking", "claude-sonnet-4-5-thinking", true},
|
||||
{"claude-opus-4-5-thinking", "claude-opus-4-5-thinking", true},
|
||||
{"claude-opus-4-6-thinking", "claude-opus-4-6-thinking", true},
|
||||
{"Claude-Sonnet-Thinking uppercase", "Claude-Sonnet-4-5-Thinking", true},
|
||||
{"claude thinking mixed case", "Claude-THINKING-Model", true},
|
||||
|
||||
|
||||
@@ -61,14 +61,20 @@ func cleanJSONSchema(jsonStr string, addPlaceholder bool) string {
|
||||
|
||||
// removeKeywords removes all occurrences of specified keywords from the JSON schema.
|
||||
func removeKeywords(jsonStr string, keywords []string) string {
|
||||
deletePaths := make([]string, 0)
|
||||
pathsByField := findPathsByFields(jsonStr, keywords)
|
||||
for _, key := range keywords {
|
||||
for _, p := range findPaths(jsonStr, key) {
|
||||
for _, p := range pathsByField[key] {
|
||||
if isPropertyDefinition(trimSuffix(p, "."+key)) {
|
||||
continue
|
||||
}
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
deletePaths = append(deletePaths, p)
|
||||
}
|
||||
}
|
||||
sortByDepth(deletePaths)
|
||||
for _, p := range deletePaths {
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
@@ -235,8 +241,9 @@ var unsupportedConstraints = []string{
|
||||
}
|
||||
|
||||
func moveConstraintsToDescription(jsonStr string) string {
|
||||
pathsByField := findPathsByFields(jsonStr, unsupportedConstraints)
|
||||
for _, key := range unsupportedConstraints {
|
||||
for _, p := range findPaths(jsonStr, key) {
|
||||
for _, p := range pathsByField[key] {
|
||||
val := gjson.Get(jsonStr, p)
|
||||
if !val.Exists() || val.IsObject() || val.IsArray() {
|
||||
continue
|
||||
@@ -424,14 +431,21 @@ func removeUnsupportedKeywords(jsonStr string) string {
|
||||
"$schema", "$defs", "definitions", "const", "$ref", "additionalProperties",
|
||||
"propertyNames", // Gemini doesn't support property name validation
|
||||
)
|
||||
|
||||
deletePaths := make([]string, 0)
|
||||
pathsByField := findPathsByFields(jsonStr, keywords)
|
||||
for _, key := range keywords {
|
||||
for _, p := range findPaths(jsonStr, key) {
|
||||
for _, p := range pathsByField[key] {
|
||||
if isPropertyDefinition(trimSuffix(p, "."+key)) {
|
||||
continue
|
||||
}
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
deletePaths = append(deletePaths, p)
|
||||
}
|
||||
}
|
||||
sortByDepth(deletePaths)
|
||||
for _, p := range deletePaths {
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
}
|
||||
// Remove x-* extension fields (e.g., x-google-enum-descriptions) that are not supported by Gemini API
|
||||
jsonStr = removeExtensionFields(jsonStr)
|
||||
return jsonStr
|
||||
@@ -581,6 +595,42 @@ func findPaths(jsonStr, field string) []string {
|
||||
return paths
|
||||
}
|
||||
|
||||
func findPathsByFields(jsonStr string, fields []string) map[string][]string {
|
||||
set := make(map[string]struct{}, len(fields))
|
||||
for _, field := range fields {
|
||||
set[field] = struct{}{}
|
||||
}
|
||||
paths := make(map[string][]string, len(set))
|
||||
walkForFields(gjson.Parse(jsonStr), "", set, paths)
|
||||
return paths
|
||||
}
|
||||
|
||||
func walkForFields(value gjson.Result, path string, fields map[string]struct{}, paths map[string][]string) {
|
||||
switch value.Type {
|
||||
case gjson.JSON:
|
||||
value.ForEach(func(key, val gjson.Result) bool {
|
||||
keyStr := key.String()
|
||||
safeKey := escapeGJSONPathKey(keyStr)
|
||||
|
||||
var childPath string
|
||||
if path == "" {
|
||||
childPath = safeKey
|
||||
} else {
|
||||
childPath = path + "." + safeKey
|
||||
}
|
||||
|
||||
if _, ok := fields[keyStr]; ok {
|
||||
paths[keyStr] = append(paths[keyStr], childPath)
|
||||
}
|
||||
|
||||
walkForFields(val, childPath, fields, paths)
|
||||
return true
|
||||
})
|
||||
case gjson.String, gjson.Number, gjson.True, gjson.False, gjson.Null:
|
||||
// Terminal types - no further traversal needed
|
||||
}
|
||||
}
|
||||
|
||||
func sortByDepth(paths []string) {
|
||||
sort.Slice(paths, func(i, j int) bool { return len(paths[i]) > len(paths[j]) })
|
||||
}
|
||||
@@ -667,6 +717,9 @@ func orDefault(val, def string) string {
|
||||
}
|
||||
|
||||
func escapeGJSONPathKey(key string) string {
|
||||
if strings.IndexAny(key, ".*?") == -1 {
|
||||
return key
|
||||
}
|
||||
return gjsonPathKeyReplacer.Replace(key)
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ package util
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -33,15 +32,15 @@ func Walk(value gjson.Result, path, field string, paths *[]string) {
|
||||
// . -> \.
|
||||
// * -> \*
|
||||
// ? -> \?
|
||||
var keyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?")
|
||||
safeKey := keyReplacer.Replace(key.String())
|
||||
keyStr := key.String()
|
||||
safeKey := escapeGJSONPathKey(keyStr)
|
||||
|
||||
if path == "" {
|
||||
childPath = safeKey
|
||||
} else {
|
||||
childPath = path + "." + safeKey
|
||||
}
|
||||
if key.String() == field {
|
||||
if keyStr == field {
|
||||
*paths = append(*paths, childPath)
|
||||
}
|
||||
Walk(val, childPath, field, paths)
|
||||
@@ -87,15 +86,6 @@ func RenameKey(jsonStr, oldKeyPath, newKeyPath string) (string, error) {
|
||||
return finalJson, nil
|
||||
}
|
||||
|
||||
func DeleteKey(jsonStr, keyName string) string {
|
||||
paths := make([]string, 0)
|
||||
Walk(gjson.Parse(jsonStr), "", keyName, &paths)
|
||||
for _, p := range paths {
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
// FixJSON converts non-standard JSON that uses single quotes for strings into
|
||||
// RFC 8259-compliant JSON by converting those single-quoted strings to
|
||||
// double-quoted strings with proper escaping.
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -72,6 +74,7 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
||||
w.clientsMutex.Lock()
|
||||
|
||||
w.lastAuthHashes = make(map[string]string)
|
||||
w.lastAuthContents = make(map[string]*coreauth.Auth)
|
||||
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
|
||||
log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir)
|
||||
} else if resolvedAuthDir != "" {
|
||||
@@ -84,6 +87,11 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
||||
sum := sha256.Sum256(data)
|
||||
normalizedPath := w.normalizeAuthPath(path)
|
||||
w.lastAuthHashes[normalizedPath] = hex.EncodeToString(sum[:])
|
||||
// Parse and cache auth content for future diff comparisons
|
||||
var auth coreauth.Auth
|
||||
if errParse := json.Unmarshal(data, &auth); errParse == nil {
|
||||
w.lastAuthContents[normalizedPath] = &auth
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -127,6 +135,13 @@ func (w *Watcher) addOrUpdateClient(path string) {
|
||||
curHash := hex.EncodeToString(sum[:])
|
||||
normalized := w.normalizeAuthPath(path)
|
||||
|
||||
// Parse new auth content for diff comparison
|
||||
var newAuth coreauth.Auth
|
||||
if errParse := json.Unmarshal(data, &newAuth); errParse != nil {
|
||||
log.Errorf("failed to parse auth file %s: %v", filepath.Base(path), errParse)
|
||||
return
|
||||
}
|
||||
|
||||
w.clientsMutex.Lock()
|
||||
|
||||
cfg := w.config
|
||||
@@ -141,7 +156,26 @@ func (w *Watcher) addOrUpdateClient(path string) {
|
||||
return
|
||||
}
|
||||
|
||||
// Get old auth for diff comparison
|
||||
var oldAuth *coreauth.Auth
|
||||
if w.lastAuthContents != nil {
|
||||
oldAuth = w.lastAuthContents[normalized]
|
||||
}
|
||||
|
||||
// Compute and log field changes
|
||||
if changes := diff.BuildAuthChangeDetails(oldAuth, &newAuth); len(changes) > 0 {
|
||||
log.Debugf("auth field changes for %s:", filepath.Base(path))
|
||||
for _, c := range changes {
|
||||
log.Debugf(" %s", c)
|
||||
}
|
||||
}
|
||||
|
||||
// Update caches
|
||||
w.lastAuthHashes[normalized] = curHash
|
||||
if w.lastAuthContents == nil {
|
||||
w.lastAuthContents = make(map[string]*coreauth.Auth)
|
||||
}
|
||||
w.lastAuthContents[normalized] = &newAuth
|
||||
|
||||
w.clientsMutex.Unlock() // Unlock before the callback
|
||||
|
||||
@@ -160,6 +194,7 @@ func (w *Watcher) removeClient(path string) {
|
||||
|
||||
cfg := w.config
|
||||
delete(w.lastAuthHashes, normalized)
|
||||
delete(w.lastAuthContents, normalized)
|
||||
|
||||
w.clientsMutex.Unlock() // Release the lock before the callback
|
||||
|
||||
|
||||
44
internal/watcher/diff/auth_diff.go
Normal file
44
internal/watcher/diff/auth_diff.go
Normal file
@@ -0,0 +1,44 @@
|
||||
// auth_diff.go computes human-readable diffs for auth file field changes.
|
||||
package diff
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// BuildAuthChangeDetails computes a redacted, human-readable list of auth field changes.
|
||||
// Only prefix, proxy_url, and disabled fields are tracked; sensitive data is never printed.
|
||||
func BuildAuthChangeDetails(oldAuth, newAuth *coreauth.Auth) []string {
|
||||
changes := make([]string, 0, 3)
|
||||
|
||||
// Handle nil cases by using empty Auth as default
|
||||
if oldAuth == nil {
|
||||
oldAuth = &coreauth.Auth{}
|
||||
}
|
||||
if newAuth == nil {
|
||||
return changes
|
||||
}
|
||||
|
||||
// Compare prefix
|
||||
oldPrefix := strings.TrimSpace(oldAuth.Prefix)
|
||||
newPrefix := strings.TrimSpace(newAuth.Prefix)
|
||||
if oldPrefix != newPrefix {
|
||||
changes = append(changes, fmt.Sprintf("prefix: %s -> %s", oldPrefix, newPrefix))
|
||||
}
|
||||
|
||||
// Compare proxy_url (redacted)
|
||||
oldProxy := strings.TrimSpace(oldAuth.ProxyURL)
|
||||
newProxy := strings.TrimSpace(newAuth.ProxyURL)
|
||||
if oldProxy != newProxy {
|
||||
changes = append(changes, fmt.Sprintf("proxy_url: %s -> %s", formatProxyURL(oldProxy), formatProxyURL(newProxy)))
|
||||
}
|
||||
|
||||
// Compare disabled
|
||||
if oldAuth.Disabled != newAuth.Disabled {
|
||||
changes = append(changes, fmt.Sprintf("disabled: %t -> %t", oldAuth.Disabled, newAuth.Disabled))
|
||||
}
|
||||
|
||||
return changes
|
||||
}
|
||||
@@ -27,6 +27,12 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if oldCfg.Debug != newCfg.Debug {
|
||||
changes = append(changes, fmt.Sprintf("debug: %t -> %t", oldCfg.Debug, newCfg.Debug))
|
||||
}
|
||||
if oldCfg.Pprof.Enable != newCfg.Pprof.Enable {
|
||||
changes = append(changes, fmt.Sprintf("pprof.enable: %t -> %t", oldCfg.Pprof.Enable, newCfg.Pprof.Enable))
|
||||
}
|
||||
if strings.TrimSpace(oldCfg.Pprof.Addr) != strings.TrimSpace(newCfg.Pprof.Addr) {
|
||||
changes = append(changes, fmt.Sprintf("pprof.addr: %s -> %s", strings.TrimSpace(oldCfg.Pprof.Addr), strings.TrimSpace(newCfg.Pprof.Addr)))
|
||||
}
|
||||
if oldCfg.LoggingToFile != newCfg.LoggingToFile {
|
||||
changes = append(changes, fmt.Sprintf("logging-to-file: %t -> %t", oldCfg.LoggingToFile, newCfg.LoggingToFile))
|
||||
}
|
||||
|
||||
@@ -38,6 +38,7 @@ type Watcher struct {
|
||||
reloadCallback func(*config.Config)
|
||||
watcher *fsnotify.Watcher
|
||||
lastAuthHashes map[string]string
|
||||
lastAuthContents map[string]*coreauth.Auth
|
||||
lastRemoveTimes map[string]time.Time
|
||||
lastConfigHash string
|
||||
authQueue chan<- AuthUpdate
|
||||
|
||||
@@ -1,12 +1,90 @@
|
||||
package access
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrNoCredentials indicates no recognizable credentials were supplied.
|
||||
ErrNoCredentials = errors.New("access: no credentials provided")
|
||||
// ErrInvalidCredential signals that supplied credentials were rejected by a provider.
|
||||
ErrInvalidCredential = errors.New("access: invalid credential")
|
||||
// ErrNotHandled tells the manager to continue trying other providers.
|
||||
ErrNotHandled = errors.New("access: not handled")
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// AuthErrorCode classifies authentication failures.
|
||||
type AuthErrorCode string
|
||||
|
||||
const (
|
||||
AuthErrorCodeNoCredentials AuthErrorCode = "no_credentials"
|
||||
AuthErrorCodeInvalidCredential AuthErrorCode = "invalid_credential"
|
||||
AuthErrorCodeNotHandled AuthErrorCode = "not_handled"
|
||||
AuthErrorCodeInternal AuthErrorCode = "internal_error"
|
||||
)
|
||||
|
||||
// AuthError carries authentication failure details and HTTP status.
|
||||
type AuthError struct {
|
||||
Code AuthErrorCode
|
||||
Message string
|
||||
StatusCode int
|
||||
Cause error
|
||||
}
|
||||
|
||||
func (e *AuthError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
message := strings.TrimSpace(e.Message)
|
||||
if message == "" {
|
||||
message = "authentication error"
|
||||
}
|
||||
if e.Cause != nil {
|
||||
return fmt.Sprintf("%s: %v", message, e.Cause)
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
func (e *AuthError) Unwrap() error {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
return e.Cause
|
||||
}
|
||||
|
||||
// HTTPStatusCode returns a safe fallback for missing status codes.
|
||||
func (e *AuthError) HTTPStatusCode() int {
|
||||
if e == nil || e.StatusCode <= 0 {
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
return e.StatusCode
|
||||
}
|
||||
|
||||
func newAuthError(code AuthErrorCode, message string, statusCode int, cause error) *AuthError {
|
||||
return &AuthError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
StatusCode: statusCode,
|
||||
Cause: cause,
|
||||
}
|
||||
}
|
||||
|
||||
func NewNoCredentialsError() *AuthError {
|
||||
return newAuthError(AuthErrorCodeNoCredentials, "Missing API key", http.StatusUnauthorized, nil)
|
||||
}
|
||||
|
||||
func NewInvalidCredentialError() *AuthError {
|
||||
return newAuthError(AuthErrorCodeInvalidCredential, "Invalid API key", http.StatusUnauthorized, nil)
|
||||
}
|
||||
|
||||
func NewNotHandledError() *AuthError {
|
||||
return newAuthError(AuthErrorCodeNotHandled, "authentication provider did not handle request", 0, nil)
|
||||
}
|
||||
|
||||
func NewInternalAuthError(message string, cause error) *AuthError {
|
||||
normalizedMessage := strings.TrimSpace(message)
|
||||
if normalizedMessage == "" {
|
||||
normalizedMessage = "Authentication service error"
|
||||
}
|
||||
return newAuthError(AuthErrorCodeInternal, normalizedMessage, http.StatusInternalServerError, cause)
|
||||
}
|
||||
|
||||
func IsAuthErrorCode(authErr *AuthError, code AuthErrorCode) bool {
|
||||
if authErr == nil {
|
||||
return false
|
||||
}
|
||||
return authErr.Code == code
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package access
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
@@ -43,7 +42,7 @@ func (m *Manager) Providers() []Provider {
|
||||
}
|
||||
|
||||
// Authenticate evaluates providers until one succeeds.
|
||||
func (m *Manager) Authenticate(ctx context.Context, r *http.Request) (*Result, error) {
|
||||
func (m *Manager) Authenticate(ctx context.Context, r *http.Request) (*Result, *AuthError) {
|
||||
if m == nil {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -61,29 +60,29 @@ func (m *Manager) Authenticate(ctx context.Context, r *http.Request) (*Result, e
|
||||
if provider == nil {
|
||||
continue
|
||||
}
|
||||
res, err := provider.Authenticate(ctx, r)
|
||||
if err == nil {
|
||||
res, authErr := provider.Authenticate(ctx, r)
|
||||
if authErr == nil {
|
||||
return res, nil
|
||||
}
|
||||
if errors.Is(err, ErrNotHandled) {
|
||||
if IsAuthErrorCode(authErr, AuthErrorCodeNotHandled) {
|
||||
continue
|
||||
}
|
||||
if errors.Is(err, ErrNoCredentials) {
|
||||
if IsAuthErrorCode(authErr, AuthErrorCodeNoCredentials) {
|
||||
missing = true
|
||||
continue
|
||||
}
|
||||
if errors.Is(err, ErrInvalidCredential) {
|
||||
if IsAuthErrorCode(authErr, AuthErrorCodeInvalidCredential) {
|
||||
invalid = true
|
||||
continue
|
||||
}
|
||||
return nil, err
|
||||
return nil, authErr
|
||||
}
|
||||
|
||||
if invalid {
|
||||
return nil, ErrInvalidCredential
|
||||
return nil, NewInvalidCredentialError()
|
||||
}
|
||||
if missing {
|
||||
return nil, ErrNoCredentials
|
||||
return nil, NewNoCredentialsError()
|
||||
}
|
||||
return nil, ErrNoCredentials
|
||||
return nil, NewNoCredentialsError()
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user