mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-04-24 07:02:33 +00:00
Compare commits
66 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cc8c4ffb5f | ||
|
|
1510bfcb6f | ||
|
|
c6bd91b86b | ||
|
|
349ddcaa89 | ||
|
|
938a799263 | ||
|
|
e17d4f8d98 | ||
|
|
c8cae1f74d | ||
|
|
0040d78496 | ||
|
|
896de027cc | ||
|
|
fc329ebf37 | ||
|
|
91841a5519 | ||
|
|
eaab1d6824 | ||
|
|
0cfe310df6 | ||
|
|
918b6955e4 | ||
|
|
532fbf00d4 | ||
|
|
45b6fffd7f | ||
|
|
5a3eb08739 | ||
|
|
0dff329162 | ||
|
|
49c1740b47 | ||
|
|
3fbee51e9f | ||
|
|
a3dc56d2a0 | ||
|
|
63643c44a1 | ||
|
|
1d93608dbe | ||
|
|
d125b7de92 | ||
|
|
d5654ee316 | ||
|
|
3b34521ad9 | ||
|
|
7197fb350b | ||
|
|
6e349bfcc7 | ||
|
|
234056072d | ||
|
|
76330f4bff | ||
|
|
d468eec6ec | ||
|
|
9bc6cc5b41 | ||
|
|
d109be159c | ||
|
|
eddf31e55b | ||
|
|
7e9d0db6aa | ||
|
|
2f1874ede5 | ||
|
|
6b83585b53 | ||
|
|
78ef04fcf1 | ||
|
|
b7e4f00c5f | ||
|
|
c20507c15e | ||
|
|
f7d0019df7 | ||
|
|
52364af5bf | ||
|
|
f410dd0440 | ||
|
|
eb5582c17c | ||
|
|
1c6cb2bec3 | ||
|
|
80b5e79e75 | ||
|
|
d182e893b6 | ||
|
|
2e8d49a641 | ||
|
|
6abd7d27d9 | ||
|
|
8fa12af403 | ||
|
|
77586ed7d3 | ||
|
|
394497fb2f | ||
|
|
fc7b6ef086 | ||
|
|
98edcad39d | ||
|
|
1187aa8222 | ||
|
|
a35d66443b | ||
|
|
40ad4a42ea | ||
|
|
dc9b4dd017 | ||
|
|
68cb81a258 | ||
|
|
16693053f5 | ||
|
|
4e3bad3907 | ||
|
|
c874f19f2a | ||
|
|
f5f26f0cbe | ||
|
|
233be6272a | ||
|
|
47cb52385e | ||
|
|
a406ca2d5a |
Binary file not shown.
|
Before Width: | Height: | Size: 51 KiB |
@@ -77,6 +77,7 @@ func main() {
|
|||||||
var noBrowser bool
|
var noBrowser bool
|
||||||
var oauthCallbackPort int
|
var oauthCallbackPort int
|
||||||
var antigravityLogin bool
|
var antigravityLogin bool
|
||||||
|
var kimiLogin bool
|
||||||
var kiroLogin bool
|
var kiroLogin bool
|
||||||
var kiroGoogleLogin bool
|
var kiroGoogleLogin bool
|
||||||
var kiroAWSLogin bool
|
var kiroAWSLogin bool
|
||||||
@@ -102,6 +103,7 @@ func main() {
|
|||||||
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
|
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
|
||||||
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
|
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
|
||||||
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
||||||
|
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
|
||||||
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
|
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
|
||||||
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
|
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
|
||||||
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
|
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
|
||||||
@@ -473,7 +475,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Register built-in access providers before constructing services.
|
// Register built-in access providers before constructing services.
|
||||||
configaccess.Register()
|
configaccess.Register(&cfg.SDKConfig)
|
||||||
|
|
||||||
// Handle different command modes based on the provided flags.
|
// Handle different command modes based on the provided flags.
|
||||||
|
|
||||||
@@ -501,6 +503,8 @@ func main() {
|
|||||||
cmd.DoIFlowLogin(cfg, options)
|
cmd.DoIFlowLogin(cfg, options)
|
||||||
} else if iflowCookie {
|
} else if iflowCookie {
|
||||||
cmd.DoIFlowCookieAuth(cfg, options)
|
cmd.DoIFlowCookieAuth(cfg, options)
|
||||||
|
} else if kimiLogin {
|
||||||
|
cmd.DoKimiLogin(cfg, options)
|
||||||
} else if kiroLogin {
|
} else if kiroLogin {
|
||||||
// For Kiro auth, default to incognito mode for multi-account support
|
// For Kiro auth, default to incognito mode for multi-account support
|
||||||
// Users can explicitly override with --no-incognito
|
// Users can explicitly override with --no-incognito
|
||||||
|
|||||||
@@ -236,10 +236,10 @@ nonstream-keepalive-interval: 0
|
|||||||
|
|
||||||
# Global OAuth model name aliases (per channel)
|
# Global OAuth model name aliases (per channel)
|
||||||
# These aliases rename model IDs for both model listing and request routing.
|
# These aliases rename model IDs for both model listing and request routing.
|
||||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
|
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
|
||||||
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
||||||
# You can repeat the same name with different aliases to expose multiple client model names.
|
# You can repeat the same name with different aliases to expose multiple client model names.
|
||||||
#oauth-model-alias:
|
# oauth-model-alias:
|
||||||
# antigravity:
|
# antigravity:
|
||||||
# - name: "rev19-uic3-1p"
|
# - name: "rev19-uic3-1p"
|
||||||
# alias: "gemini-2.5-computer-use-preview-10-2025"
|
# alias: "gemini-2.5-computer-use-preview-10-2025"
|
||||||
@@ -265,9 +265,6 @@ nonstream-keepalive-interval: 0
|
|||||||
# aistudio:
|
# aistudio:
|
||||||
# - name: "gemini-2.5-pro"
|
# - name: "gemini-2.5-pro"
|
||||||
# alias: "g2.5p"
|
# alias: "g2.5p"
|
||||||
# antigravity:
|
|
||||||
# - name: "gemini-3-pro-preview"
|
|
||||||
# alias: "g3p"
|
|
||||||
# claude:
|
# claude:
|
||||||
# - name: "claude-sonnet-4-5-20250929"
|
# - name: "claude-sonnet-4-5-20250929"
|
||||||
# alias: "cs4.5"
|
# alias: "cs4.5"
|
||||||
@@ -280,6 +277,9 @@ nonstream-keepalive-interval: 0
|
|||||||
# iflow:
|
# iflow:
|
||||||
# - name: "glm-4.7"
|
# - name: "glm-4.7"
|
||||||
# alias: "glm-god"
|
# alias: "glm-god"
|
||||||
|
# kimi:
|
||||||
|
# - name: "kimi-k2.5"
|
||||||
|
# alias: "k2.5"
|
||||||
# kiro:
|
# kiro:
|
||||||
# - name: "kiro-claude-opus-4-5"
|
# - name: "kiro-claude-opus-4-5"
|
||||||
# alias: "op45"
|
# alias: "op45"
|
||||||
@@ -309,6 +309,8 @@ nonstream-keepalive-interval: 0
|
|||||||
# - "vision-model"
|
# - "vision-model"
|
||||||
# iflow:
|
# iflow:
|
||||||
# - "tstars2.0"
|
# - "tstars2.0"
|
||||||
|
# kimi:
|
||||||
|
# - "kimi-k2-thinking"
|
||||||
# kiro:
|
# kiro:
|
||||||
# - "kiro-claude-haiku-4-5"
|
# - "kiro-claude-haiku-4-5"
|
||||||
# github-copilot:
|
# github-copilot:
|
||||||
|
|||||||
@@ -7,80 +7,71 @@ The `github.com/router-for-me/CLIProxyAPI/v6/sdk/access` package centralizes inb
|
|||||||
```go
|
```go
|
||||||
import (
|
import (
|
||||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
Add the module with `go get github.com/router-for-me/CLIProxyAPI/v6/sdk/access`.
|
Add the module with `go get github.com/router-for-me/CLIProxyAPI/v6/sdk/access`.
|
||||||
|
|
||||||
|
## Provider Registry
|
||||||
|
|
||||||
|
Providers are registered globally and then attached to a `Manager` as a snapshot:
|
||||||
|
|
||||||
|
- `RegisterProvider(type, provider)` installs a pre-initialized provider instance.
|
||||||
|
- Registration order is preserved the first time each `type` is seen.
|
||||||
|
- `RegisteredProviders()` returns the providers in that order.
|
||||||
|
|
||||||
## Manager Lifecycle
|
## Manager Lifecycle
|
||||||
|
|
||||||
```go
|
```go
|
||||||
manager := sdkaccess.NewManager()
|
manager := sdkaccess.NewManager()
|
||||||
providers, err := sdkaccess.BuildProviders(cfg)
|
manager.SetProviders(sdkaccess.RegisteredProviders())
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
manager.SetProviders(providers)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
* `NewManager` constructs an empty manager.
|
* `NewManager` constructs an empty manager.
|
||||||
* `SetProviders` replaces the provider slice using a defensive copy.
|
* `SetProviders` replaces the provider slice using a defensive copy.
|
||||||
* `Providers` retrieves a snapshot that can be iterated safely from other goroutines.
|
* `Providers` retrieves a snapshot that can be iterated safely from other goroutines.
|
||||||
* `BuildProviders` translates `config.Config` access declarations into runnable providers. When the config omits explicit providers but defines inline API keys, the helper auto-installs the built-in `config-api-key` provider.
|
|
||||||
|
If the manager itself is `nil` or no providers are configured, the call returns `nil, nil`, allowing callers to treat access control as disabled.
|
||||||
|
|
||||||
## Authenticating Requests
|
## Authenticating Requests
|
||||||
|
|
||||||
```go
|
```go
|
||||||
result, err := manager.Authenticate(ctx, req)
|
result, authErr := manager.Authenticate(ctx, req)
|
||||||
switch {
|
switch {
|
||||||
case err == nil:
|
case authErr == nil:
|
||||||
// Authentication succeeded; result describes the provider and principal.
|
// Authentication succeeded; result describes the provider and principal.
|
||||||
case errors.Is(err, sdkaccess.ErrNoCredentials):
|
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeNoCredentials):
|
||||||
// No recognizable credentials were supplied.
|
// No recognizable credentials were supplied.
|
||||||
case errors.Is(err, sdkaccess.ErrInvalidCredential):
|
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeInvalidCredential):
|
||||||
// Supplied credentials were present but rejected.
|
// Supplied credentials were present but rejected.
|
||||||
default:
|
default:
|
||||||
// Transport-level failure was returned by a provider.
|
// Internal/transport failure was returned by a provider.
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
`Manager.Authenticate` walks the configured providers in order. It returns on the first success, skips providers that surface `ErrNotHandled`, and tracks whether any provider reported `ErrNoCredentials` or `ErrInvalidCredential` for downstream error reporting.
|
`Manager.Authenticate` walks the configured providers in order. It returns on the first success, skips providers that return `AuthErrorCodeNotHandled`, and aggregates `AuthErrorCodeNoCredentials` / `AuthErrorCodeInvalidCredential` for a final result.
|
||||||
|
|
||||||
If the manager itself is `nil` or no providers are registered, the call returns `nil, nil`, allowing callers to treat access control as disabled without branching on errors.
|
|
||||||
|
|
||||||
Each `Result` includes the provider identifier, the resolved principal, and optional metadata (for example, which header carried the credential).
|
Each `Result` includes the provider identifier, the resolved principal, and optional metadata (for example, which header carried the credential).
|
||||||
|
|
||||||
## Configuration Layout
|
## Built-in `config-api-key` Provider
|
||||||
|
|
||||||
The manager expects access providers under the `auth.providers` key inside `config.yaml`:
|
The proxy includes one built-in access provider:
|
||||||
|
|
||||||
|
- `config-api-key`: Validates API keys declared under top-level `api-keys`.
|
||||||
|
- Credential sources: `Authorization: Bearer`, `X-Goog-Api-Key`, `X-Api-Key`, `?key=`, `?auth_token=`
|
||||||
|
- Metadata: `Result.Metadata["source"]` is set to the matched source label.
|
||||||
|
|
||||||
|
In the CLI server and `sdk/cliproxy`, this provider is registered automatically based on the loaded configuration.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
auth:
|
api-keys:
|
||||||
providers:
|
- sk-test-123
|
||||||
- name: inline-api
|
- sk-prod-456
|
||||||
type: config-api-key
|
|
||||||
api-keys:
|
|
||||||
- sk-test-123
|
|
||||||
- sk-prod-456
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Fields map directly to `config.AccessProvider`: `name` labels the provider, `type` selects the registered factory, `sdk` can name an external module, `api-keys` seeds inline credentials, and `config` passes provider-specific options.
|
## Loading Providers from External Go Modules
|
||||||
|
|
||||||
### Loading providers from external SDK modules
|
To consume a provider shipped in another Go module, import it for its registration side effect:
|
||||||
|
|
||||||
To consume a provider shipped in another Go module, point the `sdk` field at the module path and import it for its registration side effect:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
auth:
|
|
||||||
providers:
|
|
||||||
- name: partner-auth
|
|
||||||
type: partner-token
|
|
||||||
sdk: github.com/acme/xplatform/sdk/access/providers/partner
|
|
||||||
config:
|
|
||||||
region: us-west-2
|
|
||||||
audience: cli-proxy
|
|
||||||
```
|
|
||||||
|
|
||||||
```go
|
```go
|
||||||
import (
|
import (
|
||||||
@@ -89,19 +80,11 @@ import (
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
The blank identifier import ensures `init` runs so `sdkaccess.RegisterProvider` executes before `BuildProviders` is called.
|
The blank identifier import ensures `init` runs so `sdkaccess.RegisterProvider` executes before you call `RegisteredProviders()` (or before `cliproxy.NewBuilder().Build()`).
|
||||||
|
|
||||||
## Built-in Providers
|
|
||||||
|
|
||||||
The SDK ships with one provider out of the box:
|
|
||||||
|
|
||||||
- `config-api-key`: Validates API keys declared inline or under top-level `api-keys`. It accepts the key from `Authorization: Bearer`, `X-Goog-Api-Key`, `X-Api-Key`, or the `?key=` query string and reports `ErrInvalidCredential` when no match is found.
|
|
||||||
|
|
||||||
Additional providers can be delivered by third-party packages. When a provider package is imported, it registers itself with `sdkaccess.RegisterProvider`.
|
|
||||||
|
|
||||||
### Metadata and auditing
|
### Metadata and auditing
|
||||||
|
|
||||||
`Result.Metadata` carries provider-specific context. The built-in `config-api-key` provider, for example, stores the credential source (`authorization`, `x-goog-api-key`, `x-api-key`, or `query-key`). Populate this map in custom providers to enrich logs and downstream auditing.
|
`Result.Metadata` carries provider-specific context. The built-in `config-api-key` provider, for example, stores the credential source (`authorization`, `x-goog-api-key`, `x-api-key`, `query-key`, `query-auth-token`). Populate this map in custom providers to enrich logs and downstream auditing.
|
||||||
|
|
||||||
## Writing Custom Providers
|
## Writing Custom Providers
|
||||||
|
|
||||||
@@ -110,13 +93,13 @@ type customProvider struct{}
|
|||||||
|
|
||||||
func (p *customProvider) Identifier() string { return "my-provider" }
|
func (p *customProvider) Identifier() string { return "my-provider" }
|
||||||
|
|
||||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, error) {
|
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, *sdkaccess.AuthError) {
|
||||||
token := r.Header.Get("X-Custom")
|
token := r.Header.Get("X-Custom")
|
||||||
if token == "" {
|
if token == "" {
|
||||||
return nil, sdkaccess.ErrNoCredentials
|
return nil, sdkaccess.NewNotHandledError()
|
||||||
}
|
}
|
||||||
if token != "expected" {
|
if token != "expected" {
|
||||||
return nil, sdkaccess.ErrInvalidCredential
|
return nil, sdkaccess.NewInvalidCredentialError()
|
||||||
}
|
}
|
||||||
return &sdkaccess.Result{
|
return &sdkaccess.Result{
|
||||||
Provider: p.Identifier(),
|
Provider: p.Identifier(),
|
||||||
@@ -126,51 +109,46 @@ func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sd
|
|||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
sdkaccess.RegisterProvider("custom", func(cfg *config.AccessProvider, root *config.Config) (sdkaccess.Provider, error) {
|
sdkaccess.RegisterProvider("custom", &customProvider{})
|
||||||
return &customProvider{}, nil
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
A provider must implement `Identifier()` and `Authenticate()`. To expose it to configuration, call `RegisterProvider` inside `init`. Provider factories receive the specific `AccessProvider` block plus the full root configuration for contextual needs.
|
A provider must implement `Identifier()` and `Authenticate()`. To make it available to the access manager, call `RegisterProvider` inside `init` with an initialized provider instance.
|
||||||
|
|
||||||
## Error Semantics
|
## Error Semantics
|
||||||
|
|
||||||
- `ErrNoCredentials`: no credentials were present or recognized by any provider.
|
- `NewNoCredentialsError()` (`AuthErrorCodeNoCredentials`): no credentials were present or recognized. (HTTP 401)
|
||||||
- `ErrInvalidCredential`: at least one provider processed the credentials but rejected them.
|
- `NewInvalidCredentialError()` (`AuthErrorCodeInvalidCredential`): credentials were present but rejected. (HTTP 401)
|
||||||
- `ErrNotHandled`: instructs the manager to fall through to the next provider without affecting aggregate error reporting.
|
- `NewNotHandledError()` (`AuthErrorCodeNotHandled`): fall through to the next provider.
|
||||||
|
- `NewInternalAuthError(message, cause)` (`AuthErrorCodeInternal`): transport/system failure. (HTTP 500)
|
||||||
|
|
||||||
Return custom errors to surface transport failures; they propagate immediately to the caller instead of being masked.
|
Errors propagate immediately to the caller unless they are classified as `not_handled` / `no_credentials` / `invalid_credential` and can be aggregated by the manager.
|
||||||
|
|
||||||
## Integration with cliproxy Service
|
## Integration with cliproxy Service
|
||||||
|
|
||||||
`sdk/cliproxy` wires `@sdk/access` automatically when you build a CLI service via `cliproxy.NewBuilder`. Supplying a preconfigured manager allows you to extend or override the default providers:
|
`sdk/cliproxy` wires `@sdk/access` automatically when you build a CLI service via `cliproxy.NewBuilder`. Supplying a manager lets you reuse the same instance in your host process:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
coreCfg, _ := config.LoadConfig("config.yaml")
|
coreCfg, _ := config.LoadConfig("config.yaml")
|
||||||
providers, _ := sdkaccess.BuildProviders(coreCfg)
|
accessManager := sdkaccess.NewManager()
|
||||||
manager := sdkaccess.NewManager()
|
|
||||||
manager.SetProviders(providers)
|
|
||||||
|
|
||||||
svc, _ := cliproxy.NewBuilder().
|
svc, _ := cliproxy.NewBuilder().
|
||||||
WithConfig(coreCfg).
|
WithConfig(coreCfg).
|
||||||
WithAccessManager(manager).
|
WithConfigPath("config.yaml").
|
||||||
|
WithRequestAccessManager(accessManager).
|
||||||
Build()
|
Build()
|
||||||
```
|
```
|
||||||
|
|
||||||
The service reuses the manager for every inbound request, ensuring consistent authentication across embedded deployments and the canonical CLI binary.
|
Register any custom providers (typically via blank imports) before calling `Build()` so they are present in the global registry snapshot.
|
||||||
|
|
||||||
### Hot reloading providers
|
### Hot reloading
|
||||||
|
|
||||||
When configuration changes, rebuild providers and swap them into the manager:
|
When configuration changes, refresh any config-backed providers and then reset the manager's provider chain:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
providers, err := sdkaccess.BuildProviders(newCfg)
|
// configaccess is github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access
|
||||||
if err != nil {
|
configaccess.Register(&newCfg.SDKConfig)
|
||||||
log.Errorf("reload auth providers failed: %v", err)
|
accessManager.SetProviders(sdkaccess.RegisteredProviders())
|
||||||
return
|
|
||||||
}
|
|
||||||
accessManager.SetProviders(providers)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
This mirrors the behaviour in `cliproxy.Service.refreshAccessProviders` and `api.Server.applyAccessConfig`, enabling runtime updates without restarting the process.
|
This mirrors the behaviour in `internal/access.ApplyAccessProviders`, enabling runtime updates without restarting the process.
|
||||||
|
|||||||
@@ -7,80 +7,71 @@
|
|||||||
```go
|
```go
|
||||||
import (
|
import (
|
||||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
通过 `go get github.com/router-for-me/CLIProxyAPI/v6/sdk/access` 添加依赖。
|
通过 `go get github.com/router-for-me/CLIProxyAPI/v6/sdk/access` 添加依赖。
|
||||||
|
|
||||||
|
## Provider Registry
|
||||||
|
|
||||||
|
访问提供者是全局注册,然后以快照形式挂到 `Manager` 上:
|
||||||
|
|
||||||
|
- `RegisterProvider(type, provider)` 注册一个已经初始化好的 provider 实例。
|
||||||
|
- 每个 `type` 第一次出现时会记录其注册顺序。
|
||||||
|
- `RegisteredProviders()` 会按该顺序返回 provider 列表。
|
||||||
|
|
||||||
## 管理器生命周期
|
## 管理器生命周期
|
||||||
|
|
||||||
```go
|
```go
|
||||||
manager := sdkaccess.NewManager()
|
manager := sdkaccess.NewManager()
|
||||||
providers, err := sdkaccess.BuildProviders(cfg)
|
manager.SetProviders(sdkaccess.RegisteredProviders())
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
manager.SetProviders(providers)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
- `NewManager` 创建空管理器。
|
- `NewManager` 创建空管理器。
|
||||||
- `SetProviders` 替换提供者切片并做防御性拷贝。
|
- `SetProviders` 替换提供者切片并做防御性拷贝。
|
||||||
- `Providers` 返回适合并发读取的快照。
|
- `Providers` 返回适合并发读取的快照。
|
||||||
- `BuildProviders` 将 `config.Config` 中的访问配置转换成可运行的提供者。当配置没有显式声明但包含顶层 `api-keys` 时,会自动挂载内建的 `config-api-key` 提供者。
|
|
||||||
|
如果管理器本身为 `nil` 或未配置任何 provider,调用会返回 `nil, nil`,可视为关闭访问控制。
|
||||||
|
|
||||||
## 认证请求
|
## 认证请求
|
||||||
|
|
||||||
```go
|
```go
|
||||||
result, err := manager.Authenticate(ctx, req)
|
result, authErr := manager.Authenticate(ctx, req)
|
||||||
switch {
|
switch {
|
||||||
case err == nil:
|
case authErr == nil:
|
||||||
// Authentication succeeded; result carries provider and principal.
|
// Authentication succeeded; result carries provider and principal.
|
||||||
case errors.Is(err, sdkaccess.ErrNoCredentials):
|
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeNoCredentials):
|
||||||
// No recognizable credentials were supplied.
|
// No recognizable credentials were supplied.
|
||||||
case errors.Is(err, sdkaccess.ErrInvalidCredential):
|
case sdkaccess.IsAuthErrorCode(authErr, sdkaccess.AuthErrorCodeInvalidCredential):
|
||||||
// Credentials were present but rejected.
|
// Credentials were present but rejected.
|
||||||
default:
|
default:
|
||||||
// Provider surfaced a transport-level failure.
|
// Provider surfaced a transport-level failure.
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
`Manager.Authenticate` 按配置顺序遍历提供者。遇到成功立即返回,`ErrNotHandled` 会继续尝试下一个;若发现 `ErrNoCredentials` 或 `ErrInvalidCredential`,会在遍历结束后汇总给调用方。
|
`Manager.Authenticate` 会按顺序遍历 provider:遇到成功立即返回,`AuthErrorCodeNotHandled` 会继续尝试下一个;`AuthErrorCodeNoCredentials` / `AuthErrorCodeInvalidCredential` 会在遍历结束后汇总给调用方。
|
||||||
|
|
||||||
若管理器本身为 `nil` 或尚未注册提供者,调用会返回 `nil, nil`,让调用方无需针对错误做额外分支即可关闭访问控制。
|
|
||||||
|
|
||||||
`Result` 提供认证提供者标识、解析出的主体以及可选元数据(例如凭证来源)。
|
`Result` 提供认证提供者标识、解析出的主体以及可选元数据(例如凭证来源)。
|
||||||
|
|
||||||
## 配置结构
|
## 内建 `config-api-key` Provider
|
||||||
|
|
||||||
在 `config.yaml` 的 `auth.providers` 下定义访问提供者:
|
代理内置一个访问提供者:
|
||||||
|
|
||||||
|
- `config-api-key`:校验 `config.yaml` 顶层的 `api-keys`。
|
||||||
|
- 凭证来源:`Authorization: Bearer`、`X-Goog-Api-Key`、`X-Api-Key`、`?key=`、`?auth_token=`
|
||||||
|
- 元数据:`Result.Metadata["source"]` 会写入匹配到的来源标识
|
||||||
|
|
||||||
|
在 CLI 服务端与 `sdk/cliproxy` 中,该 provider 会根据加载到的配置自动注册。
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
auth:
|
api-keys:
|
||||||
providers:
|
- sk-test-123
|
||||||
- name: inline-api
|
- sk-prod-456
|
||||||
type: config-api-key
|
|
||||||
api-keys:
|
|
||||||
- sk-test-123
|
|
||||||
- sk-prod-456
|
|
||||||
```
|
```
|
||||||
|
|
||||||
条目映射到 `config.AccessProvider`:`name` 指定实例名,`type` 选择注册的工厂,`sdk` 可引用第三方模块,`api-keys` 提供内联凭证,`config` 用于传递特定选项。
|
## 引入外部 Go 模块提供者
|
||||||
|
|
||||||
### 引入外部 SDK 提供者
|
若要消费其它 Go 模块输出的访问提供者,直接用空白标识符导入以触发其 `init` 注册即可:
|
||||||
|
|
||||||
若要消费其它 Go 模块输出的访问提供者,可在配置里填写 `sdk` 字段并在代码中引入该包,利用其 `init` 注册过程:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
auth:
|
|
||||||
providers:
|
|
||||||
- name: partner-auth
|
|
||||||
type: partner-token
|
|
||||||
sdk: github.com/acme/xplatform/sdk/access/providers/partner
|
|
||||||
config:
|
|
||||||
region: us-west-2
|
|
||||||
audience: cli-proxy
|
|
||||||
```
|
|
||||||
|
|
||||||
```go
|
```go
|
||||||
import (
|
import (
|
||||||
@@ -89,19 +80,11 @@ import (
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
通过空白标识符导入即可确保 `init` 调用,先于 `BuildProviders` 完成 `sdkaccess.RegisterProvider`。
|
空白导入可确保 `init` 先执行,从而在你调用 `RegisteredProviders()`(或 `cliproxy.NewBuilder().Build()`)之前完成 `sdkaccess.RegisterProvider`。
|
||||||
|
|
||||||
## 内建提供者
|
|
||||||
|
|
||||||
当前 SDK 默认内置:
|
|
||||||
|
|
||||||
- `config-api-key`:校验配置中的 API Key。它从 `Authorization: Bearer`、`X-Goog-Api-Key`、`X-Api-Key` 以及查询参数 `?key=` 提取凭证,不匹配时抛出 `ErrInvalidCredential`。
|
|
||||||
|
|
||||||
导入第三方包即可通过 `sdkaccess.RegisterProvider` 注册更多类型。
|
|
||||||
|
|
||||||
### 元数据与审计
|
### 元数据与审计
|
||||||
|
|
||||||
`Result.Metadata` 用于携带提供者特定的上下文信息。内建的 `config-api-key` 会记录凭证来源(`authorization`、`x-goog-api-key`、`x-api-key` 或 `query-key`)。自定义提供者同样可以填充该 Map,以便丰富日志与审计场景。
|
`Result.Metadata` 用于携带提供者特定的上下文信息。内建的 `config-api-key` 会记录凭证来源(`authorization`、`x-goog-api-key`、`x-api-key`、`query-key`、`query-auth-token`)。自定义提供者同样可以填充该 Map,以便丰富日志与审计场景。
|
||||||
|
|
||||||
## 编写自定义提供者
|
## 编写自定义提供者
|
||||||
|
|
||||||
@@ -110,13 +93,13 @@ type customProvider struct{}
|
|||||||
|
|
||||||
func (p *customProvider) Identifier() string { return "my-provider" }
|
func (p *customProvider) Identifier() string { return "my-provider" }
|
||||||
|
|
||||||
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, error) {
|
func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sdkaccess.Result, *sdkaccess.AuthError) {
|
||||||
token := r.Header.Get("X-Custom")
|
token := r.Header.Get("X-Custom")
|
||||||
if token == "" {
|
if token == "" {
|
||||||
return nil, sdkaccess.ErrNoCredentials
|
return nil, sdkaccess.NewNotHandledError()
|
||||||
}
|
}
|
||||||
if token != "expected" {
|
if token != "expected" {
|
||||||
return nil, sdkaccess.ErrInvalidCredential
|
return nil, sdkaccess.NewInvalidCredentialError()
|
||||||
}
|
}
|
||||||
return &sdkaccess.Result{
|
return &sdkaccess.Result{
|
||||||
Provider: p.Identifier(),
|
Provider: p.Identifier(),
|
||||||
@@ -126,51 +109,46 @@ func (p *customProvider) Authenticate(ctx context.Context, r *http.Request) (*sd
|
|||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
sdkaccess.RegisterProvider("custom", func(cfg *config.AccessProvider, root *config.Config) (sdkaccess.Provider, error) {
|
sdkaccess.RegisterProvider("custom", &customProvider{})
|
||||||
return &customProvider{}, nil
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
自定义提供者需要实现 `Identifier()` 与 `Authenticate()`。在 `init` 中调用 `RegisterProvider` 暴露给配置层,工厂函数既能读取当前条目,也能访问完整根配置。
|
自定义提供者需要实现 `Identifier()` 与 `Authenticate()`。在 `init` 中用已初始化实例调用 `RegisterProvider` 注册到全局 registry。
|
||||||
|
|
||||||
## 错误语义
|
## 错误语义
|
||||||
|
|
||||||
- `ErrNoCredentials`:任何提供者都未识别到凭证。
|
- `NewNoCredentialsError()`(`AuthErrorCodeNoCredentials`):未提供或未识别到凭证。(HTTP 401)
|
||||||
- `ErrInvalidCredential`:至少一个提供者处理了凭证但判定无效。
|
- `NewInvalidCredentialError()`(`AuthErrorCodeInvalidCredential`):凭证存在但校验失败。(HTTP 401)
|
||||||
- `ErrNotHandled`:告诉管理器跳到下一个提供者,不影响最终错误统计。
|
- `NewNotHandledError()`(`AuthErrorCodeNotHandled`):告诉管理器跳到下一个 provider。
|
||||||
|
- `NewInternalAuthError(message, cause)`(`AuthErrorCodeInternal`):网络/系统错误。(HTTP 500)
|
||||||
|
|
||||||
自定义错误(例如网络异常)会马上冒泡返回。
|
除可汇总的 `not_handled` / `no_credentials` / `invalid_credential` 外,其它错误会立即冒泡返回。
|
||||||
|
|
||||||
## 与 cliproxy 集成
|
## 与 cliproxy 集成
|
||||||
|
|
||||||
使用 `sdk/cliproxy` 构建服务时会自动接入 `@sdk/access`。如果需要扩展内置行为,可传入自定义管理器:
|
使用 `sdk/cliproxy` 构建服务时会自动接入 `@sdk/access`。如果希望在宿主进程里复用同一个 `Manager` 实例,可传入自定义管理器:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
coreCfg, _ := config.LoadConfig("config.yaml")
|
coreCfg, _ := config.LoadConfig("config.yaml")
|
||||||
providers, _ := sdkaccess.BuildProviders(coreCfg)
|
accessManager := sdkaccess.NewManager()
|
||||||
manager := sdkaccess.NewManager()
|
|
||||||
manager.SetProviders(providers)
|
|
||||||
|
|
||||||
svc, _ := cliproxy.NewBuilder().
|
svc, _ := cliproxy.NewBuilder().
|
||||||
WithConfig(coreCfg).
|
WithConfig(coreCfg).
|
||||||
WithAccessManager(manager).
|
WithConfigPath("config.yaml").
|
||||||
|
WithRequestAccessManager(accessManager).
|
||||||
Build()
|
Build()
|
||||||
```
|
```
|
||||||
|
|
||||||
服务会复用该管理器处理每一个入站请求,实现与 CLI 二进制一致的访问控制体验。
|
请在调用 `Build()` 之前完成自定义 provider 的注册(通常通过空白导入触发 `init`),以确保它们被包含在全局 registry 的快照中。
|
||||||
|
|
||||||
### 动态热更新提供者
|
### 动态热更新提供者
|
||||||
|
|
||||||
当配置发生变化时,可以重新构建提供者并替换当前列表:
|
当配置发生变化时,刷新依赖配置的 provider,然后重置 manager 的 provider 链:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
providers, err := sdkaccess.BuildProviders(newCfg)
|
// configaccess is github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access
|
||||||
if err != nil {
|
configaccess.Register(&newCfg.SDKConfig)
|
||||||
log.Errorf("reload auth providers failed: %v", err)
|
accessManager.SetProviders(sdkaccess.RegisteredProviders())
|
||||||
return
|
|
||||||
}
|
|
||||||
accessManager.SetProviders(providers)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
这一流程与 `cliproxy.Service.refreshAccessProviders` 和 `api.Server.applyAccessConfig` 保持一致,避免为更新访问策略而重启进程。
|
这一流程与 `internal/access.ApplyAccessProviders` 保持一致,避免为更新访问策略而重启进程。
|
||||||
|
|||||||
@@ -4,19 +4,28 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
|
|
||||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
var registerOnce sync.Once
|
|
||||||
|
|
||||||
// Register ensures the config-access provider is available to the access manager.
|
// Register ensures the config-access provider is available to the access manager.
|
||||||
func Register() {
|
func Register(cfg *sdkconfig.SDKConfig) {
|
||||||
registerOnce.Do(func() {
|
if cfg == nil {
|
||||||
sdkaccess.RegisterProvider(sdkconfig.AccessProviderTypeConfigAPIKey, newProvider)
|
sdkaccess.UnregisterProvider(sdkaccess.AccessProviderTypeConfigAPIKey)
|
||||||
})
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := normalizeKeys(cfg.APIKeys)
|
||||||
|
if len(keys) == 0 {
|
||||||
|
sdkaccess.UnregisterProvider(sdkaccess.AccessProviderTypeConfigAPIKey)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sdkaccess.RegisterProvider(
|
||||||
|
sdkaccess.AccessProviderTypeConfigAPIKey,
|
||||||
|
newProvider(sdkaccess.DefaultAccessProviderName, keys),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
type provider struct {
|
type provider struct {
|
||||||
@@ -24,34 +33,31 @@ type provider struct {
|
|||||||
keys map[string]struct{}
|
keys map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newProvider(cfg *sdkconfig.AccessProvider, _ *sdkconfig.SDKConfig) (sdkaccess.Provider, error) {
|
func newProvider(name string, keys []string) *provider {
|
||||||
name := cfg.Name
|
providerName := strings.TrimSpace(name)
|
||||||
if name == "" {
|
if providerName == "" {
|
||||||
name = sdkconfig.DefaultAccessProviderName
|
providerName = sdkaccess.DefaultAccessProviderName
|
||||||
}
|
}
|
||||||
keys := make(map[string]struct{}, len(cfg.APIKeys))
|
keySet := make(map[string]struct{}, len(keys))
|
||||||
for _, key := range cfg.APIKeys {
|
for _, key := range keys {
|
||||||
if key == "" {
|
keySet[key] = struct{}{}
|
||||||
continue
|
|
||||||
}
|
|
||||||
keys[key] = struct{}{}
|
|
||||||
}
|
}
|
||||||
return &provider{name: name, keys: keys}, nil
|
return &provider{name: providerName, keys: keySet}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *provider) Identifier() string {
|
func (p *provider) Identifier() string {
|
||||||
if p == nil || p.name == "" {
|
if p == nil || p.name == "" {
|
||||||
return sdkconfig.DefaultAccessProviderName
|
return sdkaccess.DefaultAccessProviderName
|
||||||
}
|
}
|
||||||
return p.name
|
return p.name
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.Result, error) {
|
func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.Result, *sdkaccess.AuthError) {
|
||||||
if p == nil {
|
if p == nil {
|
||||||
return nil, sdkaccess.ErrNotHandled
|
return nil, sdkaccess.NewNotHandledError()
|
||||||
}
|
}
|
||||||
if len(p.keys) == 0 {
|
if len(p.keys) == 0 {
|
||||||
return nil, sdkaccess.ErrNotHandled
|
return nil, sdkaccess.NewNotHandledError()
|
||||||
}
|
}
|
||||||
authHeader := r.Header.Get("Authorization")
|
authHeader := r.Header.Get("Authorization")
|
||||||
authHeaderGoogle := r.Header.Get("X-Goog-Api-Key")
|
authHeaderGoogle := r.Header.Get("X-Goog-Api-Key")
|
||||||
@@ -63,7 +69,7 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
|
|||||||
queryAuthToken = r.URL.Query().Get("auth_token")
|
queryAuthToken = r.URL.Query().Get("auth_token")
|
||||||
}
|
}
|
||||||
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" && queryAuthToken == "" {
|
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" && queryAuthToken == "" {
|
||||||
return nil, sdkaccess.ErrNoCredentials
|
return nil, sdkaccess.NewNoCredentialsError()
|
||||||
}
|
}
|
||||||
|
|
||||||
apiKey := extractBearerToken(authHeader)
|
apiKey := extractBearerToken(authHeader)
|
||||||
@@ -94,7 +100,7 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, sdkaccess.ErrInvalidCredential
|
return nil, sdkaccess.NewInvalidCredentialError()
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractBearerToken(header string) string {
|
func extractBearerToken(header string) string {
|
||||||
@@ -110,3 +116,26 @@ func extractBearerToken(header string) string {
|
|||||||
}
|
}
|
||||||
return strings.TrimSpace(parts[1])
|
return strings.TrimSpace(parts[1])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeKeys(keys []string) []string {
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
normalized := make([]string, 0, len(keys))
|
||||||
|
seen := make(map[string]struct{}, len(keys))
|
||||||
|
for _, key := range keys {
|
||||||
|
trimmedKey := strings.TrimSpace(key)
|
||||||
|
if trimmedKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := seen[trimmedKey]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[trimmedKey] = struct{}{}
|
||||||
|
normalized = append(normalized, trimmedKey)
|
||||||
|
}
|
||||||
|
if len(normalized) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return normalized
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,9 +6,9 @@ import (
|
|||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||||
sdkConfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -17,26 +17,26 @@ import (
|
|||||||
// ordered provider slice along with the identifiers of providers that were added, updated, or
|
// ordered provider slice along with the identifiers of providers that were added, updated, or
|
||||||
// removed compared to the previous configuration.
|
// removed compared to the previous configuration.
|
||||||
func ReconcileProviders(oldCfg, newCfg *config.Config, existing []sdkaccess.Provider) (result []sdkaccess.Provider, added, updated, removed []string, err error) {
|
func ReconcileProviders(oldCfg, newCfg *config.Config, existing []sdkaccess.Provider) (result []sdkaccess.Provider, added, updated, removed []string, err error) {
|
||||||
|
_ = oldCfg
|
||||||
if newCfg == nil {
|
if newCfg == nil {
|
||||||
return nil, nil, nil, nil, nil
|
return nil, nil, nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
result = sdkaccess.RegisteredProviders()
|
||||||
|
|
||||||
existingMap := make(map[string]sdkaccess.Provider, len(existing))
|
existingMap := make(map[string]sdkaccess.Provider, len(existing))
|
||||||
for _, provider := range existing {
|
for _, provider := range existing {
|
||||||
if provider == nil {
|
providerID := identifierFromProvider(provider)
|
||||||
|
if providerID == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
existingMap[provider.Identifier()] = provider
|
existingMap[providerID] = provider
|
||||||
}
|
}
|
||||||
|
|
||||||
oldCfgMap := accessProviderMap(oldCfg)
|
finalIDs := make(map[string]struct{}, len(result))
|
||||||
newEntries := collectProviderEntries(newCfg)
|
|
||||||
|
|
||||||
result = make([]sdkaccess.Provider, 0, len(newEntries))
|
|
||||||
finalIDs := make(map[string]struct{}, len(newEntries))
|
|
||||||
|
|
||||||
isInlineProvider := func(id string) bool {
|
isInlineProvider := func(id string) bool {
|
||||||
return strings.EqualFold(id, sdkConfig.DefaultAccessProviderName)
|
return strings.EqualFold(id, sdkaccess.DefaultAccessProviderName)
|
||||||
}
|
}
|
||||||
appendChange := func(list *[]string, id string) {
|
appendChange := func(list *[]string, id string) {
|
||||||
if isInlineProvider(id) {
|
if isInlineProvider(id) {
|
||||||
@@ -45,85 +45,28 @@ func ReconcileProviders(oldCfg, newCfg *config.Config, existing []sdkaccess.Prov
|
|||||||
*list = append(*list, id)
|
*list = append(*list, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, providerCfg := range newEntries {
|
for _, provider := range result {
|
||||||
key := providerIdentifier(providerCfg)
|
providerID := identifierFromProvider(provider)
|
||||||
if key == "" {
|
if providerID == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
finalIDs[providerID] = struct{}{}
|
||||||
|
|
||||||
forceRebuild := strings.EqualFold(strings.TrimSpace(providerCfg.Type), sdkConfig.AccessProviderTypeConfigAPIKey)
|
existingProvider, exists := existingMap[providerID]
|
||||||
if oldCfgProvider, ok := oldCfgMap[key]; ok {
|
if !exists {
|
||||||
isAliased := oldCfgProvider == providerCfg
|
appendChange(&added, providerID)
|
||||||
if !forceRebuild && !isAliased && providerConfigEqual(oldCfgProvider, providerCfg) {
|
continue
|
||||||
if existingProvider, okExisting := existingMap[key]; okExisting {
|
|
||||||
result = append(result, existingProvider)
|
|
||||||
finalIDs[key] = struct{}{}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
if !providerInstanceEqual(existingProvider, provider) {
|
||||||
provider, buildErr := sdkaccess.BuildProvider(providerCfg, &newCfg.SDKConfig)
|
appendChange(&updated, providerID)
|
||||||
if buildErr != nil {
|
|
||||||
return nil, nil, nil, nil, buildErr
|
|
||||||
}
|
|
||||||
if _, ok := oldCfgMap[key]; ok {
|
|
||||||
if _, existed := existingMap[key]; existed {
|
|
||||||
appendChange(&updated, key)
|
|
||||||
} else {
|
|
||||||
appendChange(&added, key)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
appendChange(&added, key)
|
|
||||||
}
|
|
||||||
result = append(result, provider)
|
|
||||||
finalIDs[key] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(result) == 0 {
|
|
||||||
if inline := sdkConfig.MakeInlineAPIKeyProvider(newCfg.APIKeys); inline != nil {
|
|
||||||
key := providerIdentifier(inline)
|
|
||||||
if key != "" {
|
|
||||||
if oldCfgProvider, ok := oldCfgMap[key]; ok {
|
|
||||||
if providerConfigEqual(oldCfgProvider, inline) {
|
|
||||||
if existingProvider, okExisting := existingMap[key]; okExisting {
|
|
||||||
result = append(result, existingProvider)
|
|
||||||
finalIDs[key] = struct{}{}
|
|
||||||
goto inlineDone
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
provider, buildErr := sdkaccess.BuildProvider(inline, &newCfg.SDKConfig)
|
|
||||||
if buildErr != nil {
|
|
||||||
return nil, nil, nil, nil, buildErr
|
|
||||||
}
|
|
||||||
if _, existed := existingMap[key]; existed {
|
|
||||||
appendChange(&updated, key)
|
|
||||||
} else if _, hadOld := oldCfgMap[key]; hadOld {
|
|
||||||
appendChange(&updated, key)
|
|
||||||
} else {
|
|
||||||
appendChange(&added, key)
|
|
||||||
}
|
|
||||||
result = append(result, provider)
|
|
||||||
finalIDs[key] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
inlineDone:
|
|
||||||
}
|
|
||||||
|
|
||||||
removedSet := make(map[string]struct{})
|
|
||||||
for id := range existingMap {
|
|
||||||
if _, ok := finalIDs[id]; !ok {
|
|
||||||
if isInlineProvider(id) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
removedSet[id] = struct{}{}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
removed = make([]string, 0, len(removedSet))
|
for providerID := range existingMap {
|
||||||
for id := range removedSet {
|
if _, exists := finalIDs[providerID]; exists {
|
||||||
removed = append(removed, id)
|
continue
|
||||||
|
}
|
||||||
|
appendChange(&removed, providerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Strings(added)
|
sort.Strings(added)
|
||||||
@@ -142,6 +85,7 @@ func ApplyAccessProviders(manager *sdkaccess.Manager, oldCfg, newCfg *config.Con
|
|||||||
}
|
}
|
||||||
|
|
||||||
existing := manager.Providers()
|
existing := manager.Providers()
|
||||||
|
configaccess.Register(&newCfg.SDKConfig)
|
||||||
providers, added, updated, removed, err := ReconcileProviders(oldCfg, newCfg, existing)
|
providers, added, updated, removed, err := ReconcileProviders(oldCfg, newCfg, existing)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to reconcile request auth providers: %v", err)
|
log.Errorf("failed to reconcile request auth providers: %v", err)
|
||||||
@@ -160,111 +104,24 @@ func ApplyAccessProviders(manager *sdkaccess.Manager, oldCfg, newCfg *config.Con
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func accessProviderMap(cfg *config.Config) map[string]*sdkConfig.AccessProvider {
|
func identifierFromProvider(provider sdkaccess.Provider) string {
|
||||||
result := make(map[string]*sdkConfig.AccessProvider)
|
|
||||||
if cfg == nil {
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
for i := range cfg.Access.Providers {
|
|
||||||
providerCfg := &cfg.Access.Providers[i]
|
|
||||||
if providerCfg.Type == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
key := providerIdentifier(providerCfg)
|
|
||||||
if key == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
result[key] = providerCfg
|
|
||||||
}
|
|
||||||
if len(result) == 0 && len(cfg.APIKeys) > 0 {
|
|
||||||
if provider := sdkConfig.MakeInlineAPIKeyProvider(cfg.APIKeys); provider != nil {
|
|
||||||
if key := providerIdentifier(provider); key != "" {
|
|
||||||
result[key] = provider
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func collectProviderEntries(cfg *config.Config) []*sdkConfig.AccessProvider {
|
|
||||||
entries := make([]*sdkConfig.AccessProvider, 0, len(cfg.Access.Providers))
|
|
||||||
for i := range cfg.Access.Providers {
|
|
||||||
providerCfg := &cfg.Access.Providers[i]
|
|
||||||
if providerCfg.Type == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if key := providerIdentifier(providerCfg); key != "" {
|
|
||||||
entries = append(entries, providerCfg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(entries) == 0 && len(cfg.APIKeys) > 0 {
|
|
||||||
if inline := sdkConfig.MakeInlineAPIKeyProvider(cfg.APIKeys); inline != nil {
|
|
||||||
entries = append(entries, inline)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return entries
|
|
||||||
}
|
|
||||||
|
|
||||||
func providerIdentifier(provider *sdkConfig.AccessProvider) string {
|
|
||||||
if provider == nil {
|
if provider == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
if name := strings.TrimSpace(provider.Name); name != "" {
|
return strings.TrimSpace(provider.Identifier())
|
||||||
return name
|
|
||||||
}
|
|
||||||
typ := strings.TrimSpace(provider.Type)
|
|
||||||
if typ == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
if strings.EqualFold(typ, sdkConfig.AccessProviderTypeConfigAPIKey) {
|
|
||||||
return sdkConfig.DefaultAccessProviderName
|
|
||||||
}
|
|
||||||
return typ
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func providerConfigEqual(a, b *sdkConfig.AccessProvider) bool {
|
func providerInstanceEqual(a, b sdkaccess.Provider) bool {
|
||||||
if a == nil || b == nil {
|
if a == nil || b == nil {
|
||||||
return a == nil && b == nil
|
return a == nil && b == nil
|
||||||
}
|
}
|
||||||
if !strings.EqualFold(strings.TrimSpace(a.Type), strings.TrimSpace(b.Type)) {
|
if reflect.TypeOf(a) != reflect.TypeOf(b) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(a.SDK) != strings.TrimSpace(b.SDK) {
|
valueA := reflect.ValueOf(a)
|
||||||
return false
|
valueB := reflect.ValueOf(b)
|
||||||
|
if valueA.Kind() == reflect.Pointer && valueB.Kind() == reflect.Pointer {
|
||||||
|
return valueA.Pointer() == valueB.Pointer()
|
||||||
}
|
}
|
||||||
if !stringSetEqual(a.APIKeys, b.APIKeys) {
|
return reflect.DeepEqual(a, b)
|
||||||
return false
|
|
||||||
}
|
|
||||||
if len(a.Config) != len(b.Config) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if len(a.Config) > 0 && !reflect.DeepEqual(a.Config, b.Config) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func stringSetEqual(a, b []string) bool {
|
|
||||||
if len(a) != len(b) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if len(a) == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
seen := make(map[string]int, len(a))
|
|
||||||
for _, val := range a {
|
|
||||||
seen[val]++
|
|
||||||
}
|
|
||||||
for _, val := range b {
|
|
||||||
count := seen[val]
|
|
||||||
if count == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if count == 1 {
|
|
||||||
delete(seen, val)
|
|
||||||
} else {
|
|
||||||
seen[val] = count - 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return len(seen) == 0
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||||
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
||||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
|
||||||
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
@@ -1613,6 +1614,82 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
|||||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
fmt.Println("Initializing Kimi authentication...")
|
||||||
|
|
||||||
|
state := fmt.Sprintf("kmi-%d", time.Now().UnixNano())
|
||||||
|
// Initialize Kimi auth service
|
||||||
|
kimiAuth := kimi.NewKimiAuth(h.cfg)
|
||||||
|
|
||||||
|
// Generate authorization URL
|
||||||
|
deviceFlow, errStartDeviceFlow := kimiAuth.StartDeviceFlow(ctx)
|
||||||
|
if errStartDeviceFlow != nil {
|
||||||
|
log.Errorf("Failed to generate authorization URL: %v", errStartDeviceFlow)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
authURL := deviceFlow.VerificationURIComplete
|
||||||
|
if authURL == "" {
|
||||||
|
authURL = deviceFlow.VerificationURI
|
||||||
|
}
|
||||||
|
|
||||||
|
RegisterOAuthSession(state, "kimi")
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
fmt.Println("Waiting for authentication...")
|
||||||
|
authBundle, errWaitForAuthorization := kimiAuth.WaitForAuthorization(ctx, deviceFlow)
|
||||||
|
if errWaitForAuthorization != nil {
|
||||||
|
SetOAuthSessionError(state, "Authentication failed")
|
||||||
|
fmt.Printf("Authentication failed: %v\n", errWaitForAuthorization)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create token storage
|
||||||
|
tokenStorage := kimiAuth.CreateTokenStorage(authBundle)
|
||||||
|
|
||||||
|
metadata := map[string]any{
|
||||||
|
"type": "kimi",
|
||||||
|
"access_token": authBundle.TokenData.AccessToken,
|
||||||
|
"refresh_token": authBundle.TokenData.RefreshToken,
|
||||||
|
"token_type": authBundle.TokenData.TokenType,
|
||||||
|
"scope": authBundle.TokenData.Scope,
|
||||||
|
"timestamp": time.Now().UnixMilli(),
|
||||||
|
}
|
||||||
|
if authBundle.TokenData.ExpiresAt > 0 {
|
||||||
|
expired := time.Unix(authBundle.TokenData.ExpiresAt, 0).UTC().Format(time.RFC3339)
|
||||||
|
metadata["expired"] = expired
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(authBundle.DeviceID) != "" {
|
||||||
|
metadata["device_id"] = strings.TrimSpace(authBundle.DeviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
fileName := fmt.Sprintf("kimi-%d.json", time.Now().UnixMilli())
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: fileName,
|
||||||
|
Provider: "kimi",
|
||||||
|
FileName: fileName,
|
||||||
|
Label: "Kimi User",
|
||||||
|
Storage: tokenStorage,
|
||||||
|
Metadata: metadata,
|
||||||
|
}
|
||||||
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||||
|
if errSave != nil {
|
||||||
|
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||||
|
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||||
|
fmt.Println("You can now use Kimi services through this CLI")
|
||||||
|
CompleteOAuthSession(state)
|
||||||
|
CompleteOAuthSessionsByProvider("kimi")
|
||||||
|
}()
|
||||||
|
|
||||||
|
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
|
|||||||
@@ -109,14 +109,13 @@ func (h *Handler) GetAPIKeys(c *gin.Context) { c.JSON(200, gin.H{"api-keys": h.c
|
|||||||
func (h *Handler) PutAPIKeys(c *gin.Context) {
|
func (h *Handler) PutAPIKeys(c *gin.Context) {
|
||||||
h.putStringList(c, func(v []string) {
|
h.putStringList(c, func(v []string) {
|
||||||
h.cfg.APIKeys = append([]string(nil), v...)
|
h.cfg.APIKeys = append([]string(nil), v...)
|
||||||
h.cfg.Access.Providers = nil
|
|
||||||
}, nil)
|
}, nil)
|
||||||
}
|
}
|
||||||
func (h *Handler) PatchAPIKeys(c *gin.Context) {
|
func (h *Handler) PatchAPIKeys(c *gin.Context) {
|
||||||
h.patchStringList(c, &h.cfg.APIKeys, func() { h.cfg.Access.Providers = nil })
|
h.patchStringList(c, &h.cfg.APIKeys, func() {})
|
||||||
}
|
}
|
||||||
func (h *Handler) DeleteAPIKeys(c *gin.Context) {
|
func (h *Handler) DeleteAPIKeys(c *gin.Context) {
|
||||||
h.deleteFromStringList(c, &h.cfg.APIKeys, func() { h.cfg.Access.Providers = nil })
|
h.deleteFromStringList(c, &h.cfg.APIKeys, func() {})
|
||||||
}
|
}
|
||||||
|
|
||||||
// gemini-api-key: []GeminiKey
|
// gemini-api-key: []GeminiKey
|
||||||
|
|||||||
@@ -122,7 +122,7 @@ func (rw *ResponseRewriter) Flush() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// modelFieldPaths lists all JSON paths where model name may appear
|
// modelFieldPaths lists all JSON paths where model name may appear
|
||||||
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
|
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
|
||||||
|
|
||||||
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
|
// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
|
||||||
// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility
|
// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility
|
||||||
|
|||||||
110
internal/api/modules/amp/response_rewriter_test.go
Normal file
110
internal/api/modules/amp/response_rewriter_test.go
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
package amp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRewriteModelInResponse_TopLevel(t *testing.T) {
|
||||||
|
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||||
|
|
||||||
|
input := []byte(`{"id":"resp_1","model":"gpt-5.3-codex","output":[]}`)
|
||||||
|
result := rw.rewriteModelInResponse(input)
|
||||||
|
|
||||||
|
expected := `{"id":"resp_1","model":"gpt-5.2-codex","output":[]}`
|
||||||
|
if string(result) != expected {
|
||||||
|
t.Errorf("expected %s, got %s", expected, string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteModelInResponse_ResponseModel(t *testing.T) {
|
||||||
|
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||||
|
|
||||||
|
input := []byte(`{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.3-codex","status":"completed"}}`)
|
||||||
|
result := rw.rewriteModelInResponse(input)
|
||||||
|
|
||||||
|
expected := `{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.2-codex","status":"completed"}}`
|
||||||
|
if string(result) != expected {
|
||||||
|
t.Errorf("expected %s, got %s", expected, string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteModelInResponse_ResponseCreated(t *testing.T) {
|
||||||
|
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||||
|
|
||||||
|
input := []byte(`{"type":"response.created","response":{"id":"resp_1","model":"gpt-5.3-codex","status":"in_progress"}}`)
|
||||||
|
result := rw.rewriteModelInResponse(input)
|
||||||
|
|
||||||
|
expected := `{"type":"response.created","response":{"id":"resp_1","model":"gpt-5.2-codex","status":"in_progress"}}`
|
||||||
|
if string(result) != expected {
|
||||||
|
t.Errorf("expected %s, got %s", expected, string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteModelInResponse_NoModelField(t *testing.T) {
|
||||||
|
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||||
|
|
||||||
|
input := []byte(`{"type":"response.output_item.added","item":{"id":"item_1","type":"message"}}`)
|
||||||
|
result := rw.rewriteModelInResponse(input)
|
||||||
|
|
||||||
|
if string(result) != string(input) {
|
||||||
|
t.Errorf("expected no modification, got %s", string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteModelInResponse_EmptyOriginalModel(t *testing.T) {
|
||||||
|
rw := &ResponseRewriter{originalModel: ""}
|
||||||
|
|
||||||
|
input := []byte(`{"model":"gpt-5.3-codex"}`)
|
||||||
|
result := rw.rewriteModelInResponse(input)
|
||||||
|
|
||||||
|
if string(result) != string(input) {
|
||||||
|
t.Errorf("expected no modification when originalModel is empty, got %s", string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteStreamChunk_SSEWithResponseModel(t *testing.T) {
|
||||||
|
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||||
|
|
||||||
|
chunk := []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-5.3-codex\",\"status\":\"completed\"}}\n\n")
|
||||||
|
result := rw.rewriteStreamChunk(chunk)
|
||||||
|
|
||||||
|
expected := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-5.2-codex\",\"status\":\"completed\"}}\n\n"
|
||||||
|
if string(result) != expected {
|
||||||
|
t.Errorf("expected %s, got %s", expected, string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteStreamChunk_MultipleEvents(t *testing.T) {
|
||||||
|
rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
|
||||||
|
|
||||||
|
chunk := []byte("data: {\"type\":\"response.created\",\"response\":{\"model\":\"gpt-5.3-codex\"}}\n\ndata: {\"type\":\"response.output_item.added\",\"item\":{\"id\":\"item_1\"}}\n\n")
|
||||||
|
result := rw.rewriteStreamChunk(chunk)
|
||||||
|
|
||||||
|
if string(result) == string(chunk) {
|
||||||
|
t.Error("expected response.model to be rewritten in SSE stream")
|
||||||
|
}
|
||||||
|
if !contains(result, []byte(`"model":"gpt-5.2-codex"`)) {
|
||||||
|
t.Errorf("expected rewritten model in output, got %s", string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteStreamChunk_MessageModel(t *testing.T) {
|
||||||
|
rw := &ResponseRewriter{originalModel: "claude-opus-4.5"}
|
||||||
|
|
||||||
|
chunk := []byte("data: {\"message\":{\"model\":\"claude-sonnet-4\",\"role\":\"assistant\"}}\n\n")
|
||||||
|
result := rw.rewriteStreamChunk(chunk)
|
||||||
|
|
||||||
|
expected := "data: {\"message\":{\"model\":\"claude-opus-4.5\",\"role\":\"assistant\"}}\n\n"
|
||||||
|
if string(result) != expected {
|
||||||
|
t.Errorf("expected %s, got %s", expected, string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func contains(data, substr []byte) bool {
|
||||||
|
for i := 0; i <= len(data)-len(substr); i++ {
|
||||||
|
if string(data[i:i+len(substr)]) == string(substr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -649,6 +649,7 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
||||||
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
||||||
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
||||||
|
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
|
||||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||||
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
||||||
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
|
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
|
||||||
@@ -682,14 +683,17 @@ func (s *Server) serveManagementControlPanel(c *gin.Context) {
|
|||||||
|
|
||||||
if _, err := os.Stat(filePath); err != nil {
|
if _, err := os.Stat(filePath); err != nil {
|
||||||
if os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
// Synchronously ensure management.html is available with a detached context.
|
||||||
c.AbortWithStatus(http.StatusNotFound)
|
// Control panel bootstrap should not be canceled by client disconnects.
|
||||||
|
if !managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository) {
|
||||||
|
c.AbortWithStatus(http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.WithError(err).Error("failed to stat management control panel asset")
|
||||||
|
c.AbortWithStatus(http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithError(err).Error("failed to stat management control panel asset")
|
|
||||||
c.AbortWithStatus(http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.File(filePath)
|
c.File(filePath)
|
||||||
@@ -979,10 +983,6 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
|||||||
|
|
||||||
s.handlers.UpdateClients(&cfg.SDKConfig)
|
s.handlers.UpdateClients(&cfg.SDKConfig)
|
||||||
|
|
||||||
if !cfg.RemoteManagement.DisableControlPanel {
|
|
||||||
staticDir := managementasset.StaticDir(s.configFilePath)
|
|
||||||
go managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
|
||||||
}
|
|
||||||
if s.mgmt != nil {
|
if s.mgmt != nil {
|
||||||
s.mgmt.SetConfig(cfg)
|
s.mgmt.SetConfig(cfg)
|
||||||
s.mgmt.SetAuthManager(s.handlers.AuthManager)
|
s.mgmt.SetAuthManager(s.handlers.AuthManager)
|
||||||
@@ -1061,14 +1061,10 @@ func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
statusCode := err.HTTPStatusCode()
|
||||||
case errors.Is(err, sdkaccess.ErrNoCredentials):
|
if statusCode >= http.StatusInternalServerError {
|
||||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Missing API key"})
|
|
||||||
case errors.Is(err, sdkaccess.ErrInvalidCredential):
|
|
||||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"})
|
|
||||||
default:
|
|
||||||
log.Errorf("authentication middleware error: %v", err)
|
log.Errorf("authentication middleware error: %v", err)
|
||||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Authentication service error"})
|
|
||||||
}
|
}
|
||||||
|
c.AbortWithStatusJSON(statusCode, gin.H{"error": err.Message})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
396
internal/auth/kimi/kimi.go
Normal file
396
internal/auth/kimi/kimi.go
Normal file
@@ -0,0 +1,396 @@
|
|||||||
|
// Package kimi provides authentication and token management for Kimi (Moonshot AI) API.
|
||||||
|
// It handles the RFC 8628 OAuth2 Device Authorization Grant flow for secure authentication.
|
||||||
|
package kimi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// kimiClientID is Kimi Code's OAuth client ID.
|
||||||
|
kimiClientID = "17e5f671-d194-4dfb-9706-5516cb48c098"
|
||||||
|
// kimiOAuthHost is the OAuth server endpoint.
|
||||||
|
kimiOAuthHost = "https://auth.kimi.com"
|
||||||
|
// kimiDeviceCodeURL is the endpoint for requesting device codes.
|
||||||
|
kimiDeviceCodeURL = kimiOAuthHost + "/api/oauth/device_authorization"
|
||||||
|
// kimiTokenURL is the endpoint for exchanging device codes for tokens.
|
||||||
|
kimiTokenURL = kimiOAuthHost + "/api/oauth/token"
|
||||||
|
// KimiAPIBaseURL is the base URL for Kimi API requests.
|
||||||
|
KimiAPIBaseURL = "https://api.kimi.com/coding"
|
||||||
|
// defaultPollInterval is the default interval for polling token endpoint.
|
||||||
|
defaultPollInterval = 5 * time.Second
|
||||||
|
// maxPollDuration is the maximum time to wait for user authorization.
|
||||||
|
maxPollDuration = 15 * time.Minute
|
||||||
|
// refreshThresholdSeconds is when to refresh token before expiry (5 minutes).
|
||||||
|
refreshThresholdSeconds = 300
|
||||||
|
)
|
||||||
|
|
||||||
|
// KimiAuth handles Kimi authentication flow.
|
||||||
|
type KimiAuth struct {
|
||||||
|
deviceClient *DeviceFlowClient
|
||||||
|
cfg *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewKimiAuth creates a new KimiAuth service instance.
|
||||||
|
func NewKimiAuth(cfg *config.Config) *KimiAuth {
|
||||||
|
return &KimiAuth{
|
||||||
|
deviceClient: NewDeviceFlowClient(cfg),
|
||||||
|
cfg: cfg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartDeviceFlow initiates the device flow authentication.
|
||||||
|
func (k *KimiAuth) StartDeviceFlow(ctx context.Context) (*DeviceCodeResponse, error) {
|
||||||
|
return k.deviceClient.RequestDeviceCode(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitForAuthorization polls for user authorization and returns the auth bundle.
|
||||||
|
func (k *KimiAuth) WaitForAuthorization(ctx context.Context, deviceCode *DeviceCodeResponse) (*KimiAuthBundle, error) {
|
||||||
|
tokenData, err := k.deviceClient.PollForToken(ctx, deviceCode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &KimiAuthBundle{
|
||||||
|
TokenData: tokenData,
|
||||||
|
DeviceID: k.deviceClient.deviceID,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTokenStorage creates a new KimiTokenStorage from auth bundle.
|
||||||
|
func (k *KimiAuth) CreateTokenStorage(bundle *KimiAuthBundle) *KimiTokenStorage {
|
||||||
|
expired := ""
|
||||||
|
if bundle.TokenData.ExpiresAt > 0 {
|
||||||
|
expired = time.Unix(bundle.TokenData.ExpiresAt, 0).UTC().Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
return &KimiTokenStorage{
|
||||||
|
AccessToken: bundle.TokenData.AccessToken,
|
||||||
|
RefreshToken: bundle.TokenData.RefreshToken,
|
||||||
|
TokenType: bundle.TokenData.TokenType,
|
||||||
|
Scope: bundle.TokenData.Scope,
|
||||||
|
DeviceID: strings.TrimSpace(bundle.DeviceID),
|
||||||
|
Expired: expired,
|
||||||
|
Type: "kimi",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceFlowClient handles the OAuth2 device flow for Kimi.
|
||||||
|
type DeviceFlowClient struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
cfg *config.Config
|
||||||
|
deviceID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDeviceFlowClient creates a new device flow client.
|
||||||
|
func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient {
|
||||||
|
return NewDeviceFlowClientWithDeviceID(cfg, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDeviceFlowClientWithDeviceID creates a new device flow client with the specified device ID.
|
||||||
|
func NewDeviceFlowClientWithDeviceID(cfg *config.Config, deviceID string) *DeviceFlowClient {
|
||||||
|
client := &http.Client{Timeout: 30 * time.Second}
|
||||||
|
if cfg != nil {
|
||||||
|
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||||
|
}
|
||||||
|
resolvedDeviceID := strings.TrimSpace(deviceID)
|
||||||
|
if resolvedDeviceID == "" {
|
||||||
|
resolvedDeviceID = getOrCreateDeviceID()
|
||||||
|
}
|
||||||
|
return &DeviceFlowClient{
|
||||||
|
httpClient: client,
|
||||||
|
cfg: cfg,
|
||||||
|
deviceID: resolvedDeviceID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getOrCreateDeviceID returns an in-memory device ID for the current authentication flow.
|
||||||
|
func getOrCreateDeviceID() string {
|
||||||
|
return uuid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDeviceModel returns a device model string.
|
||||||
|
func getDeviceModel() string {
|
||||||
|
osName := runtime.GOOS
|
||||||
|
arch := runtime.GOARCH
|
||||||
|
|
||||||
|
switch osName {
|
||||||
|
case "darwin":
|
||||||
|
return fmt.Sprintf("macOS %s", arch)
|
||||||
|
case "windows":
|
||||||
|
return fmt.Sprintf("Windows %s", arch)
|
||||||
|
case "linux":
|
||||||
|
return fmt.Sprintf("Linux %s", arch)
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("%s %s", osName, arch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getHostname returns the machine hostname.
|
||||||
|
func getHostname() string {
|
||||||
|
hostname, err := os.Hostname()
|
||||||
|
if err != nil {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
return hostname
|
||||||
|
}
|
||||||
|
|
||||||
|
// commonHeaders returns headers required for Kimi API requests.
|
||||||
|
func (c *DeviceFlowClient) commonHeaders() map[string]string {
|
||||||
|
return map[string]string{
|
||||||
|
"X-Msh-Platform": "cli-proxy-api",
|
||||||
|
"X-Msh-Version": "1.0.0",
|
||||||
|
"X-Msh-Device-Name": getHostname(),
|
||||||
|
"X-Msh-Device-Model": getDeviceModel(),
|
||||||
|
"X-Msh-Device-Id": c.deviceID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestDeviceCode initiates the device flow by requesting a device code from Kimi.
|
||||||
|
func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
|
||||||
|
data := url.Values{}
|
||||||
|
data.Set("client_id", kimiClientID)
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, kimiDeviceCodeURL, strings.NewReader(data.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("kimi: failed to create device code request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
for k, v := range c.commonHeaders() {
|
||||||
|
req.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("kimi: device code request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("kimi device code: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
bodyBytes, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("kimi: failed to read device code response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("kimi: device code request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
var deviceCode DeviceCodeResponse
|
||||||
|
if err = json.Unmarshal(bodyBytes, &deviceCode); err != nil {
|
||||||
|
return nil, fmt.Errorf("kimi: failed to parse device code response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &deviceCode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PollForToken polls the token endpoint until the user authorizes or the device code expires.
|
||||||
|
func (c *DeviceFlowClient) PollForToken(ctx context.Context, deviceCode *DeviceCodeResponse) (*KimiTokenData, error) {
|
||||||
|
if deviceCode == nil {
|
||||||
|
return nil, fmt.Errorf("kimi: device code is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
interval := time.Duration(deviceCode.Interval) * time.Second
|
||||||
|
if interval < defaultPollInterval {
|
||||||
|
interval = defaultPollInterval
|
||||||
|
}
|
||||||
|
|
||||||
|
deadline := time.Now().Add(maxPollDuration)
|
||||||
|
if deviceCode.ExpiresIn > 0 {
|
||||||
|
codeDeadline := time.Now().Add(time.Duration(deviceCode.ExpiresIn) * time.Second)
|
||||||
|
if codeDeadline.Before(deadline) {
|
||||||
|
deadline = codeDeadline
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, fmt.Errorf("kimi: context cancelled: %w", ctx.Err())
|
||||||
|
case <-ticker.C:
|
||||||
|
if time.Now().After(deadline) {
|
||||||
|
return nil, fmt.Errorf("kimi: device code expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
token, pollErr, shouldContinue := c.exchangeDeviceCode(ctx, deviceCode.DeviceCode)
|
||||||
|
if token != nil {
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
if !shouldContinue {
|
||||||
|
return nil, pollErr
|
||||||
|
}
|
||||||
|
// Continue polling
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// exchangeDeviceCode attempts to exchange the device code for an access token.
|
||||||
|
// Returns (token, error, shouldContinue).
|
||||||
|
func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode string) (*KimiTokenData, error, bool) {
|
||||||
|
data := url.Values{}
|
||||||
|
data.Set("client_id", kimiClientID)
|
||||||
|
data.Set("device_code", deviceCode)
|
||||||
|
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code")
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, kimiTokenURL, strings.NewReader(data.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("kimi: failed to create token request: %w", err), false
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
for k, v := range c.commonHeaders() {
|
||||||
|
req.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("kimi: token request failed: %w", err), false
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("kimi token exchange: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
bodyBytes, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("kimi: failed to read token response: %w", err), false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse response - Kimi returns 200 for both success and pending states
|
||||||
|
var oauthResp struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
ErrorDescription string `json:"error_description"`
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn float64 `json:"expires_in"`
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = json.Unmarshal(bodyBytes, &oauthResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("kimi: failed to parse token response: %w", err), false
|
||||||
|
}
|
||||||
|
|
||||||
|
if oauthResp.Error != "" {
|
||||||
|
switch oauthResp.Error {
|
||||||
|
case "authorization_pending":
|
||||||
|
return nil, nil, true // Continue polling
|
||||||
|
case "slow_down":
|
||||||
|
return nil, nil, true // Continue polling (with increased interval handled by caller)
|
||||||
|
case "expired_token":
|
||||||
|
return nil, fmt.Errorf("kimi: device code expired"), false
|
||||||
|
case "access_denied":
|
||||||
|
return nil, fmt.Errorf("kimi: access denied by user"), false
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("kimi: OAuth error: %s - %s", oauthResp.Error, oauthResp.ErrorDescription), false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if oauthResp.AccessToken == "" {
|
||||||
|
return nil, fmt.Errorf("kimi: empty access token in response"), false
|
||||||
|
}
|
||||||
|
|
||||||
|
var expiresAt int64
|
||||||
|
if oauthResp.ExpiresIn > 0 {
|
||||||
|
expiresAt = time.Now().Unix() + int64(oauthResp.ExpiresIn)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &KimiTokenData{
|
||||||
|
AccessToken: oauthResp.AccessToken,
|
||||||
|
RefreshToken: oauthResp.RefreshToken,
|
||||||
|
TokenType: oauthResp.TokenType,
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
Scope: oauthResp.Scope,
|
||||||
|
}, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshToken exchanges a refresh token for a new access token.
|
||||||
|
func (c *DeviceFlowClient) RefreshToken(ctx context.Context, refreshToken string) (*KimiTokenData, error) {
|
||||||
|
data := url.Values{}
|
||||||
|
data.Set("client_id", kimiClientID)
|
||||||
|
data.Set("grant_type", "refresh_token")
|
||||||
|
data.Set("refresh_token", refreshToken)
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, kimiTokenURL, strings.NewReader(data.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("kimi: failed to create refresh request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
for k, v := range c.commonHeaders() {
|
||||||
|
req.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("kimi: refresh request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("kimi refresh token: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
bodyBytes, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("kimi: failed to read refresh response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||||
|
return nil, fmt.Errorf("kimi: refresh token rejected (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("kimi: refresh failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenResp struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn float64 `json:"expires_in"`
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = json.Unmarshal(bodyBytes, &tokenResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("kimi: failed to parse refresh response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tokenResp.AccessToken == "" {
|
||||||
|
return nil, fmt.Errorf("kimi: empty access token in refresh response")
|
||||||
|
}
|
||||||
|
|
||||||
|
var expiresAt int64
|
||||||
|
if tokenResp.ExpiresIn > 0 {
|
||||||
|
expiresAt = time.Now().Unix() + int64(tokenResp.ExpiresIn)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &KimiTokenData{
|
||||||
|
AccessToken: tokenResp.AccessToken,
|
||||||
|
RefreshToken: tokenResp.RefreshToken,
|
||||||
|
TokenType: tokenResp.TokenType,
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
Scope: tokenResp.Scope,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
116
internal/auth/kimi/token.go
Normal file
116
internal/auth/kimi/token.go
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
// Package kimi provides authentication and token management functionality
|
||||||
|
// for Kimi (Moonshot AI) services. It handles OAuth2 device flow token storage,
|
||||||
|
// serialization, and retrieval for maintaining authenticated sessions with the Kimi API.
|
||||||
|
package kimi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
|
)
|
||||||
|
|
||||||
|
// KimiTokenStorage stores OAuth2 token information for Kimi API authentication.
|
||||||
|
type KimiTokenStorage struct {
|
||||||
|
// AccessToken is the OAuth2 access token used for authenticating API requests.
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
// RefreshToken is the OAuth2 refresh token used to obtain new access tokens.
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
// TokenType is the type of token, typically "Bearer".
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
// Scope is the OAuth2 scope granted to the token.
|
||||||
|
Scope string `json:"scope,omitempty"`
|
||||||
|
// DeviceID is the OAuth device flow identifier used for Kimi requests.
|
||||||
|
DeviceID string `json:"device_id,omitempty"`
|
||||||
|
// Expired is the RFC3339 timestamp when the access token expires.
|
||||||
|
Expired string `json:"expired,omitempty"`
|
||||||
|
// Type indicates the authentication provider type, always "kimi" for this storage.
|
||||||
|
Type string `json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KimiTokenData holds the raw OAuth token response from Kimi.
|
||||||
|
type KimiTokenData struct {
|
||||||
|
// AccessToken is the OAuth2 access token.
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
// RefreshToken is the OAuth2 refresh token.
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
// TokenType is the type of token, typically "Bearer".
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
// ExpiresAt is the Unix timestamp when the token expires.
|
||||||
|
ExpiresAt int64 `json:"expires_at"`
|
||||||
|
// Scope is the OAuth2 scope granted to the token.
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KimiAuthBundle bundles authentication data for storage.
|
||||||
|
type KimiAuthBundle struct {
|
||||||
|
// TokenData contains the OAuth token information.
|
||||||
|
TokenData *KimiTokenData
|
||||||
|
// DeviceID is the device identifier used during OAuth device flow.
|
||||||
|
DeviceID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceCodeResponse represents Kimi's device code response.
|
||||||
|
type DeviceCodeResponse struct {
|
||||||
|
// DeviceCode is the device verification code.
|
||||||
|
DeviceCode string `json:"device_code"`
|
||||||
|
// UserCode is the code the user must enter at the verification URI.
|
||||||
|
UserCode string `json:"user_code"`
|
||||||
|
// VerificationURI is the URL where the user should enter the code.
|
||||||
|
VerificationURI string `json:"verification_uri,omitempty"`
|
||||||
|
// VerificationURIComplete is the URL with the code pre-filled.
|
||||||
|
VerificationURIComplete string `json:"verification_uri_complete"`
|
||||||
|
// ExpiresIn is the number of seconds until the device code expires.
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
// Interval is the minimum number of seconds to wait between polling requests.
|
||||||
|
Interval int `json:"interval"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveTokenToFile serializes the Kimi token storage to a JSON file.
|
||||||
|
func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||||
|
misc.LogSavingCredentials(authFilePath)
|
||||||
|
ts.Type = "kimi"
|
||||||
|
|
||||||
|
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||||
|
return fmt.Errorf("failed to create directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.Create(authFilePath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create token file: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = f.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
encoder := json.NewEncoder(f)
|
||||||
|
encoder.SetIndent("", " ")
|
||||||
|
if err = encoder.Encode(ts); err != nil {
|
||||||
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsExpired checks if the token has expired.
|
||||||
|
func (ts *KimiTokenStorage) IsExpired() bool {
|
||||||
|
if ts.Expired == "" {
|
||||||
|
return false // No expiry set, assume valid
|
||||||
|
}
|
||||||
|
t, err := time.Parse(time.RFC3339, ts.Expired)
|
||||||
|
if err != nil {
|
||||||
|
return true // Has expiry string but can't parse
|
||||||
|
}
|
||||||
|
// Consider expired if within refresh threshold
|
||||||
|
return time.Now().Add(time.Duration(refreshThresholdSeconds) * time.Second).After(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NeedsRefresh checks if the token should be refreshed.
|
||||||
|
func (ts *KimiTokenStorage) NeedsRefresh() bool {
|
||||||
|
if ts.RefreshToken == "" {
|
||||||
|
return false // Can't refresh without refresh token
|
||||||
|
}
|
||||||
|
return ts.IsExpired()
|
||||||
|
}
|
||||||
@@ -238,7 +238,7 @@ func (k *KiroAuth) ListAvailableModels(ctx context.Context, tokenData *KiroToken
|
|||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
RateMultiplier float64 `json:"rateMultiplier"`
|
RateMultiplier float64 `json:"rateMultiplier"`
|
||||||
RateUnit string `json:"rateUnit"`
|
RateUnit string `json:"rateUnit"`
|
||||||
TokenLimits struct {
|
TokenLimits *struct {
|
||||||
MaxInputTokens int `json:"maxInputTokens"`
|
MaxInputTokens int `json:"maxInputTokens"`
|
||||||
} `json:"tokenLimits"`
|
} `json:"tokenLimits"`
|
||||||
} `json:"models"`
|
} `json:"models"`
|
||||||
@@ -250,13 +250,17 @@ func (k *KiroAuth) ListAvailableModels(ctx context.Context, tokenData *KiroToken
|
|||||||
|
|
||||||
models := make([]*KiroModel, 0, len(result.Models))
|
models := make([]*KiroModel, 0, len(result.Models))
|
||||||
for _, m := range result.Models {
|
for _, m := range result.Models {
|
||||||
|
maxInputTokens := 0
|
||||||
|
if m.TokenLimits != nil {
|
||||||
|
maxInputTokens = m.TokenLimits.MaxInputTokens
|
||||||
|
}
|
||||||
models = append(models, &KiroModel{
|
models = append(models, &KiroModel{
|
||||||
ModelID: m.ModelID,
|
ModelID: m.ModelID,
|
||||||
ModelName: m.ModelName,
|
ModelName: m.ModelName,
|
||||||
Description: m.Description,
|
Description: m.Description,
|
||||||
RateMultiplier: m.RateMultiplier,
|
RateMultiplier: m.RateMultiplier,
|
||||||
RateUnit: m.RateUnit,
|
RateUnit: m.RateUnit,
|
||||||
MaxInputTokens: m.TokenLimits.MaxInputTokens,
|
MaxInputTokens: maxInputTokens,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ func newAuthManager() *sdkAuth.Manager {
|
|||||||
sdkAuth.NewQwenAuthenticator(),
|
sdkAuth.NewQwenAuthenticator(),
|
||||||
sdkAuth.NewIFlowAuthenticator(),
|
sdkAuth.NewIFlowAuthenticator(),
|
||||||
sdkAuth.NewAntigravityAuthenticator(),
|
sdkAuth.NewAntigravityAuthenticator(),
|
||||||
|
sdkAuth.NewKimiAuthenticator(),
|
||||||
sdkAuth.NewKiroAuthenticator(),
|
sdkAuth.NewKiroAuthenticator(),
|
||||||
sdkAuth.NewGitHubCopilotAuthenticator(),
|
sdkAuth.NewGitHubCopilotAuthenticator(),
|
||||||
)
|
)
|
||||||
|
|||||||
44
internal/cmd/kimi_login.go
Normal file
44
internal/cmd/kimi_login.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DoKimiLogin triggers the OAuth device flow for Kimi (Moonshot AI) and saves tokens.
|
||||||
|
// It initiates the device flow authentication, displays the verification URL for the user,
|
||||||
|
// and waits for authorization before saving the tokens.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - cfg: The application configuration containing proxy and auth directory settings
|
||||||
|
// - options: Login options including browser behavior settings
|
||||||
|
func DoKimiLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
|
NoBrowser: options.NoBrowser,
|
||||||
|
Metadata: map[string]string{},
|
||||||
|
Prompt: options.Prompt,
|
||||||
|
}
|
||||||
|
|
||||||
|
record, savedPath, err := manager.Login(context.Background(), "kimi", cfg, authOpts)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Kimi authentication failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if savedPath != "" {
|
||||||
|
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||||
|
}
|
||||||
|
if record != nil && record.Label != "" {
|
||||||
|
fmt.Printf("Authenticated as %s\n", record.Label)
|
||||||
|
}
|
||||||
|
fmt.Println("Kimi authentication successful!")
|
||||||
|
}
|
||||||
@@ -535,14 +535,15 @@ func LoadConfig(configFile string) (*Config, error) {
|
|||||||
// If optional is true and the file is missing, it returns an empty Config.
|
// If optional is true and the file is missing, it returns an empty Config.
|
||||||
// If optional is true and the file is empty or invalid, it returns an empty Config.
|
// If optional is true and the file is empty or invalid, it returns an empty Config.
|
||||||
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||||
// Perform oauth-model-alias migration before loading config.
|
// NOTE: Startup oauth-model-alias migration is intentionally disabled.
|
||||||
// This migrates oauth-model-mappings to oauth-model-alias if needed.
|
// Reason: avoid mutating config.yaml during server startup.
|
||||||
if migrated, err := MigrateOAuthModelAlias(configFile); err != nil {
|
// Re-enable the block below if automatic startup migration is needed again.
|
||||||
// Log warning but don't fail - config loading should still work
|
// if migrated, err := MigrateOAuthModelAlias(configFile); err != nil {
|
||||||
fmt.Printf("Warning: oauth-model-alias migration failed: %v\n", err)
|
// // Log warning but don't fail - config loading should still work
|
||||||
} else if migrated {
|
// fmt.Printf("Warning: oauth-model-alias migration failed: %v\n", err)
|
||||||
fmt.Println("Migrated oauth-model-mappings to oauth-model-alias")
|
// } else if migrated {
|
||||||
}
|
// fmt.Println("Migrated oauth-model-mappings to oauth-model-alias")
|
||||||
|
// }
|
||||||
|
|
||||||
// Read the entire configuration file into memory.
|
// Read the entire configuration file into memory.
|
||||||
data, err := os.ReadFile(configFile)
|
data, err := os.ReadFile(configFile)
|
||||||
@@ -583,18 +584,21 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
|||||||
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var legacy legacyConfigData
|
// NOTE: Startup legacy key migration is intentionally disabled.
|
||||||
if errLegacy := yaml.Unmarshal(data, &legacy); errLegacy == nil {
|
// Reason: avoid mutating config.yaml during server startup.
|
||||||
if cfg.migrateLegacyGeminiKeys(legacy.LegacyGeminiKeys) {
|
// Re-enable the block below if automatic startup migration is needed again.
|
||||||
cfg.legacyMigrationPending = true
|
// var legacy legacyConfigData
|
||||||
}
|
// if errLegacy := yaml.Unmarshal(data, &legacy); errLegacy == nil {
|
||||||
if cfg.migrateLegacyOpenAICompatibilityKeys(legacy.OpenAICompat) {
|
// if cfg.migrateLegacyGeminiKeys(legacy.LegacyGeminiKeys) {
|
||||||
cfg.legacyMigrationPending = true
|
// cfg.legacyMigrationPending = true
|
||||||
}
|
// }
|
||||||
if cfg.migrateLegacyAmpConfig(&legacy) {
|
// if cfg.migrateLegacyOpenAICompatibilityKeys(legacy.OpenAICompat) {
|
||||||
cfg.legacyMigrationPending = true
|
// cfg.legacyMigrationPending = true
|
||||||
}
|
// }
|
||||||
}
|
// if cfg.migrateLegacyAmpConfig(&legacy) {
|
||||||
|
// cfg.legacyMigrationPending = true
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
// Hash remote management key if plaintext is detected (nested)
|
// Hash remote management key if plaintext is detected (nested)
|
||||||
// We consider a value to be already hashed if it looks like a bcrypt hash ($2a$, $2b$, or $2y$ prefix).
|
// We consider a value to be already hashed if it looks like a bcrypt hash ($2a$, $2b$, or $2y$ prefix).
|
||||||
@@ -628,9 +632,6 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
|||||||
cfg.ErrorLogsMaxFiles = 10
|
cfg.ErrorLogsMaxFiles = 10
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sync request authentication providers with inline API keys for backwards compatibility.
|
|
||||||
syncInlineAccessProvider(&cfg)
|
|
||||||
|
|
||||||
// Sanitize Gemini API key configuration and migrate legacy entries.
|
// Sanitize Gemini API key configuration and migrate legacy entries.
|
||||||
cfg.SanitizeGeminiKeys()
|
cfg.SanitizeGeminiKeys()
|
||||||
|
|
||||||
@@ -658,17 +659,20 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
|||||||
// Validate raw payload rules and drop invalid entries.
|
// Validate raw payload rules and drop invalid entries.
|
||||||
cfg.SanitizePayloadRules()
|
cfg.SanitizePayloadRules()
|
||||||
|
|
||||||
if cfg.legacyMigrationPending {
|
// NOTE: Legacy migration persistence is intentionally disabled together with
|
||||||
fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
|
// startup legacy migration to keep startup read-only for config.yaml.
|
||||||
if !optional && configFile != "" {
|
// Re-enable the block below if automatic startup migration is needed again.
|
||||||
if err := SaveConfigPreserveComments(configFile, &cfg); err != nil {
|
// if cfg.legacyMigrationPending {
|
||||||
return nil, fmt.Errorf("failed to persist migrated legacy config: %w", err)
|
// fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
|
||||||
}
|
// if !optional && configFile != "" {
|
||||||
fmt.Println("Legacy configuration normalized and persisted.")
|
// if err := SaveConfigPreserveComments(configFile, &cfg); err != nil {
|
||||||
} else {
|
// return nil, fmt.Errorf("failed to persist migrated legacy config: %w", err)
|
||||||
fmt.Println("Legacy configuration normalized in memory; persistence skipped.")
|
// }
|
||||||
}
|
// fmt.Println("Legacy configuration normalized and persisted.")
|
||||||
}
|
// } else {
|
||||||
|
// fmt.Println("Legacy configuration normalized in memory; persistence skipped.")
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
// Return the populated configuration struct.
|
// Return the populated configuration struct.
|
||||||
return &cfg, nil
|
return &cfg, nil
|
||||||
@@ -881,18 +885,6 @@ func normalizeModelPrefix(prefix string) string {
|
|||||||
return trimmed
|
return trimmed
|
||||||
}
|
}
|
||||||
|
|
||||||
func syncInlineAccessProvider(cfg *Config) {
|
|
||||||
if cfg == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(cfg.APIKeys) == 0 {
|
|
||||||
if provider := cfg.ConfigAPIKeyProvider(); provider != nil && len(provider.APIKeys) > 0 {
|
|
||||||
cfg.APIKeys = append([]string(nil), provider.APIKeys...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cfg.Access.Providers = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// looksLikeBcrypt returns true if the provided string appears to be a bcrypt hash.
|
// looksLikeBcrypt returns true if the provided string appears to be a bcrypt hash.
|
||||||
func looksLikeBcrypt(s string) bool {
|
func looksLikeBcrypt(s string) bool {
|
||||||
return len(s) > 4 && (s[:4] == "$2a$" || s[:4] == "$2b$" || s[:4] == "$2y$")
|
return len(s) > 4 && (s[:4] == "$2a$" || s[:4] == "$2b$" || s[:4] == "$2y$")
|
||||||
@@ -980,7 +972,7 @@ func hashSecret(secret string) (string, error) {
|
|||||||
// SaveConfigPreserveComments writes the config back to YAML while preserving existing comments
|
// SaveConfigPreserveComments writes the config back to YAML while preserving existing comments
|
||||||
// and key ordering by loading the original file into a yaml.Node tree and updating values in-place.
|
// and key ordering by loading the original file into a yaml.Node tree and updating values in-place.
|
||||||
func SaveConfigPreserveComments(configFile string, cfg *Config) error {
|
func SaveConfigPreserveComments(configFile string, cfg *Config) error {
|
||||||
persistCfg := sanitizeConfigForPersist(cfg)
|
persistCfg := cfg
|
||||||
// Load original YAML as a node tree to preserve comments and ordering.
|
// Load original YAML as a node tree to preserve comments and ordering.
|
||||||
data, err := os.ReadFile(configFile)
|
data, err := os.ReadFile(configFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1048,16 +1040,6 @@ func SaveConfigPreserveComments(configFile string, cfg *Config) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func sanitizeConfigForPersist(cfg *Config) *Config {
|
|
||||||
if cfg == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
clone := *cfg
|
|
||||||
clone.SDKConfig = cfg.SDKConfig
|
|
||||||
clone.SDKConfig.Access = AccessConfig{}
|
|
||||||
return &clone
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveConfigPreserveCommentsUpdateNestedScalar updates a nested scalar key path like ["a","b"]
|
// SaveConfigPreserveCommentsUpdateNestedScalar updates a nested scalar key path like ["a","b"]
|
||||||
// while preserving comments and positions.
|
// while preserving comments and positions.
|
||||||
func SaveConfigPreserveCommentsUpdateNestedScalar(configFile string, path []string, value string) error {
|
func SaveConfigPreserveCommentsUpdateNestedScalar(configFile string, path []string, value string) error {
|
||||||
@@ -1154,8 +1136,13 @@ func getOrCreateMapValue(mapNode *yaml.Node, key string) *yaml.Node {
|
|||||||
|
|
||||||
// mergeMappingPreserve merges keys from src into dst mapping node while preserving
|
// mergeMappingPreserve merges keys from src into dst mapping node while preserving
|
||||||
// key order and comments of existing keys in dst. New keys are only added if their
|
// key order and comments of existing keys in dst. New keys are only added if their
|
||||||
// value is non-zero to avoid polluting the config with defaults.
|
// value is non-zero and not a known default to avoid polluting the config with defaults.
|
||||||
func mergeMappingPreserve(dst, src *yaml.Node) {
|
func mergeMappingPreserve(dst, src *yaml.Node, path ...[]string) {
|
||||||
|
var currentPath []string
|
||||||
|
if len(path) > 0 {
|
||||||
|
currentPath = path[0]
|
||||||
|
}
|
||||||
|
|
||||||
if dst == nil || src == nil {
|
if dst == nil || src == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1169,16 +1156,19 @@ func mergeMappingPreserve(dst, src *yaml.Node) {
|
|||||||
sk := src.Content[i]
|
sk := src.Content[i]
|
||||||
sv := src.Content[i+1]
|
sv := src.Content[i+1]
|
||||||
idx := findMapKeyIndex(dst, sk.Value)
|
idx := findMapKeyIndex(dst, sk.Value)
|
||||||
|
childPath := appendPath(currentPath, sk.Value)
|
||||||
if idx >= 0 {
|
if idx >= 0 {
|
||||||
// Merge into existing value node (always update, even to zero values)
|
// Merge into existing value node (always update, even to zero values)
|
||||||
dv := dst.Content[idx+1]
|
dv := dst.Content[idx+1]
|
||||||
mergeNodePreserve(dv, sv)
|
mergeNodePreserve(dv, sv, childPath)
|
||||||
} else {
|
} else {
|
||||||
// New key: only add if value is non-zero to avoid polluting config with defaults
|
// New key: only add if value is non-zero and not a known default
|
||||||
if isZeroValueNode(sv) {
|
candidate := deepCopyNode(sv)
|
||||||
|
pruneKnownDefaultsInNewNode(childPath, candidate)
|
||||||
|
if isKnownDefaultValue(childPath, candidate) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
dst.Content = append(dst.Content, deepCopyNode(sk), deepCopyNode(sv))
|
dst.Content = append(dst.Content, deepCopyNode(sk), candidate)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1186,7 +1176,12 @@ func mergeMappingPreserve(dst, src *yaml.Node) {
|
|||||||
// mergeNodePreserve merges src into dst for scalars, mappings and sequences while
|
// mergeNodePreserve merges src into dst for scalars, mappings and sequences while
|
||||||
// reusing destination nodes to keep comments and anchors. For sequences, it updates
|
// reusing destination nodes to keep comments and anchors. For sequences, it updates
|
||||||
// in-place by index.
|
// in-place by index.
|
||||||
func mergeNodePreserve(dst, src *yaml.Node) {
|
func mergeNodePreserve(dst, src *yaml.Node, path ...[]string) {
|
||||||
|
var currentPath []string
|
||||||
|
if len(path) > 0 {
|
||||||
|
currentPath = path[0]
|
||||||
|
}
|
||||||
|
|
||||||
if dst == nil || src == nil {
|
if dst == nil || src == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1195,7 +1190,7 @@ func mergeNodePreserve(dst, src *yaml.Node) {
|
|||||||
if dst.Kind != yaml.MappingNode {
|
if dst.Kind != yaml.MappingNode {
|
||||||
copyNodeShallow(dst, src)
|
copyNodeShallow(dst, src)
|
||||||
}
|
}
|
||||||
mergeMappingPreserve(dst, src)
|
mergeMappingPreserve(dst, src, currentPath)
|
||||||
case yaml.SequenceNode:
|
case yaml.SequenceNode:
|
||||||
// Preserve explicit null style if dst was null and src is empty sequence
|
// Preserve explicit null style if dst was null and src is empty sequence
|
||||||
if dst.Kind == yaml.ScalarNode && dst.Tag == "!!null" && len(src.Content) == 0 {
|
if dst.Kind == yaml.ScalarNode && dst.Tag == "!!null" && len(src.Content) == 0 {
|
||||||
@@ -1218,7 +1213,7 @@ func mergeNodePreserve(dst, src *yaml.Node) {
|
|||||||
dst.Content[i] = deepCopyNode(src.Content[i])
|
dst.Content[i] = deepCopyNode(src.Content[i])
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
mergeNodePreserve(dst.Content[i], src.Content[i])
|
mergeNodePreserve(dst.Content[i], src.Content[i], currentPath)
|
||||||
if dst.Content[i] != nil && src.Content[i] != nil &&
|
if dst.Content[i] != nil && src.Content[i] != nil &&
|
||||||
dst.Content[i].Kind == yaml.MappingNode && src.Content[i].Kind == yaml.MappingNode {
|
dst.Content[i].Kind == yaml.MappingNode && src.Content[i].Kind == yaml.MappingNode {
|
||||||
pruneMissingMapKeys(dst.Content[i], src.Content[i])
|
pruneMissingMapKeys(dst.Content[i], src.Content[i])
|
||||||
@@ -1260,6 +1255,94 @@ func findMapKeyIndex(mapNode *yaml.Node, key string) int {
|
|||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// appendPath appends a key to the path, returning a new slice to avoid modifying the original.
|
||||||
|
func appendPath(path []string, key string) []string {
|
||||||
|
if len(path) == 0 {
|
||||||
|
return []string{key}
|
||||||
|
}
|
||||||
|
newPath := make([]string, len(path)+1)
|
||||||
|
copy(newPath, path)
|
||||||
|
newPath[len(path)] = key
|
||||||
|
return newPath
|
||||||
|
}
|
||||||
|
|
||||||
|
// isKnownDefaultValue returns true if the given node at the specified path
|
||||||
|
// represents a known default value that should not be written to the config file.
|
||||||
|
// This prevents non-zero defaults from polluting the config.
|
||||||
|
func isKnownDefaultValue(path []string, node *yaml.Node) bool {
|
||||||
|
// First check if it's a zero value
|
||||||
|
if isZeroValueNode(node) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match known non-zero defaults by exact dotted path.
|
||||||
|
if len(path) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
fullPath := strings.Join(path, ".")
|
||||||
|
|
||||||
|
// Check string defaults
|
||||||
|
if node.Kind == yaml.ScalarNode && node.Tag == "!!str" {
|
||||||
|
switch fullPath {
|
||||||
|
case "pprof.addr":
|
||||||
|
return node.Value == DefaultPprofAddr
|
||||||
|
case "remote-management.panel-github-repository":
|
||||||
|
return node.Value == DefaultPanelGitHubRepository
|
||||||
|
case "routing.strategy":
|
||||||
|
return node.Value == "round-robin"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check integer defaults
|
||||||
|
if node.Kind == yaml.ScalarNode && node.Tag == "!!int" {
|
||||||
|
switch fullPath {
|
||||||
|
case "error-logs-max-files":
|
||||||
|
return node.Value == "10"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// pruneKnownDefaultsInNewNode removes default-valued descendants from a new node
|
||||||
|
// before it is appended into the destination YAML tree.
|
||||||
|
func pruneKnownDefaultsInNewNode(path []string, node *yaml.Node) {
|
||||||
|
if node == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch node.Kind {
|
||||||
|
case yaml.MappingNode:
|
||||||
|
filtered := make([]*yaml.Node, 0, len(node.Content))
|
||||||
|
for i := 0; i+1 < len(node.Content); i += 2 {
|
||||||
|
keyNode := node.Content[i]
|
||||||
|
valueNode := node.Content[i+1]
|
||||||
|
if keyNode == nil || valueNode == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
childPath := appendPath(path, keyNode.Value)
|
||||||
|
if isKnownDefaultValue(childPath, valueNode) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
pruneKnownDefaultsInNewNode(childPath, valueNode)
|
||||||
|
if (valueNode.Kind == yaml.MappingNode || valueNode.Kind == yaml.SequenceNode) &&
|
||||||
|
len(valueNode.Content) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered = append(filtered, keyNode, valueNode)
|
||||||
|
}
|
||||||
|
node.Content = filtered
|
||||||
|
case yaml.SequenceNode:
|
||||||
|
for _, child := range node.Content {
|
||||||
|
pruneKnownDefaultsInNewNode(path, child)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// isZeroValueNode returns true if the YAML node represents a zero/default value
|
// isZeroValueNode returns true if the YAML node represents a zero/default value
|
||||||
// that should not be written as a new key to preserve config cleanliness.
|
// that should not be written as a new key to preserve config cleanliness.
|
||||||
// For mappings and sequences, recursively checks if all children are zero values.
|
// For mappings and sequences, recursively checks if all children are zero values.
|
||||||
|
|||||||
@@ -20,9 +20,6 @@ type SDKConfig struct {
|
|||||||
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
||||||
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
||||||
|
|
||||||
// Access holds request authentication provider configuration.
|
|
||||||
Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"`
|
|
||||||
|
|
||||||
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
||||||
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
||||||
|
|
||||||
@@ -42,65 +39,3 @@ type StreamingConfig struct {
|
|||||||
// <= 0 disables bootstrap retries. Default is 0.
|
// <= 0 disables bootstrap retries. Default is 0.
|
||||||
BootstrapRetries int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
|
BootstrapRetries int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// AccessConfig groups request authentication providers.
|
|
||||||
type AccessConfig struct {
|
|
||||||
// Providers lists configured authentication providers.
|
|
||||||
Providers []AccessProvider `yaml:"providers,omitempty" json:"providers,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// AccessProvider describes a request authentication provider entry.
|
|
||||||
type AccessProvider struct {
|
|
||||||
// Name is the instance identifier for the provider.
|
|
||||||
Name string `yaml:"name" json:"name"`
|
|
||||||
|
|
||||||
// Type selects the provider implementation registered via the SDK.
|
|
||||||
Type string `yaml:"type" json:"type"`
|
|
||||||
|
|
||||||
// SDK optionally names a third-party SDK module providing this provider.
|
|
||||||
SDK string `yaml:"sdk,omitempty" json:"sdk,omitempty"`
|
|
||||||
|
|
||||||
// APIKeys lists inline keys for providers that require them.
|
|
||||||
APIKeys []string `yaml:"api-keys,omitempty" json:"api-keys,omitempty"`
|
|
||||||
|
|
||||||
// Config passes provider-specific options to the implementation.
|
|
||||||
Config map[string]any `yaml:"config,omitempty" json:"config,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
// AccessProviderTypeConfigAPIKey is the built-in provider validating inline API keys.
|
|
||||||
AccessProviderTypeConfigAPIKey = "config-api-key"
|
|
||||||
|
|
||||||
// DefaultAccessProviderName is applied when no provider name is supplied.
|
|
||||||
DefaultAccessProviderName = "config-inline"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ConfigAPIKeyProvider returns the first inline API key provider if present.
|
|
||||||
func (c *SDKConfig) ConfigAPIKeyProvider() *AccessProvider {
|
|
||||||
if c == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
for i := range c.Access.Providers {
|
|
||||||
if c.Access.Providers[i].Type == AccessProviderTypeConfigAPIKey {
|
|
||||||
if c.Access.Providers[i].Name == "" {
|
|
||||||
c.Access.Providers[i].Name = DefaultAccessProviderName
|
|
||||||
}
|
|
||||||
return &c.Access.Providers[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MakeInlineAPIKeyProvider constructs an inline API key provider configuration.
|
|
||||||
// It returns nil when no keys are supplied.
|
|
||||||
func MakeInlineAPIKeyProvider(keys []string) *AccessProvider {
|
|
||||||
if len(keys) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
provider := &AccessProvider{
|
|
||||||
Name: DefaultAccessProviderName,
|
|
||||||
Type: AccessProviderTypeConfigAPIKey,
|
|
||||||
APIKeys: append([]string(nil), keys...),
|
|
||||||
}
|
|
||||||
return provider
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sync/singleflight"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -28,6 +29,7 @@ const (
|
|||||||
defaultManagementFallbackURL = "https://cpamc.router-for.me/"
|
defaultManagementFallbackURL = "https://cpamc.router-for.me/"
|
||||||
managementAssetName = "management.html"
|
managementAssetName = "management.html"
|
||||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||||
|
managementSyncMinInterval = 30 * time.Second
|
||||||
updateCheckInterval = 3 * time.Hour
|
updateCheckInterval = 3 * time.Hour
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -37,11 +39,10 @@ const ManagementFileName = managementAssetName
|
|||||||
var (
|
var (
|
||||||
lastUpdateCheckMu sync.Mutex
|
lastUpdateCheckMu sync.Mutex
|
||||||
lastUpdateCheckTime time.Time
|
lastUpdateCheckTime time.Time
|
||||||
|
|
||||||
currentConfigPtr atomic.Pointer[config.Config]
|
currentConfigPtr atomic.Pointer[config.Config]
|
||||||
disableControlPanel atomic.Bool
|
|
||||||
schedulerOnce sync.Once
|
schedulerOnce sync.Once
|
||||||
schedulerConfigPath atomic.Value
|
schedulerConfigPath atomic.Value
|
||||||
|
sfGroup singleflight.Group
|
||||||
)
|
)
|
||||||
|
|
||||||
// SetCurrentConfig stores the latest configuration snapshot for management asset decisions.
|
// SetCurrentConfig stores the latest configuration snapshot for management asset decisions.
|
||||||
@@ -50,16 +51,7 @@ func SetCurrentConfig(cfg *config.Config) {
|
|||||||
currentConfigPtr.Store(nil)
|
currentConfigPtr.Store(nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
prevDisabled := disableControlPanel.Load()
|
|
||||||
currentConfigPtr.Store(cfg)
|
currentConfigPtr.Store(cfg)
|
||||||
disableControlPanel.Store(cfg.RemoteManagement.DisableControlPanel)
|
|
||||||
|
|
||||||
if prevDisabled && !cfg.RemoteManagement.DisableControlPanel {
|
|
||||||
lastUpdateCheckMu.Lock()
|
|
||||||
lastUpdateCheckTime = time.Time{}
|
|
||||||
lastUpdateCheckMu.Unlock()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// StartAutoUpdater launches a background goroutine that periodically ensures the management asset is up to date.
|
// StartAutoUpdater launches a background goroutine that periodically ensures the management asset is up to date.
|
||||||
@@ -92,7 +84,7 @@ func runAutoUpdater(ctx context.Context) {
|
|||||||
log.Debug("management asset auto-updater skipped: config not yet available")
|
log.Debug("management asset auto-updater skipped: config not yet available")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if disableControlPanel.Load() {
|
if cfg.RemoteManagement.DisableControlPanel {
|
||||||
log.Debug("management asset auto-updater skipped: control panel disabled")
|
log.Debug("management asset auto-updater skipped: control panel disabled")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -181,103 +173,106 @@ func FilePath(configFilePath string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// EnsureLatestManagementHTML checks the latest management.html asset and updates the local copy when needed.
|
// EnsureLatestManagementHTML checks the latest management.html asset and updates the local copy when needed.
|
||||||
// The function is designed to run in a background goroutine and will never panic.
|
// It coalesces concurrent sync attempts and returns whether the asset exists after the sync attempt.
|
||||||
// It enforces a 3-hour rate limit to avoid frequent checks on config/auth file changes.
|
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string, panelRepository string) bool {
|
||||||
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string, panelRepository string) {
|
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
ctx = context.Background()
|
ctx = context.Background()
|
||||||
}
|
}
|
||||||
|
|
||||||
if disableControlPanel.Load() {
|
|
||||||
log.Debug("management asset sync skipped: control panel disabled by configuration")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
staticDir = strings.TrimSpace(staticDir)
|
staticDir = strings.TrimSpace(staticDir)
|
||||||
if staticDir == "" {
|
if staticDir == "" {
|
||||||
log.Debug("management asset sync skipped: empty static directory")
|
log.Debug("management asset sync skipped: empty static directory")
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
localPath := filepath.Join(staticDir, managementAssetName)
|
localPath := filepath.Join(staticDir, managementAssetName)
|
||||||
localFileMissing := false
|
|
||||||
if _, errStat := os.Stat(localPath); errStat != nil {
|
|
||||||
if errors.Is(errStat, os.ErrNotExist) {
|
|
||||||
localFileMissing = true
|
|
||||||
} else {
|
|
||||||
log.WithError(errStat).Debug("failed to stat local management asset")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rate limiting: check only once every 3 hours
|
_, _, _ = sfGroup.Do(localPath, func() (interface{}, error) {
|
||||||
lastUpdateCheckMu.Lock()
|
lastUpdateCheckMu.Lock()
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
timeSinceLastCheck := now.Sub(lastUpdateCheckTime)
|
timeSinceLastAttempt := now.Sub(lastUpdateCheckTime)
|
||||||
if timeSinceLastCheck < updateCheckInterval {
|
if !lastUpdateCheckTime.IsZero() && timeSinceLastAttempt < managementSyncMinInterval {
|
||||||
|
lastUpdateCheckMu.Unlock()
|
||||||
|
log.Debugf(
|
||||||
|
"management asset sync skipped by throttle: last attempt %v ago (interval %v)",
|
||||||
|
timeSinceLastAttempt.Round(time.Second),
|
||||||
|
managementSyncMinInterval,
|
||||||
|
)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
lastUpdateCheckTime = now
|
||||||
lastUpdateCheckMu.Unlock()
|
lastUpdateCheckMu.Unlock()
|
||||||
log.Debugf("management asset update check skipped: last check was %v ago (interval: %v)", timeSinceLastCheck.Round(time.Second), updateCheckInterval)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
lastUpdateCheckTime = now
|
|
||||||
lastUpdateCheckMu.Unlock()
|
|
||||||
|
|
||||||
if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
|
localFileMissing := false
|
||||||
log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
|
if _, errStat := os.Stat(localPath); errStat != nil {
|
||||||
return
|
if errors.Is(errStat, os.ErrNotExist) {
|
||||||
}
|
localFileMissing = true
|
||||||
|
} else {
|
||||||
releaseURL := resolveReleaseURL(panelRepository)
|
log.WithError(errStat).Debug("failed to stat local management asset")
|
||||||
client := newHTTPClient(proxyURL)
|
|
||||||
|
|
||||||
localHash, err := fileSHA256(localPath)
|
|
||||||
if err != nil {
|
|
||||||
if !errors.Is(err, os.ErrNotExist) {
|
|
||||||
log.WithError(err).Debug("failed to read local management asset hash")
|
|
||||||
}
|
|
||||||
localHash = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
|
|
||||||
if err != nil {
|
|
||||||
if localFileMissing {
|
|
||||||
log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
|
|
||||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
log.WithError(err).Warn("failed to fetch latest management release information")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if remoteHash != "" && localHash != "" && strings.EqualFold(remoteHash, localHash) {
|
if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
|
||||||
log.Debug("management asset is already up to date")
|
log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
|
||||||
return
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
|
releaseURL := resolveReleaseURL(panelRepository)
|
||||||
if err != nil {
|
client := newHTTPClient(proxyURL)
|
||||||
if localFileMissing {
|
|
||||||
log.WithError(err).Warn("failed to download management asset, trying fallback page")
|
localHash, err := fileSHA256(localPath)
|
||||||
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
if err != nil {
|
||||||
return
|
if !errors.Is(err, os.ErrNotExist) {
|
||||||
|
log.WithError(err).Debug("failed to read local management asset hash")
|
||||||
}
|
}
|
||||||
return
|
localHash = ""
|
||||||
}
|
}
|
||||||
log.WithError(err).Warn("failed to download management asset")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
|
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
|
||||||
log.Warnf("remote digest mismatch for management asset: expected %s got %s", remoteHash, downloadedHash)
|
if err != nil {
|
||||||
}
|
if localFileMissing {
|
||||||
|
log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
|
||||||
|
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
log.WithError(err).Warn("failed to fetch latest management release information")
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
if err = atomicWriteFile(localPath, data); err != nil {
|
if remoteHash != "" && localHash != "" && strings.EqualFold(remoteHash, localHash) {
|
||||||
log.WithError(err).Warn("failed to update management asset on disk")
|
log.Debug("management asset is already up to date")
|
||||||
return
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
|
data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
|
||||||
|
if err != nil {
|
||||||
|
if localFileMissing {
|
||||||
|
log.WithError(err).Warn("failed to download management asset, trying fallback page")
|
||||||
|
if ensureFallbackManagementHTML(ctx, client, localPath) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
log.WithError(err).Warn("failed to download management asset")
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
|
||||||
|
log.Warnf("remote digest mismatch for management asset: expected %s got %s", remoteHash, downloadedHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = atomicWriteFile(localPath, data); err != nil {
|
||||||
|
log.WithError(err).Warn("failed to update management asset on disk")
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
|
||||||
|
return nil, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := os.Stat(localPath)
|
||||||
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, localPath string) bool {
|
func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, localPath string) bool {
|
||||||
|
|||||||
@@ -277,6 +277,18 @@ func GetGitHubCopilotModels() []*ModelInfo {
|
|||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
SupportedEndpoints: []string{"/chat/completions"},
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "claude-opus-4.6",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "Claude Opus 4.6",
|
||||||
|
Description: "Anthropic Claude Opus 4.6 via GitHub Copilot",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
SupportedEndpoints: []string{"/chat/completions"},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-sonnet-4",
|
ID: "claude-sonnet-4",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
|
|||||||
@@ -866,9 +866,50 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
|||||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||||
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 128000},
|
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
|
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
|
||||||
"gpt-oss-120b-medium": {},
|
"gpt-oss-120b-medium": {},
|
||||||
"tab_flash_lite_preview": {},
|
"tab_flash_lite_preview": {},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetKimiModels returns the standard Kimi (Moonshot AI) model definitions
|
||||||
|
func GetKimiModels() []*ModelInfo {
|
||||||
|
return []*ModelInfo{
|
||||||
|
{
|
||||||
|
ID: "kimi-k2",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1752192000, // 2025-07-11
|
||||||
|
OwnedBy: "moonshot",
|
||||||
|
Type: "kimi",
|
||||||
|
DisplayName: "Kimi K2",
|
||||||
|
Description: "Kimi K2 - Moonshot AI's flagship coding model",
|
||||||
|
ContextLength: 131072,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "kimi-k2-thinking",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1762387200, // 2025-11-06
|
||||||
|
OwnedBy: "moonshot",
|
||||||
|
Type: "kimi",
|
||||||
|
DisplayName: "Kimi K2 Thinking",
|
||||||
|
Description: "Kimi K2 Thinking - Extended reasoning model",
|
||||||
|
ContextLength: 131072,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "kimi-k2.5",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1769472000, // 2026-01-26
|
||||||
|
OwnedBy: "moonshot",
|
||||||
|
Type: "kimi",
|
||||||
|
DisplayName: "Kimi K2.5",
|
||||||
|
Description: "Kimi K2.5 - Latest Moonshot AI coding model with improved capabilities",
|
||||||
|
ContextLength: 131072,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -33,11 +34,11 @@ const (
|
|||||||
maxScannerBufferSize = 20_971_520
|
maxScannerBufferSize = 20_971_520
|
||||||
|
|
||||||
// Copilot API header values.
|
// Copilot API header values.
|
||||||
copilotUserAgent = "GithubCopilot/1.0"
|
copilotUserAgent = "GitHubCopilotChat/0.35.0"
|
||||||
copilotEditorVersion = "vscode/1.100.0"
|
copilotEditorVersion = "vscode/1.107.0"
|
||||||
copilotPluginVersion = "copilot/1.300.0"
|
copilotPluginVersion = "copilot-chat/0.35.0"
|
||||||
copilotIntegrationID = "vscode-chat"
|
copilotIntegrationID = "vscode-chat"
|
||||||
copilotOpenAIIntent = "conversation-panel"
|
copilotOpenAIIntent = "conversation-edits"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GitHubCopilotExecutor handles requests to the GitHub Copilot API.
|
// GitHubCopilotExecutor handles requests to the GitHub Copilot API.
|
||||||
@@ -77,7 +78,7 @@ func (e *GitHubCopilotExecutor) PrepareRequest(req *http.Request, auth *cliproxy
|
|||||||
if errToken != nil {
|
if errToken != nil {
|
||||||
return errToken
|
return errToken
|
||||||
}
|
}
|
||||||
e.applyHeaders(req, apiToken)
|
e.applyHeaders(req, apiToken, nil)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -120,6 +121,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
|
||||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||||
body = e.normalizeModel(req.Model, body)
|
body = e.normalizeModel(req.Model, body)
|
||||||
|
body = flattenAssistantContent(body)
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel)
|
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, _ = sjson.SetBytes(body, "stream", false)
|
body, _ = sjson.SetBytes(body, "stream", false)
|
||||||
@@ -133,7 +135,7 @@ func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
e.applyHeaders(httpReq, apiToken)
|
e.applyHeaders(httpReq, apiToken, body)
|
||||||
|
|
||||||
// Add Copilot-Vision-Request header if the request contains vision content
|
// Add Copilot-Vision-Request header if the request contains vision content
|
||||||
if detectVisionContent(body) {
|
if detectVisionContent(body) {
|
||||||
@@ -225,6 +227,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
|
||||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||||
body = e.normalizeModel(req.Model, body)
|
body = e.normalizeModel(req.Model, body)
|
||||||
|
body = flattenAssistantContent(body)
|
||||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel)
|
body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
body, _ = sjson.SetBytes(body, "stream", true)
|
body, _ = sjson.SetBytes(body, "stream", true)
|
||||||
@@ -242,7 +245,7 @@ func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliprox
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
e.applyHeaders(httpReq, apiToken)
|
e.applyHeaders(httpReq, apiToken, body)
|
||||||
|
|
||||||
// Add Copilot-Vision-Request header if the request contains vision content
|
// Add Copilot-Vision-Request header if the request contains vision content
|
||||||
if detectVisionContent(body) {
|
if detectVisionContent(body) {
|
||||||
@@ -414,7 +417,7 @@ func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *clipro
|
|||||||
}
|
}
|
||||||
|
|
||||||
// applyHeaders sets the required headers for GitHub Copilot API requests.
|
// applyHeaders sets the required headers for GitHub Copilot API requests.
|
||||||
func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string) {
|
func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string, body []byte) {
|
||||||
r.Header.Set("Content-Type", "application/json")
|
r.Header.Set("Content-Type", "application/json")
|
||||||
r.Header.Set("Authorization", "Bearer "+apiToken)
|
r.Header.Set("Authorization", "Bearer "+apiToken)
|
||||||
r.Header.Set("Accept", "application/json")
|
r.Header.Set("Accept", "application/json")
|
||||||
@@ -424,6 +427,20 @@ func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string) {
|
|||||||
r.Header.Set("Openai-Intent", copilotOpenAIIntent)
|
r.Header.Set("Openai-Intent", copilotOpenAIIntent)
|
||||||
r.Header.Set("Copilot-Integration-Id", copilotIntegrationID)
|
r.Header.Set("Copilot-Integration-Id", copilotIntegrationID)
|
||||||
r.Header.Set("X-Request-Id", uuid.NewString())
|
r.Header.Set("X-Request-Id", uuid.NewString())
|
||||||
|
|
||||||
|
initiator := "user"
|
||||||
|
if len(body) > 0 {
|
||||||
|
if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
|
||||||
|
arr := messages.Array()
|
||||||
|
if len(arr) > 0 {
|
||||||
|
lastRole := arr[len(arr)-1].Get("role").String()
|
||||||
|
if lastRole != "" && lastRole != "user" {
|
||||||
|
initiator = "agent"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.Header.Set("X-Initiator", initiator)
|
||||||
}
|
}
|
||||||
|
|
||||||
// detectVisionContent checks if the request body contains vision/image content.
|
// detectVisionContent checks if the request body contains vision/image content.
|
||||||
@@ -464,6 +481,38 @@ func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format) bool {
|
|||||||
return sourceFormat.String() == "openai-response"
|
return sourceFormat.String() == "openai-response"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// flattenAssistantContent converts assistant message content from array format
|
||||||
|
// to a joined string. GitHub Copilot requires assistant content as a string;
|
||||||
|
// sending it as an array causes Claude models to re-answer all previous prompts.
|
||||||
|
func flattenAssistantContent(body []byte) []byte {
|
||||||
|
messages := gjson.GetBytes(body, "messages")
|
||||||
|
if !messages.Exists() || !messages.IsArray() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
result := body
|
||||||
|
for i, msg := range messages.Array() {
|
||||||
|
if msg.Get("role").String() != "assistant" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
content := msg.Get("content")
|
||||||
|
if !content.Exists() || !content.IsArray() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var textParts []string
|
||||||
|
for _, part := range content.Array() {
|
||||||
|
if part.Get("type").String() == "text" {
|
||||||
|
if t := part.Get("text").String(); t != "" {
|
||||||
|
textParts = append(textParts, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
joined := strings.Join(textParts, "")
|
||||||
|
path := fmt.Sprintf("messages.%d.content", i)
|
||||||
|
result, _ = sjson.SetBytes(result, path, joined)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
// isHTTPSuccess checks if the status code indicates success (2xx).
|
// isHTTPSuccess checks if the status code indicates success (2xx).
|
||||||
func isHTTPSuccess(statusCode int) bool {
|
func isHTTPSuccess(statusCode int) bool {
|
||||||
return statusCode >= 200 && statusCode < 300
|
return statusCode >= 200 && statusCode < 300
|
||||||
|
|||||||
@@ -4,12 +4,16 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
@@ -453,6 +457,20 @@ func applyIFlowHeaders(r *http.Request, apiKey string, stream bool) {
|
|||||||
r.Header.Set("Content-Type", "application/json")
|
r.Header.Set("Content-Type", "application/json")
|
||||||
r.Header.Set("Authorization", "Bearer "+apiKey)
|
r.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
r.Header.Set("User-Agent", iflowUserAgent)
|
r.Header.Set("User-Agent", iflowUserAgent)
|
||||||
|
|
||||||
|
// Generate session-id
|
||||||
|
sessionID := "session-" + generateUUID()
|
||||||
|
r.Header.Set("session-id", sessionID)
|
||||||
|
|
||||||
|
// Generate timestamp and signature
|
||||||
|
timestamp := time.Now().UnixMilli()
|
||||||
|
r.Header.Set("x-iflow-timestamp", fmt.Sprintf("%d", timestamp))
|
||||||
|
|
||||||
|
signature := createIFlowSignature(iflowUserAgent, sessionID, timestamp, apiKey)
|
||||||
|
if signature != "" {
|
||||||
|
r.Header.Set("x-iflow-signature", signature)
|
||||||
|
}
|
||||||
|
|
||||||
if stream {
|
if stream {
|
||||||
r.Header.Set("Accept", "text/event-stream")
|
r.Header.Set("Accept", "text/event-stream")
|
||||||
} else {
|
} else {
|
||||||
@@ -460,6 +478,23 @@ func applyIFlowHeaders(r *http.Request, apiKey string, stream bool) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// createIFlowSignature generates HMAC-SHA256 signature for iFlow API requests.
|
||||||
|
// The signature payload format is: userAgent:sessionId:timestamp
|
||||||
|
func createIFlowSignature(userAgent, sessionID string, timestamp int64, apiKey string) string {
|
||||||
|
if apiKey == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
payload := fmt.Sprintf("%s:%s:%d", userAgent, sessionID, timestamp)
|
||||||
|
h := hmac.New(sha256.New, []byte(apiKey))
|
||||||
|
h.Write([]byte(payload))
|
||||||
|
return hex.EncodeToString(h.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateUUID generates a random UUID v4 string.
|
||||||
|
func generateUUID() string {
|
||||||
|
return uuid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
func iflowCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
func iflowCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||||
if a == nil {
|
if a == nil {
|
||||||
return "", ""
|
return "", ""
|
||||||
|
|||||||
618
internal/runtime/executor/kimi_executor.go
Normal file
618
internal/runtime/executor/kimi_executor.go
Normal file
@@ -0,0 +1,618 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
kimiauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// KimiExecutor is a stateless executor for Kimi API using OpenAI-compatible chat completions.
|
||||||
|
type KimiExecutor struct {
|
||||||
|
ClaudeExecutor
|
||||||
|
cfg *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewKimiExecutor creates a new Kimi executor.
|
||||||
|
func NewKimiExecutor(cfg *config.Config) *KimiExecutor { return &KimiExecutor{cfg: cfg} }
|
||||||
|
|
||||||
|
// Identifier returns the executor identifier.
|
||||||
|
func (e *KimiExecutor) Identifier() string { return "kimi" }
|
||||||
|
|
||||||
|
// PrepareRequest injects Kimi credentials into the outgoing HTTP request.
|
||||||
|
func (e *KimiExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||||
|
if req == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
token := kimiCreds(auth)
|
||||||
|
if strings.TrimSpace(token) != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HttpRequest injects Kimi credentials into the request and executes it.
|
||||||
|
func (e *KimiExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
if req == nil {
|
||||||
|
return nil, fmt.Errorf("kimi executor: request is nil")
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = req.Context()
|
||||||
|
}
|
||||||
|
httpReq := req.WithContext(ctx)
|
||||||
|
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
return httpClient.Do(httpReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute performs a non-streaming chat completion request to Kimi.
|
||||||
|
func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
|
from := opts.SourceFormat
|
||||||
|
if from.String() == "claude" {
|
||||||
|
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
|
||||||
|
return e.ClaudeExecutor.Execute(ctx, auth, req, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
|
||||||
|
token := kimiCreds(auth)
|
||||||
|
|
||||||
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
|
to := sdktranslator.FromString("openai")
|
||||||
|
originalPayloadSource := req.Payload
|
||||||
|
if len(opts.OriginalRequest) > 0 {
|
||||||
|
originalPayloadSource = opts.OriginalRequest
|
||||||
|
}
|
||||||
|
originalPayload := bytes.Clone(originalPayloadSource)
|
||||||
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||||
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
|
// Strip kimi- prefix for upstream API
|
||||||
|
upstreamModel := stripKimiPrefix(baseModel)
|
||||||
|
body, err = sjson.SetBytes(body, "model", upstreamModel)
|
||||||
|
if err != nil {
|
||||||
|
return resp, fmt.Errorf("kimi executor: failed to set model in payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "kimi", e.Identifier())
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
body, err = normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
url := kimiauth.KimiAPIBaseURL + "/v1/chat/completions"
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
applyKimiHeadersWithAuth(httpReq, token, false, auth)
|
||||||
|
var authID, authLabel, authType, authValue string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
authLabel = auth.Label
|
||||||
|
authType, authValue = auth.AccountInfo()
|
||||||
|
}
|
||||||
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||||
|
URL: url,
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Headers: httpReq.Header.Clone(),
|
||||||
|
Body: body,
|
||||||
|
Provider: e.Identifier(),
|
||||||
|
AuthID: authID,
|
||||||
|
AuthLabel: authLabel,
|
||||||
|
AuthType: authType,
|
||||||
|
AuthValue: authValue,
|
||||||
|
})
|
||||||
|
|
||||||
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
httpResp, err := httpClient.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, err)
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("kimi executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
data, err := io.ReadAll(httpResp.Body)
|
||||||
|
if err != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, err)
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
|
reporter.publish(ctx, parseOpenAIUsage(data))
|
||||||
|
var param any
|
||||||
|
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||||
|
// the original model name in the response for client compatibility.
|
||||||
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||||
|
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteStream performs a streaming chat completion request to Kimi.
|
||||||
|
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||||
|
from := opts.SourceFormat
|
||||||
|
if from.String() == "claude" {
|
||||||
|
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
|
||||||
|
return e.ClaudeExecutor.ExecuteStream(ctx, auth, req, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
token := kimiCreds(auth)
|
||||||
|
|
||||||
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
|
to := sdktranslator.FromString("openai")
|
||||||
|
originalPayloadSource := req.Payload
|
||||||
|
if len(opts.OriginalRequest) > 0 {
|
||||||
|
originalPayloadSource = opts.OriginalRequest
|
||||||
|
}
|
||||||
|
originalPayload := bytes.Clone(originalPayloadSource)
|
||||||
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||||
|
|
||||||
|
// Strip kimi- prefix for upstream API
|
||||||
|
upstreamModel := stripKimiPrefix(baseModel)
|
||||||
|
body, err = sjson.SetBytes(body, "model", upstreamModel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("kimi executor: failed to set model in payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "kimi", e.Identifier())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err = sjson.SetBytes(body, "stream_options.include_usage", true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("kimi executor: failed to set stream_options in payload: %w", err)
|
||||||
|
}
|
||||||
|
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||||
|
body, err = normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
url := kimiauth.KimiAPIBaseURL + "/v1/chat/completions"
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
applyKimiHeadersWithAuth(httpReq, token, true, auth)
|
||||||
|
var authID, authLabel, authType, authValue string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
authLabel = auth.Label
|
||||||
|
authType, authValue = auth.AccountInfo()
|
||||||
|
}
|
||||||
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||||
|
URL: url,
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Headers: httpReq.Header.Clone(),
|
||||||
|
Body: body,
|
||||||
|
Provider: e.Identifier(),
|
||||||
|
AuthID: authID,
|
||||||
|
AuthLabel: authLabel,
|
||||||
|
AuthType: authType,
|
||||||
|
AuthValue: authValue,
|
||||||
|
})
|
||||||
|
|
||||||
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
httpResp, err := httpClient.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
|
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("kimi executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
|
stream = out
|
||||||
|
go func() {
|
||||||
|
defer close(out)
|
||||||
|
defer func() {
|
||||||
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("kimi executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
scanner := bufio.NewScanner(httpResp.Body)
|
||||||
|
scanner.Buffer(nil, 1_048_576) // 1MB
|
||||||
|
var param any
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Bytes()
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||||
|
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||||
|
reporter.publish(ctx, detail)
|
||||||
|
}
|
||||||
|
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||||
|
for i := range chunks {
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||||
|
for i := range doneChunks {
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(doneChunks[i])}
|
||||||
|
}
|
||||||
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
|
reporter.publishFailure(ctx)
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return stream, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountTokens estimates token count for Kimi requests.
|
||||||
|
func (e *KimiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
|
||||||
|
return e.ClaudeExecutor.CountTokens(ctx, auth, req, opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeKimiToolMessageLinks(body []byte) ([]byte, error) {
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
messages := gjson.GetBytes(body, "messages")
|
||||||
|
if !messages.Exists() || !messages.IsArray() {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
out := body
|
||||||
|
pending := make([]string, 0)
|
||||||
|
patched := 0
|
||||||
|
patchedReasoning := 0
|
||||||
|
ambiguous := 0
|
||||||
|
latestReasoning := ""
|
||||||
|
hasLatestReasoning := false
|
||||||
|
|
||||||
|
removePending := func(id string) {
|
||||||
|
for idx := range pending {
|
||||||
|
if pending[idx] != id {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pending = append(pending[:idx], pending[idx+1:]...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs := messages.Array()
|
||||||
|
for msgIdx := range msgs {
|
||||||
|
msg := msgs[msgIdx]
|
||||||
|
role := strings.TrimSpace(msg.Get("role").String())
|
||||||
|
switch role {
|
||||||
|
case "assistant":
|
||||||
|
reasoning := msg.Get("reasoning_content")
|
||||||
|
if reasoning.Exists() {
|
||||||
|
reasoningText := reasoning.String()
|
||||||
|
if strings.TrimSpace(reasoningText) != "" {
|
||||||
|
latestReasoning = reasoningText
|
||||||
|
hasLatestReasoning = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCalls := msg.Get("tool_calls")
|
||||||
|
if !toolCalls.Exists() || !toolCalls.IsArray() || len(toolCalls.Array()) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reasoning.Exists() || strings.TrimSpace(reasoning.String()) == "" {
|
||||||
|
reasoningText := fallbackAssistantReasoning(msg, hasLatestReasoning, latestReasoning)
|
||||||
|
path := fmt.Sprintf("messages.%d.reasoning_content", msgIdx)
|
||||||
|
next, err := sjson.SetBytes(out, path, reasoningText)
|
||||||
|
if err != nil {
|
||||||
|
return body, fmt.Errorf("kimi executor: failed to set assistant reasoning_content: %w", err)
|
||||||
|
}
|
||||||
|
out = next
|
||||||
|
patchedReasoning++
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range toolCalls.Array() {
|
||||||
|
id := strings.TrimSpace(tc.Get("id").String())
|
||||||
|
if id == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pending = append(pending, id)
|
||||||
|
}
|
||||||
|
case "tool":
|
||||||
|
toolCallID := strings.TrimSpace(msg.Get("tool_call_id").String())
|
||||||
|
if toolCallID == "" {
|
||||||
|
toolCallID = strings.TrimSpace(msg.Get("call_id").String())
|
||||||
|
if toolCallID != "" {
|
||||||
|
path := fmt.Sprintf("messages.%d.tool_call_id", msgIdx)
|
||||||
|
next, err := sjson.SetBytes(out, path, toolCallID)
|
||||||
|
if err != nil {
|
||||||
|
return body, fmt.Errorf("kimi executor: failed to set tool_call_id from call_id: %w", err)
|
||||||
|
}
|
||||||
|
out = next
|
||||||
|
patched++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if toolCallID == "" {
|
||||||
|
if len(pending) == 1 {
|
||||||
|
toolCallID = pending[0]
|
||||||
|
path := fmt.Sprintf("messages.%d.tool_call_id", msgIdx)
|
||||||
|
next, err := sjson.SetBytes(out, path, toolCallID)
|
||||||
|
if err != nil {
|
||||||
|
return body, fmt.Errorf("kimi executor: failed to infer tool_call_id: %w", err)
|
||||||
|
}
|
||||||
|
out = next
|
||||||
|
patched++
|
||||||
|
} else if len(pending) > 1 {
|
||||||
|
ambiguous++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if toolCallID != "" {
|
||||||
|
removePending(toolCallID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if patched > 0 || patchedReasoning > 0 {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"patched_tool_messages": patched,
|
||||||
|
"patched_reasoning_messages": patchedReasoning,
|
||||||
|
}).Debug("kimi executor: normalized tool message fields")
|
||||||
|
}
|
||||||
|
if ambiguous > 0 {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"ambiguous_tool_messages": ambiguous,
|
||||||
|
"pending_tool_calls": len(pending),
|
||||||
|
}).Warn("kimi executor: tool messages missing tool_call_id with ambiguous candidates")
|
||||||
|
}
|
||||||
|
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func fallbackAssistantReasoning(msg gjson.Result, hasLatest bool, latest string) string {
|
||||||
|
if hasLatest && strings.TrimSpace(latest) != "" {
|
||||||
|
return latest
|
||||||
|
}
|
||||||
|
|
||||||
|
content := msg.Get("content")
|
||||||
|
if content.Type == gjson.String {
|
||||||
|
if text := strings.TrimSpace(content.String()); text != "" {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if content.IsArray() {
|
||||||
|
parts := make([]string, 0, len(content.Array()))
|
||||||
|
for _, item := range content.Array() {
|
||||||
|
text := strings.TrimSpace(item.Get("text").String())
|
||||||
|
if text == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parts = append(parts, text)
|
||||||
|
}
|
||||||
|
if len(parts) > 0 {
|
||||||
|
return strings.Join(parts, "\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "[reasoning unavailable]"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh refreshes the Kimi token using the refresh token.
|
||||||
|
func (e *KimiExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||||
|
log.Debugf("kimi executor: refresh called")
|
||||||
|
if auth == nil {
|
||||||
|
return nil, fmt.Errorf("kimi executor: auth is nil")
|
||||||
|
}
|
||||||
|
// Expect refresh_token in metadata for OAuth-based accounts
|
||||||
|
var refreshToken string
|
||||||
|
if auth.Metadata != nil {
|
||||||
|
if v, ok := auth.Metadata["refresh_token"].(string); ok && strings.TrimSpace(v) != "" {
|
||||||
|
refreshToken = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(refreshToken) == "" {
|
||||||
|
// Nothing to refresh
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
client := kimiauth.NewDeviceFlowClientWithDeviceID(e.cfg, resolveKimiDeviceID(auth))
|
||||||
|
td, err := client.RefreshToken(ctx, refreshToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if auth.Metadata == nil {
|
||||||
|
auth.Metadata = make(map[string]any)
|
||||||
|
}
|
||||||
|
auth.Metadata["access_token"] = td.AccessToken
|
||||||
|
if td.RefreshToken != "" {
|
||||||
|
auth.Metadata["refresh_token"] = td.RefreshToken
|
||||||
|
}
|
||||||
|
if td.ExpiresAt > 0 {
|
||||||
|
exp := time.Unix(td.ExpiresAt, 0).UTC().Format(time.RFC3339)
|
||||||
|
auth.Metadata["expired"] = exp
|
||||||
|
}
|
||||||
|
auth.Metadata["type"] = "kimi"
|
||||||
|
now := time.Now().Format(time.RFC3339)
|
||||||
|
auth.Metadata["last_refresh"] = now
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyKimiHeaders sets required headers for Kimi API requests.
|
||||||
|
// Headers match kimi-cli client for compatibility.
|
||||||
|
func applyKimiHeaders(r *http.Request, token string, stream bool) {
|
||||||
|
r.Header.Set("Content-Type", "application/json")
|
||||||
|
r.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
// Match kimi-cli headers exactly
|
||||||
|
r.Header.Set("User-Agent", "KimiCLI/1.10.6")
|
||||||
|
r.Header.Set("X-Msh-Platform", "kimi_cli")
|
||||||
|
r.Header.Set("X-Msh-Version", "1.10.6")
|
||||||
|
r.Header.Set("X-Msh-Device-Name", getKimiHostname())
|
||||||
|
r.Header.Set("X-Msh-Device-Model", getKimiDeviceModel())
|
||||||
|
r.Header.Set("X-Msh-Device-Id", getKimiDeviceID())
|
||||||
|
if stream {
|
||||||
|
r.Header.Set("Accept", "text/event-stream")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.Header.Set("Accept", "application/json")
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveKimiDeviceIDFromAuth(auth *cliproxyauth.Auth) string {
|
||||||
|
if auth == nil || auth.Metadata == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceIDRaw, ok := auth.Metadata["device_id"]
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceID, ok := deviceIDRaw.(string)
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.TrimSpace(deviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveKimiDeviceIDFromStorage(auth *cliproxyauth.Auth) string {
|
||||||
|
if auth == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
storage, ok := auth.Storage.(*kimiauth.KimiTokenStorage)
|
||||||
|
if !ok || storage == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.TrimSpace(storage.DeviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveKimiDeviceID(auth *cliproxyauth.Auth) string {
|
||||||
|
deviceID := resolveKimiDeviceIDFromAuth(auth)
|
||||||
|
if deviceID != "" {
|
||||||
|
return deviceID
|
||||||
|
}
|
||||||
|
return resolveKimiDeviceIDFromStorage(auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyKimiHeadersWithAuth(r *http.Request, token string, stream bool, auth *cliproxyauth.Auth) {
|
||||||
|
applyKimiHeaders(r, token, stream)
|
||||||
|
|
||||||
|
if deviceID := resolveKimiDeviceID(auth); deviceID != "" {
|
||||||
|
r.Header.Set("X-Msh-Device-Id", deviceID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getKimiHostname returns the machine hostname.
|
||||||
|
func getKimiHostname() string {
|
||||||
|
hostname, err := os.Hostname()
|
||||||
|
if err != nil {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
return hostname
|
||||||
|
}
|
||||||
|
|
||||||
|
// getKimiDeviceModel returns a device model string matching kimi-cli format.
|
||||||
|
func getKimiDeviceModel() string {
|
||||||
|
return fmt.Sprintf("%s %s", runtime.GOOS, runtime.GOARCH)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getKimiDeviceID returns a stable device ID, matching kimi-cli storage location.
|
||||||
|
func getKimiDeviceID() string {
|
||||||
|
homeDir, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "cli-proxy-api-device"
|
||||||
|
}
|
||||||
|
// Check kimi-cli's device_id location first (platform-specific)
|
||||||
|
var kimiShareDir string
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
kimiShareDir = filepath.Join(homeDir, "Library", "Application Support", "kimi")
|
||||||
|
case "windows":
|
||||||
|
appData := os.Getenv("APPDATA")
|
||||||
|
if appData == "" {
|
||||||
|
appData = filepath.Join(homeDir, "AppData", "Roaming")
|
||||||
|
}
|
||||||
|
kimiShareDir = filepath.Join(appData, "kimi")
|
||||||
|
default: // linux and other unix-like
|
||||||
|
kimiShareDir = filepath.Join(homeDir, ".local", "share", "kimi")
|
||||||
|
}
|
||||||
|
deviceIDPath := filepath.Join(kimiShareDir, "device_id")
|
||||||
|
if data, err := os.ReadFile(deviceIDPath); err == nil {
|
||||||
|
return strings.TrimSpace(string(data))
|
||||||
|
}
|
||||||
|
return "cli-proxy-api-device"
|
||||||
|
}
|
||||||
|
|
||||||
|
// kimiCreds extracts the access token from auth.
|
||||||
|
func kimiCreds(a *cliproxyauth.Auth) (token string) {
|
||||||
|
if a == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
// Check metadata first (OAuth flow stores tokens here)
|
||||||
|
if a.Metadata != nil {
|
||||||
|
if v, ok := a.Metadata["access_token"].(string); ok && strings.TrimSpace(v) != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Fallback to attributes (API key style)
|
||||||
|
if a.Attributes != nil {
|
||||||
|
if v := a.Attributes["access_token"]; v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
if v := a.Attributes["api_key"]; v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// stripKimiPrefix removes the "kimi-" prefix from model names for the upstream API.
|
||||||
|
func stripKimiPrefix(model string) string {
|
||||||
|
model = strings.TrimSpace(model)
|
||||||
|
if strings.HasPrefix(strings.ToLower(model), "kimi-") {
|
||||||
|
return model[5:]
|
||||||
|
}
|
||||||
|
return model
|
||||||
|
}
|
||||||
205
internal/runtime/executor/kimi_executor_test.go
Normal file
205
internal/runtime/executor/kimi_executor_test.go
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNormalizeKimiToolMessageLinks_UsesCallIDFallback(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"messages":[
|
||||||
|
{"role":"assistant","tool_calls":[{"id":"list_directory:1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]},
|
||||||
|
{"role":"tool","call_id":"list_directory:1","content":"[]"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := gjson.GetBytes(out, "messages.1.tool_call_id").String()
|
||||||
|
if got != "list_directory:1" {
|
||||||
|
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "list_directory:1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeKimiToolMessageLinks_InferSinglePendingID(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"messages":[
|
||||||
|
{"role":"assistant","tool_calls":[{"id":"call_123","type":"function","function":{"name":"read_file","arguments":"{}"}}]},
|
||||||
|
{"role":"tool","content":"file-content"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := gjson.GetBytes(out, "messages.1.tool_call_id").String()
|
||||||
|
if got != "call_123" {
|
||||||
|
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "call_123")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeKimiToolMessageLinks_AmbiguousMissingIDIsNotInferred(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"messages":[
|
||||||
|
{"role":"assistant","tool_calls":[
|
||||||
|
{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}},
|
||||||
|
{"id":"call_2","type":"function","function":{"name":"read_file","arguments":"{}"}}
|
||||||
|
]},
|
||||||
|
{"role":"tool","content":"result-without-id"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if gjson.GetBytes(out, "messages.1.tool_call_id").Exists() {
|
||||||
|
t.Fatalf("messages.1.tool_call_id should be absent for ambiguous case, got %q", gjson.GetBytes(out, "messages.1.tool_call_id").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeKimiToolMessageLinks_PreservesExistingToolCallID(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"messages":[
|
||||||
|
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]},
|
||||||
|
{"role":"tool","tool_call_id":"call_1","call_id":"different-id","content":"result"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := gjson.GetBytes(out, "messages.1.tool_call_id").String()
|
||||||
|
if got != "call_1" {
|
||||||
|
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "call_1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeKimiToolMessageLinks_InheritsPreviousReasoningForAssistantToolCalls(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"messages":[
|
||||||
|
{"role":"assistant","content":"plan","reasoning_content":"previous reasoning"},
|
||||||
|
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := gjson.GetBytes(out, "messages.1.reasoning_content").String()
|
||||||
|
if got != "previous reasoning" {
|
||||||
|
t.Fatalf("messages.1.reasoning_content = %q, want %q", got, "previous reasoning")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeKimiToolMessageLinks_InsertsFallbackReasoningWhenMissing(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"messages":[
|
||||||
|
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
reasoning := gjson.GetBytes(out, "messages.0.reasoning_content")
|
||||||
|
if !reasoning.Exists() {
|
||||||
|
t.Fatalf("messages.0.reasoning_content should exist")
|
||||||
|
}
|
||||||
|
if reasoning.String() != "[reasoning unavailable]" {
|
||||||
|
t.Fatalf("messages.0.reasoning_content = %q, want %q", reasoning.String(), "[reasoning unavailable]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeKimiToolMessageLinks_UsesContentAsReasoningFallback(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"messages":[
|
||||||
|
{"role":"assistant","content":[{"type":"text","text":"first line"},{"type":"text","text":"second line"}],"tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := gjson.GetBytes(out, "messages.0.reasoning_content").String()
|
||||||
|
if got != "first line\nsecond line" {
|
||||||
|
t.Fatalf("messages.0.reasoning_content = %q, want %q", got, "first line\nsecond line")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeKimiToolMessageLinks_ReplacesEmptyReasoningContent(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"messages":[
|
||||||
|
{"role":"assistant","content":"assistant summary","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}],"reasoning_content":""}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := gjson.GetBytes(out, "messages.0.reasoning_content").String()
|
||||||
|
if got != "assistant summary" {
|
||||||
|
t.Fatalf("messages.0.reasoning_content = %q, want %q", got, "assistant summary")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeKimiToolMessageLinks_PreservesExistingAssistantReasoning(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"messages":[
|
||||||
|
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}],"reasoning_content":"keep me"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := gjson.GetBytes(out, "messages.0.reasoning_content").String()
|
||||||
|
if got != "keep me" {
|
||||||
|
t.Fatalf("messages.0.reasoning_content = %q, want %q", got, "keep me")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeKimiToolMessageLinks_RepairsIDsAndReasoningTogether(t *testing.T) {
|
||||||
|
body := []byte(`{
|
||||||
|
"messages":[
|
||||||
|
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}],"reasoning_content":"r1"},
|
||||||
|
{"role":"tool","call_id":"call_1","content":"[]"},
|
||||||
|
{"role":"assistant","tool_calls":[{"id":"call_2","type":"function","function":{"name":"read_file","arguments":"{}"}}]},
|
||||||
|
{"role":"tool","call_id":"call_2","content":"file"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
|
||||||
|
out, err := normalizeKimiToolMessageLinks(body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(out, "messages.1.tool_call_id").String(); got != "call_1" {
|
||||||
|
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "call_1")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.3.tool_call_id").String(); got != "call_2" {
|
||||||
|
t.Fatalf("messages.3.tool_call_id = %q, want %q", got, "call_2")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "messages.2.reasoning_content").String(); got != "r1" {
|
||||||
|
t.Fatalf("messages.2.reasoning_content = %q, want %q", got, "r1")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2102,6 +2102,22 @@ func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroclaude.Ki
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case "contextUsageEvent":
|
||||||
|
// Handle context usage events from Kiro API
|
||||||
|
// Format: {"contextUsageEvent": {"contextUsagePercentage": 0.53}}
|
||||||
|
if ctxUsage, ok := event["contextUsageEvent"].(map[string]interface{}); ok {
|
||||||
|
if ctxPct, ok := ctxUsage["contextUsagePercentage"].(float64); ok {
|
||||||
|
upstreamContextPercentage = ctxPct
|
||||||
|
log.Debugf("kiro: parseEventStream received contextUsageEvent: %.2f%%", ctxPct*100)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Try direct field (fallback)
|
||||||
|
if ctxPct, ok := event["contextUsagePercentage"].(float64); ok {
|
||||||
|
upstreamContextPercentage = ctxPct
|
||||||
|
log.Debugf("kiro: parseEventStream received contextUsagePercentage (direct): %.2f%%", ctxPct*100)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
case "error", "exception", "internalServerException", "invalidStateEvent":
|
case "error", "exception", "internalServerException", "invalidStateEvent":
|
||||||
// Handle error events from Kiro API stream
|
// Handle error events from Kiro API stream
|
||||||
errMsg := ""
|
errMsg := ""
|
||||||
@@ -2705,6 +2721,22 @@ func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case "contextUsageEvent":
|
||||||
|
// Handle context usage events from Kiro API
|
||||||
|
// Format: {"contextUsageEvent": {"contextUsagePercentage": 0.53}}
|
||||||
|
if ctxUsage, ok := event["contextUsageEvent"].(map[string]interface{}); ok {
|
||||||
|
if ctxPct, ok := ctxUsage["contextUsagePercentage"].(float64); ok {
|
||||||
|
upstreamContextPercentage = ctxPct
|
||||||
|
log.Debugf("kiro: streamToChannel received contextUsageEvent: %.2f%%", ctxPct*100)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Try direct field (fallback)
|
||||||
|
if ctxPct, ok := event["contextUsagePercentage"].(float64); ok {
|
||||||
|
upstreamContextPercentage = ctxPct
|
||||||
|
log.Debugf("kiro: streamToChannel received contextUsagePercentage (direct): %.2f%%", ctxPct*100)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
case "error", "exception", "internalServerException":
|
case "error", "exception", "internalServerException":
|
||||||
// Handle error events from Kiro API stream
|
// Handle error events from Kiro API stream
|
||||||
errMsg := ""
|
errMsg := ""
|
||||||
|
|||||||
@@ -7,5 +7,6 @@ import (
|
|||||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/gemini"
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/gemini"
|
||||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/geminicli"
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/geminicli"
|
||||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/iflow"
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/iflow"
|
||||||
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/kimi"
|
||||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/openai"
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/openai"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -21,6 +21,9 @@ import (
|
|||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// gcInterval defines minimum time between garbage collection runs.
|
||||||
|
const gcInterval = 5 * time.Minute
|
||||||
|
|
||||||
// GitTokenStore persists token records and auth metadata using git as the backing storage.
|
// GitTokenStore persists token records and auth metadata using git as the backing storage.
|
||||||
type GitTokenStore struct {
|
type GitTokenStore struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@@ -31,6 +34,7 @@ type GitTokenStore struct {
|
|||||||
remote string
|
remote string
|
||||||
username string
|
username string
|
||||||
password string
|
password string
|
||||||
|
lastGC time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGitTokenStore creates a token store that saves credentials to disk through the
|
// NewGitTokenStore creates a token store that saves credentials to disk through the
|
||||||
@@ -613,6 +617,7 @@ func (s *GitTokenStore) commitAndPushLocked(message string, relPaths ...string)
|
|||||||
} else if errRewrite := s.rewriteHeadAsSingleCommit(repo, headRef.Name(), commitHash, message, signature); errRewrite != nil {
|
} else if errRewrite := s.rewriteHeadAsSingleCommit(repo, headRef.Name(), commitHash, message, signature); errRewrite != nil {
|
||||||
return errRewrite
|
return errRewrite
|
||||||
}
|
}
|
||||||
|
s.maybeRunGC(repo)
|
||||||
if err = repo.Push(&git.PushOptions{Auth: s.gitAuth(), Force: true}); err != nil {
|
if err = repo.Push(&git.PushOptions{Auth: s.gitAuth(), Force: true}); err != nil {
|
||||||
if errors.Is(err, git.NoErrAlreadyUpToDate) {
|
if errors.Is(err, git.NoErrAlreadyUpToDate) {
|
||||||
return nil
|
return nil
|
||||||
@@ -652,6 +657,23 @@ func (s *GitTokenStore) rewriteHeadAsSingleCommit(repo *git.Repository, branch p
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *GitTokenStore) maybeRunGC(repo *git.Repository) {
|
||||||
|
now := time.Now()
|
||||||
|
if now.Sub(s.lastGC) < gcInterval {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.lastGC = now
|
||||||
|
|
||||||
|
pruneOpts := git.PruneOptions{
|
||||||
|
OnlyObjectsOlderThan: now,
|
||||||
|
Handler: repo.DeleteObject,
|
||||||
|
}
|
||||||
|
if err := repo.Prune(pruneOpts); err != nil && !errors.Is(err, git.ErrLooseObjectsNotSupported) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = repo.RepackObjects(&git.RepackConfig{})
|
||||||
|
}
|
||||||
|
|
||||||
// PersistConfig commits and pushes configuration changes to git.
|
// PersistConfig commits and pushes configuration changes to git.
|
||||||
func (s *GitTokenStore) PersistConfig(_ context.Context) error {
|
func (s *GitTokenStore) PersistConfig(_ context.Context) error {
|
||||||
if err := s.EnsureRepository(); err != nil {
|
if err := s.EnsureRepository(); err != nil {
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ var providerAppliers = map[string]ProviderApplier{
|
|||||||
"codex": nil,
|
"codex": nil,
|
||||||
"iflow": nil,
|
"iflow": nil,
|
||||||
"antigravity": nil,
|
"antigravity": nil,
|
||||||
|
"kimi": nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProviderApplier returns the ProviderApplier for the given provider name.
|
// GetProviderApplier returns the ProviderApplier for the given provider name.
|
||||||
@@ -326,6 +327,9 @@ func extractThinkingConfig(body []byte, provider string) ThinkingConfig {
|
|||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
return extractOpenAIConfig(body)
|
return extractOpenAIConfig(body)
|
||||||
|
case "kimi":
|
||||||
|
// Kimi uses OpenAI-compatible reasoning_effort format
|
||||||
|
return extractOpenAIConfig(body)
|
||||||
default:
|
default:
|
||||||
return ThinkingConfig{}
|
return ThinkingConfig{}
|
||||||
}
|
}
|
||||||
|
|||||||
126
internal/thinking/provider/kimi/apply.go
Normal file
126
internal/thinking/provider/kimi/apply.go
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
// Package kimi implements thinking configuration for Kimi (Moonshot AI) models.
|
||||||
|
//
|
||||||
|
// Kimi models use the OpenAI-compatible reasoning_effort format with discrete levels
|
||||||
|
// (low/medium/high). The provider strips any existing thinking config and applies
|
||||||
|
// the unified ThinkingConfig in OpenAI format.
|
||||||
|
package kimi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Applier implements thinking.ProviderApplier for Kimi models.
|
||||||
|
//
|
||||||
|
// Kimi-specific behavior:
|
||||||
|
// - Output format: reasoning_effort (string: low/medium/high)
|
||||||
|
// - Uses OpenAI-compatible format
|
||||||
|
// - Supports budget-to-level conversion
|
||||||
|
type Applier struct{}
|
||||||
|
|
||||||
|
var _ thinking.ProviderApplier = (*Applier)(nil)
|
||||||
|
|
||||||
|
// NewApplier creates a new Kimi thinking applier.
|
||||||
|
func NewApplier() *Applier {
|
||||||
|
return &Applier{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
thinking.RegisterProvider("kimi", NewApplier())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply applies thinking configuration to Kimi request body.
|
||||||
|
//
|
||||||
|
// Expected output format:
|
||||||
|
//
|
||||||
|
// {
|
||||||
|
// "reasoning_effort": "high"
|
||||||
|
// }
|
||||||
|
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
|
||||||
|
if thinking.IsUserDefinedModel(modelInfo) {
|
||||||
|
return applyCompatibleKimi(body, config)
|
||||||
|
}
|
||||||
|
if modelInfo.Thinking == nil {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
body = []byte(`{}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
var effort string
|
||||||
|
switch config.Mode {
|
||||||
|
case thinking.ModeLevel:
|
||||||
|
if config.Level == "" {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
effort = string(config.Level)
|
||||||
|
case thinking.ModeNone:
|
||||||
|
// Kimi uses "none" to disable thinking
|
||||||
|
effort = string(thinking.LevelNone)
|
||||||
|
case thinking.ModeBudget:
|
||||||
|
// Convert budget to level using threshold mapping
|
||||||
|
level, ok := thinking.ConvertBudgetToLevel(config.Budget)
|
||||||
|
if !ok {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
effort = level
|
||||||
|
case thinking.ModeAuto:
|
||||||
|
// Auto mode maps to "auto" effort
|
||||||
|
effort = string(thinking.LevelAuto)
|
||||||
|
default:
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if effort == "" {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||||
|
if err != nil {
|
||||||
|
return body, fmt.Errorf("kimi thinking: failed to set reasoning_effort: %w", err)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyCompatibleKimi applies thinking config for user-defined Kimi models.
|
||||||
|
func applyCompatibleKimi(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||||
|
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||||
|
body = []byte(`{}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
var effort string
|
||||||
|
switch config.Mode {
|
||||||
|
case thinking.ModeLevel:
|
||||||
|
if config.Level == "" {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
effort = string(config.Level)
|
||||||
|
case thinking.ModeNone:
|
||||||
|
effort = string(thinking.LevelNone)
|
||||||
|
if config.Level != "" {
|
||||||
|
effort = string(config.Level)
|
||||||
|
}
|
||||||
|
case thinking.ModeAuto:
|
||||||
|
effort = string(thinking.LevelAuto)
|
||||||
|
case thinking.ModeBudget:
|
||||||
|
// Convert budget to level
|
||||||
|
level, ok := thinking.ConvertBudgetToLevel(config.Budget)
|
||||||
|
if !ok {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
effort = level
|
||||||
|
default:
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||||
|
if err != nil {
|
||||||
|
return body, fmt.Errorf("kimi thinking: failed to set reasoning_effort: %w", err)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
@@ -344,7 +344,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
// Inject interleaved thinking hint when both tools and thinking are active
|
// Inject interleaved thinking hint when both tools and thinking are active
|
||||||
hasTools := toolDeclCount > 0
|
hasTools := toolDeclCount > 0
|
||||||
thinkingResult := gjson.GetBytes(rawJSON, "thinking")
|
thinkingResult := gjson.GetBytes(rawJSON, "thinking")
|
||||||
hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && thinkingResult.Get("type").String() == "enabled"
|
thinkingType := thinkingResult.Get("type").String()
|
||||||
|
hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && (thinkingType == "enabled" || thinkingType == "adaptive")
|
||||||
isClaudeThinking := util.IsClaudeThinkingModel(modelName)
|
isClaudeThinking := util.IsClaudeThinkingModel(modelName)
|
||||||
|
|
||||||
if hasTools && hasThinking && isClaudeThinking {
|
if hasTools && hasThinking && isClaudeThinking {
|
||||||
@@ -377,12 +378,18 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
|
|
||||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
||||||
if t := gjson.GetBytes(rawJSON, "thinking"); enableThoughtTranslate && t.Exists() && t.IsObject() {
|
if t := gjson.GetBytes(rawJSON, "thinking"); enableThoughtTranslate && t.Exists() && t.IsObject() {
|
||||||
if t.Get("type").String() == "enabled" {
|
switch t.Get("type").String() {
|
||||||
|
case "enabled":
|
||||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||||
budget := int(b.Int())
|
budget := int(b.Int())
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||||
}
|
}
|
||||||
|
case "adaptive":
|
||||||
|
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
|
||||||
|
// to model-specific max capability.
|
||||||
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||||
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
|
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
|
||||||
|
|||||||
@@ -222,6 +222,10 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
reasoningEffort = effort
|
reasoningEffort = effort
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case "adaptive":
|
||||||
|
// Claude adaptive means "enable with max capacity"; keep it as highest level
|
||||||
|
// and let ApplyThinking normalize per target model capability.
|
||||||
|
reasoningEffort = string(thinking.LevelXHigh)
|
||||||
case "disabled":
|
case "disabled":
|
||||||
if effort, ok := thinking.ConvertBudgetToLevel(0); ok && effort != "" {
|
if effort, ok := thinking.ConvertBudgetToLevel(0); ok && effort != "" {
|
||||||
reasoningEffort = effort
|
reasoningEffort = effort
|
||||||
|
|||||||
@@ -113,10 +113,10 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
|||||||
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||||
p := (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall
|
p := (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall
|
||||||
stopReason := rootResult.Get("response.stop_reason").String()
|
stopReason := rootResult.Get("response.stop_reason").String()
|
||||||
if stopReason != "" {
|
if p {
|
||||||
template, _ = sjson.Set(template, "delta.stop_reason", stopReason)
|
|
||||||
} else if p {
|
|
||||||
template, _ = sjson.Set(template, "delta.stop_reason", "tool_use")
|
template, _ = sjson.Set(template, "delta.stop_reason", "tool_use")
|
||||||
|
} else if stopReason == "max_tokens" || stopReason == "stop" {
|
||||||
|
template, _ = sjson.Set(template, "delta.stop_reason", stopReason)
|
||||||
} else {
|
} else {
|
||||||
template, _ = sjson.Set(template, "delta.stop_reason", "end_turn")
|
template, _ = sjson.Set(template, "delta.stop_reason", "end_turn")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -173,12 +173,18 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
|||||||
|
|
||||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
||||||
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() {
|
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() {
|
||||||
if t.Get("type").String() == "enabled" {
|
switch t.Get("type").String() {
|
||||||
|
case "enabled":
|
||||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||||
budget := int(b.Int())
|
budget := int(b.Int())
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||||
}
|
}
|
||||||
|
case "adaptive":
|
||||||
|
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
|
||||||
|
// to model-specific max capability.
|
||||||
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||||
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
|
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
|
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
@@ -77,14 +78,20 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
|||||||
template, _ = sjson.Set(template, "id", responseIDResult.String())
|
template, _ = sjson.Set(template, "id", responseIDResult.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract and set the finish reason.
|
finishReason := ""
|
||||||
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
if stopReasonResult := gjson.GetBytes(rawJSON, "response.stop_reason"); stopReasonResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
|
finishReason = stopReasonResult.String()
|
||||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
|
|
||||||
}
|
}
|
||||||
|
if finishReason == "" {
|
||||||
|
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||||
|
finishReason = finishReasonResult.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
finishReason = strings.ToLower(finishReason)
|
||||||
|
|
||||||
// Extract and set usage metadata (token counts).
|
// Extract and set usage metadata (token counts).
|
||||||
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
|
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
|
||||||
|
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||||
}
|
}
|
||||||
@@ -97,6 +104,14 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
|||||||
if thoughtsTokenCount > 0 {
|
if thoughtsTokenCount > 0 {
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||||
}
|
}
|
||||||
|
// Include cached token count if present (indicates prompt caching is working)
|
||||||
|
if cachedTokenCount > 0 {
|
||||||
|
var err error
|
||||||
|
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("gemini-cli openai response: failed to set cached_tokens: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process the main content part of the response.
|
// Process the main content part of the response.
|
||||||
@@ -187,6 +202,12 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
|||||||
if hasFunctionCall {
|
if hasFunctionCall {
|
||||||
template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls")
|
template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls")
|
||||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls")
|
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls")
|
||||||
|
} else if finishReason != "" && (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex == 0 {
|
||||||
|
// Only pass through specific finish reasons
|
||||||
|
if finishReason == "max_tokens" || finishReason == "stop" {
|
||||||
|
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
|
||||||
|
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return []string{template}
|
return []string{template}
|
||||||
|
|||||||
@@ -154,12 +154,18 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when enabled
|
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when enabled
|
||||||
// Translator only does format conversion, ApplyThinking handles model capability validation.
|
// Translator only does format conversion, ApplyThinking handles model capability validation.
|
||||||
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() {
|
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() {
|
||||||
if t.Get("type").String() == "enabled" {
|
switch t.Get("type").String() {
|
||||||
|
case "enabled":
|
||||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||||
budget := int(b.Int())
|
budget := int(b.Int())
|
||||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", budget)
|
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.includeThoughts", true)
|
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.includeThoughts", true)
|
||||||
}
|
}
|
||||||
|
case "adaptive":
|
||||||
|
// Keep adaptive as a high level sentinel; ApplyThinking resolves it
|
||||||
|
// to model-specific max capability.
|
||||||
|
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||||
|
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.includeThoughts", true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
|
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
|
||||||
|
|||||||
@@ -129,11 +129,16 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
|||||||
candidateIndex := int(candidate.Get("index").Int())
|
candidateIndex := int(candidate.Get("index").Int())
|
||||||
template, _ = sjson.Set(template, "choices.0.index", candidateIndex)
|
template, _ = sjson.Set(template, "choices.0.index", candidateIndex)
|
||||||
|
|
||||||
// Extract and set the finish reason.
|
finishReason := ""
|
||||||
if finishReasonResult := candidate.Get("finishReason"); finishReasonResult.Exists() {
|
if stopReasonResult := gjson.GetBytes(rawJSON, "stop_reason"); stopReasonResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
|
finishReason = stopReasonResult.String()
|
||||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
|
|
||||||
}
|
}
|
||||||
|
if finishReason == "" {
|
||||||
|
if finishReasonResult := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||||
|
finishReason = finishReasonResult.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
finishReason = strings.ToLower(finishReason)
|
||||||
|
|
||||||
partsResult := candidate.Get("content.parts")
|
partsResult := candidate.Get("content.parts")
|
||||||
hasFunctionCall := false
|
hasFunctionCall := false
|
||||||
@@ -225,6 +230,12 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
|||||||
if hasFunctionCall {
|
if hasFunctionCall {
|
||||||
template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls")
|
template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls")
|
||||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls")
|
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls")
|
||||||
|
} else if finishReason != "" {
|
||||||
|
// Only pass through specific finish reasons
|
||||||
|
if finishReason == "max_tokens" || finishReason == "stop" {
|
||||||
|
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
|
||||||
|
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
responseStrings = append(responseStrings, template)
|
responseStrings = append(responseStrings, template)
|
||||||
|
|||||||
@@ -117,19 +117,29 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
|||||||
switch itemType {
|
switch itemType {
|
||||||
case "message":
|
case "message":
|
||||||
if strings.EqualFold(itemRole, "system") {
|
if strings.EqualFold(itemRole, "system") {
|
||||||
if contentArray := item.Get("content"); contentArray.Exists() && contentArray.IsArray() {
|
if contentArray := item.Get("content"); contentArray.Exists() {
|
||||||
var builder strings.Builder
|
systemInstr := ""
|
||||||
contentArray.ForEach(func(_, contentItem gjson.Result) bool {
|
if systemInstructionResult := gjson.Get(out, "system_instruction"); systemInstructionResult.Exists() {
|
||||||
text := contentItem.Get("text").String()
|
systemInstr = systemInstructionResult.Raw
|
||||||
if builder.Len() > 0 && text != "" {
|
} else {
|
||||||
builder.WriteByte('\n')
|
systemInstr = `{"parts":[]}`
|
||||||
}
|
}
|
||||||
builder.WriteString(text)
|
|
||||||
return true
|
if contentArray.IsArray() {
|
||||||
})
|
contentArray.ForEach(func(_, contentItem gjson.Result) bool {
|
||||||
if !gjson.Get(out, "system_instruction").Exists() {
|
part := `{"text":""}`
|
||||||
systemInstr := `{"parts":[{"text":""}]}`
|
text := contentItem.Get("text").String()
|
||||||
systemInstr, _ = sjson.Set(systemInstr, "parts.0.text", builder.String())
|
part, _ = sjson.Set(part, "text", text)
|
||||||
|
systemInstr, _ = sjson.SetRaw(systemInstr, "parts.-1", part)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
} else if contentArray.Type == gjson.String {
|
||||||
|
part := `{"text":""}`
|
||||||
|
part, _ = sjson.Set(part, "text", contentArray.String())
|
||||||
|
systemInstr, _ = sjson.SetRaw(systemInstr, "parts.-1", part)
|
||||||
|
}
|
||||||
|
|
||||||
|
if systemInstr != `{"parts":[]}` {
|
||||||
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr)
|
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -236,8 +246,22 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
|||||||
})
|
})
|
||||||
|
|
||||||
flush()
|
flush()
|
||||||
}
|
} else if contentArray.Type == gjson.String {
|
||||||
|
effRole := "user"
|
||||||
|
if itemRole != "" {
|
||||||
|
switch strings.ToLower(itemRole) {
|
||||||
|
case "assistant", "model":
|
||||||
|
effRole = "model"
|
||||||
|
default:
|
||||||
|
effRole = strings.ToLower(itemRole)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
one := `{"role":"","parts":[{"text":""}]}`
|
||||||
|
one, _ = sjson.Set(one, "role", effRole)
|
||||||
|
one, _ = sjson.Set(one, "parts.0.text", contentArray.String())
|
||||||
|
out, _ = sjson.SetRaw(out, "contents.-1", one)
|
||||||
|
}
|
||||||
case "function_call":
|
case "function_call":
|
||||||
// Handle function calls - convert to model message with functionCall
|
// Handle function calls - convert to model message with functionCall
|
||||||
name := item.Get("name").String()
|
name := item.Get("name").String()
|
||||||
|
|||||||
@@ -608,18 +608,22 @@ func processMessages(messages gjson.Result, modelID, origin string) ([]KiroHisto
|
|||||||
|
|
||||||
if role == "user" {
|
if role == "user" {
|
||||||
userMsg, toolResults := BuildUserMessageStruct(msg, modelID, origin)
|
userMsg, toolResults := BuildUserMessageStruct(msg, modelID, origin)
|
||||||
|
// CRITICAL: Kiro API requires content to be non-empty for ALL user messages
|
||||||
|
// This includes both history messages and the current message.
|
||||||
|
// When user message contains only tool_result (no text), content will be empty.
|
||||||
|
// This commonly happens in compaction requests from OpenCode.
|
||||||
|
if strings.TrimSpace(userMsg.Content) == "" {
|
||||||
|
if len(toolResults) > 0 {
|
||||||
|
userMsg.Content = kirocommon.DefaultUserContentWithToolResults
|
||||||
|
} else {
|
||||||
|
userMsg.Content = kirocommon.DefaultUserContent
|
||||||
|
}
|
||||||
|
log.Debugf("kiro: user content was empty, using default: %s", userMsg.Content)
|
||||||
|
}
|
||||||
if isLastMessage {
|
if isLastMessage {
|
||||||
currentUserMsg = &userMsg
|
currentUserMsg = &userMsg
|
||||||
currentToolResults = toolResults
|
currentToolResults = toolResults
|
||||||
} else {
|
} else {
|
||||||
// CRITICAL: Kiro API requires content to be non-empty for history messages too
|
|
||||||
if strings.TrimSpace(userMsg.Content) == "" {
|
|
||||||
if len(toolResults) > 0 {
|
|
||||||
userMsg.Content = "Tool results provided."
|
|
||||||
} else {
|
|
||||||
userMsg.Content = "Continue"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// For history messages, embed tool results in context
|
// For history messages, embed tool results in context
|
||||||
if len(toolResults) > 0 {
|
if len(toolResults) > 0 {
|
||||||
userMsg.UserInputMessageContext = &KiroUserInputMessageContext{
|
userMsg.UserInputMessageContext = &KiroUserInputMessageContext{
|
||||||
|
|||||||
@@ -31,11 +31,23 @@ const (
|
|||||||
|
|
||||||
// DefaultAssistantContentWithTools is the fallback content for assistant messages
|
// DefaultAssistantContentWithTools is the fallback content for assistant messages
|
||||||
// that have tool_use but no text content. Kiro API requires non-empty content.
|
// that have tool_use but no text content. Kiro API requires non-empty content.
|
||||||
DefaultAssistantContentWithTools = "I'll help you with that."
|
// IMPORTANT: Use a minimal neutral string that the model won't mimic in responses.
|
||||||
|
// Previously "I'll help you with that." which caused the model to parrot it back.
|
||||||
|
DefaultAssistantContentWithTools = "."
|
||||||
|
|
||||||
// DefaultAssistantContent is the fallback content for assistant messages
|
// DefaultAssistantContent is the fallback content for assistant messages
|
||||||
// that have no content at all. Kiro API requires non-empty content.
|
// that have no content at all. Kiro API requires non-empty content.
|
||||||
DefaultAssistantContent = "I understand."
|
// IMPORTANT: Use a minimal neutral string that the model won't mimic in responses.
|
||||||
|
// Previously "I understand." which could leak into model behavior.
|
||||||
|
DefaultAssistantContent = "."
|
||||||
|
|
||||||
|
// DefaultUserContentWithToolResults is the fallback content for user messages
|
||||||
|
// that have only tool_result (no text). Kiro API requires non-empty content.
|
||||||
|
DefaultUserContentWithToolResults = "Tool results provided."
|
||||||
|
|
||||||
|
// DefaultUserContent is the fallback content for user messages
|
||||||
|
// that have no content at all. Kiro API requires non-empty content.
|
||||||
|
DefaultUserContent = "Continue"
|
||||||
|
|
||||||
// KiroAgenticSystemPrompt is injected only for -agentic models to prevent timeouts on large writes.
|
// KiroAgenticSystemPrompt is injected only for -agentic models to prevent timeouts on large writes.
|
||||||
// AWS Kiro API has a 2-3 minute timeout for large file write operations.
|
// AWS Kiro API has a 2-3 minute timeout for large file write operations.
|
||||||
|
|||||||
@@ -75,6 +75,10 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
|||||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
case "adaptive":
|
||||||
|
// Claude adaptive means "enable with max capacity"; keep it as highest level
|
||||||
|
// and let ApplyThinking normalize per target model capability.
|
||||||
|
out, _ = sjson.Set(out, "reasoning_effort", string(thinking.LevelXHigh))
|
||||||
case "disabled":
|
case "disabled":
|
||||||
if effort, ok := thinking.ConvertBudgetToLevel(0); ok && effort != "" {
|
if effort, ok := thinking.ConvertBudgetToLevel(0); ok && effort != "" {
|
||||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
|
|||||||
if role == "developer" {
|
if role == "developer" {
|
||||||
role = "user"
|
role = "user"
|
||||||
}
|
}
|
||||||
message := `{"role":"","content":""}`
|
message := `{"role":"","content":[]}`
|
||||||
message, _ = sjson.Set(message, "role", role)
|
message, _ = sjson.Set(message, "role", role)
|
||||||
|
|
||||||
if content := item.Get("content"); content.Exists() && content.IsArray() {
|
if content := item.Get("content"); content.Exists() && content.IsArray() {
|
||||||
@@ -84,20 +84,16 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch contentType {
|
switch contentType {
|
||||||
case "input_text":
|
case "input_text", "output_text":
|
||||||
text := contentItem.Get("text").String()
|
text := contentItem.Get("text").String()
|
||||||
if messageContent != "" {
|
contentPart := `{"type":"text","text":""}`
|
||||||
messageContent += "\n" + text
|
contentPart, _ = sjson.Set(contentPart, "text", text)
|
||||||
} else {
|
message, _ = sjson.SetRaw(message, "content.-1", contentPart)
|
||||||
messageContent = text
|
case "input_image":
|
||||||
}
|
imageURL := contentItem.Get("image_url").String()
|
||||||
case "output_text":
|
contentPart := `{"type":"image_url","image_url":{"url":""}}`
|
||||||
text := contentItem.Get("text").String()
|
contentPart, _ = sjson.Set(contentPart, "image_url.url", imageURL)
|
||||||
if messageContent != "" {
|
message, _ = sjson.SetRaw(message, "content.-1", contentPart)
|
||||||
messageContent += "\n" + text
|
|
||||||
} else {
|
|
||||||
messageContent = text
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,12 +1,90 @@
|
|||||||
package access
|
package access
|
||||||
|
|
||||||
import "errors"
|
import (
|
||||||
|
"fmt"
|
||||||
var (
|
"net/http"
|
||||||
// ErrNoCredentials indicates no recognizable credentials were supplied.
|
"strings"
|
||||||
ErrNoCredentials = errors.New("access: no credentials provided")
|
|
||||||
// ErrInvalidCredential signals that supplied credentials were rejected by a provider.
|
|
||||||
ErrInvalidCredential = errors.New("access: invalid credential")
|
|
||||||
// ErrNotHandled tells the manager to continue trying other providers.
|
|
||||||
ErrNotHandled = errors.New("access: not handled")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// AuthErrorCode classifies authentication failures.
|
||||||
|
type AuthErrorCode string
|
||||||
|
|
||||||
|
const (
|
||||||
|
AuthErrorCodeNoCredentials AuthErrorCode = "no_credentials"
|
||||||
|
AuthErrorCodeInvalidCredential AuthErrorCode = "invalid_credential"
|
||||||
|
AuthErrorCodeNotHandled AuthErrorCode = "not_handled"
|
||||||
|
AuthErrorCodeInternal AuthErrorCode = "internal_error"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthError carries authentication failure details and HTTP status.
|
||||||
|
type AuthError struct {
|
||||||
|
Code AuthErrorCode
|
||||||
|
Message string
|
||||||
|
StatusCode int
|
||||||
|
Cause error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *AuthError) Error() string {
|
||||||
|
if e == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
message := strings.TrimSpace(e.Message)
|
||||||
|
if message == "" {
|
||||||
|
message = "authentication error"
|
||||||
|
}
|
||||||
|
if e.Cause != nil {
|
||||||
|
return fmt.Sprintf("%s: %v", message, e.Cause)
|
||||||
|
}
|
||||||
|
return message
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *AuthError) Unwrap() error {
|
||||||
|
if e == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return e.Cause
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPStatusCode returns a safe fallback for missing status codes.
|
||||||
|
func (e *AuthError) HTTPStatusCode() int {
|
||||||
|
if e == nil || e.StatusCode <= 0 {
|
||||||
|
return http.StatusInternalServerError
|
||||||
|
}
|
||||||
|
return e.StatusCode
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAuthError(code AuthErrorCode, message string, statusCode int, cause error) *AuthError {
|
||||||
|
return &AuthError{
|
||||||
|
Code: code,
|
||||||
|
Message: message,
|
||||||
|
StatusCode: statusCode,
|
||||||
|
Cause: cause,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNoCredentialsError() *AuthError {
|
||||||
|
return newAuthError(AuthErrorCodeNoCredentials, "Missing API key", http.StatusUnauthorized, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewInvalidCredentialError() *AuthError {
|
||||||
|
return newAuthError(AuthErrorCodeInvalidCredential, "Invalid API key", http.StatusUnauthorized, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNotHandledError() *AuthError {
|
||||||
|
return newAuthError(AuthErrorCodeNotHandled, "authentication provider did not handle request", 0, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewInternalAuthError(message string, cause error) *AuthError {
|
||||||
|
normalizedMessage := strings.TrimSpace(message)
|
||||||
|
if normalizedMessage == "" {
|
||||||
|
normalizedMessage = "Authentication service error"
|
||||||
|
}
|
||||||
|
return newAuthError(AuthErrorCodeInternal, normalizedMessage, http.StatusInternalServerError, cause)
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsAuthErrorCode(authErr *AuthError, code AuthErrorCode) bool {
|
||||||
|
if authErr == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return authErr.Code == code
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package access
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
@@ -43,7 +42,7 @@ func (m *Manager) Providers() []Provider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Authenticate evaluates providers until one succeeds.
|
// Authenticate evaluates providers until one succeeds.
|
||||||
func (m *Manager) Authenticate(ctx context.Context, r *http.Request) (*Result, error) {
|
func (m *Manager) Authenticate(ctx context.Context, r *http.Request) (*Result, *AuthError) {
|
||||||
if m == nil {
|
if m == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -61,29 +60,29 @@ func (m *Manager) Authenticate(ctx context.Context, r *http.Request) (*Result, e
|
|||||||
if provider == nil {
|
if provider == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
res, err := provider.Authenticate(ctx, r)
|
res, authErr := provider.Authenticate(ctx, r)
|
||||||
if err == nil {
|
if authErr == nil {
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
if errors.Is(err, ErrNotHandled) {
|
if IsAuthErrorCode(authErr, AuthErrorCodeNotHandled) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if errors.Is(err, ErrNoCredentials) {
|
if IsAuthErrorCode(authErr, AuthErrorCodeNoCredentials) {
|
||||||
missing = true
|
missing = true
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if errors.Is(err, ErrInvalidCredential) {
|
if IsAuthErrorCode(authErr, AuthErrorCodeInvalidCredential) {
|
||||||
invalid = true
|
invalid = true
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, authErr
|
||||||
}
|
}
|
||||||
|
|
||||||
if invalid {
|
if invalid {
|
||||||
return nil, ErrInvalidCredential
|
return nil, NewInvalidCredentialError()
|
||||||
}
|
}
|
||||||
if missing {
|
if missing {
|
||||||
return nil, ErrNoCredentials
|
return nil, NewNoCredentialsError()
|
||||||
}
|
}
|
||||||
return nil, ErrNoCredentials
|
return nil, NewNoCredentialsError()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,17 +2,15 @@ package access
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Provider validates credentials for incoming requests.
|
// Provider validates credentials for incoming requests.
|
||||||
type Provider interface {
|
type Provider interface {
|
||||||
Identifier() string
|
Identifier() string
|
||||||
Authenticate(ctx context.Context, r *http.Request) (*Result, error)
|
Authenticate(ctx context.Context, r *http.Request) (*Result, *AuthError)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Result conveys authentication outcome.
|
// Result conveys authentication outcome.
|
||||||
@@ -22,66 +20,64 @@ type Result struct {
|
|||||||
Metadata map[string]string
|
Metadata map[string]string
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProviderFactory builds a provider from configuration data.
|
|
||||||
type ProviderFactory func(cfg *config.AccessProvider, root *config.SDKConfig) (Provider, error)
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
registryMu sync.RWMutex
|
registryMu sync.RWMutex
|
||||||
registry = make(map[string]ProviderFactory)
|
registry = make(map[string]Provider)
|
||||||
|
order []string
|
||||||
)
|
)
|
||||||
|
|
||||||
// RegisterProvider registers a provider factory for a given type identifier.
|
// RegisterProvider registers a pre-built provider instance for a given type identifier.
|
||||||
func RegisterProvider(typ string, factory ProviderFactory) {
|
func RegisterProvider(typ string, provider Provider) {
|
||||||
if typ == "" || factory == nil {
|
normalizedType := strings.TrimSpace(typ)
|
||||||
|
if normalizedType == "" || provider == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
registryMu.Lock()
|
registryMu.Lock()
|
||||||
registry[typ] = factory
|
if _, exists := registry[normalizedType]; !exists {
|
||||||
|
order = append(order, normalizedType)
|
||||||
|
}
|
||||||
|
registry[normalizedType] = provider
|
||||||
registryMu.Unlock()
|
registryMu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func BuildProvider(cfg *config.AccessProvider, root *config.SDKConfig) (Provider, error) {
|
// UnregisterProvider removes a provider by type identifier.
|
||||||
if cfg == nil {
|
func UnregisterProvider(typ string) {
|
||||||
return nil, fmt.Errorf("access: nil provider config")
|
normalizedType := strings.TrimSpace(typ)
|
||||||
|
if normalizedType == "" {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
registryMu.RLock()
|
registryMu.Lock()
|
||||||
factory, ok := registry[cfg.Type]
|
if _, exists := registry[normalizedType]; !exists {
|
||||||
registryMu.RUnlock()
|
registryMu.Unlock()
|
||||||
if !ok {
|
return
|
||||||
return nil, fmt.Errorf("access: provider type %q is not registered", cfg.Type)
|
|
||||||
}
|
}
|
||||||
provider, err := factory(cfg, root)
|
delete(registry, normalizedType)
|
||||||
if err != nil {
|
for index := range order {
|
||||||
return nil, fmt.Errorf("access: failed to build provider %q: %w", cfg.Name, err)
|
if order[index] != normalizedType {
|
||||||
}
|
|
||||||
return provider, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// BuildProviders constructs providers declared in configuration.
|
|
||||||
func BuildProviders(root *config.SDKConfig) ([]Provider, error) {
|
|
||||||
if root == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
providers := make([]Provider, 0, len(root.Access.Providers))
|
|
||||||
for i := range root.Access.Providers {
|
|
||||||
providerCfg := &root.Access.Providers[i]
|
|
||||||
if providerCfg.Type == "" {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
provider, err := BuildProvider(providerCfg, root)
|
order = append(order[:index], order[index+1:]...)
|
||||||
if err != nil {
|
break
|
||||||
return nil, err
|
}
|
||||||
|
registryMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisteredProviders returns the global provider instances in registration order.
|
||||||
|
func RegisteredProviders() []Provider {
|
||||||
|
registryMu.RLock()
|
||||||
|
if len(order) == 0 {
|
||||||
|
registryMu.RUnlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
providers := make([]Provider, 0, len(order))
|
||||||
|
for _, providerType := range order {
|
||||||
|
provider, exists := registry[providerType]
|
||||||
|
if !exists || provider == nil {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
providers = append(providers, provider)
|
providers = append(providers, provider)
|
||||||
}
|
}
|
||||||
if len(providers) == 0 {
|
registryMu.RUnlock()
|
||||||
if inline := config.MakeInlineAPIKeyProvider(root.APIKeys); inline != nil {
|
return providers
|
||||||
provider, err := BuildProvider(inline, root)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
providers = append(providers, provider)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return providers, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
47
sdk/access/types.go
Normal file
47
sdk/access/types.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package access
|
||||||
|
|
||||||
|
// AccessConfig groups request authentication providers.
|
||||||
|
type AccessConfig struct {
|
||||||
|
// Providers lists configured authentication providers.
|
||||||
|
Providers []AccessProvider `yaml:"providers,omitempty" json:"providers,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccessProvider describes a request authentication provider entry.
|
||||||
|
type AccessProvider struct {
|
||||||
|
// Name is the instance identifier for the provider.
|
||||||
|
Name string `yaml:"name" json:"name"`
|
||||||
|
|
||||||
|
// Type selects the provider implementation registered via the SDK.
|
||||||
|
Type string `yaml:"type" json:"type"`
|
||||||
|
|
||||||
|
// SDK optionally names a third-party SDK module providing this provider.
|
||||||
|
SDK string `yaml:"sdk,omitempty" json:"sdk,omitempty"`
|
||||||
|
|
||||||
|
// APIKeys lists inline keys for providers that require them.
|
||||||
|
APIKeys []string `yaml:"api-keys,omitempty" json:"api-keys,omitempty"`
|
||||||
|
|
||||||
|
// Config passes provider-specific options to the implementation.
|
||||||
|
Config map[string]any `yaml:"config,omitempty" json:"config,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// AccessProviderTypeConfigAPIKey is the built-in provider validating inline API keys.
|
||||||
|
AccessProviderTypeConfigAPIKey = "config-api-key"
|
||||||
|
|
||||||
|
// DefaultAccessProviderName is applied when no provider name is supplied.
|
||||||
|
DefaultAccessProviderName = "config-inline"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MakeInlineAPIKeyProvider constructs an inline API key provider configuration.
|
||||||
|
// It returns nil when no keys are supplied.
|
||||||
|
func MakeInlineAPIKeyProvider(keys []string) *AccessProvider {
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
provider := &AccessProvider{
|
||||||
|
Name: DefaultAccessProviderName,
|
||||||
|
Type: AccessProviderTypeConfigAPIKey,
|
||||||
|
APIKeys: append([]string(nil), keys...),
|
||||||
|
}
|
||||||
|
return provider
|
||||||
|
}
|
||||||
@@ -18,6 +18,7 @@ type ManagementTokenRequester interface {
|
|||||||
RequestCodexToken(*gin.Context)
|
RequestCodexToken(*gin.Context)
|
||||||
RequestAntigravityToken(*gin.Context)
|
RequestAntigravityToken(*gin.Context)
|
||||||
RequestQwenToken(*gin.Context)
|
RequestQwenToken(*gin.Context)
|
||||||
|
RequestKimiToken(*gin.Context)
|
||||||
RequestIFlowToken(*gin.Context)
|
RequestIFlowToken(*gin.Context)
|
||||||
RequestIFlowCookieToken(*gin.Context)
|
RequestIFlowCookieToken(*gin.Context)
|
||||||
GetAuthStatus(c *gin.Context)
|
GetAuthStatus(c *gin.Context)
|
||||||
@@ -55,6 +56,10 @@ func (m *managementTokenRequester) RequestQwenToken(c *gin.Context) {
|
|||||||
m.handler.RequestQwenToken(c)
|
m.handler.RequestQwenToken(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *managementTokenRequester) RequestKimiToken(c *gin.Context) {
|
||||||
|
m.handler.RequestKimiToken(c)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *managementTokenRequester) RequestIFlowToken(c *gin.Context) {
|
func (m *managementTokenRequester) RequestIFlowToken(c *gin.Context) {
|
||||||
m.handler.RequestIFlowToken(c)
|
m.handler.RequestIFlowToken(c)
|
||||||
}
|
}
|
||||||
|
|||||||
123
sdk/auth/kimi.go
Normal file
123
sdk/auth/kimi.go
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// kimiRefreshLead is the duration before token expiry when refresh should occur.
|
||||||
|
var kimiRefreshLead = 5 * time.Minute
|
||||||
|
|
||||||
|
// KimiAuthenticator implements the OAuth device flow login for Kimi (Moonshot AI).
|
||||||
|
type KimiAuthenticator struct{}
|
||||||
|
|
||||||
|
// NewKimiAuthenticator constructs a new Kimi authenticator.
|
||||||
|
func NewKimiAuthenticator() Authenticator {
|
||||||
|
return &KimiAuthenticator{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Provider returns the provider key for kimi.
|
||||||
|
func (KimiAuthenticator) Provider() string {
|
||||||
|
return "kimi"
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshLead returns the duration before token expiry when refresh should occur.
|
||||||
|
// Kimi tokens expire and need to be refreshed before expiry.
|
||||||
|
func (KimiAuthenticator) RefreshLead() *time.Duration {
|
||||||
|
return &kimiRefreshLead
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login initiates the Kimi device flow authentication.
|
||||||
|
func (a KimiAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||||
|
if cfg == nil {
|
||||||
|
return nil, fmt.Errorf("cliproxy auth: configuration is required")
|
||||||
|
}
|
||||||
|
if opts == nil {
|
||||||
|
opts = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
authSvc := kimi.NewKimiAuth(cfg)
|
||||||
|
|
||||||
|
// Start the device flow
|
||||||
|
fmt.Println("Starting Kimi authentication...")
|
||||||
|
deviceCode, err := authSvc.StartDeviceFlow(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("kimi: failed to start device flow: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Display the verification URL
|
||||||
|
verificationURL := deviceCode.VerificationURIComplete
|
||||||
|
if verificationURL == "" {
|
||||||
|
verificationURL = deviceCode.VerificationURI
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("\nTo authenticate, please visit:\n%s\n\n", verificationURL)
|
||||||
|
if deviceCode.UserCode != "" {
|
||||||
|
fmt.Printf("User code: %s\n\n", deviceCode.UserCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to open the browser automatically
|
||||||
|
if !opts.NoBrowser {
|
||||||
|
if browser.IsAvailable() {
|
||||||
|
if errOpen := browser.OpenURL(verificationURL); errOpen != nil {
|
||||||
|
log.Warnf("Failed to open browser automatically: %v", errOpen)
|
||||||
|
} else {
|
||||||
|
fmt.Println("Browser opened automatically.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Waiting for authorization...")
|
||||||
|
if deviceCode.ExpiresIn > 0 {
|
||||||
|
fmt.Printf("(This will timeout in %d seconds if not authorized)\n", deviceCode.ExpiresIn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for user authorization
|
||||||
|
authBundle, err := authSvc.WaitForAuthorization(ctx, deviceCode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("kimi: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the token storage
|
||||||
|
tokenStorage := authSvc.CreateTokenStorage(authBundle)
|
||||||
|
|
||||||
|
// Build metadata with token information
|
||||||
|
metadata := map[string]any{
|
||||||
|
"type": "kimi",
|
||||||
|
"access_token": authBundle.TokenData.AccessToken,
|
||||||
|
"refresh_token": authBundle.TokenData.RefreshToken,
|
||||||
|
"token_type": authBundle.TokenData.TokenType,
|
||||||
|
"scope": authBundle.TokenData.Scope,
|
||||||
|
"timestamp": time.Now().UnixMilli(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if authBundle.TokenData.ExpiresAt > 0 {
|
||||||
|
exp := time.Unix(authBundle.TokenData.ExpiresAt, 0).UTC().Format(time.RFC3339)
|
||||||
|
metadata["expired"] = exp
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(authBundle.DeviceID) != "" {
|
||||||
|
metadata["device_id"] = strings.TrimSpace(authBundle.DeviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a unique filename
|
||||||
|
fileName := fmt.Sprintf("kimi-%d.json", time.Now().UnixMilli())
|
||||||
|
|
||||||
|
fmt.Println("\nKimi authentication successful!")
|
||||||
|
|
||||||
|
return &coreauth.Auth{
|
||||||
|
ID: fileName,
|
||||||
|
Provider: a.Provider(),
|
||||||
|
FileName: fileName,
|
||||||
|
Label: "Kimi User",
|
||||||
|
Storage: tokenStorage,
|
||||||
|
Metadata: metadata,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
@@ -14,6 +14,7 @@ func init() {
|
|||||||
registerRefreshLead("gemini", func() Authenticator { return NewGeminiAuthenticator() })
|
registerRefreshLead("gemini", func() Authenticator { return NewGeminiAuthenticator() })
|
||||||
registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() })
|
registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() })
|
||||||
registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() })
|
registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() })
|
||||||
|
registerRefreshLead("kimi", func() Authenticator { return NewKimiAuthenticator() })
|
||||||
registerRefreshLead("kiro", func() Authenticator { return NewKiroAuthenticator() })
|
registerRefreshLead("kiro", func() Authenticator { return NewKiroAuthenticator() })
|
||||||
registerRefreshLead("github-copilot", func() Authenticator { return NewGitHubCopilotAuthenticator() })
|
registerRefreshLead("github-copilot", func() Authenticator { return NewGitHubCopilotAuthenticator() })
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -607,6 +607,9 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
|
|||||||
result.RetryAfter = ra
|
result.RetryAfter = ra
|
||||||
}
|
}
|
||||||
m.MarkResult(execCtx, result)
|
m.MarkResult(execCtx, result)
|
||||||
|
if isRequestInvalidError(errExec) {
|
||||||
|
return cliproxyexecutor.Response{}, errExec
|
||||||
|
}
|
||||||
lastErr = errExec
|
lastErr = errExec
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -660,6 +663,9 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
|||||||
result.RetryAfter = ra
|
result.RetryAfter = ra
|
||||||
}
|
}
|
||||||
m.MarkResult(execCtx, result)
|
m.MarkResult(execCtx, result)
|
||||||
|
if isRequestInvalidError(errExec) {
|
||||||
|
return cliproxyexecutor.Response{}, errExec
|
||||||
|
}
|
||||||
lastErr = errExec
|
lastErr = errExec
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -711,6 +717,9 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
|||||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||||
result.RetryAfter = retryAfterFromError(errStream)
|
result.RetryAfter = retryAfterFromError(errStream)
|
||||||
m.MarkResult(execCtx, result)
|
m.MarkResult(execCtx, result)
|
||||||
|
if isRequestInvalidError(errStream) {
|
||||||
|
return nil, errStream
|
||||||
|
}
|
||||||
lastErr = errStream
|
lastErr = errStream
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -1110,6 +1119,9 @@ func (m *Manager) shouldRetryAfterError(err error, attempt int, providers []stri
|
|||||||
if status := statusCodeFromError(err); status == http.StatusOK {
|
if status := statusCodeFromError(err); status == http.StatusOK {
|
||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
|
if isRequestInvalidError(err) {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
wait, found := m.closestCooldownWait(providers, model, attempt)
|
wait, found := m.closestCooldownWait(providers, model, attempt)
|
||||||
if !found || wait > maxWait {
|
if !found || wait > maxWait {
|
||||||
return 0, false
|
return 0, false
|
||||||
@@ -1299,7 +1311,7 @@ func updateAggregatedAvailability(auth *Auth, now time.Time) {
|
|||||||
stateUnavailable = true
|
stateUnavailable = true
|
||||||
} else if state.Unavailable {
|
} else if state.Unavailable {
|
||||||
if state.NextRetryAfter.IsZero() {
|
if state.NextRetryAfter.IsZero() {
|
||||||
stateUnavailable = true
|
stateUnavailable = false
|
||||||
} else if state.NextRetryAfter.After(now) {
|
} else if state.NextRetryAfter.After(now) {
|
||||||
stateUnavailable = true
|
stateUnavailable = true
|
||||||
if earliestRetry.IsZero() || state.NextRetryAfter.Before(earliestRetry) {
|
if earliestRetry.IsZero() || state.NextRetryAfter.Before(earliestRetry) {
|
||||||
@@ -1430,6 +1442,21 @@ func statusCodeFromResult(err *Error) int {
|
|||||||
return err.StatusCode()
|
return err.StatusCode()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isRequestInvalidError returns true if the error represents a client request
|
||||||
|
// error that should not be retried. Specifically, it checks for 400 Bad Request
|
||||||
|
// with "invalid_request_error" in the message, indicating the request itself is
|
||||||
|
// malformed and switching to a different auth will not help.
|
||||||
|
func isRequestInvalidError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
status := statusCodeFromError(err)
|
||||||
|
if status != http.StatusBadRequest {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.Contains(err.Error(), "invalid_request_error")
|
||||||
|
}
|
||||||
|
|
||||||
func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Duration, now time.Time) {
|
func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Duration, now time.Time) {
|
||||||
if auth == nil {
|
if auth == nil {
|
||||||
return
|
return
|
||||||
|
|||||||
61
sdk/cliproxy/auth/conductor_availability_test.go
Normal file
61
sdk/cliproxy/auth/conductor_availability_test.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUpdateAggregatedAvailability_UnavailableWithoutNextRetryDoesNotBlockAuth(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
model := "test-model"
|
||||||
|
auth := &Auth{
|
||||||
|
ID: "a",
|
||||||
|
ModelStates: map[string]*ModelState{
|
||||||
|
model: {
|
||||||
|
Status: StatusError,
|
||||||
|
Unavailable: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
updateAggregatedAvailability(auth, now)
|
||||||
|
|
||||||
|
if auth.Unavailable {
|
||||||
|
t.Fatalf("auth.Unavailable = true, want false")
|
||||||
|
}
|
||||||
|
if !auth.NextRetryAfter.IsZero() {
|
||||||
|
t.Fatalf("auth.NextRetryAfter = %v, want zero", auth.NextRetryAfter)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateAggregatedAvailability_FutureNextRetryBlocksAuth(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
model := "test-model"
|
||||||
|
next := now.Add(5 * time.Minute)
|
||||||
|
auth := &Auth{
|
||||||
|
ID: "a",
|
||||||
|
ModelStates: map[string]*ModelState{
|
||||||
|
model: {
|
||||||
|
Status: StatusError,
|
||||||
|
Unavailable: true,
|
||||||
|
NextRetryAfter: next,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
updateAggregatedAvailability(auth, now)
|
||||||
|
|
||||||
|
if !auth.Unavailable {
|
||||||
|
t.Fatalf("auth.Unavailable = false, want true")
|
||||||
|
}
|
||||||
|
if auth.NextRetryAfter.IsZero() {
|
||||||
|
t.Fatalf("auth.NextRetryAfter = zero, want %v", next)
|
||||||
|
}
|
||||||
|
if auth.NextRetryAfter.Sub(next) > time.Second || next.Sub(auth.NextRetryAfter) > time.Second {
|
||||||
|
t.Fatalf("auth.NextRetryAfter = %v, want %v", auth.NextRetryAfter, next)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -221,7 +221,7 @@ func modelAliasChannel(auth *Auth) string {
|
|||||||
// and auth kind. Returns empty string if the provider/authKind combination doesn't support
|
// and auth kind. Returns empty string if the provider/authKind combination doesn't support
|
||||||
// OAuth model alias (e.g., API key authentication).
|
// OAuth model alias (e.g., API key authentication).
|
||||||
//
|
//
|
||||||
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
|
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
|
||||||
func OAuthModelAliasChannel(provider, authKind string) string {
|
func OAuthModelAliasChannel(provider, authKind string) string {
|
||||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||||
authKind = strings.ToLower(strings.TrimSpace(authKind))
|
authKind = strings.ToLower(strings.TrimSpace(authKind))
|
||||||
@@ -245,7 +245,7 @@ func OAuthModelAliasChannel(provider, authKind string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return "codex"
|
return "codex"
|
||||||
case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow", "kiro", "github-copilot":
|
case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow", "kiro", "github-copilot", "kimi":
|
||||||
return provider
|
return provider
|
||||||
default:
|
default:
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -79,6 +79,15 @@ func TestResolveOAuthUpstreamModel_SuffixPreservation(t *testing.T) {
|
|||||||
input: "gemini-2.5-pro(none)",
|
input: "gemini-2.5-pro(none)",
|
||||||
want: "gemini-2.5-pro-exp-03-25(none)",
|
want: "gemini-2.5-pro-exp-03-25(none)",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "kimi suffix preserved",
|
||||||
|
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||||
|
"kimi": {{Name: "kimi-k2.5", Alias: "k2.5"}},
|
||||||
|
},
|
||||||
|
channel: "kimi",
|
||||||
|
input: "k2.5(high)",
|
||||||
|
want: "kimi-k2.5(high)",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "case insensitive alias lookup with suffix",
|
name: "case insensitive alias lookup with suffix",
|
||||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||||
@@ -161,6 +170,8 @@ func createAuthForChannel(channel string) *Auth {
|
|||||||
return &Auth{Provider: "qwen"}
|
return &Auth{Provider: "qwen"}
|
||||||
case "iflow":
|
case "iflow":
|
||||||
return &Auth{Provider: "iflow"}
|
return &Auth{Provider: "iflow"}
|
||||||
|
case "kimi":
|
||||||
|
return &Auth{Provider: "kimi"}
|
||||||
case "kiro":
|
case "kiro":
|
||||||
return &Auth{Provider: "kiro"}
|
return &Auth{Provider: "kiro"}
|
||||||
default:
|
default:
|
||||||
@@ -168,6 +179,14 @@ func createAuthForChannel(channel string) *Auth {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOAuthModelAliasChannel_Kimi(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
if got := OAuthModelAliasChannel("kimi", "oauth"); got != "kimi" {
|
||||||
|
t.Fatalf("OAuthModelAliasChannel() = %q, want %q", got, "kimi")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestApplyOAuthModelAlias_SuffixPreservation(t *testing.T) {
|
func TestApplyOAuthModelAlias_SuffixPreservation(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -19,6 +20,7 @@ import (
|
|||||||
type RoundRobinSelector struct {
|
type RoundRobinSelector struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
cursors map[string]int
|
cursors map[string]int
|
||||||
|
maxKeys int
|
||||||
}
|
}
|
||||||
|
|
||||||
// FillFirstSelector selects the first available credential (deterministic ordering).
|
// FillFirstSelector selects the first available credential (deterministic ordering).
|
||||||
@@ -119,6 +121,19 @@ func authPriority(auth *Auth) int {
|
|||||||
return parsed
|
return parsed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func canonicalModelKey(model string) string {
|
||||||
|
model = strings.TrimSpace(model)
|
||||||
|
if model == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
parsed := thinking.ParseSuffix(model)
|
||||||
|
modelName := strings.TrimSpace(parsed.ModelName)
|
||||||
|
if modelName == "" {
|
||||||
|
return model
|
||||||
|
}
|
||||||
|
return modelName
|
||||||
|
}
|
||||||
|
|
||||||
func collectAvailableByPriority(auths []*Auth, model string, now time.Time) (available map[int][]*Auth, cooldownCount int, earliest time.Time) {
|
func collectAvailableByPriority(auths []*Auth, model string, now time.Time) (available map[int][]*Auth, cooldownCount int, earliest time.Time) {
|
||||||
available = make(map[int][]*Auth)
|
available = make(map[int][]*Auth)
|
||||||
for i := 0; i < len(auths); i++ {
|
for i := 0; i < len(auths); i++ {
|
||||||
@@ -185,11 +200,18 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
key := provider + ":" + model
|
key := provider + ":" + canonicalModelKey(model)
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
if s.cursors == nil {
|
if s.cursors == nil {
|
||||||
s.cursors = make(map[string]int)
|
s.cursors = make(map[string]int)
|
||||||
}
|
}
|
||||||
|
limit := s.maxKeys
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 4096
|
||||||
|
}
|
||||||
|
if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit {
|
||||||
|
s.cursors = make(map[string]int)
|
||||||
|
}
|
||||||
index := s.cursors[key]
|
index := s.cursors[key]
|
||||||
|
|
||||||
if index >= 2_147_483_640 {
|
if index >= 2_147_483_640 {
|
||||||
@@ -223,7 +245,14 @@ func isAuthBlockedForModel(auth *Auth, model string, now time.Time) (bool, block
|
|||||||
}
|
}
|
||||||
if model != "" {
|
if model != "" {
|
||||||
if len(auth.ModelStates) > 0 {
|
if len(auth.ModelStates) > 0 {
|
||||||
if state, ok := auth.ModelStates[model]; ok && state != nil {
|
state, ok := auth.ModelStates[model]
|
||||||
|
if (!ok || state == nil) && model != "" {
|
||||||
|
baseModel := canonicalModelKey(model)
|
||||||
|
if baseModel != "" && baseModel != model {
|
||||||
|
state, ok = auth.ModelStates[baseModel]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ok && state != nil {
|
||||||
if state.Status == StatusDisabled {
|
if state.Status == StatusDisabled {
|
||||||
return true, blockReasonDisabled, time.Time{}
|
return true, blockReasonDisabled, time.Time{}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,9 @@ package auth
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -175,3 +177,228 @@ func TestRoundRobinSelectorPick_Concurrent(t *testing.T) {
|
|||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSelectorPick_AllCooldownReturnsModelCooldownError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
model := "test-model"
|
||||||
|
now := time.Now()
|
||||||
|
next := now.Add(60 * time.Second)
|
||||||
|
auths := []*Auth{
|
||||||
|
{
|
||||||
|
ID: "a",
|
||||||
|
ModelStates: map[string]*ModelState{
|
||||||
|
model: {
|
||||||
|
Status: StatusActive,
|
||||||
|
Unavailable: true,
|
||||||
|
NextRetryAfter: next,
|
||||||
|
Quota: QuotaState{
|
||||||
|
Exceeded: true,
|
||||||
|
NextRecoverAt: next,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "b",
|
||||||
|
ModelStates: map[string]*ModelState{
|
||||||
|
model: {
|
||||||
|
Status: StatusActive,
|
||||||
|
Unavailable: true,
|
||||||
|
NextRetryAfter: next,
|
||||||
|
Quota: QuotaState{
|
||||||
|
Exceeded: true,
|
||||||
|
NextRecoverAt: next,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("mixed provider redacts provider field", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
selector := &FillFirstSelector{}
|
||||||
|
_, err := selector.Pick(context.Background(), "mixed", model, cliproxyexecutor.Options{}, auths)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Pick() error = nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
var mce *modelCooldownError
|
||||||
|
if !errors.As(err, &mce) {
|
||||||
|
t.Fatalf("Pick() error = %T, want *modelCooldownError", err)
|
||||||
|
}
|
||||||
|
if mce.StatusCode() != http.StatusTooManyRequests {
|
||||||
|
t.Fatalf("StatusCode() = %d, want %d", mce.StatusCode(), http.StatusTooManyRequests)
|
||||||
|
}
|
||||||
|
|
||||||
|
headers := mce.Headers()
|
||||||
|
if got := headers.Get("Retry-After"); got == "" {
|
||||||
|
t.Fatalf("Headers().Get(Retry-After) = empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(mce.Error()), &payload); err != nil {
|
||||||
|
t.Fatalf("json.Unmarshal(Error()) error = %v", err)
|
||||||
|
}
|
||||||
|
rawErr, ok := payload["error"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Error() payload missing error object: %v", payload)
|
||||||
|
}
|
||||||
|
if got, _ := rawErr["code"].(string); got != "model_cooldown" {
|
||||||
|
t.Fatalf("Error().error.code = %q, want %q", got, "model_cooldown")
|
||||||
|
}
|
||||||
|
if _, ok := rawErr["provider"]; ok {
|
||||||
|
t.Fatalf("Error().error.provider exists for mixed provider: %v", rawErr["provider"])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-mixed provider includes provider field", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
selector := &FillFirstSelector{}
|
||||||
|
_, err := selector.Pick(context.Background(), "gemini", model, cliproxyexecutor.Options{}, auths)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("Pick() error = nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
var mce *modelCooldownError
|
||||||
|
if !errors.As(err, &mce) {
|
||||||
|
t.Fatalf("Pick() error = %T, want *modelCooldownError", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(mce.Error()), &payload); err != nil {
|
||||||
|
t.Fatalf("json.Unmarshal(Error()) error = %v", err)
|
||||||
|
}
|
||||||
|
rawErr, ok := payload["error"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Error() payload missing error object: %v", payload)
|
||||||
|
}
|
||||||
|
if got, _ := rawErr["provider"].(string); got != "gemini" {
|
||||||
|
t.Fatalf("Error().error.provider = %q, want %q", got, "gemini")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsAuthBlockedForModel_UnavailableWithoutNextRetryIsNotBlocked(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
model := "test-model"
|
||||||
|
auth := &Auth{
|
||||||
|
ID: "a",
|
||||||
|
ModelStates: map[string]*ModelState{
|
||||||
|
model: {
|
||||||
|
Status: StatusActive,
|
||||||
|
Unavailable: true,
|
||||||
|
Quota: QuotaState{
|
||||||
|
Exceeded: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
blocked, reason, next := isAuthBlockedForModel(auth, model, now)
|
||||||
|
if blocked {
|
||||||
|
t.Fatalf("blocked = true, want false")
|
||||||
|
}
|
||||||
|
if reason != blockReasonNone {
|
||||||
|
t.Fatalf("reason = %v, want %v", reason, blockReasonNone)
|
||||||
|
}
|
||||||
|
if !next.IsZero() {
|
||||||
|
t.Fatalf("next = %v, want zero", next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFillFirstSelectorPick_ThinkingSuffixFallsBackToBaseModelState(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
selector := &FillFirstSelector{}
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
baseModel := "test-model"
|
||||||
|
requestedModel := "test-model(high)"
|
||||||
|
|
||||||
|
high := &Auth{
|
||||||
|
ID: "high",
|
||||||
|
Attributes: map[string]string{"priority": "10"},
|
||||||
|
ModelStates: map[string]*ModelState{
|
||||||
|
baseModel: {
|
||||||
|
Status: StatusActive,
|
||||||
|
Unavailable: true,
|
||||||
|
NextRetryAfter: now.Add(30 * time.Minute),
|
||||||
|
Quota: QuotaState{
|
||||||
|
Exceeded: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
low := &Auth{
|
||||||
|
ID: "low",
|
||||||
|
Attributes: map[string]string{"priority": "0"},
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := selector.Pick(context.Background(), "mixed", requestedModel, cliproxyexecutor.Options{}, []*Auth{high, low})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() error = %v", err)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("Pick() auth = nil")
|
||||||
|
}
|
||||||
|
if got.ID != "low" {
|
||||||
|
t.Fatalf("Pick() auth.ID = %q, want %q", got.ID, "low")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundRobinSelectorPick_ThinkingSuffixSharesCursor(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
selector := &RoundRobinSelector{}
|
||||||
|
auths := []*Auth{
|
||||||
|
{ID: "b"},
|
||||||
|
{ID: "a"},
|
||||||
|
}
|
||||||
|
|
||||||
|
first, err := selector.Pick(context.Background(), "gemini", "test-model(high)", cliproxyexecutor.Options{}, auths)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() first error = %v", err)
|
||||||
|
}
|
||||||
|
second, err := selector.Pick(context.Background(), "gemini", "test-model(low)", cliproxyexecutor.Options{}, auths)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() second error = %v", err)
|
||||||
|
}
|
||||||
|
if first == nil || second == nil {
|
||||||
|
t.Fatalf("Pick() returned nil auth")
|
||||||
|
}
|
||||||
|
if first.ID != "a" {
|
||||||
|
t.Fatalf("Pick() first auth.ID = %q, want %q", first.ID, "a")
|
||||||
|
}
|
||||||
|
if second.ID != "b" {
|
||||||
|
t.Fatalf("Pick() second auth.ID = %q, want %q", second.ID, "b")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundRobinSelectorPick_CursorKeyCap(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
selector := &RoundRobinSelector{maxKeys: 2}
|
||||||
|
auths := []*Auth{{ID: "a"}}
|
||||||
|
|
||||||
|
_, _ = selector.Pick(context.Background(), "gemini", "m1", cliproxyexecutor.Options{}, auths)
|
||||||
|
_, _ = selector.Pick(context.Background(), "gemini", "m2", cliproxyexecutor.Options{}, auths)
|
||||||
|
_, _ = selector.Pick(context.Background(), "gemini", "m3", cliproxyexecutor.Options{}, auths)
|
||||||
|
|
||||||
|
selector.mu.Lock()
|
||||||
|
defer selector.mu.Unlock()
|
||||||
|
|
||||||
|
if selector.cursors == nil {
|
||||||
|
t.Fatalf("selector.cursors = nil")
|
||||||
|
}
|
||||||
|
if len(selector.cursors) != 1 {
|
||||||
|
t.Fatalf("len(selector.cursors) = %d, want %d", len(selector.cursors), 1)
|
||||||
|
}
|
||||||
|
if _, ok := selector.cursors["gemini:m3"]; !ok {
|
||||||
|
t.Fatalf("selector.cursors missing key %q", "gemini:m3")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
||||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
@@ -186,11 +187,8 @@ func (b *Builder) Build() (*Service, error) {
|
|||||||
accessManager = sdkaccess.NewManager()
|
accessManager = sdkaccess.NewManager()
|
||||||
}
|
}
|
||||||
|
|
||||||
providers, err := sdkaccess.BuildProviders(&b.cfg.SDKConfig)
|
configaccess.Register(&b.cfg.SDKConfig)
|
||||||
if err != nil {
|
accessManager.SetProviders(sdkaccess.RegisteredProviders())
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
accessManager.SetProviders(providers)
|
|
||||||
|
|
||||||
coreManager := b.coreManager
|
coreManager := b.coreManager
|
||||||
if coreManager == nil {
|
if coreManager == nil {
|
||||||
|
|||||||
@@ -409,6 +409,8 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
|
|||||||
s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg))
|
s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg))
|
||||||
case "iflow":
|
case "iflow":
|
||||||
s.coreManager.RegisterExecutor(executor.NewIFlowExecutor(s.cfg))
|
s.coreManager.RegisterExecutor(executor.NewIFlowExecutor(s.cfg))
|
||||||
|
case "kimi":
|
||||||
|
s.coreManager.RegisterExecutor(executor.NewKimiExecutor(s.cfg))
|
||||||
case "kiro":
|
case "kiro":
|
||||||
s.coreManager.RegisterExecutor(executor.NewKiroExecutor(s.cfg))
|
s.coreManager.RegisterExecutor(executor.NewKiroExecutor(s.cfg))
|
||||||
case "github-copilot":
|
case "github-copilot":
|
||||||
@@ -826,6 +828,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
|||||||
case "iflow":
|
case "iflow":
|
||||||
models = registry.GetIFlowModels()
|
models = registry.GetIFlowModels()
|
||||||
models = applyExcludedModels(models, excluded)
|
models = applyExcludedModels(models, excluded)
|
||||||
|
case "kimi":
|
||||||
|
models = registry.GetKimiModels()
|
||||||
|
models = applyExcludedModels(models, excluded)
|
||||||
case "github-copilot":
|
case "github-copilot":
|
||||||
models = registry.GetGitHubCopilotModels()
|
models = registry.GetGitHubCopilotModels()
|
||||||
models = applyExcludedModels(models, excluded)
|
models = applyExcludedModels(models, excluded)
|
||||||
|
|||||||
@@ -7,8 +7,6 @@ package config
|
|||||||
import internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
import internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
|
||||||
type SDKConfig = internalconfig.SDKConfig
|
type SDKConfig = internalconfig.SDKConfig
|
||||||
type AccessConfig = internalconfig.AccessConfig
|
|
||||||
type AccessProvider = internalconfig.AccessProvider
|
|
||||||
|
|
||||||
type Config = internalconfig.Config
|
type Config = internalconfig.Config
|
||||||
|
|
||||||
@@ -34,15 +32,9 @@ type OpenAICompatibilityModel = internalconfig.OpenAICompatibilityModel
|
|||||||
type TLS = internalconfig.TLSConfig
|
type TLS = internalconfig.TLSConfig
|
||||||
|
|
||||||
const (
|
const (
|
||||||
AccessProviderTypeConfigAPIKey = internalconfig.AccessProviderTypeConfigAPIKey
|
DefaultPanelGitHubRepository = internalconfig.DefaultPanelGitHubRepository
|
||||||
DefaultAccessProviderName = internalconfig.DefaultAccessProviderName
|
|
||||||
DefaultPanelGitHubRepository = internalconfig.DefaultPanelGitHubRepository
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func MakeInlineAPIKeyProvider(keys []string) *AccessProvider {
|
|
||||||
return internalconfig.MakeInlineAPIKeyProvider(keys)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoadConfig(configFile string) (*Config, error) { return internalconfig.LoadConfig(configFile) }
|
func LoadConfig(configFile string) (*Config, error) { return internalconfig.LoadConfig(configFile) }
|
||||||
|
|
||||||
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||||
|
|||||||
@@ -1,195 +0,0 @@
|
|||||||
package test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestLegacyConfigMigration(t *testing.T) {
|
|
||||||
t.Run("onlyLegacyFields", func(t *testing.T) {
|
|
||||||
path := writeConfig(t, `
|
|
||||||
port: 8080
|
|
||||||
generative-language-api-key:
|
|
||||||
- "legacy-gemini-1"
|
|
||||||
openai-compatibility:
|
|
||||||
- name: "legacy-provider"
|
|
||||||
base-url: "https://example.com"
|
|
||||||
api-keys:
|
|
||||||
- "legacy-openai-1"
|
|
||||||
amp-upstream-url: "https://amp.example.com"
|
|
||||||
amp-upstream-api-key: "amp-legacy-key"
|
|
||||||
amp-restrict-management-to-localhost: false
|
|
||||||
amp-model-mappings:
|
|
||||||
- from: "old-model"
|
|
||||||
to: "new-model"
|
|
||||||
`)
|
|
||||||
cfg, err := config.LoadConfig(path)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("load legacy config: %v", err)
|
|
||||||
}
|
|
||||||
if got := len(cfg.GeminiKey); got != 1 || cfg.GeminiKey[0].APIKey != "legacy-gemini-1" {
|
|
||||||
t.Fatalf("gemini migration mismatch: %+v", cfg.GeminiKey)
|
|
||||||
}
|
|
||||||
if got := len(cfg.OpenAICompatibility); got != 1 {
|
|
||||||
t.Fatalf("expected 1 openai-compat provider, got %d", got)
|
|
||||||
}
|
|
||||||
if entries := cfg.OpenAICompatibility[0].APIKeyEntries; len(entries) != 1 || entries[0].APIKey != "legacy-openai-1" {
|
|
||||||
t.Fatalf("openai-compat migration mismatch: %+v", entries)
|
|
||||||
}
|
|
||||||
if cfg.AmpCode.UpstreamURL != "https://amp.example.com" || cfg.AmpCode.UpstreamAPIKey != "amp-legacy-key" {
|
|
||||||
t.Fatalf("amp migration failed: %+v", cfg.AmpCode)
|
|
||||||
}
|
|
||||||
if cfg.AmpCode.RestrictManagementToLocalhost {
|
|
||||||
t.Fatalf("expected amp restriction to be false after migration")
|
|
||||||
}
|
|
||||||
if got := len(cfg.AmpCode.ModelMappings); got != 1 || cfg.AmpCode.ModelMappings[0].From != "old-model" {
|
|
||||||
t.Fatalf("amp mappings migration mismatch: %+v", cfg.AmpCode.ModelMappings)
|
|
||||||
}
|
|
||||||
updated := readFile(t, path)
|
|
||||||
if strings.Contains(updated, "generative-language-api-key") {
|
|
||||||
t.Fatalf("legacy gemini key still present:\n%s", updated)
|
|
||||||
}
|
|
||||||
if strings.Contains(updated, "amp-upstream-url") || strings.Contains(updated, "amp-restrict-management-to-localhost") {
|
|
||||||
t.Fatalf("legacy amp keys still present:\n%s", updated)
|
|
||||||
}
|
|
||||||
if strings.Contains(updated, "\n api-keys:") {
|
|
||||||
t.Fatalf("legacy openai compat keys still present:\n%s", updated)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("mixedLegacyAndNewFields", func(t *testing.T) {
|
|
||||||
path := writeConfig(t, `
|
|
||||||
gemini-api-key:
|
|
||||||
- api-key: "new-gemini"
|
|
||||||
generative-language-api-key:
|
|
||||||
- "new-gemini"
|
|
||||||
- "legacy-gemini-only"
|
|
||||||
openai-compatibility:
|
|
||||||
- name: "mixed-provider"
|
|
||||||
base-url: "https://mixed.example.com"
|
|
||||||
api-key-entries:
|
|
||||||
- api-key: "new-entry"
|
|
||||||
api-keys:
|
|
||||||
- "legacy-entry"
|
|
||||||
- "new-entry"
|
|
||||||
`)
|
|
||||||
cfg, err := config.LoadConfig(path)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("load mixed config: %v", err)
|
|
||||||
}
|
|
||||||
if got := len(cfg.GeminiKey); got != 2 {
|
|
||||||
t.Fatalf("expected 2 gemini entries, got %d: %+v", got, cfg.GeminiKey)
|
|
||||||
}
|
|
||||||
seen := make(map[string]struct{}, len(cfg.GeminiKey))
|
|
||||||
for _, entry := range cfg.GeminiKey {
|
|
||||||
if _, exists := seen[entry.APIKey]; exists {
|
|
||||||
t.Fatalf("duplicate gemini key %q after migration", entry.APIKey)
|
|
||||||
}
|
|
||||||
seen[entry.APIKey] = struct{}{}
|
|
||||||
}
|
|
||||||
provider := cfg.OpenAICompatibility[0]
|
|
||||||
if got := len(provider.APIKeyEntries); got != 2 {
|
|
||||||
t.Fatalf("expected 2 openai entries, got %d: %+v", got, provider.APIKeyEntries)
|
|
||||||
}
|
|
||||||
entrySeen := make(map[string]struct{}, len(provider.APIKeyEntries))
|
|
||||||
for _, entry := range provider.APIKeyEntries {
|
|
||||||
if _, ok := entrySeen[entry.APIKey]; ok {
|
|
||||||
t.Fatalf("duplicate openai key %q after migration", entry.APIKey)
|
|
||||||
}
|
|
||||||
entrySeen[entry.APIKey] = struct{}{}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("onlyNewFields", func(t *testing.T) {
|
|
||||||
path := writeConfig(t, `
|
|
||||||
gemini-api-key:
|
|
||||||
- api-key: "new-only"
|
|
||||||
openai-compatibility:
|
|
||||||
- name: "new-only-provider"
|
|
||||||
base-url: "https://new-only.example.com"
|
|
||||||
api-key-entries:
|
|
||||||
- api-key: "new-only-entry"
|
|
||||||
ampcode:
|
|
||||||
upstream-url: "https://amp.new"
|
|
||||||
upstream-api-key: "new-amp-key"
|
|
||||||
restrict-management-to-localhost: true
|
|
||||||
model-mappings:
|
|
||||||
- from: "a"
|
|
||||||
to: "b"
|
|
||||||
`)
|
|
||||||
cfg, err := config.LoadConfig(path)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("load new config: %v", err)
|
|
||||||
}
|
|
||||||
if len(cfg.GeminiKey) != 1 || cfg.GeminiKey[0].APIKey != "new-only" {
|
|
||||||
t.Fatalf("unexpected gemini entries: %+v", cfg.GeminiKey)
|
|
||||||
}
|
|
||||||
if len(cfg.OpenAICompatibility) != 1 || len(cfg.OpenAICompatibility[0].APIKeyEntries) != 1 {
|
|
||||||
t.Fatalf("unexpected openai compat entries: %+v", cfg.OpenAICompatibility)
|
|
||||||
}
|
|
||||||
if cfg.AmpCode.UpstreamURL != "https://amp.new" || cfg.AmpCode.UpstreamAPIKey != "new-amp-key" {
|
|
||||||
t.Fatalf("unexpected amp config: %+v", cfg.AmpCode)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("duplicateNamesDifferentBase", func(t *testing.T) {
|
|
||||||
path := writeConfig(t, `
|
|
||||||
openai-compatibility:
|
|
||||||
- name: "dup-provider"
|
|
||||||
base-url: "https://provider-a"
|
|
||||||
api-keys:
|
|
||||||
- "key-a"
|
|
||||||
- name: "dup-provider"
|
|
||||||
base-url: "https://provider-b"
|
|
||||||
api-keys:
|
|
||||||
- "key-b"
|
|
||||||
`)
|
|
||||||
cfg, err := config.LoadConfig(path)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("load duplicate config: %v", err)
|
|
||||||
}
|
|
||||||
if len(cfg.OpenAICompatibility) != 2 {
|
|
||||||
t.Fatalf("expected 2 providers, got %d", len(cfg.OpenAICompatibility))
|
|
||||||
}
|
|
||||||
for _, entry := range cfg.OpenAICompatibility {
|
|
||||||
if len(entry.APIKeyEntries) != 1 {
|
|
||||||
t.Fatalf("expected 1 key entry per provider: %+v", entry)
|
|
||||||
}
|
|
||||||
switch entry.BaseURL {
|
|
||||||
case "https://provider-a":
|
|
||||||
if entry.APIKeyEntries[0].APIKey != "key-a" {
|
|
||||||
t.Fatalf("provider-a key mismatch: %+v", entry.APIKeyEntries)
|
|
||||||
}
|
|
||||||
case "https://provider-b":
|
|
||||||
if entry.APIKeyEntries[0].APIKey != "key-b" {
|
|
||||||
t.Fatalf("provider-b key mismatch: %+v", entry.APIKeyEntries)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
t.Fatalf("unexpected provider base url: %s", entry.BaseURL)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeConfig(t *testing.T, content string) string {
|
|
||||||
t.Helper()
|
|
||||||
dir := t.TempDir()
|
|
||||||
path := filepath.Join(dir, "config.yaml")
|
|
||||||
if err := os.WriteFile(path, []byte(strings.TrimSpace(content)+"\n"), 0o644); err != nil {
|
|
||||||
t.Fatalf("write temp config: %v", err)
|
|
||||||
}
|
|
||||||
return path
|
|
||||||
}
|
|
||||||
|
|
||||||
func readFile(t *testing.T, path string) string {
|
|
||||||
t.Helper()
|
|
||||||
data, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("read temp config: %v", err)
|
|
||||||
}
|
|
||||||
return string(data)
|
|
||||||
}
|
|
||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/gemini"
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/gemini"
|
||||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/geminicli"
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/geminicli"
|
||||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/iflow"
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/iflow"
|
||||||
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/kimi"
|
||||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/openai"
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/openai"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
@@ -2589,6 +2590,135 @@ func TestThinkingE2EMatrix_Body(t *testing.T) {
|
|||||||
runThinkingTests(t, cases)
|
runThinkingTests(t, cases)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestThinkingE2EClaudeAdaptive_Body tests Claude thinking.type=adaptive extended body-only cases.
|
||||||
|
// These cases validate that adaptive means "thinking enabled without explicit budget", and
|
||||||
|
// cross-protocol conversion should resolve to target-model maximum thinking capability.
|
||||||
|
func TestThinkingE2EClaudeAdaptive_Body(t *testing.T) {
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
uid := fmt.Sprintf("thinking-e2e-claude-adaptive-%d", time.Now().UnixNano())
|
||||||
|
|
||||||
|
reg.RegisterClient(uid, "test", getTestModels())
|
||||||
|
defer reg.UnregisterClient(uid)
|
||||||
|
|
||||||
|
cases := []thinkingTestCase{
|
||||||
|
// A1: Claude adaptive to OpenAI level model -> highest supported level
|
||||||
|
{
|
||||||
|
name: "A1",
|
||||||
|
from: "claude",
|
||||||
|
to: "openai",
|
||||||
|
model: "level-model",
|
||||||
|
inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`,
|
||||||
|
expectField: "reasoning_effort",
|
||||||
|
expectValue: "high",
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
// A2: Claude adaptive to Gemini level subset model -> highest supported level
|
||||||
|
{
|
||||||
|
name: "A2",
|
||||||
|
from: "claude",
|
||||||
|
to: "gemini",
|
||||||
|
model: "level-subset-model",
|
||||||
|
inputJSON: `{"model":"level-subset-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`,
|
||||||
|
expectField: "generationConfig.thinkingConfig.thinkingLevel",
|
||||||
|
expectValue: "high",
|
||||||
|
includeThoughts: "true",
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
// A3: Claude adaptive to Gemini budget model -> max budget
|
||||||
|
{
|
||||||
|
name: "A3",
|
||||||
|
from: "claude",
|
||||||
|
to: "gemini",
|
||||||
|
model: "gemini-budget-model",
|
||||||
|
inputJSON: `{"model":"gemini-budget-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`,
|
||||||
|
expectField: "generationConfig.thinkingConfig.thinkingBudget",
|
||||||
|
expectValue: "20000",
|
||||||
|
includeThoughts: "true",
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
// A4: Claude adaptive to Gemini mixed model -> highest supported level
|
||||||
|
{
|
||||||
|
name: "A4",
|
||||||
|
from: "claude",
|
||||||
|
to: "gemini",
|
||||||
|
model: "gemini-mixed-model",
|
||||||
|
inputJSON: `{"model":"gemini-mixed-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`,
|
||||||
|
expectField: "generationConfig.thinkingConfig.thinkingLevel",
|
||||||
|
expectValue: "high",
|
||||||
|
includeThoughts: "true",
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
// A5: Claude adaptive passthrough for same protocol
|
||||||
|
{
|
||||||
|
name: "A5",
|
||||||
|
from: "claude",
|
||||||
|
to: "claude",
|
||||||
|
model: "claude-budget-model",
|
||||||
|
inputJSON: `{"model":"claude-budget-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`,
|
||||||
|
expectField: "thinking.type",
|
||||||
|
expectValue: "adaptive",
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
// A6: Claude adaptive to Antigravity budget model -> max budget
|
||||||
|
{
|
||||||
|
name: "A6",
|
||||||
|
from: "claude",
|
||||||
|
to: "antigravity",
|
||||||
|
model: "antigravity-budget-model",
|
||||||
|
inputJSON: `{"model":"antigravity-budget-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`,
|
||||||
|
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
|
||||||
|
expectValue: "20000",
|
||||||
|
includeThoughts: "true",
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
// A7: Claude adaptive to iFlow GLM -> enabled boolean
|
||||||
|
{
|
||||||
|
name: "A7",
|
||||||
|
from: "claude",
|
||||||
|
to: "iflow",
|
||||||
|
model: "glm-test",
|
||||||
|
inputJSON: `{"model":"glm-test","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`,
|
||||||
|
expectField: "chat_template_kwargs.enable_thinking",
|
||||||
|
expectValue: "true",
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
// A8: Claude adaptive to iFlow MiniMax -> enabled boolean
|
||||||
|
{
|
||||||
|
name: "A8",
|
||||||
|
from: "claude",
|
||||||
|
to: "iflow",
|
||||||
|
model: "minimax-test",
|
||||||
|
inputJSON: `{"model":"minimax-test","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`,
|
||||||
|
expectField: "reasoning_split",
|
||||||
|
expectValue: "true",
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
// A9: Claude adaptive to Codex level model -> highest supported level
|
||||||
|
{
|
||||||
|
name: "A9",
|
||||||
|
from: "claude",
|
||||||
|
to: "codex",
|
||||||
|
model: "level-model",
|
||||||
|
inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`,
|
||||||
|
expectField: "reasoning.effort",
|
||||||
|
expectValue: "high",
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
// A10: Claude adaptive on non-thinking model should still be stripped
|
||||||
|
{
|
||||||
|
name: "A10",
|
||||||
|
from: "claude",
|
||||||
|
to: "openai",
|
||||||
|
model: "no-thinking-model",
|
||||||
|
inputJSON: `{"model":"no-thinking-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`,
|
||||||
|
expectField: "",
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
runThinkingTests(t, cases)
|
||||||
|
}
|
||||||
|
|
||||||
// getTestModels returns the shared model definitions for E2E tests.
|
// getTestModels returns the shared model definitions for E2E tests.
|
||||||
func getTestModels() []*registry.ModelInfo {
|
func getTestModels() []*registry.ModelInfo {
|
||||||
return []*registry.ModelInfo{
|
return []*registry.ModelInfo{
|
||||||
|
|||||||
Reference in New Issue
Block a user