Merge branch 'main' into plus

This commit is contained in:
Luis Pater
2026-03-01 02:44:26 +08:00
committed by GitHub
127 changed files with 32229 additions and 419 deletions

View File

@@ -1,13 +1,14 @@
name: docker-image
on:
workflow_dispatch:
push:
tags:
- v*
env:
APP_NAME: CLIProxyAPI
DOCKERHUB_REPO: eceasy/cli-proxy-api
DOCKERHUB_REPO: ${{ secrets.DOCKERHUB_USERNAME }}/cli-proxy-api-plus
jobs:
docker_amd64:

View File

@@ -23,7 +23,8 @@ jobs:
cache: true
- name: Generate Build Metadata
run: |
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
VERSION=$(git describe --tags --always --dirty)
echo "VERSION=${VERSION}" >> $GITHUB_ENV
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
- uses: goreleaser/goreleaser-action@v4

7
.gitignore vendored
View File

@@ -1,17 +1,20 @@
# Binaries
cli-proxy-api
cliproxy
*.exe
# Configuration
config.yaml
.env
.mcp.json
# Generated content
bin/*
logs/*
conv/*
temp/*
refs/*
tmp/*
# Storage backends
pgstore/*
@@ -44,7 +47,9 @@ GEMINI.md
.bmad/*
_bmad/*
_bmad-output/*
.mcp/cache/
# macOS
.DS_Store
._*
*.bak

View File

@@ -1,5 +1,5 @@
builds:
- id: "cli-proxy-api"
- id: "cli-proxy-api-plus"
env:
- CGO_ENABLED=0
goos:
@@ -10,11 +10,11 @@ builds:
- amd64
- arm64
main: ./cmd/server/
binary: cli-proxy-api
binary: cli-proxy-api-plus
ldflags:
- -s -w -X 'main.Version={{.Version}}' -X 'main.Commit={{.ShortCommit}}' -X 'main.BuildDate={{.Date}}'
- -s -w -X 'main.Version={{.Version}}-plus' -X 'main.Commit={{.ShortCommit}}' -X 'main.BuildDate={{.Date}}'
archives:
- id: "cli-proxy-api"
- id: "cli-proxy-api-plus"
format: tar.gz
format_overrides:
- goos: windows

View File

@@ -12,7 +12,7 @@ ARG VERSION=dev
ARG COMMIT=none
ARG BUILD_DATE=unknown
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w -X 'main.Version=${VERSION}' -X 'main.Commit=${COMMIT}' -X 'main.BuildDate=${BUILD_DATE}'" -o ./CLIProxyAPI ./cmd/server/
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w -X 'main.Version=${VERSION}-plus' -X 'main.Commit=${COMMIT}' -X 'main.BuildDate=${BUILD_DATE}'" -o ./CLIProxyAPIPlus ./cmd/server/
FROM alpine:3.22.0
@@ -20,7 +20,7 @@ RUN apk add --no-cache tzdata
RUN mkdir /CLIProxyAPI
COPY --from=builder ./app/CLIProxyAPI /CLIProxyAPI/CLIProxyAPI
COPY --from=builder ./app/CLIProxyAPIPlus /CLIProxyAPI/CLIProxyAPIPlus
COPY config.example.yaml /CLIProxyAPI/config.example.yaml
@@ -32,4 +32,4 @@ ENV TZ=Asia/Shanghai
RUN cp /usr/share/zoneinfo/${TZ} /etc/localtime && echo "${TZ}" > /etc/timezone
CMD ["./CLIProxyAPI"]
CMD ["./CLIProxyAPIPlus"]

258
README.md
View File

@@ -1,174 +1,144 @@
# CLI Proxy API
# CLIProxyAPI Plus
English | [中文](README_CN.md)
English | [Chinese](README_CN.md)
A proxy server that provides OpenAI/Gemini/Claude/Codex compatible API interfaces for CLI.
This is the Plus version of [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI), adding support for third-party providers on top of the mainline project.
It now also supports OpenAI Codex (GPT models) and Claude Code via OAuth.
All third-party provider support is maintained by community contributors; CLIProxyAPI does not provide technical support. Please contact the corresponding community maintainer if you need assistance.
So you can use local or multi-account CLI access with OpenAI(include Responses)/Gemini/Claude-compatible clients and SDKs.
The Plus release stays in lockstep with the mainline features.
## Sponsor
## Differences from the Mainline
[![z.ai](https://assets.router-for.me/english-4.7.png)](https://z.ai/subscribe?ic=8JVLJQFSKB)
- Added GitHub Copilot support (OAuth login), provided by [em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)
- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration), [Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)
This project is sponsored by Z.ai, supporting us with their GLM CODING PLAN.
## New Features (Plus Enhanced)
GLM CODING PLAN is a subscription service designed for AI coding, starting at just $3/month. It provides access to their flagship GLM-4.7 model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
- **OAuth Web Authentication**: Browser-based OAuth login for Kiro with beautiful web UI
- **Rate Limiter**: Built-in request rate limiting to prevent API abuse
- **Background Token Refresh**: Automatic token refresh 10 minutes before expiration
- **Metrics & Monitoring**: Request metrics collection for monitoring and debugging
- **Device Fingerprint**: Device fingerprint generation for enhanced security
- **Cooldown Management**: Smart cooldown mechanism for API rate limits
- **Usage Checker**: Real-time usage monitoring and quota management
- **Model Converter**: Unified model name conversion across providers
- **UTF-8 Stream Processing**: Improved streaming response handling
Get 10% OFF GLM CODING PLANhttps://z.ai/subscribe?ic=8JVLJQFSKB
## Kiro Authentication
---
### CLI Login
<table>
<tbody>
<tr>
<td width="180"><a href="https://www.packyapi.com/register?aff=cliproxyapi"><img src="./assets/packycode.png" alt="PackyCode" width="150"></a></td>
<td>Thanks to PackyCode for sponsoring this project! PackyCode is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. PackyCode provides special discounts for our software users: register using <a href="https://www.packyapi.com/register?aff=cliproxyapi">this link</a> and enter the "cliproxyapi" promo code during recharge to get 10% off.</td>
</tr>
<tr>
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
<td>Thanks to AICodeMirror for sponsoring this project! AICodeMirror provides official high-stability relay services for Claude Code / Codex / Gemini CLI, with enterprise-grade concurrency, fast invoicing, and 24/7 dedicated technical support. Claude Code / Codex / Gemini official channels at 38% / 2% / 9% of original price, with extra discounts on top-ups! AICodeMirror offers special benefits for CLIProxyAPI users: register via <a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">this link</a> to enjoy 20% off your first top-up, and enterprise customers can get up to 25% off!</td>
</tr>
</tbody>
</table>
> **Note:** Google/GitHub login is not available for third-party applications due to AWS Cognito restrictions.
## Overview
**AWS Builder ID** (recommended):
- OpenAI/Gemini/Claude compatible API endpoints for CLI models
- OpenAI Codex support (GPT models) via OAuth login
- Claude Code support via OAuth login
- Qwen Code support via OAuth login
- iFlow support via OAuth login
- Amp CLI and IDE extensions support with provider routing
- Streaming and non-streaming responses
- Function calling/tools support
- Multimodal input support (text and images)
- Multiple accounts with round-robin load balancing (Gemini, OpenAI, Claude, Qwen and iFlow)
- Simple CLI authentication flows (Gemini, OpenAI, Claude, Qwen and iFlow)
- Generative Language API Key support
- AI Studio Build multi-account load balancing
- Gemini CLI multi-account load balancing
- Claude Code multi-account load balancing
- Qwen Code multi-account load balancing
- iFlow multi-account load balancing
- OpenAI Codex multi-account load balancing
- OpenAI-compatible upstream providers via config (e.g., OpenRouter)
- Reusable Go SDK for embedding the proxy (see `docs/sdk-usage.md`)
```bash
# Device code flow
./CLIProxyAPI --kiro-aws-login
## Getting Started
# Authorization code flow
./CLIProxyAPI --kiro-aws-authcode
```
CLIProxyAPI Guides: [https://help.router-for.me/](https://help.router-for.me/)
**Import token from Kiro IDE:**
## Management API
```bash
./CLIProxyAPI --kiro-import
```
see [MANAGEMENT_API.md](https://help.router-for.me/management/api)
To get a token from Kiro IDE:
## Amp CLI Support
1. Open Kiro IDE and login with Google (or GitHub)
2. Find the token file: `~/.kiro/kiro-auth-token.json`
3. Run: `./CLIProxyAPI --kiro-import`
CLIProxyAPI includes integrated support for [Amp CLI](https://ampcode.com) and Amp IDE extensions, enabling you to use your Google/ChatGPT/Claude OAuth subscriptions with Amp's coding tools:
**AWS IAM Identity Center (IDC):**
- Provider route aliases for Amp's API patterns (`/api/provider/{provider}/v1...`)
- Management proxy for OAuth authentication and account features
- Smart model fallback with automatic routing
- **Model mapping** to route unavailable models to alternatives (e.g., `claude-opus-4.5``claude-sonnet-4`)
- Security-first design with localhost-only management endpoints
```bash
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
**→ [Complete Amp CLI Integration Guide](https://help.router-for.me/agent-client/amp-cli.html)**
# Specify region
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
```
## SDK Docs
**Additional flags:**
- Usage: [docs/sdk-usage.md](docs/sdk-usage.md)
- Advanced (executors & translators): [docs/sdk-advanced.md](docs/sdk-advanced.md)
- Access: [docs/sdk-access.md](docs/sdk-access.md)
- Watcher: [docs/sdk-watcher.md](docs/sdk-watcher.md)
- Custom Provider Example: `examples/custom-provider`
| Flag | Description |
|------|-------------|
| `--no-browser` | Don't open browser automatically, print URL instead |
| `--no-incognito` | Use existing browser session (Kiro defaults to incognito). Useful for corporate SSO that requires an authenticated browser session |
| `--kiro-idc-start-url` | IDC Start URL (required with `--kiro-idc-login`) |
| `--kiro-idc-region` | IDC region (default: `us-east-1`) |
| `--kiro-idc-flow` | IDC flow type: `authcode` (default) or `device` |
### Web-based OAuth Login
Access the Kiro OAuth web interface at:
```
http://your-server:8080/v0/oauth/kiro
```
This provides a browser-based OAuth flow for Kiro (AWS CodeWhisperer) authentication with:
- AWS Builder ID login
- AWS Identity Center (IDC) login
- Token import from Kiro IDE
## Quick Deployment with Docker
### One-Command Deployment
```bash
# Create deployment directory
mkdir -p ~/cli-proxy && cd ~/cli-proxy
# Create docker-compose.yml
cat > docker-compose.yml << 'EOF'
services:
cli-proxy-api:
image: eceasy/cli-proxy-api-plus:latest
container_name: cli-proxy-api-plus
ports:
- "8317:8317"
volumes:
- ./config.yaml:/CLIProxyAPI/config.yaml
- ./auths:/root/.cli-proxy-api
- ./logs:/CLIProxyAPI/logs
restart: unless-stopped
EOF
# Download example config
curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
# Pull and start
docker compose pull && docker compose up -d
```
### Configuration
Edit `config.yaml` before starting:
```yaml
# Basic configuration example
server:
port: 8317
# Add your provider configurations here
```
### Update to Latest Version
```bash
cd ~/cli-proxy
docker compose pull && docker compose up -d
```
## Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
This project only accepts pull requests that relate to third-party provider support. Any pull requests unrelated to third-party provider support will be rejected.
1. Fork the repository
2. Create your feature branch (`git checkout -b feature/amazing-feature`)
3. Commit your changes (`git commit -m 'Add some amazing feature'`)
4. Push to the branch (`git push origin feature/amazing-feature`)
5. Open a Pull Request
## Who is with us?
Those projects are based on CLIProxyAPI:
### [vibeproxy](https://github.com/automazeio/vibeproxy)
Native macOS menu bar app to use your Claude Code & ChatGPT subscriptions with AI coding tools - no API keys needed
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
Browser-based tool to translate SRT subtitles using your Gemini subscription via CLIProxyAPI with automatic validation/error correction - no API keys needed
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
CLI wrapper for instant switching between multiple Claude accounts and alternative models (Gemini, Codex, Antigravity) via CLIProxyAPI OAuth - no API keys needed
### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)
Native macOS GUI for managing CLIProxyAPI: configure providers, model mappings, and endpoints via OAuth - no API keys needed.
### [Quotio](https://github.com/nguyenphutrong/quotio)
Native macOS menu bar app that unifies Claude, Gemini, OpenAI, Qwen, and Antigravity subscriptions with real-time quota tracking and smart auto-failover for AI coding tools like Claude Code, OpenCode, and Droid - no API keys needed.
### [CodMate](https://github.com/loocor/CodMate)
Native macOS SwiftUI app for managing CLI AI sessions (Codex, Claude Code, Gemini CLI) with unified provider management, Git review, project organization, global search, and terminal integration. Integrates CLIProxyAPI to provide OAuth authentication for Codex, Claude, Gemini, Antigravity, and Qwen Code, with built-in and third-party provider rerouting through a single proxy endpoint - no API keys needed for OAuth providers.
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
Windows-native CLIProxyAPI fork with TUI, system tray, and multi-provider OAuth for AI coding tools - no API keys needed.
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
VSCode extension for quick switching between Claude Code models, featuring integrated CLIProxyAPI as its backend with automatic background lifecycle management.
### [ZeroLimit](https://github.com/0xtbug/zero-limit)
Windows desktop app built with Tauri + React for monitoring AI coding assistant quotas via CLIProxyAPI. Track usage across Gemini, Claude, OpenAI Codex, and Antigravity accounts with real-time dashboard, system tray integration, and one-click proxy control - no API keys needed.
### [CPA-XXX Panel](https://github.com/ferretgeek/CPA-X)
A lightweight web admin panel for CLIProxyAPI with health checks, resource monitoring, real-time logs, auto-update, request statistics and pricing display. Supports one-click installation and systemd service.
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
A Windows tray application implemented using PowerShell scripts, without relying on any third-party libraries. The main features include: automatic creation of shortcuts, silent running, password management, channel switching (Main / Plus), and automatic downloading and updating.
### [霖君](https://github.com/wangdabaoqq/LinJun)
霖君 is a cross-platform desktop application for managing AI programming assistants, supporting macOS, Windows, and Linux systems. Unified management of Claude Code, Gemini CLI, OpenAI Codex, Qwen Code, and other AI coding tools, with local proxy for multi-account quota tracking and one-click configuration.
### [CLIProxyAPI Dashboard](https://github.com/itsmylife44/cliproxyapi-dashboard)
A modern web-based management dashboard for CLIProxyAPI built with Next.js, React, and PostgreSQL. Features real-time log streaming, structured configuration editing, API key management, OAuth provider integration for Claude/Gemini/Codex, usage analytics, container management, and config sync with OpenCode via companion plugin - no manual YAML editing needed.
> [!NOTE]
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
## More choices
Those projects are ports of CLIProxyAPI or inspired by it:
### [9Router](https://github.com/decolua/9router)
A Next.js implementation inspired by CLIProxyAPI, easy to install and use, built from scratch with format translation (OpenAI/Claude/Gemini/Ollama), combo system with auto-fallback, multi-account management with exponential backoff, a Next.js web dashboard, and support for CLI tools (Cursor, Claude Code, Cline, RooCode) - no API keys needed.
### [OmniRoute](https://github.com/diegosouzapw/OmniRoute)
Never stop coding. Smart routing to FREE & low-cost AI models with automatic fallback.
OmniRoute is an AI gateway for multi-provider LLMs: an OpenAI-compatible endpoint with smart routing, load balancing, retries, and fallbacks. Add policies, rate limits, caching, and observability for reliable, cost-aware inference.
> [!NOTE]
> If you have developed a port of CLIProxyAPI or a project inspired by it, please open a PR to add it to this list.
If you need to submit any non-third-party provider changes, please open them against the [mainline](https://github.com/router-for-me/CLIProxyAPI) repository.
## License

View File

@@ -1,182 +1,145 @@
# CLI 代理 API
# CLIProxyAPI Plus
[English](README.md) | 中文
一个为 CLI 提供 OpenAI/Gemini/Claude/Codex 兼容 API 接口的代理服务器
这是 [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI) 的 Plus 版本,在原有基础上增加了第三方供应商的支持
现已支持通过 OAuth 登录接入 OpenAI CodexGPT 系列)和 Claude Code
所有的第三方供应商支持都由第三方社区维护者提供CLIProxyAPI 不提供技术支持。如需取得支持,请与对应的社区维护者联系
您可以使用本地或多账户的CLI方式通过任何与 OpenAI包括Responses/Gemini/Claude 兼容的客户端和SDK进行访问
该 Plus 版本的主线功能与主线功能强制同步
## 赞助商
## 与主线版本版本差异
[![bigmodel.cn](https://assets.router-for.me/chinese-4.7.png)](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
- 新增 GitHub Copilot 支持OAuth 登录),由[em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供
- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)、[Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)提供
本项目由 Z智谱 提供赞助, 他们通过 GLM CODING PLAN 对本项目提供技术支持。
## 新增功能 (Plus 增强版)
GLM CODING PLAN 是专为AI编码打造的订阅套餐每月最低仅需20元即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.7,为开发者提供顶尖的编码体验。
- **OAuth Web 认证**: 基于浏览器的 Kiro OAuth 登录,提供美观的 Web UI
- **请求限流器**: 内置请求限流,防止 API 滥用
- **后台令牌刷新**: 过期前 10 分钟自动刷新令牌
- **监控指标**: 请求指标收集,用于监控和调试
- **设备指纹**: 设备指纹生成,增强安全性
- **冷却管理**: 智能冷却机制,应对 API 速率限制
- **用量检查器**: 实时用量监控和配额管理
- **模型转换器**: 跨供应商的统一模型名称转换
- **UTF-8 流处理**: 改进的流式响应处理
智谱AI为本软件提供了特别优惠使用以下链接购买可以享受九折优惠https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII
## Kiro 认证
---
### 命令行登录
<table>
<tbody>
<tr>
<td width="180"><a href="https://www.packyapi.com/register?aff=cliproxyapi"><img src="./assets/packycode.png" alt="PackyCode" width="150"></a></td>
<td>感谢 PackyCode 对本项目的赞助PackyCode 是一家可靠高效的 API 中转服务商,提供 Claude Code、Codex、Gemini 等多种服务的中转。PackyCode 为本软件用户提供了特别优惠:使用<a href="https://www.packyapi.com/register?aff=cliproxyapi">此链接</a>注册,并在充值时输入 "cliproxyapi" 优惠码即可享受九折优惠。</td>
</tr>
<tr>
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
<td>感谢 AICodeMirror 赞助了本项目AICodeMirror 提供 Claude Code / Codex / Gemini CLI 官方高稳定中转服务支持企业级高并发、极速开票、7×24 专属技术支持。 Claude Code / Codex / Gemini 官方渠道低至 3.8 / 0.2 / 0.9 折充值更有折上折AICodeMirror 为 CLIProxyAPI 的用户提供了特别福利,通过<a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">此链接</a>注册的用户可享受首充8折企业客户最高可享 7.5 折!</td>
</tr>
</tbody>
</table>
> **注意:** 由于 AWS Cognito 限制Google/GitHub 登录不可用于第三方应用。
**AWS Builder ID**(推荐):
## 功能特性
```bash
# 设备码流程
./CLIProxyAPI --kiro-aws-login
- 为 CLI 模型提供 OpenAI/Gemini/Claude/Codex 兼容的 API 端点
- 新增 OpenAI CodexGPT 系列支持OAuth 登录)
- 新增 Claude Code 支持OAuth 登录)
- 新增 Qwen Code 支持OAuth 登录)
- 新增 iFlow 支持OAuth 登录)
- 支持流式与非流式响应
- 函数调用/工具支持
- 多模态输入(文本、图片)
- 多账户支持与轮询负载均衡Gemini、OpenAI、Claude、Qwen 与 iFlow
- 简单的 CLI 身份验证流程Gemini、OpenAI、Claude、Qwen 与 iFlow
- 支持 Gemini AIStudio API 密钥
- 支持 AI Studio Build 多账户轮询
- 支持 Gemini CLI 多账户轮询
- 支持 Claude Code 多账户轮询
- 支持 Qwen Code 多账户轮询
- 支持 iFlow 多账户轮询
- 支持 OpenAI Codex 多账户轮询
- 通过配置接入上游 OpenAI 兼容提供商(例如 OpenRouter
- 可复用的 Go SDK`docs/sdk-usage_CN.md`
# 授权码流程
./CLIProxyAPI --kiro-aws-authcode
```
## 新手入门
**从 Kiro IDE 导入令牌:**
CLIProxyAPI 用户手册: [https://help.router-for.me/](https://help.router-for.me/cn/)
```bash
./CLIProxyAPI --kiro-import
```
## 管理 API 文档
获取令牌步骤:
请参见 [MANAGEMENT_API_CN.md](https://help.router-for.me/cn/management/api)
1. 打开 Kiro IDE使用 Google或 GitHub登录
2. 找到令牌文件:`~/.kiro/kiro-auth-token.json`
3. 运行:`./CLIProxyAPI --kiro-import`
## Amp CLI 支持
**AWS IAM Identity Center (IDC)**
CLIProxyAPI 已内置对 [Amp CLI](https://ampcode.com) 和 Amp IDE 扩展的支持,可让你使用自己的 Google/ChatGPT/Claude OAuth 订阅来配合 Amp 编码工具:
```bash
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start
- 提供商路由别名,兼容 Amp 的 API 路径模式(`/api/provider/{provider}/v1...`
- 管理代理,处理 OAuth 认证和账号功能
- 智能模型回退与自动路由
- 以安全为先的设计,管理端点仅限 localhost
# 指定区域
./CLIProxyAPI --kiro-idc-login --kiro-idc-start-url https://d-xxxxxxxxxx.awsapps.com/start --kiro-idc-region us-west-2
```
**→ [Amp CLI 完整集成指南](https://help.router-for.me/cn/agent-client/amp-cli.html)**
**附加参数:**
## SDK 文档
| 参数 | 说明 |
|------|------|
| `--no-browser` | 不自动打开浏览器,打印 URL |
| `--no-incognito` | 使用已有浏览器会话Kiro 默认使用无痕模式),适用于需要已登录浏览器会话的企业 SSO 场景 |
| `--kiro-idc-start-url` | IDC Start URL`--kiro-idc-login` 必需) |
| `--kiro-idc-region` | IDC 区域(默认:`us-east-1` |
| `--kiro-idc-flow` | IDC 流程类型:`authcode`(默认)或 `device` |
- 使用文档:[docs/sdk-usage_CN.md](docs/sdk-usage_CN.md)
- 高级(执行器与翻译器):[docs/sdk-advanced_CN.md](docs/sdk-advanced_CN.md)
- 认证: [docs/sdk-access_CN.md](docs/sdk-access_CN.md)
- 凭据加载/更新: [docs/sdk-watcher_CN.md](docs/sdk-watcher_CN.md)
- 自定义 Provider 示例:`examples/custom-provider`
### 网页端 OAuth 登录
访问 Kiro OAuth 网页认证界面:
```
http://your-server:8080/v0/oauth/kiro
```
提供基于浏览器的 Kiro (AWS CodeWhisperer) OAuth 认证流程,支持:
- AWS Builder ID 登录
- AWS Identity Center (IDC) 登录
- 从 Kiro IDE 导入令牌
## Docker 快速部署
### 一键部署
```bash
# 创建部署目录
mkdir -p ~/cli-proxy && cd ~/cli-proxy
# 创建 docker-compose.yml
cat > docker-compose.yml << 'EOF'
services:
cli-proxy-api:
image: eceasy/cli-proxy-api-plus:latest
container_name: cli-proxy-api-plus
ports:
- "8317:8317"
volumes:
- ./config.yaml:/CLIProxyAPI/config.yaml
- ./auths:/root/.cli-proxy-api
- ./logs:/CLIProxyAPI/logs
restart: unless-stopped
EOF
# 下载示例配置
curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
# 拉取并启动
docker compose pull && docker compose up -d
```
### 配置说明
启动前请编辑 `config.yaml`
```yaml
# 基本配置示例
server:
port: 8317
# 在此添加你的供应商配置
```
### 更新到最新版本
```bash
cd ~/cli-proxy
docker compose pull && docker compose up -d
```
## 贡献
欢迎贡献!请随时提交 Pull Request。
该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝
1. Fork 仓库
2. 创建您的功能分支(`git checkout -b feature/amazing-feature`
3. 提交您的更改(`git commit -m 'Add some amazing feature'`
4. 推送到分支(`git push origin feature/amazing-feature`
5. 打开 Pull Request
## 谁与我们在一起?
这些项目基于 CLIProxyAPI:
### [vibeproxy](https://github.com/automazeio/vibeproxy)
一个原生 macOS 菜单栏应用,让您可以使用 Claude Code & ChatGPT 订阅服务和 AI 编程工具,无需 API 密钥。
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
一款基于浏览器的 SRT 字幕翻译工具,可通过 CLI 代理 API 使用您的 Gemini 订阅。内置自动验证与错误修正功能,无需 API 密钥。
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
CLI 封装器,用于通过 CLIProxyAPI OAuth 即时切换多个 Claude 账户和替代模型Gemini, Codex, Antigravity无需 API 密钥。
### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)
基于 macOS 平台的原生 CLIProxyAPI GUI配置供应商、模型映射以及OAuth端点无需 API 密钥。
### [Quotio](https://github.com/nguyenphutrong/quotio)
原生 macOS 菜单栏应用,统一管理 Claude、Gemini、OpenAI、Qwen 和 Antigravity 订阅,提供实时配额追踪和智能自动故障转移,支持 Claude Code、OpenCode 和 Droid 等 AI 编程工具,无需 API 密钥。
### [CodMate](https://github.com/loocor/CodMate)
原生 macOS SwiftUI 应用,用于管理 CLI AI 会话Claude Code、Codex、Gemini CLI提供统一的提供商管理、Git 审查、项目组织、全局搜索和终端集成。集成 CLIProxyAPI 为 Codex、Claude、Gemini、Antigravity 和 Qwen Code 提供统一的 OAuth 认证,支持内置和第三方提供商通过单一代理端点重路由 - OAuth 提供商无需 API 密钥。
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
原生 Windows CLIProxyAPI 分支,集成 TUI、系统托盘及多服务商 OAuth 认证,专为 AI 编程工具打造,无需 API 密钥。
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
一款 VSCode 扩展,提供了在 VSCode 中快速切换 Claude Code 模型的功能,内置 CLIProxyAPI 作为其后端,支持后台自动启动和关闭。
### [ZeroLimit](https://github.com/0xtbug/zero-limit)
Windows 桌面应用,基于 Tauri + React 构建,用于通过 CLIProxyAPI 监控 AI 编程助手配额。支持跨 Gemini、Claude、OpenAI Codex 和 Antigravity 账户的使用量追踪,提供实时仪表盘、系统托盘集成和一键代理控制,无需 API 密钥。
### [CPA-XXX Panel](https://github.com/ferretgeek/CPA-X)
面向 CLIProxyAPI 的 Web 管理面板,提供健康检查、资源监控、日志查看、自动更新、请求统计与定价展示,支持一键安装与 systemd 服务。
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
Windows 托盘应用,基于 PowerShell 脚本实现不依赖任何第三方库。主要功能包括自动创建快捷方式、静默运行、密码管理、通道切换Main / Plus以及自动下载与更新。
### [霖君](https://github.com/wangdabaoqq/LinJun)
霖君是一款用于管理AI编程助手的跨平台桌面应用支持macOS、Windows、Linux系统。统一管理Claude Code、Gemini CLI、OpenAI Codex、Qwen Code等AI编程工具本地代理实现多账户配额跟踪和一键配置。
### [CLIProxyAPI Dashboard](https://github.com/itsmylife44/cliproxyapi-dashboard)
一个面向 CLIProxyAPI 的现代化 Web 管理仪表盘,基于 Next.js、React 和 PostgreSQL 构建。支持实时日志流、结构化配置编辑、API Key 管理、Claude/Gemini/Codex 的 OAuth 提供方集成、使用量分析、容器管理,并可通过配套插件与 OpenCode 同步配置,无需手动编辑 YAML。
> [!NOTE]
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR拉取请求将其添加到此列表中。
## 更多选择
以下项目是 CLIProxyAPI 的移植版或受其启发:
### [9Router](https://github.com/decolua/9router)
基于 Next.js 的实现,灵感来自 CLIProxyAPI易于安装使用自研格式转换OpenAI/Claude/Gemini/Ollama、组合系统与自动回退、多账户管理指数退避、Next.js Web 控制台,并支持 Cursor、Claude Code、Cline、RooCode 等 CLI 工具,无需 API 密钥。
### [OmniRoute](https://github.com/diegosouzapw/OmniRoute)
代码不止,创新不停。智能路由至免费及低成本 AI 模型,并支持自动故障转移。
OmniRoute 是一个面向多供应商大语言模型的 AI 网关:它提供兼容 OpenAI 的端点,具备智能路由、负载均衡、重试及回退机制。通过添加策略、速率限制、缓存和可观测性,确保推理过程既可靠又具备成本意识。
> [!NOTE]
> 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。
如果需要提交任何非第三方供应商支持的 Pull Request请提交到[主线](https://github.com/router-for-me/CLIProxyAPI)版本。
## 许可证
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
## 写给所有中国网友的
QQ 群188637136
Telegram 群https://t.me/CLIProxyAPI
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。

View File

@@ -18,6 +18,7 @@ import (
"github.com/joho/godotenv"
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cmd"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
@@ -49,6 +50,19 @@ func init() {
buildinfo.BuildDate = BuildDate
}
// setKiroIncognitoMode sets the incognito browser mode for Kiro authentication.
// Kiro defaults to incognito mode for multi-account support.
// Users can explicitly override with --incognito or --no-incognito flags.
func setKiroIncognitoMode(cfg *config.Config, useIncognito, noIncognito bool) {
if useIncognito {
cfg.IncognitoBrowser = true
} else if noIncognito {
cfg.IncognitoBrowser = false
} else {
cfg.IncognitoBrowser = true // Kiro default
}
}
// main is the entry point of the application.
// It parses command-line flags, loads configuration, and starts the appropriate
// service based on the provided flags (login, codex-login, or server mode).
@@ -61,18 +75,31 @@ func main() {
var codexDeviceLogin bool
var claudeLogin bool
var qwenLogin bool
var kiloLogin bool
var iflowLogin bool
var iflowCookie bool
var noBrowser bool
var oauthCallbackPort int
var antigravityLogin bool
var kimiLogin bool
var kiroLogin bool
var kiroGoogleLogin bool
var kiroAWSLogin bool
var kiroAWSAuthCode bool
var kiroImport bool
var kiroIDCLogin bool
var kiroIDCStartURL string
var kiroIDCRegion string
var kiroIDCFlow string
var githubCopilotLogin bool
var projectID string
var vertexImport string
var configPath string
var password string
var tuiMode bool
var standalone bool
var noIncognito bool
var useIncognito bool
// Define command-line flags for different operation modes.
flag.BoolVar(&login, "login", false, "Login Google Account")
@@ -80,12 +107,25 @@ func main() {
flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow")
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
flag.BoolVar(&kiroAWSAuthCode, "kiro-aws-authcode", false, "Login to Kiro using AWS Builder ID (authorization code flow, better UX)")
flag.BoolVar(&kiroImport, "kiro-import", false, "Import Kiro token from Kiro IDE (~/.aws/sso/cache/kiro-auth-token.json)")
flag.BoolVar(&kiroIDCLogin, "kiro-idc-login", false, "Login to Kiro using IAM Identity Center (IDC)")
flag.StringVar(&kiroIDCStartURL, "kiro-idc-start-url", "", "IDC start URL (required with --kiro-idc-login)")
flag.StringVar(&kiroIDCRegion, "kiro-idc-region", "", "IDC region (default: us-east-1)")
flag.StringVar(&kiroIDCFlow, "kiro-idc-flow", "", "IDC flow type: authcode (default) or device")
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
@@ -466,6 +506,9 @@ func main() {
} else if antigravityLogin {
// Handle Antigravity login
cmd.DoAntigravityLogin(cfg, options)
} else if githubCopilotLogin {
// Handle GitHub Copilot login
cmd.DoGitHubCopilotLogin(cfg, options)
} else if codexLogin {
// Handle Codex login
cmd.DoCodexLogin(cfg, options)
@@ -477,12 +520,48 @@ func main() {
cmd.DoClaudeLogin(cfg, options)
} else if qwenLogin {
cmd.DoQwenLogin(cfg, options)
} else if kiloLogin {
cmd.DoKiloLogin(cfg, options)
} else if iflowLogin {
cmd.DoIFlowLogin(cfg, options)
} else if iflowCookie {
cmd.DoIFlowCookieAuth(cfg, options)
} else if kimiLogin {
cmd.DoKimiLogin(cfg, options)
} else if kiroLogin {
// For Kiro auth, default to incognito mode for multi-account support
// Users can explicitly override with --no-incognito
// Note: This config mutation is safe - auth commands exit after completion
// and don't share config with StartService (which is in the else branch)
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
kiro.InitFingerprintConfig(cfg)
cmd.DoKiroLogin(cfg, options)
} else if kiroGoogleLogin {
// For Kiro auth, default to incognito mode for multi-account support
// Users can explicitly override with --no-incognito
// Note: This config mutation is safe - auth commands exit after completion
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
kiro.InitFingerprintConfig(cfg)
cmd.DoKiroGoogleLogin(cfg, options)
} else if kiroAWSLogin {
// For Kiro auth, default to incognito mode for multi-account support
// Users can explicitly override with --no-incognito
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
kiro.InitFingerprintConfig(cfg)
cmd.DoKiroAWSLogin(cfg, options)
} else if kiroAWSAuthCode {
// For Kiro auth with authorization code flow (better UX)
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
kiro.InitFingerprintConfig(cfg)
cmd.DoKiroAWSAuthCodeLogin(cfg, options)
} else if kiroImport {
kiro.InitFingerprintConfig(cfg)
cmd.DoKiroImport(cfg, options)
} else if kiroIDCLogin {
// For Kiro IDC auth, default to incognito mode for multi-account support
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
kiro.InitFingerprintConfig(cfg)
cmd.DoKiroIDCLogin(cfg, options, kiroIDCStartURL, kiroIDCRegion, kiroIDCFlow)
} else {
// In cloud deploy mode without config file, just wait for shutdown signals
if isCloudDeploy && !configFileExists {
@@ -564,9 +643,15 @@ func main() {
}
}
} else {
// Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath)
cmd.StartService(cfg, configFilePath, password)
// Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath)
if cfg.AuthDir != "" {
kiro.InitializeAndStart(cfg.AuthDir, cfg)
defer kiro.StopGlobalRefreshManager()
}
cmd.StartService(cfg, configFilePath, password)
}
}
}

View File

@@ -1,6 +1,6 @@
# Server host/interface to bind to. Default is empty ("") to bind all interfaces (IPv4 + IPv6).
# Use "127.0.0.1" or "localhost" to restrict access to local machine only.
host: ""
host: ''
# Server port
port: 8317
@@ -8,8 +8,8 @@ port: 8317
# TLS settings for HTTPS. When enabled, the server listens with the provided certificate and key.
tls:
enable: false
cert: ""
key: ""
cert: ''
key: ''
# Management API settings
remote-management:
@@ -20,22 +20,22 @@ remote-management:
# Management key. If a plaintext value is provided here, it will be hashed on startup.
# All management requests (even from localhost) require this key.
# Leave empty to disable the Management API entirely (404 for all /v0/management routes).
secret-key: ""
secret-key: ''
# Disable the bundled management control panel asset download and HTTP route when true.
disable-control-panel: false
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
panel-github-repository: "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
panel-github-repository: 'https://github.com/router-for-me/Cli-Proxy-API-Management-Center'
# Authentication directory (supports ~ for home directory)
auth-dir: "~/.cli-proxy-api"
auth-dir: '~/.cli-proxy-api'
# API keys for authentication
api-keys:
- "your-api-key-1"
- "your-api-key-2"
- "your-api-key-3"
- 'your-api-key-1'
- 'your-api-key-2'
- 'your-api-key-3'
# Enable debug logging
debug: false
@@ -43,11 +43,16 @@ debug: false
# Enable pprof HTTP debug server (host:port). Keep it bound to localhost for safety.
pprof:
enable: false
addr: "127.0.0.1:8316"
addr: '127.0.0.1:8316'
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
commercial-mode: false
# Open OAuth URLs in incognito/private browser mode.
# Useful when you want to login with a different account without logging out from your current session.
# Default: false (but Kiro auth defaults to true for multi-account support)
incognito-browser: true
# When true, write application logs to rotating files instead of stdout
logging-to-file: false
@@ -63,7 +68,7 @@ error-logs-max-files: 10
usage-statistics-enabled: false
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
proxy-url: ""
proxy-url: ''
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
force-model-prefix: false
@@ -89,7 +94,7 @@ quota-exceeded:
# Routing strategy for selecting credentials when multiple match.
routing:
strategy: "round-robin" # round-robin (default), fill-first
strategy: 'round-robin' # round-robin (default), fill-first
# When true, enable authentication for the WebSocket API (/v1/ws).
ws-auth: false
@@ -173,6 +178,33 @@ nonstream-keepalive-interval: 0
# runtime-version: "v24.3.0"
# timeout: "600"
# Kiro (AWS CodeWhisperer) configuration
# Note: Kiro API currently only operates in us-east-1 region
#kiro:
# - token-file: "~/.aws/sso/cache/kiro-auth-token.json" # path to Kiro token file
# agent-task-type: "" # optional: "vibe" or empty (API default)
# start-url: "https://your-company.awsapps.com/start" # optional: IDC start URL (preset for login)
# region: "us-east-1" # optional: OIDC region for IDC login and token refresh
# - access-token: "aoaAAAAA..." # or provide tokens directly
# refresh-token: "aorAAAAA..."
# profile-arn: "arn:aws:codewhisperer:us-east-1:..."
# proxy-url: "socks5://proxy.example.com:1080" # optional: proxy override
# Kilocode (OAuth-based code assistant)
# Note: Kilocode uses OAuth device flow authentication.
# Use the CLI command: ./server --kilo-login
# This will save credentials to the auth directory (default: ~/.cli-proxy-api/)
# oauth-model-alias:
# kilo:
# - name: "minimax/minimax-m2.5:free"
# alias: "minimax-m2.5"
# - name: "z-ai/glm-5:free"
# alias: "glm-5"
# oauth-excluded-models:
# kilo:
# - "kilo-claude-opus-4-6" # exclude specific models (exact match)
# - "*:free" # wildcard matching suffix (e.g. all free models)
# OpenAI compatibility providers
# openai-compatibility:
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.
@@ -238,10 +270,25 @@ nonstream-keepalive-interval: 0
# Global OAuth model name aliases (per channel)
# These aliases rename model IDs for both model listing and request routing.
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kimi.
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
# You can repeat the same name with different aliases to expose multiple client model names.
# oauth-model-alias:
# antigravity:
# - name: "rev19-uic3-1p"
# alias: "gemini-2.5-computer-use-preview-10-2025"
# - name: "gemini-3-pro-image"
# alias: "gemini-3-pro-image-preview"
# - name: "gemini-3-pro-high"
# alias: "gemini-3-pro-preview"
# - name: "gemini-3-flash"
# alias: "gemini-3-flash-preview"
# - name: "claude-sonnet-4-5"
# alias: "gemini-claude-sonnet-4-5"
# - name: "claude-sonnet-4-5-thinking"
# alias: "gemini-claude-sonnet-4-5-thinking"
# - name: "claude-opus-4-5-thinking"
# alias: "gemini-claude-opus-4-5-thinking"
# gemini-cli:
# - name: "gemini-2.5-pro" # original model name under this channel
# alias: "g2.5p" # client-visible alias
@@ -252,9 +299,6 @@ nonstream-keepalive-interval: 0
# aistudio:
# - name: "gemini-2.5-pro"
# alias: "g2.5p"
# antigravity:
# - name: "gemini-3-pro-high"
# alias: "gemini-3-pro-preview"
# claude:
# - name: "claude-sonnet-4-5-20250929"
# alias: "cs4.5"
@@ -270,8 +314,15 @@ nonstream-keepalive-interval: 0
# kimi:
# - name: "kimi-k2.5"
# alias: "k2.5"
# kiro:
# - name: "kiro-claude-opus-4-5"
# alias: "op45"
# github-copilot:
# - name: "gpt-5"
# alias: "copilot-gpt5"
# OAuth provider excluded models
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
# oauth-excluded-models:
# gemini-cli:
# - "gemini-2.5-pro" # exclude specific models (exact match)
@@ -294,6 +345,10 @@ nonstream-keepalive-interval: 0
# - "tstars2.0"
# kimi:
# - "kimi-k2-thinking"
# kiro:
# - "kiro-claude-haiku-4-5"
# github-copilot:
# - "raptor-mini"
# Optional payload configuration
# payload:

View File

@@ -1,6 +1,6 @@
services:
cli-proxy-api:
image: ${CLI_PROXY_IMAGE:-eceasy/cli-proxy-api:latest}
image: ${CLI_PROXY_IMAGE:-eceasy/cli-proxy-api-plus:latest}
pull_policy: always
build:
context: .
@@ -9,7 +9,7 @@ services:
VERSION: ${VERSION:-dev}
COMMIT: ${COMMIT:-none}
BUILD_DATE: ${BUILD_DATE:-unknown}
container_name: cli-proxy-api
container_name: cli-proxy-api-plus
# env_file:
# - .env
environment:

5
go.mod
View File

@@ -9,6 +9,7 @@ require (
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0
github.com/fsnotify/fsnotify v1.9.0
github.com/fxamacker/cbor/v2 v2.9.0
github.com/gin-gonic/gin v1.10.1
github.com/go-git/go-git/v6 v6.0.0-20251009132922-75a182125145
github.com/google/uuid v1.6.0
@@ -17,9 +18,9 @@ require (
github.com/joho/godotenv v1.5.1
github.com/klauspost/compress v1.17.4
github.com/minio/minio-go/v7 v7.0.66
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
github.com/refraction-networking/utls v1.8.2
github.com/sirupsen/logrus v1.9.3
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
github.com/tiktoken-go/tokenizer v0.7.0
@@ -27,6 +28,7 @@ require (
golang.org/x/net v0.47.0
golang.org/x/oauth2 v0.30.0
golang.org/x/sync v0.18.0
golang.org/x/term v0.37.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
)
@@ -90,6 +92,7 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/x448/float16 v0.8.4 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect

9
go.sum
View File

@@ -61,6 +61,8 @@ github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM=
github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
@@ -154,6 +156,8 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
@@ -168,8 +172,6 @@ github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw=
github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA=
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -201,6 +203,8 @@ github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65E
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
@@ -216,6 +220,7 @@ golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=

View File

@@ -1,6 +1,7 @@
package management
import (
"bytes"
"context"
"encoding/json"
"fmt"
@@ -11,13 +12,15 @@ import (
"strings"
"time"
"github.com/fxamacker/cbor/v2"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
const defaultAPICallTimeout = 60 * time.Second
@@ -54,6 +57,7 @@ type apiCallResponse struct {
StatusCode int `json:"status_code"`
Header map[string][]string `json:"header"`
Body string `json:"body"`
Quota *QuotaSnapshots `json:"quota,omitempty"`
}
// APICall makes a generic HTTP request on behalf of the management API caller.
@@ -70,7 +74,7 @@ type apiCallResponse struct {
// - Authorization: Bearer <key>
// - X-Management-Key: <key>
//
// Request JSON:
// Request JSON (supports both application/json and application/cbor):
// - auth_index / authIndex / AuthIndex (optional):
// The credential "auth_index" from GET /v0/management/auth-files (or other endpoints returning it).
// If omitted or not found, credential-specific proxy/token substitution is skipped.
@@ -90,10 +94,14 @@ type apiCallResponse struct {
// 2. Global config proxy-url
// 3. Direct connect (environment proxies are not used)
//
// Response JSON (returned with HTTP 200 when the APICall itself succeeds):
// - status_code: Upstream HTTP status code.
// - header: Upstream response headers.
// - body: Upstream response body as string.
// Response (returned with HTTP 200 when the APICall itself succeeds):
//
// Format matches request Content-Type (application/json or application/cbor)
// - status_code: Upstream HTTP status code.
// - header: Upstream response headers.
// - body: Upstream response body as string.
// - quota (optional): For GitHub Copilot enterprise accounts, contains quota_snapshots
// with details for chat, completions, and premium_interactions.
//
// Example:
//
@@ -107,10 +115,28 @@ type apiCallResponse struct {
// -H "Content-Type: application/json" \
// -d '{"auth_index":"<AUTH_INDEX>","method":"POST","url":"https://api.example.com/v1/fetchAvailableModels","header":{"Authorization":"Bearer $TOKEN$","Content-Type":"application/json","User-Agent":"cliproxyapi"},"data":"{}"}'
func (h *Handler) APICall(c *gin.Context) {
// Detect content type
contentType := strings.ToLower(strings.TrimSpace(c.GetHeader("Content-Type")))
isCBOR := strings.Contains(contentType, "application/cbor")
var body apiCallRequest
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
// Parse request body based on content type
if isCBOR {
rawBody, errRead := io.ReadAll(c.Request.Body)
if errRead != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
return
}
if errUnmarshal := cbor.Unmarshal(rawBody, &body); errUnmarshal != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid cbor body"})
return
}
} else {
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
}
method := strings.ToUpper(strings.TrimSpace(body.Method))
@@ -164,9 +190,21 @@ func (h *Handler) APICall(c *gin.Context) {
reqHeaders[key] = strings.ReplaceAll(value, "$TOKEN$", token)
}
// When caller indicates CBOR in request headers, convert JSON string payload to CBOR bytes.
useCBORPayload := headerContainsValue(reqHeaders, "Content-Type", "application/cbor")
var requestBody io.Reader
if body.Data != "" {
requestBody = strings.NewReader(body.Data)
if useCBORPayload {
cborPayload, errEncode := encodeJSONStringToCBOR(body.Data)
if errEncode != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json data for cbor content-type"})
return
}
requestBody = bytes.NewReader(cborPayload)
} else {
requestBody = strings.NewReader(body.Data)
}
}
req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody)
@@ -209,11 +247,38 @@ func (h *Handler) APICall(c *gin.Context) {
return
}
c.JSON(http.StatusOK, apiCallResponse{
// For CBOR upstream responses, decode into plain text or JSON string before returning.
responseBodyText := string(respBody)
if headerContainsValue(reqHeaders, "Accept", "application/cbor") || strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "application/cbor") {
if decodedBody, errDecode := decodeCBORBodyToTextOrJSON(respBody); errDecode == nil {
responseBodyText = decodedBody
}
}
response := apiCallResponse{
StatusCode: resp.StatusCode,
Header: resp.Header,
Body: string(respBody),
})
Body: responseBodyText,
}
// If this is a GitHub Copilot token endpoint response, try to enrich with quota information
if resp.StatusCode == http.StatusOK &&
strings.Contains(urlStr, "copilot_internal") &&
strings.Contains(urlStr, "/token") {
response = h.enrichCopilotTokenResponse(c.Request.Context(), response, auth, urlStr)
}
// Return response in the same format as the request
if isCBOR {
cborData, errMarshal := cbor.Marshal(response)
if errMarshal != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to encode cbor response"})
return
}
c.Data(http.StatusOK, "application/cbor", cborData)
} else {
c.JSON(http.StatusOK, response)
}
}
func firstNonEmptyString(values ...*string) string {
@@ -702,3 +767,421 @@ func buildProxyTransport(proxyStr string) *http.Transport {
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
return nil
}
// headerContainsValue checks whether a header map contains a target value (case-insensitive key and value).
func headerContainsValue(headers map[string]string, targetKey, targetValue string) bool {
if len(headers) == 0 {
return false
}
for key, value := range headers {
if !strings.EqualFold(strings.TrimSpace(key), strings.TrimSpace(targetKey)) {
continue
}
if strings.Contains(strings.ToLower(value), strings.ToLower(strings.TrimSpace(targetValue))) {
return true
}
}
return false
}
// encodeJSONStringToCBOR converts a JSON string payload into CBOR bytes.
func encodeJSONStringToCBOR(jsonString string) ([]byte, error) {
var payload any
if errUnmarshal := json.Unmarshal([]byte(jsonString), &payload); errUnmarshal != nil {
return nil, errUnmarshal
}
return cbor.Marshal(payload)
}
// decodeCBORBodyToTextOrJSON decodes CBOR bytes to plain text (for string payloads) or JSON string.
func decodeCBORBodyToTextOrJSON(raw []byte) (string, error) {
if len(raw) == 0 {
return "", nil
}
var payload any
if errUnmarshal := cbor.Unmarshal(raw, &payload); errUnmarshal != nil {
return "", errUnmarshal
}
jsonCompatible := cborValueToJSONCompatible(payload)
switch typed := jsonCompatible.(type) {
case string:
return typed, nil
case []byte:
return string(typed), nil
default:
jsonBytes, errMarshal := json.Marshal(jsonCompatible)
if errMarshal != nil {
return "", errMarshal
}
return string(jsonBytes), nil
}
}
// cborValueToJSONCompatible recursively converts CBOR-decoded values into JSON-marshalable values.
func cborValueToJSONCompatible(value any) any {
switch typed := value.(type) {
case map[any]any:
out := make(map[string]any, len(typed))
for key, item := range typed {
out[fmt.Sprint(key)] = cborValueToJSONCompatible(item)
}
return out
case map[string]any:
out := make(map[string]any, len(typed))
for key, item := range typed {
out[key] = cborValueToJSONCompatible(item)
}
return out
case []any:
out := make([]any, len(typed))
for i, item := range typed {
out[i] = cborValueToJSONCompatible(item)
}
return out
default:
return typed
}
}
// QuotaDetail represents quota information for a specific resource type
type QuotaDetail struct {
Entitlement float64 `json:"entitlement"`
OverageCount float64 `json:"overage_count"`
OveragePermitted bool `json:"overage_permitted"`
PercentRemaining float64 `json:"percent_remaining"`
QuotaID string `json:"quota_id"`
QuotaRemaining float64 `json:"quota_remaining"`
Remaining float64 `json:"remaining"`
Unlimited bool `json:"unlimited"`
}
// QuotaSnapshots contains quota details for different resource types
type QuotaSnapshots struct {
Chat QuotaDetail `json:"chat"`
Completions QuotaDetail `json:"completions"`
PremiumInteractions QuotaDetail `json:"premium_interactions"`
}
// CopilotUsageResponse represents the GitHub Copilot usage information
type CopilotUsageResponse struct {
AccessTypeSKU string `json:"access_type_sku"`
AnalyticsTrackingID string `json:"analytics_tracking_id"`
AssignedDate string `json:"assigned_date"`
CanSignupForLimited bool `json:"can_signup_for_limited"`
ChatEnabled bool `json:"chat_enabled"`
CopilotPlan string `json:"copilot_plan"`
OrganizationLoginList []interface{} `json:"organization_login_list"`
OrganizationList []interface{} `json:"organization_list"`
QuotaResetDate string `json:"quota_reset_date"`
QuotaSnapshots QuotaSnapshots `json:"quota_snapshots"`
}
type copilotQuotaRequest struct {
AuthIndexSnake *string `json:"auth_index"`
AuthIndexCamel *string `json:"authIndex"`
AuthIndexPascal *string `json:"AuthIndex"`
}
// GetCopilotQuota fetches GitHub Copilot quota information from the /copilot_internal/user endpoint.
//
// Endpoint:
//
// GET /v0/management/copilot-quota
//
// Query Parameters (optional):
// - auth_index: The credential "auth_index" from GET /v0/management/auth-files.
// If omitted, uses the first available GitHub Copilot credential.
//
// Response:
//
// Returns the CopilotUsageResponse with quota_snapshots containing detailed quota information
// for chat, completions, and premium_interactions.
//
// Example:
//
// curl -sS -X GET "http://127.0.0.1:8317/v0/management/copilot-quota?auth_index=<AUTH_INDEX>" \
// -H "Authorization: Bearer <MANAGEMENT_KEY>"
func (h *Handler) GetCopilotQuota(c *gin.Context) {
authIndex := strings.TrimSpace(c.Query("auth_index"))
if authIndex == "" {
authIndex = strings.TrimSpace(c.Query("authIndex"))
}
if authIndex == "" {
authIndex = strings.TrimSpace(c.Query("AuthIndex"))
}
auth := h.findCopilotAuth(authIndex)
if auth == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "no github copilot credential found"})
return
}
token, tokenErr := h.resolveTokenForAuth(c.Request.Context(), auth)
if tokenErr != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to refresh copilot token"})
return
}
if token == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "copilot token not found"})
return
}
apiURL := "https://api.github.com/copilot_internal/user"
req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, apiURL, nil)
if errNewRequest != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to build request"})
return
}
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("User-Agent", "CLIProxyAPIPlus")
req.Header.Set("Accept", "application/json")
httpClient := &http.Client{
Timeout: defaultAPICallTimeout,
Transport: h.apiCallTransport(auth),
}
resp, errDo := httpClient.Do(req)
if errDo != nil {
log.WithError(errDo).Debug("copilot quota request failed")
c.JSON(http.StatusBadGateway, gin.H{"error": "request failed"})
return
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
}()
respBody, errReadAll := io.ReadAll(resp.Body)
if errReadAll != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": "failed to read response"})
return
}
if resp.StatusCode != http.StatusOK {
c.JSON(http.StatusBadGateway, gin.H{
"error": "github api request failed",
"status_code": resp.StatusCode,
"body": string(respBody),
})
return
}
var usage CopilotUsageResponse
if errUnmarshal := json.Unmarshal(respBody, &usage); errUnmarshal != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to parse response"})
return
}
c.JSON(http.StatusOK, usage)
}
// findCopilotAuth locates a GitHub Copilot credential by auth_index or returns the first available one
func (h *Handler) findCopilotAuth(authIndex string) *coreauth.Auth {
if h == nil || h.authManager == nil {
return nil
}
auths := h.authManager.List()
var firstCopilot *coreauth.Auth
for _, auth := range auths {
if auth == nil {
continue
}
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
if provider != "copilot" && provider != "github" && provider != "github-copilot" {
continue
}
if firstCopilot == nil {
firstCopilot = auth
}
if authIndex != "" {
auth.EnsureIndex()
if auth.Index == authIndex {
return auth
}
}
}
return firstCopilot
}
// enrichCopilotTokenResponse fetches quota information and adds it to the Copilot token response body
func (h *Handler) enrichCopilotTokenResponse(ctx context.Context, response apiCallResponse, auth *coreauth.Auth, originalURL string) apiCallResponse {
if auth == nil || response.Body == "" {
return response
}
// Parse the token response to check if it's enterprise (null limited_user_quotas)
var tokenResp map[string]interface{}
if err := json.Unmarshal([]byte(response.Body), &tokenResp); err != nil {
log.WithError(err).Debug("enrichCopilotTokenResponse: failed to parse copilot token response")
return response
}
// Get the GitHub token to call the copilot_internal/user endpoint
token, tokenErr := h.resolveTokenForAuth(ctx, auth)
if tokenErr != nil {
log.WithError(tokenErr).Debug("enrichCopilotTokenResponse: failed to resolve token")
return response
}
if token == "" {
return response
}
// Fetch quota information from /copilot_internal/user
// Derive the base URL from the original token request to support proxies and test servers
parsedURL, errParse := url.Parse(originalURL)
if errParse != nil {
log.WithError(errParse).Debug("enrichCopilotTokenResponse: failed to parse URL")
return response
}
quotaURL := fmt.Sprintf("%s://%s/copilot_internal/user", parsedURL.Scheme, parsedURL.Host)
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodGet, quotaURL, nil)
if errNewRequest != nil {
log.WithError(errNewRequest).Debug("enrichCopilotTokenResponse: failed to build request")
return response
}
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("User-Agent", "CLIProxyAPIPlus")
req.Header.Set("Accept", "application/json")
httpClient := &http.Client{
Timeout: defaultAPICallTimeout,
Transport: h.apiCallTransport(auth),
}
quotaResp, errDo := httpClient.Do(req)
if errDo != nil {
log.WithError(errDo).Debug("enrichCopilotTokenResponse: quota fetch HTTP request failed")
return response
}
defer func() {
if errClose := quotaResp.Body.Close(); errClose != nil {
log.Errorf("quota response body close error: %v", errClose)
}
}()
if quotaResp.StatusCode != http.StatusOK {
return response
}
quotaBody, errReadAll := io.ReadAll(quotaResp.Body)
if errReadAll != nil {
log.WithError(errReadAll).Debug("enrichCopilotTokenResponse: failed to read response")
return response
}
// Parse the quota response
var quotaData CopilotUsageResponse
if err := json.Unmarshal(quotaBody, &quotaData); err != nil {
log.WithError(err).Debug("enrichCopilotTokenResponse: failed to parse response")
return response
}
// Check if this is an enterprise account by looking for quota_snapshots in the response
// Enterprise accounts have quota_snapshots, non-enterprise have limited_user_quotas
var quotaRaw map[string]interface{}
if err := json.Unmarshal(quotaBody, &quotaRaw); err == nil {
if _, hasQuotaSnapshots := quotaRaw["quota_snapshots"]; hasQuotaSnapshots {
// Enterprise account - has quota_snapshots
tokenResp["quota_snapshots"] = quotaData.QuotaSnapshots
tokenResp["access_type_sku"] = quotaData.AccessTypeSKU
tokenResp["copilot_plan"] = quotaData.CopilotPlan
// Add quota reset date for enterprise (quota_reset_date_utc)
if quotaResetDateUTC, ok := quotaRaw["quota_reset_date_utc"]; ok {
tokenResp["quota_reset_date"] = quotaResetDateUTC
} else if quotaData.QuotaResetDate != "" {
tokenResp["quota_reset_date"] = quotaData.QuotaResetDate
}
} else {
// Non-enterprise account - build quota from limited_user_quotas and monthly_quotas
var quotaSnapshots QuotaSnapshots
// Get monthly quotas (total entitlement) and limited_user_quotas (remaining)
monthlyQuotas, hasMonthly := quotaRaw["monthly_quotas"].(map[string]interface{})
limitedQuotas, hasLimited := quotaRaw["limited_user_quotas"].(map[string]interface{})
// Process chat quota
if hasMonthly && hasLimited {
if chatTotal, ok := monthlyQuotas["chat"].(float64); ok {
chatRemaining := chatTotal // default to full if no limited quota
if chatLimited, ok := limitedQuotas["chat"].(float64); ok {
chatRemaining = chatLimited
}
percentRemaining := 0.0
if chatTotal > 0 {
percentRemaining = (chatRemaining / chatTotal) * 100.0
}
quotaSnapshots.Chat = QuotaDetail{
Entitlement: chatTotal,
Remaining: chatRemaining,
QuotaRemaining: chatRemaining,
PercentRemaining: percentRemaining,
QuotaID: "chat",
Unlimited: false,
}
}
// Process completions quota
if completionsTotal, ok := monthlyQuotas["completions"].(float64); ok {
completionsRemaining := completionsTotal // default to full if no limited quota
if completionsLimited, ok := limitedQuotas["completions"].(float64); ok {
completionsRemaining = completionsLimited
}
percentRemaining := 0.0
if completionsTotal > 0 {
percentRemaining = (completionsRemaining / completionsTotal) * 100.0
}
quotaSnapshots.Completions = QuotaDetail{
Entitlement: completionsTotal,
Remaining: completionsRemaining,
QuotaRemaining: completionsRemaining,
PercentRemaining: percentRemaining,
QuotaID: "completions",
Unlimited: false,
}
}
}
// Premium interactions don't exist for non-enterprise, leave as zero values
quotaSnapshots.PremiumInteractions = QuotaDetail{
QuotaID: "premium_interactions",
Unlimited: false,
}
// Add quota_snapshots to the token response
tokenResp["quota_snapshots"] = quotaSnapshots
tokenResp["access_type_sku"] = quotaData.AccessTypeSKU
tokenResp["copilot_plan"] = quotaData.CopilotPlan
// Add quota reset date for non-enterprise (limited_user_reset_date)
if limitedResetDate, ok := quotaRaw["limited_user_reset_date"]; ok {
tokenResp["quota_reset_date"] = limitedResetDate
}
}
}
// Re-serialize the enriched response
enrichedBody, errMarshal := json.Marshal(tokenResp)
if errMarshal != nil {
log.WithError(errMarshal).Debug("failed to marshal enriched response")
return response
}
response.Body = string(enrichedBody)
return response
}

View File

@@ -0,0 +1,149 @@
package management
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/fxamacker/cbor/v2"
"github.com/gin-gonic/gin"
)
func TestAPICall_CBOR_Support(t *testing.T) {
gin.SetMode(gin.TestMode)
// Create a test handler
h := &Handler{}
// Create test request data
reqData := apiCallRequest{
Method: "GET",
URL: "https://httpbin.org/get",
Header: map[string]string{
"User-Agent": "test-client",
},
}
t.Run("JSON request and response", func(t *testing.T) {
// Marshal request as JSON
jsonData, err := json.Marshal(reqData)
if err != nil {
t.Fatalf("Failed to marshal JSON: %v", err)
}
// Create HTTP request
req := httptest.NewRequest(http.MethodPost, "/v0/management/api-call", bytes.NewReader(jsonData))
req.Header.Set("Content-Type", "application/json")
// Create response recorder
w := httptest.NewRecorder()
// Create Gin context
c, _ := gin.CreateTestContext(w)
c.Request = req
// Call handler
h.APICall(c)
// Verify response
if w.Code != http.StatusOK && w.Code != http.StatusBadGateway {
t.Logf("Response status: %d", w.Code)
t.Logf("Response body: %s", w.Body.String())
}
// Check content type
contentType := w.Header().Get("Content-Type")
if w.Code == http.StatusOK && !contains(contentType, "application/json") {
t.Errorf("Expected JSON response, got: %s", contentType)
}
})
t.Run("CBOR request and response", func(t *testing.T) {
// Marshal request as CBOR
cborData, err := cbor.Marshal(reqData)
if err != nil {
t.Fatalf("Failed to marshal CBOR: %v", err)
}
// Create HTTP request
req := httptest.NewRequest(http.MethodPost, "/v0/management/api-call", bytes.NewReader(cborData))
req.Header.Set("Content-Type", "application/cbor")
// Create response recorder
w := httptest.NewRecorder()
// Create Gin context
c, _ := gin.CreateTestContext(w)
c.Request = req
// Call handler
h.APICall(c)
// Verify response
if w.Code != http.StatusOK && w.Code != http.StatusBadGateway {
t.Logf("Response status: %d", w.Code)
t.Logf("Response body: %s", w.Body.String())
}
// Check content type
contentType := w.Header().Get("Content-Type")
if w.Code == http.StatusOK && !contains(contentType, "application/cbor") {
t.Errorf("Expected CBOR response, got: %s", contentType)
}
// Try to decode CBOR response
if w.Code == http.StatusOK {
var response apiCallResponse
if err := cbor.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Errorf("Failed to unmarshal CBOR response: %v", err)
} else {
t.Logf("CBOR response decoded successfully: status_code=%d", response.StatusCode)
}
}
})
t.Run("CBOR encoding and decoding consistency", func(t *testing.T) {
// Test data
testReq := apiCallRequest{
Method: "POST",
URL: "https://example.com/api",
Header: map[string]string{
"Authorization": "Bearer $TOKEN$",
"Content-Type": "application/json",
},
Data: `{"key":"value"}`,
}
// Encode to CBOR
cborData, err := cbor.Marshal(testReq)
if err != nil {
t.Fatalf("Failed to marshal to CBOR: %v", err)
}
// Decode from CBOR
var decoded apiCallRequest
if err := cbor.Unmarshal(cborData, &decoded); err != nil {
t.Fatalf("Failed to unmarshal from CBOR: %v", err)
}
// Verify fields
if decoded.Method != testReq.Method {
t.Errorf("Method mismatch: got %s, want %s", decoded.Method, testReq.Method)
}
if decoded.URL != testReq.URL {
t.Errorf("URL mismatch: got %s, want %s", decoded.URL, testReq.URL)
}
if decoded.Data != testReq.Data {
t.Errorf("Data mismatch: got %s, want %s", decoded.Data, testReq.Data)
}
if len(decoded.Header) != len(testReq.Header) {
t.Errorf("Header count mismatch: got %d, want %d", len(decoded.Header), len(testReq.Header))
}
})
}
func contains(s, substr string) bool {
return len(s) > 0 && len(substr) > 0 && (s == substr || len(s) >= len(substr) && s[:len(substr)] == substr || bytes.Contains([]byte(s), []byte(substr)))
}

View File

@@ -3,7 +3,9 @@ package management
import (
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
@@ -11,6 +13,7 @@ import (
"io"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"sort"
@@ -23,9 +26,12 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/antigravity"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kilo"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
@@ -1918,6 +1924,99 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
}
func (h *Handler) RequestGitHubToken(c *gin.Context) {
ctx := context.Background()
fmt.Println("Initializing GitHub Copilot authentication...")
state := fmt.Sprintf("gh-%d", time.Now().UnixNano())
// Initialize Copilot auth service
deviceClient := copilot.NewDeviceFlowClient(h.cfg)
// Initiate device flow
deviceCode, err := deviceClient.RequestDeviceCode(ctx)
if err != nil {
log.Errorf("Failed to initiate device flow: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initiate device flow"})
return
}
authURL := deviceCode.VerificationURI
userCode := deviceCode.UserCode
RegisterOAuthSession(state, "github-copilot")
go func() {
fmt.Printf("Please visit %s and enter code: %s\n", authURL, userCode)
tokenData, errPoll := deviceClient.PollForToken(ctx, deviceCode)
if errPoll != nil {
SetOAuthSessionError(state, "Authentication failed")
fmt.Printf("Authentication failed: %v\n", errPoll)
return
}
userInfo, errUser := deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
if errUser != nil {
log.Warnf("Failed to fetch user info: %v", errUser)
}
username := userInfo.Login
if username == "" {
username = "github-user"
}
tokenStorage := &copilot.CopilotTokenStorage{
AccessToken: tokenData.AccessToken,
TokenType: tokenData.TokenType,
Scope: tokenData.Scope,
Username: username,
Email: userInfo.Email,
Name: userInfo.Name,
Type: "github-copilot",
}
fileName := fmt.Sprintf("github-copilot-%s.json", username)
label := userInfo.Email
if label == "" {
label = username
}
record := &coreauth.Auth{
ID: fileName,
Provider: "github-copilot",
Label: label,
FileName: fileName,
Storage: tokenStorage,
Metadata: map[string]any{
"email": userInfo.Email,
"username": username,
"name": userInfo.Name,
},
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Errorf("Failed to save authentication tokens: %v", errSave)
SetOAuthSessionError(state, "Failed to save authentication tokens")
return
}
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
fmt.Println("You can now use GitHub Copilot services through this CLI")
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("github-copilot")
}()
c.JSON(200, gin.H{
"status": "ok",
"url": authURL,
"state": state,
"user_code": userCode,
"verification_uri": authURL,
})
}
func (h *Handler) RequestIFlowCookieToken(c *gin.Context) {
ctx := context.Background()
@@ -2422,6 +2521,25 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
return
}
if status != "" {
if strings.HasPrefix(status, "device_code|") {
parts := strings.SplitN(status, "|", 3)
if len(parts) == 3 {
c.JSON(http.StatusOK, gin.H{
"status": "device_code",
"verification_url": parts[1],
"user_code": parts[2],
})
return
}
}
if strings.HasPrefix(status, "auth_url|") {
authURL := strings.TrimPrefix(status, "auth_url|")
c.JSON(http.StatusOK, gin.H{
"status": "auth_url",
"url": authURL,
})
return
}
c.JSON(http.StatusOK, gin.H{"status": "error", "error": status})
return
}
@@ -2436,3 +2554,382 @@ func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context {
}
return coreauth.WithRequestInfo(ctx, info)
}
const kiroCallbackPort = 9876
func (h *Handler) RequestKiroToken(c *gin.Context) {
ctx := context.Background()
// Get the login method from query parameter (default: aws for device code flow)
method := strings.ToLower(strings.TrimSpace(c.Query("method")))
if method == "" {
method = "aws"
}
fmt.Println("Initializing Kiro authentication...")
state := fmt.Sprintf("kiro-%d", time.Now().UnixNano())
switch method {
case "aws", "builder-id":
RegisterOAuthSession(state, "kiro")
// AWS Builder ID uses device code flow (no callback needed)
go func() {
ssoClient := kiroauth.NewSSOOIDCClient(h.cfg)
// Step 1: Register client
fmt.Println("Registering client...")
regResp, errRegister := ssoClient.RegisterClient(ctx)
if errRegister != nil {
log.Errorf("Failed to register client: %v", errRegister)
SetOAuthSessionError(state, "Failed to register client")
return
}
// Step 2: Start device authorization
fmt.Println("Starting device authorization...")
authResp, errAuth := ssoClient.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret)
if errAuth != nil {
log.Errorf("Failed to start device auth: %v", errAuth)
SetOAuthSessionError(state, "Failed to start device authorization")
return
}
// Store the verification URL for the frontend to display.
// Using "|" as separator because URLs contain ":".
SetOAuthSessionError(state, "device_code|"+authResp.VerificationURIComplete+"|"+authResp.UserCode)
// Step 3: Poll for token
fmt.Println("Waiting for authorization...")
interval := 5 * time.Second
if authResp.Interval > 0 {
interval = time.Duration(authResp.Interval) * time.Second
}
deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)
for time.Now().Before(deadline) {
select {
case <-ctx.Done():
SetOAuthSessionError(state, "Authorization cancelled")
return
case <-time.After(interval):
tokenResp, errToken := ssoClient.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
if errToken != nil {
errStr := errToken.Error()
if strings.Contains(errStr, "authorization_pending") {
continue
}
if strings.Contains(errStr, "slow_down") {
interval += 5 * time.Second
continue
}
log.Errorf("Token creation failed: %v", errToken)
SetOAuthSessionError(state, "Token creation failed")
return
}
// Success! Save the token
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken)
idPart := kiroauth.SanitizeEmailForFilename(email)
if idPart == "" {
idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000)
}
now := time.Now()
fileName := fmt.Sprintf("kiro-aws-%s.json", idPart)
record := &coreauth.Auth{
ID: fileName,
Provider: "kiro",
FileName: fileName,
Metadata: map[string]any{
"type": "kiro",
"access_token": tokenResp.AccessToken,
"refresh_token": tokenResp.RefreshToken,
"expires_at": expiresAt.Format(time.RFC3339),
"auth_method": "builder-id",
"provider": "AWS",
"client_id": regResp.ClientID,
"client_secret": regResp.ClientSecret,
"email": email,
"last_refresh": now.Format(time.RFC3339),
},
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Errorf("Failed to save authentication tokens: %v", errSave)
SetOAuthSessionError(state, "Failed to save authentication tokens")
return
}
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
if email != "" {
fmt.Printf("Authenticated as: %s\n", email)
}
CompleteOAuthSession(state)
return
}
}
SetOAuthSessionError(state, "Authorization timed out")
}()
// Return immediately with the state for polling
c.JSON(http.StatusOK, gin.H{"status": "ok", "state": state, "method": "device_code"})
case "google", "github":
RegisterOAuthSession(state, "kiro")
// Social auth uses protocol handler - for WEB UI we use a callback forwarder
provider := "Google"
if method == "github" {
provider = "Github"
}
isWebUI := isWebUIRequest(c)
if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/kiro/callback")
if errTarget != nil {
log.WithError(errTarget).Error("failed to compute kiro callback target")
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return
}
if _, errStart := startCallbackForwarder(kiroCallbackPort, "kiro", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start kiro callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return
}
}
go func() {
if isWebUI {
defer stopCallbackForwarder(kiroCallbackPort)
}
socialClient := kiroauth.NewSocialAuthClient(h.cfg)
// Generate PKCE codes
codeVerifier, codeChallenge, errPKCE := generateKiroPKCE()
if errPKCE != nil {
log.Errorf("Failed to generate PKCE: %v", errPKCE)
SetOAuthSessionError(state, "Failed to generate PKCE")
return
}
// Build login URL
authURL := fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account",
"https://prod.us-east-1.auth.desktop.kiro.dev",
provider,
url.QueryEscape(kiroauth.KiroRedirectURI),
codeChallenge,
state,
)
// Store auth URL for frontend.
// Using "|" as separator because URLs contain ":".
SetOAuthSessionError(state, "auth_url|"+authURL)
// Wait for callback file
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-kiro-%s.oauth", state))
deadline := time.Now().Add(5 * time.Minute)
for {
if time.Now().After(deadline) {
log.Error("oauth flow timed out")
SetOAuthSessionError(state, "OAuth flow timed out")
return
}
if data, errRead := os.ReadFile(waitFile); errRead == nil {
var m map[string]string
_ = json.Unmarshal(data, &m)
_ = os.Remove(waitFile)
if errStr := m["error"]; errStr != "" {
log.Errorf("Authentication failed: %s", errStr)
SetOAuthSessionError(state, "Authentication failed")
return
}
if m["state"] != state {
log.Errorf("State mismatch")
SetOAuthSessionError(state, "State mismatch")
return
}
code := m["code"]
if code == "" {
log.Error("No authorization code received")
SetOAuthSessionError(state, "No authorization code received")
return
}
// Exchange code for tokens
tokenReq := &kiroauth.CreateTokenRequest{
Code: code,
CodeVerifier: codeVerifier,
RedirectURI: kiroauth.KiroRedirectURI,
}
tokenResp, errToken := socialClient.CreateToken(ctx, tokenReq)
if errToken != nil {
log.Errorf("Failed to exchange code for tokens: %v", errToken)
SetOAuthSessionError(state, "Failed to exchange code for tokens")
return
}
// Save the token
expiresIn := tokenResp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken)
idPart := kiroauth.SanitizeEmailForFilename(email)
if idPart == "" {
idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000)
}
now := time.Now()
fileName := fmt.Sprintf("kiro-%s-%s.json", strings.ToLower(provider), idPart)
record := &coreauth.Auth{
ID: fileName,
Provider: "kiro",
FileName: fileName,
Metadata: map[string]any{
"type": "kiro",
"access_token": tokenResp.AccessToken,
"refresh_token": tokenResp.RefreshToken,
"profile_arn": tokenResp.ProfileArn,
"expires_at": expiresAt.Format(time.RFC3339),
"auth_method": "social",
"provider": provider,
"email": email,
"last_refresh": now.Format(time.RFC3339),
},
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Errorf("Failed to save authentication tokens: %v", errSave)
SetOAuthSessionError(state, "Failed to save authentication tokens")
return
}
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
if email != "" {
fmt.Printf("Authenticated as: %s\n", email)
}
CompleteOAuthSession(state)
return
}
time.Sleep(500 * time.Millisecond)
}
}()
c.JSON(http.StatusOK, gin.H{"status": "ok", "state": state, "method": "social"})
default:
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid method, use 'aws', 'google', or 'github'"})
}
}
// generateKiroPKCE generates PKCE code verifier and challenge for Kiro OAuth.
func generateKiroPKCE() (verifier, challenge string, err error) {
b := make([]byte, 32)
if _, errRead := io.ReadFull(rand.Reader, b); errRead != nil {
return "", "", fmt.Errorf("failed to generate random bytes: %w", errRead)
}
verifier = base64.RawURLEncoding.EncodeToString(b)
h := sha256.Sum256([]byte(verifier))
challenge = base64.RawURLEncoding.EncodeToString(h[:])
return verifier, challenge, nil
}
func (h *Handler) RequestKiloToken(c *gin.Context) {
ctx := context.Background()
fmt.Println("Initializing Kilo authentication...")
state := fmt.Sprintf("kil-%d", time.Now().UnixNano())
kilocodeAuth := kilo.NewKiloAuth()
resp, err := kilocodeAuth.InitiateDeviceFlow(ctx)
if err != nil {
log.Errorf("Failed to initiate device flow: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initiate device flow"})
return
}
RegisterOAuthSession(state, "kilo")
go func() {
fmt.Printf("Please visit %s and enter code: %s\n", resp.VerificationURL, resp.Code)
status, err := kilocodeAuth.PollForToken(ctx, resp.Code)
if err != nil {
SetOAuthSessionError(state, "Authentication failed")
fmt.Printf("Authentication failed: %v\n", err)
return
}
profile, err := kilocodeAuth.GetProfile(ctx, status.Token)
if err != nil {
log.Warnf("Failed to fetch profile: %v", err)
profile = &kilo.Profile{Email: status.UserEmail}
}
var orgID string
if len(profile.Orgs) > 0 {
orgID = profile.Orgs[0].ID
}
defaults, err := kilocodeAuth.GetDefaults(ctx, status.Token, orgID)
if err != nil {
defaults = &kilo.Defaults{}
}
ts := &kilo.KiloTokenStorage{
Token: status.Token,
OrganizationID: orgID,
Model: defaults.Model,
Email: status.UserEmail,
Type: "kilo",
}
fileName := kilo.CredentialFileName(status.UserEmail)
record := &coreauth.Auth{
ID: fileName,
Provider: "kilo",
FileName: fileName,
Storage: ts,
Metadata: map[string]any{
"email": status.UserEmail,
"organization_id": orgID,
"model": defaults.Model,
},
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Errorf("Failed to save authentication tokens: %v", errSave)
SetOAuthSessionError(state, "Failed to save authentication tokens")
return
}
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("kilo")
}()
c.JSON(200, gin.H{
"status": "ok",
"url": resp.VerificationURL,
"state": state,
"user_code": resp.Code,
"verification_uri": resp.VerificationURL,
})
}

View File

@@ -19,8 +19,8 @@ import (
)
const (
latestReleaseURL = "https://api.github.com/repos/router-for-me/CLIProxyAPI/releases/latest"
latestReleaseUserAgent = "CLIProxyAPI"
latestReleaseURL = "https://api.github.com/repos/router-for-me/CLIProxyAPIPlus/releases/latest"
latestReleaseUserAgent = "CLIProxyAPIPlus"
)
func (h *Handler) GetConfig(c *gin.Context) {

View File

@@ -753,18 +753,22 @@ func (h *Handler) PatchOAuthModelAlias(c *gin.Context) {
normalizedMap := sanitizedOAuthModelAlias(map[string][]config.OAuthModelAlias{channel: body.Aliases})
normalized := normalizedMap[channel]
if len(normalized) == 0 {
// Only delete if channel exists, otherwise just create empty entry
if h.cfg.OAuthModelAlias != nil {
if _, ok := h.cfg.OAuthModelAlias[channel]; ok {
delete(h.cfg.OAuthModelAlias, channel)
if len(h.cfg.OAuthModelAlias) == 0 {
h.cfg.OAuthModelAlias = nil
}
h.persist(c)
return
}
}
// Create new channel with empty aliases
if h.cfg.OAuthModelAlias == nil {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
if _, ok := h.cfg.OAuthModelAlias[channel]; !ok {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
delete(h.cfg.OAuthModelAlias, channel)
if len(h.cfg.OAuthModelAlias) == 0 {
h.cfg.OAuthModelAlias = nil
h.cfg.OAuthModelAlias = make(map[string][]config.OAuthModelAlias)
}
h.cfg.OAuthModelAlias[channel] = []config.OAuthModelAlias{}
h.persist(c)
return
}
@@ -792,10 +796,10 @@ func (h *Handler) DeleteOAuthModelAlias(c *gin.Context) {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
delete(h.cfg.OAuthModelAlias, channel)
if len(h.cfg.OAuthModelAlias) == 0 {
h.cfg.OAuthModelAlias = nil
}
// Set to nil instead of deleting the key so that the "explicitly disabled"
// marker survives config reload and prevents SanitizeOAuthModelAlias from
// re-injecting default aliases (fixes #222).
h.cfg.OAuthModelAlias[channel] = nil
h.persist(c)
}

View File

@@ -158,7 +158,12 @@ func (s *oauthSessionStore) IsPending(state, provider string) bool {
return false
}
if session.Status != "" {
return false
if !strings.EqualFold(session.Provider, "kiro") {
return false
}
if !strings.HasPrefix(session.Status, "device_code|") && !strings.HasPrefix(session.Status, "auth_url|") {
return false
}
}
if provider == "" {
return true
@@ -231,6 +236,10 @@ func NormalizeOAuthProvider(provider string) (string, error) {
return "antigravity", nil
case "qwen":
return "qwen", nil
case "kiro":
return "kiro", nil
case "github":
return "github", nil
default:
return "", errUnsupportedOAuthFlow
}

View File

@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httputil"
"net/url"
@@ -104,7 +105,15 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
// Modify incoming responses to handle gzip without Content-Encoding
// This addresses the same issue as inline handler gzip handling, but at the proxy level
proxy.ModifyResponse = func(resp *http.Response) error {
// Only process successful responses
// Log upstream error responses for diagnostics (502, 503, etc.)
// These are NOT proxy connection errors - the upstream responded with an error status
if resp.StatusCode >= 500 {
log.Errorf("amp upstream responded with error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path)
} else if resp.StatusCode >= 400 {
log.Warnf("amp upstream responded with client error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path)
}
// Only process successful responses for gzip decompression
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil
}
@@ -188,13 +197,29 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
return nil
}
// Error handler for proxy failures
// Error handler for proxy failures with detailed error classification for diagnostics
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
// Client-side cancellations are common during polling; suppress logging in this case
// Classify the error type for better diagnostics
var errType string
if errors.Is(err, context.DeadlineExceeded) {
errType = "timeout"
} else if errors.Is(err, context.Canceled) {
errType = "canceled"
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
errType = "dial_timeout"
} else if _, ok := err.(net.Error); ok {
errType = "network_error"
} else {
errType = "connection_error"
}
// Don't log as error for context canceled - it's usually client closing connection
if errors.Is(err, context.Canceled) {
return
} else {
log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err)
}
log.Errorf("amp upstream proxy error for %s %s: %v", req.Method, req.URL.Path, err)
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(http.StatusBadGateway)
_, _ = rw.Write([]byte(`{"error":"amp_upstream_proxy_error","message":"Failed to reach Amp upstream"}`))

View File

@@ -29,15 +29,71 @@ func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRe
}
}
const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap
func looksLikeSSEChunk(data []byte) bool {
// Fallback detection: some upstreams may omit/lie about Content-Type, causing SSE to be buffered.
// Heuristics are intentionally simple and cheap.
return bytes.Contains(data, []byte("data:")) ||
bytes.Contains(data, []byte("event:")) ||
bytes.Contains(data, []byte("message_start")) ||
bytes.Contains(data, []byte("message_delta")) ||
bytes.Contains(data, []byte("content_block_start")) ||
bytes.Contains(data, []byte("content_block_delta")) ||
bytes.Contains(data, []byte("content_block_stop")) ||
bytes.Contains(data, []byte("\n\n"))
}
func (rw *ResponseRewriter) enableStreaming(reason string) error {
if rw.isStreaming {
return nil
}
rw.isStreaming = true
// Flush any previously buffered data to avoid reordering or data loss.
if rw.body != nil && rw.body.Len() > 0 {
buf := rw.body.Bytes()
// Copy before Reset() to keep bytes stable.
toFlush := make([]byte, len(buf))
copy(toFlush, buf)
rw.body.Reset()
if _, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(toFlush)); err != nil {
return err
}
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
log.Debugf("amp response rewriter: switched to streaming (%s)", reason)
return nil
}
// Write intercepts response writes and buffers them for model name replacement
func (rw *ResponseRewriter) Write(data []byte) (int, error) {
// Detect streaming on first write
if rw.body.Len() == 0 && !rw.isStreaming {
// Detect streaming on first write (header-based)
if !rw.isStreaming && rw.body.Len() == 0 {
contentType := rw.Header().Get("Content-Type")
rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
strings.Contains(contentType, "stream")
}
if !rw.isStreaming {
// Content-based fallback: detect SSE-like chunks even if Content-Type is missing/wrong.
if looksLikeSSEChunk(data) {
if err := rw.enableStreaming("sse heuristic"); err != nil {
return 0, err
}
} else if rw.body.Len()+len(data) > maxBufferedResponseBytes {
// Safety cap: avoid unbounded buffering on large responses.
log.Warnf("amp response rewriter: buffer exceeded %d bytes, switching to streaming", maxBufferedResponseBytes)
if err := rw.enableStreaming("buffer limit"); err != nil {
return 0, err
}
}
}
if rw.isStreaming {
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
if err == nil {

View File

@@ -24,6 +24,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/middleware"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
@@ -301,6 +302,11 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
s.registerManagementRoutes()
}
// === CLIProxyAPIPlus 扩展: 注册 Kiro OAuth Web 路由 ===
kiroOAuthHandler := kiro.NewOAuthWebHandler(cfg)
kiroOAuthHandler.RegisterRoutes(engine)
log.Info("Kiro OAuth Web routes registered at /v0/oauth/kiro/*")
if optionState.keepAliveEnabled {
s.enableKeepAlive(optionState.keepAliveTimeout, optionState.keepAliveOnTimeout)
}
@@ -358,6 +364,12 @@ func (s *Server) setupRoutes() {
},
})
})
// Event logging endpoint - handles Claude Code telemetry requests
// Returns 200 OK to prevent 404 errors in logs
s.engine.POST("/api/event_logging/batch", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)
// OAuth callback endpoints (reuse main server port)
@@ -433,6 +445,20 @@ func (s *Server) setupRoutes() {
c.String(http.StatusOK, oauthCallbackSuccessHTML)
})
s.engine.GET("/kiro/callback", func(c *gin.Context) {
code := c.Query("code")
state := c.Query("state")
errStr := c.Query("error")
if errStr == "" {
errStr = c.Query("error_description")
}
if state != "" {
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "kiro", state, code, errStr)
}
c.Header("Content-Type", "text/html; charset=utf-8")
c.String(http.StatusOK, oauthCallbackSuccessHTML)
})
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
}
@@ -635,9 +661,12 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
mgmt.GET("/kilo-auth-url", s.mgmt.RequestKiloToken)
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
mgmt.GET("/github-auth-url", s.mgmt.RequestGitHubToken)
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
}

View File

@@ -242,6 +242,11 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
platformURL = "https://console.anthropic.com/"
}
// Validate platformURL to prevent XSS - only allow http/https URLs
if !isValidURL(platformURL) {
platformURL = "https://console.anthropic.com/"
}
// Generate success page HTML with dynamic content
successHTML := s.generateSuccessHTML(setupRequired, platformURL)
@@ -251,6 +256,12 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
}
}
// isValidURL checks if the URL is a valid http/https URL to prevent XSS
func isValidURL(urlStr string) bool {
urlStr = strings.TrimSpace(urlStr)
return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://")
}
// generateSuccessHTML creates the HTML content for the success page.
// It customizes the page based on whether additional setup is required
// and includes a link to the platform.

View File

@@ -239,6 +239,11 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
platformURL = "https://platform.openai.com"
}
// Validate platformURL to prevent XSS - only allow http/https URLs
if !isValidURL(platformURL) {
platformURL = "https://platform.openai.com"
}
// Generate success page HTML with dynamic content
successHTML := s.generateSuccessHTML(setupRequired, platformURL)
@@ -248,6 +253,12 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
}
}
// isValidURL checks if the URL is a valid http/https URL to prevent XSS
func isValidURL(urlStr string) bool {
urlStr = strings.TrimSpace(urlStr)
return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://")
}
// generateSuccessHTML creates the HTML content for the success page.
// It customizes the page based on whether additional setup is required
// and includes a link to the platform.

View File

@@ -0,0 +1,233 @@
// Package copilot provides authentication and token management for GitHub Copilot API.
// It handles the OAuth2 device flow for secure authentication with the Copilot API.
package copilot
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
const (
// copilotAPITokenURL is the endpoint for getting Copilot API tokens from GitHub token.
copilotAPITokenURL = "https://api.github.com/copilot_internal/v2/token"
// copilotAPIEndpoint is the base URL for making API requests.
copilotAPIEndpoint = "https://api.githubcopilot.com"
// Common HTTP header values for Copilot API requests.
copilotUserAgent = "GithubCopilot/1.0"
copilotEditorVersion = "vscode/1.100.0"
copilotPluginVersion = "copilot/1.300.0"
copilotIntegrationID = "vscode-chat"
copilotOpenAIIntent = "conversation-panel"
)
// CopilotAPIToken represents the Copilot API token response.
type CopilotAPIToken struct {
// Token is the JWT token for authenticating with the Copilot API.
Token string `json:"token"`
// ExpiresAt is the Unix timestamp when the token expires.
ExpiresAt int64 `json:"expires_at"`
// Endpoints contains the available API endpoints.
Endpoints struct {
API string `json:"api"`
Proxy string `json:"proxy"`
OriginTracker string `json:"origin-tracker"`
Telemetry string `json:"telemetry"`
} `json:"endpoints,omitempty"`
// ErrorDetails contains error information if the request failed.
ErrorDetails *struct {
URL string `json:"url"`
Message string `json:"message"`
DocumentationURL string `json:"documentation_url"`
} `json:"error_details,omitempty"`
}
// CopilotAuth handles GitHub Copilot authentication flow.
// It provides methods for device flow authentication and token management.
type CopilotAuth struct {
httpClient *http.Client
deviceClient *DeviceFlowClient
cfg *config.Config
}
// NewCopilotAuth creates a new CopilotAuth service instance.
// It initializes an HTTP client with proxy settings from the provided configuration.
func NewCopilotAuth(cfg *config.Config) *CopilotAuth {
return &CopilotAuth{
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}),
deviceClient: NewDeviceFlowClient(cfg),
cfg: cfg,
}
}
// StartDeviceFlow initiates the device flow authentication.
// Returns the device code response containing the user code and verification URI.
func (c *CopilotAuth) StartDeviceFlow(ctx context.Context) (*DeviceCodeResponse, error) {
return c.deviceClient.RequestDeviceCode(ctx)
}
// WaitForAuthorization polls for user authorization and returns the auth bundle.
func (c *CopilotAuth) WaitForAuthorization(ctx context.Context, deviceCode *DeviceCodeResponse) (*CopilotAuthBundle, error) {
tokenData, err := c.deviceClient.PollForToken(ctx, deviceCode)
if err != nil {
return nil, err
}
// Fetch the GitHub username
userInfo, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
if err != nil {
log.Warnf("copilot: failed to fetch user info: %v", err)
}
username := userInfo.Login
if username == "" {
username = "github-user"
}
return &CopilotAuthBundle{
TokenData: tokenData,
Username: username,
Email: userInfo.Email,
Name: userInfo.Name,
}, nil
}
// GetCopilotAPIToken exchanges a GitHub access token for a Copilot API token.
// This token is used to make authenticated requests to the Copilot API.
func (c *CopilotAuth) GetCopilotAPIToken(ctx context.Context, githubAccessToken string) (*CopilotAPIToken, error) {
if githubAccessToken == "" {
return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("github access token is empty"))
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotAPITokenURL, nil)
if err != nil {
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
}
req.Header.Set("Authorization", "token "+githubAccessToken)
req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", copilotUserAgent)
req.Header.Set("Editor-Version", copilotEditorVersion)
req.Header.Set("Editor-Plugin-Version", copilotPluginVersion)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("copilot api token: close body error: %v", errClose)
}
}()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
}
if !isHTTPSuccess(resp.StatusCode) {
return nil, NewAuthenticationError(ErrTokenExchangeFailed,
fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
}
var apiToken CopilotAPIToken
if err = json.Unmarshal(bodyBytes, &apiToken); err != nil {
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
}
if apiToken.Token == "" {
return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("empty copilot api token"))
}
return &apiToken, nil
}
// ValidateToken checks if a GitHub access token is valid by attempting to fetch user info.
func (c *CopilotAuth) ValidateToken(ctx context.Context, accessToken string) (bool, string, error) {
if accessToken == "" {
return false, "", nil
}
userInfo, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
if err != nil {
return false, "", err
}
return true, userInfo.Login, nil
}
// CreateTokenStorage creates a new CopilotTokenStorage from auth bundle.
func (c *CopilotAuth) CreateTokenStorage(bundle *CopilotAuthBundle) *CopilotTokenStorage {
return &CopilotTokenStorage{
AccessToken: bundle.TokenData.AccessToken,
TokenType: bundle.TokenData.TokenType,
Scope: bundle.TokenData.Scope,
Username: bundle.Username,
Email: bundle.Email,
Name: bundle.Name,
Type: "github-copilot",
}
}
// LoadAndValidateToken loads a token from storage and validates it.
// Returns the storage if valid, or an error if the token is invalid or expired.
func (c *CopilotAuth) LoadAndValidateToken(ctx context.Context, storage *CopilotTokenStorage) (bool, error) {
if storage == nil || storage.AccessToken == "" {
return false, fmt.Errorf("no token available")
}
// Check if we can still use the GitHub token to get a Copilot API token
apiToken, err := c.GetCopilotAPIToken(ctx, storage.AccessToken)
if err != nil {
return false, err
}
// Check if the API token is expired
if apiToken.ExpiresAt > 0 && time.Now().Unix() >= apiToken.ExpiresAt {
return false, fmt.Errorf("copilot api token expired")
}
return true, nil
}
// GetAPIEndpoint returns the Copilot API endpoint URL.
func (c *CopilotAuth) GetAPIEndpoint() string {
return copilotAPIEndpoint
}
// MakeAuthenticatedRequest creates an authenticated HTTP request to the Copilot API.
func (c *CopilotAuth) MakeAuthenticatedRequest(ctx context.Context, method, url string, body io.Reader, apiToken *CopilotAPIToken) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, method, url, body)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+apiToken.Token)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", copilotUserAgent)
req.Header.Set("Editor-Version", copilotEditorVersion)
req.Header.Set("Editor-Plugin-Version", copilotPluginVersion)
req.Header.Set("Openai-Intent", copilotOpenAIIntent)
req.Header.Set("Copilot-Integration-Id", copilotIntegrationID)
return req, nil
}
// buildChatCompletionURL builds the URL for chat completions API.
func buildChatCompletionURL() string {
return copilotAPIEndpoint + "/chat/completions"
}
// isHTTPSuccess checks if the status code indicates success (2xx).
func isHTTPSuccess(statusCode int) bool {
return statusCode >= 200 && statusCode < 300
}

View File

@@ -0,0 +1,187 @@
package copilot
import (
"errors"
"fmt"
"net/http"
)
// OAuthError represents an OAuth-specific error.
type OAuthError struct {
// Code is the OAuth error code.
Code string `json:"error"`
// Description is a human-readable description of the error.
Description string `json:"error_description,omitempty"`
// URI is a URI identifying a human-readable web page with information about the error.
URI string `json:"error_uri,omitempty"`
// StatusCode is the HTTP status code associated with the error.
StatusCode int `json:"-"`
}
// Error returns a string representation of the OAuth error.
func (e *OAuthError) Error() string {
if e.Description != "" {
return fmt.Sprintf("OAuth error %s: %s", e.Code, e.Description)
}
return fmt.Sprintf("OAuth error: %s", e.Code)
}
// NewOAuthError creates a new OAuth error with the specified code, description, and status code.
func NewOAuthError(code, description string, statusCode int) *OAuthError {
return &OAuthError{
Code: code,
Description: description,
StatusCode: statusCode,
}
}
// AuthenticationError represents authentication-related errors.
type AuthenticationError struct {
// Type is the type of authentication error.
Type string `json:"type"`
// Message is a human-readable message describing the error.
Message string `json:"message"`
// Code is the HTTP status code associated with the error.
Code int `json:"code"`
// Cause is the underlying error that caused this authentication error.
Cause error `json:"-"`
}
// Error returns a string representation of the authentication error.
func (e *AuthenticationError) Error() string {
if e.Cause != nil {
return fmt.Sprintf("%s: %s (caused by: %v)", e.Type, e.Message, e.Cause)
}
return fmt.Sprintf("%s: %s", e.Type, e.Message)
}
// Unwrap returns the underlying cause of the error.
func (e *AuthenticationError) Unwrap() error {
return e.Cause
}
// Common authentication error types for GitHub Copilot device flow.
var (
// ErrDeviceCodeFailed represents an error when requesting the device code fails.
ErrDeviceCodeFailed = &AuthenticationError{
Type: "device_code_failed",
Message: "Failed to request device code from GitHub",
Code: http.StatusBadRequest,
}
// ErrDeviceCodeExpired represents an error when the device code has expired.
ErrDeviceCodeExpired = &AuthenticationError{
Type: "device_code_expired",
Message: "Device code has expired. Please try again.",
Code: http.StatusGone,
}
// ErrAuthorizationPending represents a pending authorization state (not an error, used for polling).
ErrAuthorizationPending = &AuthenticationError{
Type: "authorization_pending",
Message: "Authorization is pending. Waiting for user to authorize.",
Code: http.StatusAccepted,
}
// ErrSlowDown represents a request to slow down polling.
ErrSlowDown = &AuthenticationError{
Type: "slow_down",
Message: "Polling too frequently. Slowing down.",
Code: http.StatusTooManyRequests,
}
// ErrAccessDenied represents an error when the user denies authorization.
ErrAccessDenied = &AuthenticationError{
Type: "access_denied",
Message: "User denied authorization",
Code: http.StatusForbidden,
}
// ErrTokenExchangeFailed represents an error when token exchange fails.
ErrTokenExchangeFailed = &AuthenticationError{
Type: "token_exchange_failed",
Message: "Failed to exchange device code for access token",
Code: http.StatusBadRequest,
}
// ErrPollingTimeout represents an error when polling times out.
ErrPollingTimeout = &AuthenticationError{
Type: "polling_timeout",
Message: "Timeout waiting for user authorization",
Code: http.StatusRequestTimeout,
}
// ErrUserInfoFailed represents an error when fetching user info fails.
ErrUserInfoFailed = &AuthenticationError{
Type: "user_info_failed",
Message: "Failed to fetch GitHub user information",
Code: http.StatusBadRequest,
}
)
// NewAuthenticationError creates a new authentication error with a cause based on a base error.
func NewAuthenticationError(baseErr *AuthenticationError, cause error) *AuthenticationError {
return &AuthenticationError{
Type: baseErr.Type,
Message: baseErr.Message,
Code: baseErr.Code,
Cause: cause,
}
}
// IsAuthenticationError checks if an error is an authentication error.
func IsAuthenticationError(err error) bool {
var authenticationError *AuthenticationError
ok := errors.As(err, &authenticationError)
return ok
}
// IsOAuthError checks if an error is an OAuth error.
func IsOAuthError(err error) bool {
var oAuthError *OAuthError
ok := errors.As(err, &oAuthError)
return ok
}
// GetUserFriendlyMessage returns a user-friendly error message based on the error type.
func GetUserFriendlyMessage(err error) string {
var authErr *AuthenticationError
if errors.As(err, &authErr) {
switch authErr.Type {
case "device_code_failed":
return "Failed to start GitHub authentication. Please check your network connection and try again."
case "device_code_expired":
return "The authentication code has expired. Please try again."
case "authorization_pending":
return "Waiting for you to authorize the application on GitHub."
case "slow_down":
return "Please wait a moment before trying again."
case "access_denied":
return "Authentication was cancelled or denied."
case "token_exchange_failed":
return "Failed to complete authentication. Please try again."
case "polling_timeout":
return "Authentication timed out. Please try again."
case "user_info_failed":
return "Failed to get your GitHub account information. Please try again."
default:
return "Authentication failed. Please try again."
}
}
var oauthErr *OAuthError
if errors.As(err, &oauthErr) {
switch oauthErr.Code {
case "access_denied":
return "Authentication was cancelled or denied."
case "invalid_request":
return "Invalid authentication request. Please try again."
case "server_error":
return "GitHub server error. Please try again later."
default:
return fmt.Sprintf("Authentication failed: %s", oauthErr.Description)
}
}
return "An unexpected error occurred. Please try again."
}

View File

@@ -0,0 +1,271 @@
package copilot
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
const (
// copilotClientID is GitHub's Copilot CLI OAuth client ID.
copilotClientID = "Iv1.b507a08c87ecfe98"
// copilotDeviceCodeURL is the endpoint for requesting device codes.
copilotDeviceCodeURL = "https://github.com/login/device/code"
// copilotTokenURL is the endpoint for exchanging device codes for tokens.
copilotTokenURL = "https://github.com/login/oauth/access_token"
// copilotUserInfoURL is the endpoint for fetching GitHub user information.
copilotUserInfoURL = "https://api.github.com/user"
// defaultPollInterval is the default interval for polling token endpoint.
defaultPollInterval = 5 * time.Second
// maxPollDuration is the maximum time to wait for user authorization.
maxPollDuration = 15 * time.Minute
)
// DeviceFlowClient handles the OAuth2 device flow for GitHub Copilot.
type DeviceFlowClient struct {
httpClient *http.Client
cfg *config.Config
}
// NewDeviceFlowClient creates a new device flow client.
func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient {
client := &http.Client{Timeout: 30 * time.Second}
if cfg != nil {
client = util.SetProxy(&cfg.SDKConfig, client)
}
return &DeviceFlowClient{
httpClient: client,
cfg: cfg,
}
}
// RequestDeviceCode initiates the device flow by requesting a device code from GitHub.
func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
data := url.Values{}
data.Set("client_id", copilotClientID)
data.Set("scope", "read:user user:email")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotDeviceCodeURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, NewAuthenticationError(ErrDeviceCodeFailed, err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, NewAuthenticationError(ErrDeviceCodeFailed, err)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("copilot device code: close body error: %v", errClose)
}
}()
if !isHTTPSuccess(resp.StatusCode) {
bodyBytes, _ := io.ReadAll(resp.Body)
return nil, NewAuthenticationError(ErrDeviceCodeFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
}
var deviceCode DeviceCodeResponse
if err = json.NewDecoder(resp.Body).Decode(&deviceCode); err != nil {
return nil, NewAuthenticationError(ErrDeviceCodeFailed, err)
}
return &deviceCode, nil
}
// PollForToken polls the token endpoint until the user authorizes or the device code expires.
func (c *DeviceFlowClient) PollForToken(ctx context.Context, deviceCode *DeviceCodeResponse) (*CopilotTokenData, error) {
if deviceCode == nil {
return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("device code is nil"))
}
interval := time.Duration(deviceCode.Interval) * time.Second
if interval < defaultPollInterval {
interval = defaultPollInterval
}
deadline := time.Now().Add(maxPollDuration)
if deviceCode.ExpiresIn > 0 {
codeDeadline := time.Now().Add(time.Duration(deviceCode.ExpiresIn) * time.Second)
if codeDeadline.Before(deadline) {
deadline = codeDeadline
}
}
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return nil, NewAuthenticationError(ErrPollingTimeout, ctx.Err())
case <-ticker.C:
if time.Now().After(deadline) {
return nil, ErrPollingTimeout
}
token, err := c.exchangeDeviceCode(ctx, deviceCode.DeviceCode)
if err != nil {
var authErr *AuthenticationError
if errors.As(err, &authErr) {
switch authErr.Type {
case ErrAuthorizationPending.Type:
// Continue polling
continue
case ErrSlowDown.Type:
// Increase interval and continue
interval += 5 * time.Second
ticker.Reset(interval)
continue
case ErrDeviceCodeExpired.Type:
return nil, err
case ErrAccessDenied.Type:
return nil, err
}
}
return nil, err
}
return token, nil
}
}
}
// exchangeDeviceCode attempts to exchange the device code for an access token.
func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode string) (*CopilotTokenData, error) {
data := url.Values{}
data.Set("client_id", copilotClientID)
data.Set("device_code", deviceCode)
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code")
req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotTokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("copilot token exchange: close body error: %v", errClose)
}
}()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
}
// GitHub returns 200 for both success and error cases in device flow
// Check for OAuth error response first
var oauthResp struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
Scope string `json:"scope"`
}
if err = json.Unmarshal(bodyBytes, &oauthResp); err != nil {
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
}
if oauthResp.Error != "" {
switch oauthResp.Error {
case "authorization_pending":
return nil, ErrAuthorizationPending
case "slow_down":
return nil, ErrSlowDown
case "expired_token":
return nil, ErrDeviceCodeExpired
case "access_denied":
return nil, ErrAccessDenied
default:
return nil, NewOAuthError(oauthResp.Error, oauthResp.ErrorDescription, resp.StatusCode)
}
}
if oauthResp.AccessToken == "" {
return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("empty access token"))
}
return &CopilotTokenData{
AccessToken: oauthResp.AccessToken,
TokenType: oauthResp.TokenType,
Scope: oauthResp.Scope,
}, nil
}
// GitHubUserInfo holds GitHub user profile information.
type GitHubUserInfo struct {
// Login is the GitHub username.
Login string
// Email is the primary email address (may be empty if not public).
Email string
// Name is the display name.
Name string
}
// FetchUserInfo retrieves the GitHub user profile for the authenticated user.
func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (GitHubUserInfo, error) {
if accessToken == "" {
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotUserInfoURL, nil)
if err != nil {
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", "CLIProxyAPI")
resp, err := c.httpClient.Do(req)
if err != nil {
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("copilot user info: close body error: %v", errClose)
}
}()
if !isHTTPSuccess(resp.StatusCode) {
bodyBytes, _ := io.ReadAll(resp.Body)
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
}
var raw struct {
Login string `json:"login"`
Email string `json:"email"`
Name string `json:"name"`
}
if err = json.NewDecoder(resp.Body).Decode(&raw); err != nil {
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, err)
}
if raw.Login == "" {
return GitHubUserInfo{}, NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
}
return GitHubUserInfo{
Login: raw.Login,
Email: raw.Email,
Name: raw.Name,
}, nil
}

View File

@@ -0,0 +1,213 @@
package copilot
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// roundTripFunc lets us inject a custom transport for testing.
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
// newTestClient returns an *http.Client whose requests are redirected to the given test server,
// regardless of the original URL host.
func newTestClient(srv *httptest.Server) *http.Client {
return &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
req2 := req.Clone(req.Context())
req2.URL.Scheme = "http"
req2.URL.Host = strings.TrimPrefix(srv.URL, "http://")
return srv.Client().Transport.RoundTrip(req2)
}),
}
}
// TestFetchUserInfo_FullProfile verifies that FetchUserInfo returns login, email, and name.
func TestFetchUserInfo_FullProfile(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") {
w.WriteHeader(http.StatusUnauthorized)
return
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]string{
"login": "octocat",
"email": "octocat@github.com",
"name": "The Octocat",
})
}))
defer srv.Close()
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
info, err := client.FetchUserInfo(context.Background(), "test-token")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info.Login != "octocat" {
t.Errorf("Login: got %q, want %q", info.Login, "octocat")
}
if info.Email != "octocat@github.com" {
t.Errorf("Email: got %q, want %q", info.Email, "octocat@github.com")
}
if info.Name != "The Octocat" {
t.Errorf("Name: got %q, want %q", info.Name, "The Octocat")
}
}
// TestFetchUserInfo_EmptyEmail verifies graceful handling when email is absent (private account).
func TestFetchUserInfo_EmptyEmail(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
// GitHub returns null for private emails.
_, _ = w.Write([]byte(`{"login":"privateuser","email":null,"name":"Private User"}`))
}))
defer srv.Close()
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
info, err := client.FetchUserInfo(context.Background(), "test-token")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info.Login != "privateuser" {
t.Errorf("Login: got %q, want %q", info.Login, "privateuser")
}
if info.Email != "" {
t.Errorf("Email: got %q, want empty string", info.Email)
}
if info.Name != "Private User" {
t.Errorf("Name: got %q, want %q", info.Name, "Private User")
}
}
// TestFetchUserInfo_EmptyToken verifies error is returned for empty access token.
func TestFetchUserInfo_EmptyToken(t *testing.T) {
client := &DeviceFlowClient{httpClient: http.DefaultClient}
_, err := client.FetchUserInfo(context.Background(), "")
if err == nil {
t.Fatal("expected error for empty token, got nil")
}
}
// TestFetchUserInfo_EmptyLogin verifies error is returned when API returns no login.
func TestFetchUserInfo_EmptyLogin(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"email":"someone@example.com","name":"No Login"}`))
}))
defer srv.Close()
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
_, err := client.FetchUserInfo(context.Background(), "test-token")
if err == nil {
t.Fatal("expected error for empty login, got nil")
}
}
// TestFetchUserInfo_HTTPError verifies error is returned on non-2xx response.
func TestFetchUserInfo_HTTPError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"message":"Bad credentials"}`))
}))
defer srv.Close()
client := &DeviceFlowClient{httpClient: newTestClient(srv)}
_, err := client.FetchUserInfo(context.Background(), "bad-token")
if err == nil {
t.Fatal("expected error for 401 response, got nil")
}
}
// TestCopilotTokenStorage_EmailNameFields verifies Email and Name serialise correctly.
func TestCopilotTokenStorage_EmailNameFields(t *testing.T) {
ts := &CopilotTokenStorage{
AccessToken: "ghu_abc",
TokenType: "bearer",
Scope: "read:user user:email",
Username: "octocat",
Email: "octocat@github.com",
Name: "The Octocat",
Type: "github-copilot",
}
data, err := json.Marshal(ts)
if err != nil {
t.Fatalf("marshal error: %v", err)
}
var out map[string]any
if err = json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
for _, key := range []string{"access_token", "username", "email", "name", "type"} {
if _, ok := out[key]; !ok {
t.Errorf("expected key %q in JSON output, not found", key)
}
}
if out["email"] != "octocat@github.com" {
t.Errorf("email: got %v, want %q", out["email"], "octocat@github.com")
}
if out["name"] != "The Octocat" {
t.Errorf("name: got %v, want %q", out["name"], "The Octocat")
}
}
// TestCopilotTokenStorage_OmitEmptyEmailName verifies email/name are omitted when empty (omitempty).
func TestCopilotTokenStorage_OmitEmptyEmailName(t *testing.T) {
ts := &CopilotTokenStorage{
AccessToken: "ghu_abc",
Username: "octocat",
Type: "github-copilot",
}
data, err := json.Marshal(ts)
if err != nil {
t.Fatalf("marshal error: %v", err)
}
var out map[string]any
if err = json.Unmarshal(data, &out); err != nil {
t.Fatalf("unmarshal error: %v", err)
}
if _, ok := out["email"]; ok {
t.Error("email key should be omitted when empty (omitempty), but was present")
}
if _, ok := out["name"]; ok {
t.Error("name key should be omitted when empty (omitempty), but was present")
}
}
// TestCopilotAuthBundle_EmailNameFields verifies bundle carries email and name through the pipeline.
func TestCopilotAuthBundle_EmailNameFields(t *testing.T) {
bundle := &CopilotAuthBundle{
TokenData: &CopilotTokenData{AccessToken: "ghu_abc"},
Username: "octocat",
Email: "octocat@github.com",
Name: "The Octocat",
}
if bundle.Email != "octocat@github.com" {
t.Errorf("bundle.Email: got %q, want %q", bundle.Email, "octocat@github.com")
}
if bundle.Name != "The Octocat" {
t.Errorf("bundle.Name: got %q, want %q", bundle.Name, "The Octocat")
}
}
// TestGitHubUserInfo_Struct verifies the exported GitHubUserInfo struct fields are accessible.
func TestGitHubUserInfo_Struct(t *testing.T) {
info := GitHubUserInfo{
Login: "octocat",
Email: "octocat@github.com",
Name: "The Octocat",
}
if info.Login == "" || info.Email == "" || info.Name == "" {
t.Error("GitHubUserInfo fields should not be empty")
}
}

View File

@@ -0,0 +1,101 @@
// Package copilot provides authentication and token management functionality
// for GitHub Copilot AI services. It handles OAuth2 device flow token storage,
// serialization, and retrieval for maintaining authenticated sessions with the Copilot API.
package copilot
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
)
// CopilotTokenStorage stores OAuth2 token information for GitHub Copilot API authentication.
// It maintains compatibility with the existing auth system while adding Copilot-specific fields
// for managing access tokens and user account information.
type CopilotTokenStorage struct {
// AccessToken is the OAuth2 access token used for authenticating API requests.
AccessToken string `json:"access_token"`
// TokenType is the type of token, typically "bearer".
TokenType string `json:"token_type"`
// Scope is the OAuth2 scope granted to the token.
Scope string `json:"scope"`
// ExpiresAt is the timestamp when the access token expires (if provided).
ExpiresAt string `json:"expires_at,omitempty"`
// Username is the GitHub username associated with this token.
Username string `json:"username"`
// Email is the GitHub email address associated with this token.
Email string `json:"email,omitempty"`
// Name is the GitHub display name associated with this token.
Name string `json:"name,omitempty"`
// Type indicates the authentication provider type, always "github-copilot" for this storage.
Type string `json:"type"`
}
// CopilotTokenData holds the raw OAuth token response from GitHub.
type CopilotTokenData struct {
// AccessToken is the OAuth2 access token.
AccessToken string `json:"access_token"`
// TokenType is the type of token, typically "bearer".
TokenType string `json:"token_type"`
// Scope is the OAuth2 scope granted to the token.
Scope string `json:"scope"`
}
// CopilotAuthBundle bundles authentication data for storage.
type CopilotAuthBundle struct {
// TokenData contains the OAuth token information.
TokenData *CopilotTokenData
// Username is the GitHub username.
Username string
// Email is the GitHub email address.
Email string
// Name is the GitHub display name.
Name string
}
// DeviceCodeResponse represents GitHub's device code response.
type DeviceCodeResponse struct {
// DeviceCode is the device verification code.
DeviceCode string `json:"device_code"`
// UserCode is the code the user must enter at the verification URI.
UserCode string `json:"user_code"`
// VerificationURI is the URL where the user should enter the code.
VerificationURI string `json:"verification_uri"`
// ExpiresIn is the number of seconds until the device code expires.
ExpiresIn int `json:"expires_in"`
// Interval is the minimum number of seconds to wait between polling requests.
Interval int `json:"interval"`
}
// SaveTokenToFile serializes the Copilot token storage to a JSON file.
// This method creates the necessary directory structure and writes the token
// data in JSON format to the specified file path for persistent storage.
//
// Parameters:
// - authFilePath: The full path where the token file should be saved
//
// Returns:
// - error: An error if the operation fails, nil otherwise
func (ts *CopilotTokenStorage) SaveTokenToFile(authFilePath string) error {
misc.LogSavingCredentials(authFilePath)
ts.Type = "github-copilot"
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
return fmt.Errorf("failed to create directory: %v", err)
}
f, err := os.Create(authFilePath)
if err != nil {
return fmt.Errorf("failed to create token file: %w", err)
}
defer func() {
_ = f.Close()
}()
if err = json.NewEncoder(f).Encode(ts); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil
}

View File

@@ -9,6 +9,7 @@ import (
"io"
"net/http"
"net/url"
"os"
"strings"
"time"
@@ -28,10 +29,21 @@ const (
iFlowAPIKeyEndpoint = "https://platform.iflow.cn/api/openapi/apikey"
// Client credentials provided by iFlow for the Code Assist integration.
iFlowOAuthClientID = "10009311001"
iFlowOAuthClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW"
iFlowOAuthClientID = "10009311001"
// Default client secret (can be overridden via IFLOW_CLIENT_SECRET env var)
defaultIFlowClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW"
)
// getIFlowClientSecret returns the iFlow OAuth client secret.
// It first checks the IFLOW_CLIENT_SECRET environment variable,
// falling back to the default value if not set.
func getIFlowClientSecret() string {
if secret := os.Getenv("IFLOW_CLIENT_SECRET"); secret != "" {
return secret
}
return defaultIFlowClientSecret
}
// DefaultAPIBaseURL is the canonical chat completions endpoint.
const DefaultAPIBaseURL = "https://apis.iflow.cn/v1"
@@ -72,7 +84,7 @@ func (ia *IFlowAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectUR
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
form.Set("client_id", iFlowOAuthClientID)
form.Set("client_secret", iFlowOAuthClientSecret)
form.Set("client_secret", getIFlowClientSecret())
req, err := ia.newTokenRequest(ctx, form)
if err != nil {
@@ -88,7 +100,7 @@ func (ia *IFlowAuth) RefreshTokens(ctx context.Context, refreshToken string) (*I
form.Set("grant_type", "refresh_token")
form.Set("refresh_token", refreshToken)
form.Set("client_id", iFlowOAuthClientID)
form.Set("client_secret", iFlowOAuthClientSecret)
form.Set("client_secret", getIFlowClientSecret())
req, err := ia.newTokenRequest(ctx, form)
if err != nil {
@@ -104,7 +116,7 @@ func (ia *IFlowAuth) newTokenRequest(ctx context.Context, form url.Values) (*htt
return nil, fmt.Errorf("iflow token: create request failed: %w", err)
}
basic := base64.StdEncoding.EncodeToString([]byte(iFlowOAuthClientID + ":" + iFlowOAuthClientSecret))
basic := base64.StdEncoding.EncodeToString([]byte(iFlowOAuthClientID + ":" + getIFlowClientSecret()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Basic "+basic)

View File

@@ -0,0 +1,168 @@
// Package kilo provides authentication and token management functionality
// for Kilo AI services.
package kilo
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
)
const (
// BaseURL is the base URL for the Kilo AI API.
BaseURL = "https://api.kilo.ai/api"
)
// DeviceAuthResponse represents the response from initiating device flow.
type DeviceAuthResponse struct {
Code string `json:"code"`
VerificationURL string `json:"verificationUrl"`
ExpiresIn int `json:"expiresIn"`
}
// DeviceStatusResponse represents the response when polling for device flow status.
type DeviceStatusResponse struct {
Status string `json:"status"`
Token string `json:"token"`
UserEmail string `json:"userEmail"`
}
// Profile represents the user profile from Kilo AI.
type Profile struct {
Email string `json:"email"`
Orgs []Organization `json:"organizations"`
}
// Organization represents a Kilo AI organization.
type Organization struct {
ID string `json:"id"`
Name string `json:"name"`
}
// Defaults represents default settings for an organization or user.
type Defaults struct {
Model string `json:"model"`
}
// KiloAuth provides methods for handling the Kilo AI authentication flow.
type KiloAuth struct {
client *http.Client
}
// NewKiloAuth creates a new instance of KiloAuth.
func NewKiloAuth() *KiloAuth {
return &KiloAuth{
client: &http.Client{Timeout: 30 * time.Second},
}
}
// InitiateDeviceFlow starts the device authentication flow.
func (k *KiloAuth) InitiateDeviceFlow(ctx context.Context) (*DeviceAuthResponse, error) {
resp, err := k.client.Post(BaseURL+"/device-auth/codes", "application/json", nil)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to initiate device flow: status %d", resp.StatusCode)
}
var data DeviceAuthResponse
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
return nil, err
}
return &data, nil
}
// PollForToken polls for the device flow completion.
func (k *KiloAuth) PollForToken(ctx context.Context, code string) (*DeviceStatusResponse, error) {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-ticker.C:
resp, err := k.client.Get(BaseURL + "/device-auth/codes/" + code)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var data DeviceStatusResponse
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
return nil, err
}
switch data.Status {
case "approved":
return &data, nil
case "denied", "expired":
return nil, fmt.Errorf("device flow %s", data.Status)
case "pending":
continue
default:
return nil, fmt.Errorf("unknown status: %s", data.Status)
}
}
}
}
// GetProfile fetches the user's profile.
func (k *KiloAuth) GetProfile(ctx context.Context, token string) (*Profile, error) {
req, err := http.NewRequestWithContext(ctx, "GET", BaseURL+"/profile", nil)
if err != nil {
return nil, fmt.Errorf("failed to create get profile request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+token)
resp, err := k.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to get profile: status %d", resp.StatusCode)
}
var profile Profile
if err := json.NewDecoder(resp.Body).Decode(&profile); err != nil {
return nil, err
}
return &profile, nil
}
// GetDefaults fetches default settings for an organization.
func (k *KiloAuth) GetDefaults(ctx context.Context, token, orgID string) (*Defaults, error) {
url := BaseURL + "/defaults"
if orgID != "" {
url = BaseURL + "/organizations/" + orgID + "/defaults"
}
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create get defaults request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+token)
resp, err := k.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to get defaults: status %d", resp.StatusCode)
}
var defaults Defaults
if err := json.NewDecoder(resp.Body).Decode(&defaults); err != nil {
return nil, err
}
return &defaults, nil
}

View File

@@ -0,0 +1,60 @@
// Package kilo provides authentication and token management functionality
// for Kilo AI services.
package kilo
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
log "github.com/sirupsen/logrus"
)
// KiloTokenStorage stores token information for Kilo AI authentication.
type KiloTokenStorage struct {
// Token is the Kilo access token.
Token string `json:"kilocodeToken"`
// OrganizationID is the Kilo organization ID.
OrganizationID string `json:"kilocodeOrganizationId"`
// Model is the default model to use.
Model string `json:"kilocodeModel"`
// Email is the email address of the authenticated user.
Email string `json:"email"`
// Type indicates the authentication provider type, always "kilo" for this storage.
Type string `json:"type"`
}
// SaveTokenToFile serializes the Kilo token storage to a JSON file.
func (ts *KiloTokenStorage) SaveTokenToFile(authFilePath string) error {
misc.LogSavingCredentials(authFilePath)
ts.Type = "kilo"
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
return fmt.Errorf("failed to create directory: %v", err)
}
f, err := os.Create(authFilePath)
if err != nil {
return fmt.Errorf("failed to create token file: %w", err)
}
defer func() {
if errClose := f.Close(); errClose != nil {
log.Errorf("failed to close file: %v", errClose)
}
}()
if err = json.NewEncoder(f).Encode(ts); err != nil {
return fmt.Errorf("failed to write token to file: %w", err)
}
return nil
}
// CredentialFileName returns the filename used to persist Kilo credentials.
func CredentialFileName(email string) string {
return fmt.Sprintf("kilo-%s.json", email)
}

681
internal/auth/kiro/aws.go Normal file
View File

@@ -0,0 +1,681 @@
// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API.
// It includes interfaces and implementations for token storage and authentication methods.
package kiro
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/url"
"os"
"path/filepath"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
type PKCECodes struct {
// CodeVerifier is the cryptographically random string used to correlate
// the authorization request to the token request
CodeVerifier string `json:"code_verifier"`
// CodeChallenge is the SHA256 hash of the code verifier, base64url-encoded
CodeChallenge string `json:"code_challenge"`
}
// KiroTokenData holds OAuth token information from AWS CodeWhisperer (Kiro)
type KiroTokenData struct {
// AccessToken is the OAuth2 access token for API access
AccessToken string `json:"accessToken"`
// RefreshToken is used to obtain new access tokens
RefreshToken string `json:"refreshToken"`
// ProfileArn is the AWS CodeWhisperer profile ARN
ProfileArn string `json:"profileArn"`
// ExpiresAt is the timestamp when the token expires
ExpiresAt string `json:"expiresAt"`
// AuthMethod indicates the authentication method used (e.g., "builder-id", "social", "idc")
AuthMethod string `json:"authMethod"`
// Provider indicates the OAuth provider (e.g., "AWS", "Google", "Enterprise")
Provider string `json:"provider"`
// ClientID is the OIDC client ID (needed for token refresh)
ClientID string `json:"clientId,omitempty"`
// ClientSecret is the OIDC client secret (needed for token refresh)
ClientSecret string `json:"clientSecret,omitempty"`
// ClientIDHash is the hash of client ID used to locate device registration file
// (Enterprise Kiro IDE stores clientId/clientSecret in ~/.aws/sso/cache/{clientIdHash}.json)
ClientIDHash string `json:"clientIdHash,omitempty"`
// Email is the user's email address (used for file naming)
Email string `json:"email,omitempty"`
// StartURL is the IDC/Identity Center start URL (only for IDC auth method)
StartURL string `json:"startUrl,omitempty"`
// Region is the OIDC region for IDC login and token refresh
Region string `json:"region,omitempty"`
}
// KiroAuthBundle aggregates authentication data after OAuth flow completion
type KiroAuthBundle struct {
// TokenData contains the OAuth tokens from the authentication flow
TokenData KiroTokenData `json:"token_data"`
// LastRefresh is the timestamp of the last token refresh
LastRefresh string `json:"last_refresh"`
}
// KiroUsageInfo represents usage information from CodeWhisperer API
type KiroUsageInfo struct {
// SubscriptionTitle is the subscription plan name (e.g., "KIRO FREE")
SubscriptionTitle string `json:"subscription_title"`
// CurrentUsage is the current credit usage
CurrentUsage float64 `json:"current_usage"`
// UsageLimit is the maximum credit limit
UsageLimit float64 `json:"usage_limit"`
// NextReset is the timestamp of the next usage reset
NextReset string `json:"next_reset"`
}
// KiroModel represents a model available through the CodeWhisperer API
type KiroModel struct {
// ModelID is the unique identifier for the model
ModelID string `json:"modelId"`
// ModelName is the human-readable name
ModelName string `json:"modelName"`
// Description is the model description
Description string `json:"description"`
// RateMultiplier is the credit multiplier for this model
RateMultiplier float64 `json:"rateMultiplier"`
// RateUnit is the unit for rate calculation (e.g., "credit")
RateUnit string `json:"rateUnit"`
// MaxInputTokens is the maximum input token limit
MaxInputTokens int `json:"maxInputTokens,omitempty"`
}
// KiroIDETokenFile is the default path to Kiro IDE's token file
const KiroIDETokenFile = ".aws/sso/cache/kiro-auth-token.json"
// Default retry configuration for file reading
const (
defaultTokenReadMaxAttempts = 10 // Maximum retry attempts
defaultTokenReadBaseDelay = 50 * time.Millisecond // Base delay between retries
)
// isTransientFileError checks if the error is a transient file access error
// that may be resolved by retrying (e.g., file locked by another process on Windows).
func isTransientFileError(err error) bool {
if err == nil {
return false
}
// Check for OS-level file access errors (Windows sharing violation, etc.)
var pathErr *os.PathError
if errors.As(err, &pathErr) {
// Windows sharing violation (ERROR_SHARING_VIOLATION = 32)
// Windows lock violation (ERROR_LOCK_VIOLATION = 33)
errStr := pathErr.Err.Error()
if strings.Contains(errStr, "being used by another process") ||
strings.Contains(errStr, "sharing violation") ||
strings.Contains(errStr, "lock violation") {
return true
}
}
// Check error message for common transient patterns
errMsg := strings.ToLower(err.Error())
transientPatterns := []string{
"being used by another process",
"sharing violation",
"lock violation",
"access is denied",
"unexpected end of json",
"unexpected eof",
}
for _, pattern := range transientPatterns {
if strings.Contains(errMsg, pattern) {
return true
}
}
return false
}
// LoadKiroIDETokenWithRetry loads token data from Kiro IDE's token file with retry logic.
// This handles transient file access errors (e.g., file locked by Kiro IDE during write).
// maxAttempts: maximum number of retry attempts (default 10 if <= 0)
// baseDelay: base delay between retries with exponential backoff (default 50ms if <= 0)
func LoadKiroIDETokenWithRetry(maxAttempts int, baseDelay time.Duration) (*KiroTokenData, error) {
if maxAttempts <= 0 {
maxAttempts = defaultTokenReadMaxAttempts
}
if baseDelay <= 0 {
baseDelay = defaultTokenReadBaseDelay
}
var lastErr error
for attempt := 0; attempt < maxAttempts; attempt++ {
token, err := LoadKiroIDEToken()
if err == nil {
return token, nil
}
lastErr = err
// Only retry for transient errors
if !isTransientFileError(err) {
return nil, err
}
// Exponential backoff: delay * 2^attempt, capped at 500ms
delay := baseDelay * time.Duration(1<<uint(attempt))
if delay > 500*time.Millisecond {
delay = 500 * time.Millisecond
}
time.Sleep(delay)
}
return nil, fmt.Errorf("failed to read token file after %d attempts: %w", maxAttempts, lastErr)
}
// LoadKiroIDEToken loads token data from Kiro IDE's token file.
// For Enterprise Kiro IDE (IDC auth), it also loads clientId and clientSecret
// from the device registration file referenced by clientIdHash.
func LoadKiroIDEToken() (*KiroTokenData, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("failed to get home directory: %w", err)
}
tokenPath := filepath.Join(homeDir, KiroIDETokenFile)
data, err := os.ReadFile(tokenPath)
if err != nil {
return nil, fmt.Errorf("failed to read Kiro IDE token file (%s): %w", tokenPath, err)
}
var token KiroTokenData
if err := json.Unmarshal(data, &token); err != nil {
return nil, fmt.Errorf("failed to parse Kiro IDE token: %w", err)
}
if token.AccessToken == "" {
return nil, fmt.Errorf("access token is empty in Kiro IDE token file")
}
// Normalize AuthMethod to lowercase (Kiro IDE uses "IdC" but we expect "idc")
token.AuthMethod = strings.ToLower(token.AuthMethod)
// For Enterprise Kiro IDE (IDC auth), load clientId and clientSecret from device registration
// The device registration file is located at ~/.aws/sso/cache/{clientIdHash}.json
if token.ClientIDHash != "" && token.ClientID == "" {
if err := loadDeviceRegistration(homeDir, token.ClientIDHash, &token); err != nil {
// Log warning but don't fail - token might still work for some operations
fmt.Printf("warning: failed to load device registration for clientIdHash %s: %v\n", token.ClientIDHash, err)
}
}
return &token, nil
}
// loadDeviceRegistration loads clientId and clientSecret from the device registration file.
// Enterprise Kiro IDE stores these in ~/.aws/sso/cache/{clientIdHash}.json
func loadDeviceRegistration(homeDir, clientIDHash string, token *KiroTokenData) error {
if clientIDHash == "" {
return fmt.Errorf("clientIdHash is empty")
}
// Sanitize clientIdHash to prevent path traversal
if strings.Contains(clientIDHash, "/") || strings.Contains(clientIDHash, "\\") || strings.Contains(clientIDHash, "..") {
return fmt.Errorf("invalid clientIdHash: contains path separator")
}
deviceRegPath := filepath.Join(homeDir, ".aws", "sso", "cache", clientIDHash+".json")
data, err := os.ReadFile(deviceRegPath)
if err != nil {
return fmt.Errorf("failed to read device registration file (%s): %w", deviceRegPath, err)
}
// Device registration file structure
var deviceReg struct {
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
ExpiresAt string `json:"expiresAt"`
}
if err := json.Unmarshal(data, &deviceReg); err != nil {
return fmt.Errorf("failed to parse device registration: %w", err)
}
if deviceReg.ClientID == "" || deviceReg.ClientSecret == "" {
return fmt.Errorf("device registration missing clientId or clientSecret")
}
token.ClientID = deviceReg.ClientID
token.ClientSecret = deviceReg.ClientSecret
return nil
}
// LoadKiroTokenFromPath loads token data from a custom path.
// This supports multiple accounts by allowing different token files.
// For Enterprise Kiro IDE (IDC auth), it also loads clientId and clientSecret
// from the device registration file referenced by clientIdHash.
func LoadKiroTokenFromPath(tokenPath string) (*KiroTokenData, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("failed to get home directory: %w", err)
}
// Expand ~ to home directory
if len(tokenPath) > 0 && tokenPath[0] == '~' {
tokenPath = filepath.Join(homeDir, tokenPath[1:])
}
data, err := os.ReadFile(tokenPath)
if err != nil {
return nil, fmt.Errorf("failed to read token file (%s): %w", tokenPath, err)
}
var token KiroTokenData
if err := json.Unmarshal(data, &token); err != nil {
return nil, fmt.Errorf("failed to parse token file: %w", err)
}
if token.AccessToken == "" {
return nil, fmt.Errorf("access token is empty in token file")
}
// Normalize AuthMethod to lowercase (Kiro IDE uses "IdC" but we expect "idc")
token.AuthMethod = strings.ToLower(token.AuthMethod)
// For Enterprise Kiro IDE (IDC auth), load clientId and clientSecret from device registration
if token.ClientIDHash != "" && token.ClientID == "" {
if err := loadDeviceRegistration(homeDir, token.ClientIDHash, &token); err != nil {
// Log warning but don't fail - token might still work for some operations
fmt.Printf("warning: failed to load device registration for clientIdHash %s: %v\n", token.ClientIDHash, err)
}
}
return &token, nil
}
// ListKiroTokenFiles lists all Kiro token files in the cache directory.
// This supports multiple accounts by finding all token files.
func ListKiroTokenFiles() ([]string, error) {
homeDir, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("failed to get home directory: %w", err)
}
cacheDir := filepath.Join(homeDir, ".aws", "sso", "cache")
// Check if directory exists
if _, err := os.Stat(cacheDir); os.IsNotExist(err) {
return nil, nil // No token files
}
entries, err := os.ReadDir(cacheDir)
if err != nil {
return nil, fmt.Errorf("failed to read cache directory: %w", err)
}
var tokenFiles []string
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
// Look for kiro token files only (avoid matching unrelated AWS SSO cache files)
if strings.HasSuffix(name, ".json") && strings.HasPrefix(name, "kiro") {
tokenFiles = append(tokenFiles, filepath.Join(cacheDir, name))
}
}
return tokenFiles, nil
}
// LoadAllKiroTokens loads all Kiro tokens from the cache directory.
// This supports multiple accounts.
func LoadAllKiroTokens() ([]*KiroTokenData, error) {
files, err := ListKiroTokenFiles()
if err != nil {
return nil, err
}
var tokens []*KiroTokenData
for _, file := range files {
token, err := LoadKiroTokenFromPath(file)
if err != nil {
// Skip invalid token files
continue
}
tokens = append(tokens, token)
}
return tokens, nil
}
// JWTClaims represents the claims we care about from a JWT token.
// JWT tokens from Kiro/AWS contain user information in the payload.
type JWTClaims struct {
Email string `json:"email,omitempty"`
Sub string `json:"sub,omitempty"`
PreferredUser string `json:"preferred_username,omitempty"`
Name string `json:"name,omitempty"`
Iss string `json:"iss,omitempty"`
}
// ExtractEmailFromJWT extracts the user's email from a JWT access token.
// JWT tokens typically have format: header.payload.signature
// The payload is base64url-encoded JSON containing user claims.
func ExtractEmailFromJWT(accessToken string) string {
if accessToken == "" {
return ""
}
// JWT format: header.payload.signature
parts := strings.Split(accessToken, ".")
if len(parts) != 3 {
return ""
}
// Decode the payload (second part)
payload := parts[1]
// Add padding if needed (base64url requires padding)
switch len(payload) % 4 {
case 2:
payload += "=="
case 3:
payload += "="
}
decoded, err := base64.URLEncoding.DecodeString(payload)
if err != nil {
// Try RawURLEncoding (no padding)
decoded, err = base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return ""
}
}
var claims JWTClaims
if err := json.Unmarshal(decoded, &claims); err != nil {
return ""
}
// Return email if available
if claims.Email != "" {
return claims.Email
}
// Fallback to preferred_username (some providers use this)
if claims.PreferredUser != "" && strings.Contains(claims.PreferredUser, "@") {
return claims.PreferredUser
}
// Fallback to sub if it looks like an email
if claims.Sub != "" && strings.Contains(claims.Sub, "@") {
return claims.Sub
}
return ""
}
// SanitizeEmailForFilename sanitizes an email address for use in a filename.
// Replaces special characters with underscores and prevents path traversal attacks.
// Also handles URL-encoded characters to prevent encoded path traversal attempts.
func SanitizeEmailForFilename(email string) string {
if email == "" {
return ""
}
result := email
// First, handle URL-encoded path traversal attempts (%2F, %2E, %5C, etc.)
// This prevents encoded characters from bypassing the sanitization.
// Note: We replace % last to catch any remaining encodings including double-encoding (%252F)
result = strings.ReplaceAll(result, "%2F", "_") // /
result = strings.ReplaceAll(result, "%2f", "_")
result = strings.ReplaceAll(result, "%5C", "_") // \
result = strings.ReplaceAll(result, "%5c", "_")
result = strings.ReplaceAll(result, "%2E", "_") // .
result = strings.ReplaceAll(result, "%2e", "_")
result = strings.ReplaceAll(result, "%00", "_") // null byte
result = strings.ReplaceAll(result, "%", "_") // Catch remaining % to prevent double-encoding attacks
// Replace characters that are problematic in filenames
// Keep @ and . in middle but replace other special characters
for _, char := range []string{"/", "\\", ":", "*", "?", "\"", "<", ">", "|", " ", "\x00"} {
result = strings.ReplaceAll(result, char, "_")
}
// Prevent path traversal: replace leading dots in each path component
// This handles cases like "../../../etc/passwd" → "_.._.._.._etc_passwd"
parts := strings.Split(result, "_")
for i, part := range parts {
for strings.HasPrefix(part, ".") {
part = "_" + part[1:]
}
parts[i] = part
}
result = strings.Join(parts, "_")
return result
}
// ExtractIDCIdentifier extracts a unique identifier from IDC startUrl.
// Examples:
// - "https://d-1234567890.awsapps.com/start" -> "d-1234567890"
// - "https://my-company.awsapps.com/start" -> "my-company"
// - "https://acme-corp.awsapps.com/start" -> "acme-corp"
func ExtractIDCIdentifier(startURL string) string {
if startURL == "" {
return ""
}
// Remove protocol prefix
url := strings.TrimPrefix(startURL, "https://")
url = strings.TrimPrefix(url, "http://")
// Extract subdomain (first part before the first dot)
// Format: {identifier}.awsapps.com/start
parts := strings.Split(url, ".")
if len(parts) > 0 && parts[0] != "" {
identifier := parts[0]
// Sanitize for filename safety
identifier = strings.ReplaceAll(identifier, "/", "_")
identifier = strings.ReplaceAll(identifier, "\\", "_")
identifier = strings.ReplaceAll(identifier, ":", "_")
return identifier
}
return ""
}
// GenerateTokenFileName generates a unique filename for token storage.
// Priority: email > startUrl identifier (for IDC) > authMethod only
// Email is unique, so no sequence suffix needed. Sequence is only added
// when email is unavailable to prevent filename collisions.
// Format: kiro-{authMethod}-{identifier}[-{seq}].json
func GenerateTokenFileName(tokenData *KiroTokenData) string {
authMethod := tokenData.AuthMethod
if authMethod == "" {
authMethod = "unknown"
}
// Priority 1: Use email if available (no sequence needed, email is unique)
if tokenData.Email != "" {
// Sanitize email for filename (replace @ and . with -)
sanitizedEmail := tokenData.Email
sanitizedEmail = strings.ReplaceAll(sanitizedEmail, "@", "-")
sanitizedEmail = strings.ReplaceAll(sanitizedEmail, ".", "-")
return fmt.Sprintf("kiro-%s-%s.json", authMethod, sanitizedEmail)
}
// Generate sequence only when email is unavailable
seq := time.Now().UnixNano() % 100000
// Priority 2: For IDC, use startUrl identifier with sequence
if authMethod == "idc" && tokenData.StartURL != "" {
identifier := ExtractIDCIdentifier(tokenData.StartURL)
if identifier != "" {
return fmt.Sprintf("kiro-%s-%s-%05d.json", authMethod, identifier, seq)
}
}
// Priority 3: Fallback to authMethod only with sequence
return fmt.Sprintf("kiro-%s-%05d.json", authMethod, seq)
}
// DefaultKiroRegion is the fallback region when none is specified.
const DefaultKiroRegion = "us-east-1"
// GetCodeWhispererLegacyEndpoint returns the legacy CodeWhisperer JSON-RPC endpoint.
// This endpoint supports JSON-RPC style requests with x-amz-target headers.
// The Q endpoint (q.{region}.amazonaws.com) does NOT support JSON-RPC style.
func GetCodeWhispererLegacyEndpoint(region string) string {
if region == "" {
region = DefaultKiroRegion
}
return "https://codewhisperer." + region + ".amazonaws.com"
}
// ProfileARN represents a parsed AWS CodeWhisperer profile ARN.
// ARN format: arn:partition:service:region:account-id:resource-type/resource-id
// Example: arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL
type ProfileARN struct {
// Raw is the original ARN string
Raw string
// Partition is the AWS partition (aws)
Partition string
// Service is the AWS service name (codewhisperer)
Service string
// Region is the AWS region (us-east-1, ap-southeast-1, etc.)
Region string
// AccountID is the AWS account ID
AccountID string
// ResourceType is the resource type (profile)
ResourceType string
// ResourceID is the resource identifier (e.g., ABCDEFGHIJKL)
ResourceID string
}
// ParseProfileARN parses an AWS ARN string into a ProfileARN struct.
// Returns nil if the ARN is empty, invalid, or not a codewhisperer ARN.
func ParseProfileARN(arn string) *ProfileARN {
if arn == "" {
return nil
}
// ARN format: arn:partition:service:region:account-id:resource
// Minimum 6 parts separated by ":"
parts := strings.Split(arn, ":")
if len(parts) < 6 {
log.Warnf("invalid ARN format: %s", arn)
return nil
}
// Validate ARN prefix
if parts[0] != "arn" {
return nil
}
// Validate partition
partition := parts[1]
if partition == "" {
return nil
}
// Validate service is codewhisperer
service := parts[2]
if service != "codewhisperer" {
return nil
}
// Validate region format (must contain "-")
region := parts[3]
if region == "" || !strings.Contains(region, "-") {
return nil
}
// Account ID
accountID := parts[4]
// Parse resource (format: resource-type/resource-id)
// Join remaining parts in case resource contains ":"
resource := strings.Join(parts[5:], ":")
resourceType := ""
resourceID := ""
if idx := strings.Index(resource, "/"); idx > 0 {
resourceType = resource[:idx]
resourceID = resource[idx+1:]
} else {
resourceType = resource
}
return &ProfileARN{
Raw: arn,
Partition: partition,
Service: service,
Region: region,
AccountID: accountID,
ResourceType: resourceType,
ResourceID: resourceID,
}
}
// GetKiroAPIEndpoint returns the Q API endpoint for the specified region.
// If region is empty, defaults to us-east-1.
func GetKiroAPIEndpoint(region string) string {
if region == "" {
region = DefaultKiroRegion
}
return "https://q." + region + ".amazonaws.com"
}
// GetKiroAPIEndpointFromProfileArn extracts region from profileArn and returns the endpoint.
// Returns default us-east-1 endpoint if region cannot be extracted.
func GetKiroAPIEndpointFromProfileArn(profileArn string) string {
region := ExtractRegionFromProfileArn(profileArn)
return GetKiroAPIEndpoint(region)
}
// ExtractRegionFromProfileArn extracts the AWS region from a ProfileARN string.
// Returns empty string if ARN is invalid or region cannot be extracted.
func ExtractRegionFromProfileArn(profileArn string) string {
parsed := ParseProfileARN(profileArn)
if parsed == nil {
return ""
}
return parsed.Region
}
// ExtractRegionFromMetadata extracts API region from auth metadata.
// Priority: api_region > profile_arn > DefaultKiroRegion
func ExtractRegionFromMetadata(metadata map[string]interface{}) string {
if metadata == nil {
return DefaultKiroRegion
}
// Priority 1: Explicit api_region override
if r, ok := metadata["api_region"].(string); ok && r != "" {
return r
}
// Priority 2: Extract from ProfileARN
if profileArn, ok := metadata["profile_arn"].(string); ok && profileArn != "" {
if region := ExtractRegionFromProfileArn(profileArn); region != "" {
return region
}
}
return DefaultKiroRegion
}
func buildURL(endpoint, path string, queryParams map[string]string) string {
fullURL := fmt.Sprintf("%s/%s", endpoint, path)
if len(queryParams) > 0 {
values := url.Values{}
for key, value := range queryParams {
if value == "" {
continue
}
values.Set(key, value)
}
if encoded := values.Encode(); encoded != "" {
fullURL = fullURL + "?" + encoded
}
}
return fullURL
}

View File

@@ -0,0 +1,326 @@
// Package kiro provides OAuth2 authentication functionality for AWS CodeWhisperer (Kiro) API.
// This package implements token loading, refresh, and API communication with CodeWhisperer.
package kiro
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
const (
pathGetUsageLimits = "getUsageLimits"
pathListAvailableModels = "ListAvailableModels"
)
// KiroAuth handles AWS CodeWhisperer authentication and API communication.
// It provides methods for loading tokens, refreshing expired tokens,
// and communicating with the CodeWhisperer API.
type KiroAuth struct {
httpClient *http.Client
}
// NewKiroAuth creates a new Kiro authentication service.
// It initializes the HTTP client with proxy settings from the configuration.
//
// Parameters:
// - cfg: The application configuration containing proxy settings
//
// Returns:
// - *KiroAuth: A new Kiro authentication service instance
func NewKiroAuth(cfg *config.Config) *KiroAuth {
return &KiroAuth{
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 120 * time.Second}),
}
}
// LoadTokenFromFile loads token data from a file path.
// This method reads and parses the token file, expanding ~ to the home directory.
//
// Parameters:
// - tokenFile: Path to the token file (supports ~ expansion)
//
// Returns:
// - *KiroTokenData: The parsed token data
// - error: An error if file reading or parsing fails
func (k *KiroAuth) LoadTokenFromFile(tokenFile string) (*KiroTokenData, error) {
// Expand ~ to home directory
if strings.HasPrefix(tokenFile, "~") {
home, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("failed to get home directory: %w", err)
}
tokenFile = filepath.Join(home, tokenFile[1:])
}
data, err := os.ReadFile(tokenFile)
if err != nil {
return nil, fmt.Errorf("failed to read token file: %w", err)
}
var tokenData KiroTokenData
if err := json.Unmarshal(data, &tokenData); err != nil {
return nil, fmt.Errorf("failed to parse token file: %w", err)
}
return &tokenData, nil
}
// IsTokenExpired checks if the token has expired.
// This method parses the expiration timestamp and compares it with the current time.
//
// Parameters:
// - tokenData: The token data to check
//
// Returns:
// - bool: True if the token has expired, false otherwise
func (k *KiroAuth) IsTokenExpired(tokenData *KiroTokenData) bool {
if tokenData.ExpiresAt == "" {
return true
}
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
if err != nil {
// Try alternate format
expiresAt, err = time.Parse("2006-01-02T15:04:05.000Z", tokenData.ExpiresAt)
if err != nil {
return true
}
}
return time.Now().After(expiresAt)
}
// makeRequest sends a REST-style GET request to the CodeWhisperer API.
//
// Parameters:
// - ctx: The context for the request
// - path: The API path (e.g., "getUsageLimits")
// - tokenData: The token data containing access token, refresh token, and profile ARN
// - queryParams: Query parameters to add to the URL
//
// Returns:
// - []byte: The response body
// - error: An error if the request fails
func (k *KiroAuth) makeRequest(ctx context.Context, path string, tokenData *KiroTokenData, queryParams map[string]string) ([]byte, error) {
// Get endpoint from profileArn (defaults to us-east-1 if empty)
profileArn := queryParams["profileArn"]
endpoint := GetKiroAPIEndpointFromProfileArn(profileArn)
url := buildURL(endpoint, path, queryParams)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
accountKey := GetAccountKey(tokenData.ClientID, tokenData.RefreshToken)
setRuntimeHeaders(req, tokenData.AccessToken, accountKey)
resp, err := k.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("failed to close response body: %v", errClose)
}
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
return body, nil
}
// GetUsageLimits retrieves usage information from the CodeWhisperer API.
// This method fetches the current usage statistics and subscription information.
//
// Parameters:
// - ctx: The context for the request
// - tokenData: The token data containing access token and profile ARN
//
// Returns:
// - *KiroUsageInfo: The usage information
// - error: An error if the request fails
func (k *KiroAuth) GetUsageLimits(ctx context.Context, tokenData *KiroTokenData) (*KiroUsageInfo, error) {
queryParams := map[string]string{
"origin": "AI_EDITOR",
"profileArn": tokenData.ProfileArn,
"resourceType": "AGENTIC_REQUEST",
}
body, err := k.makeRequest(ctx, pathGetUsageLimits, tokenData, queryParams)
if err != nil {
return nil, err
}
var result struct {
SubscriptionInfo struct {
SubscriptionTitle string `json:"subscriptionTitle"`
} `json:"subscriptionInfo"`
UsageBreakdownList []struct {
CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
} `json:"usageBreakdownList"`
NextDateReset float64 `json:"nextDateReset"`
}
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("failed to parse usage response: %w", err)
}
usage := &KiroUsageInfo{
SubscriptionTitle: result.SubscriptionInfo.SubscriptionTitle,
NextReset: fmt.Sprintf("%v", result.NextDateReset),
}
if len(result.UsageBreakdownList) > 0 {
usage.CurrentUsage = result.UsageBreakdownList[0].CurrentUsageWithPrecision
usage.UsageLimit = result.UsageBreakdownList[0].UsageLimitWithPrecision
}
return usage, nil
}
// ListAvailableModels retrieves available models from the CodeWhisperer API.
// This method fetches the list of AI models available for the authenticated user.
//
// Parameters:
// - ctx: The context for the request
// - tokenData: The token data containing access token and profile ARN
//
// Returns:
// - []*KiroModel: The list of available models
// - error: An error if the request fails
func (k *KiroAuth) ListAvailableModels(ctx context.Context, tokenData *KiroTokenData) ([]*KiroModel, error) {
queryParams := map[string]string{
"origin": "AI_EDITOR",
"profileArn": tokenData.ProfileArn,
}
body, err := k.makeRequest(ctx, pathListAvailableModels, tokenData, queryParams)
if err != nil {
return nil, err
}
var result struct {
Models []struct {
ModelID string `json:"modelId"`
ModelName string `json:"modelName"`
Description string `json:"description"`
RateMultiplier float64 `json:"rateMultiplier"`
RateUnit string `json:"rateUnit"`
TokenLimits *struct {
MaxInputTokens int `json:"maxInputTokens"`
} `json:"tokenLimits"`
} `json:"models"`
}
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("failed to parse models response: %w", err)
}
models := make([]*KiroModel, 0, len(result.Models))
for _, m := range result.Models {
maxInputTokens := 0
if m.TokenLimits != nil {
maxInputTokens = m.TokenLimits.MaxInputTokens
}
models = append(models, &KiroModel{
ModelID: m.ModelID,
ModelName: m.ModelName,
Description: m.Description,
RateMultiplier: m.RateMultiplier,
RateUnit: m.RateUnit,
MaxInputTokens: maxInputTokens,
})
}
return models, nil
}
// CreateTokenStorage creates a new KiroTokenStorage from token data.
// This method converts the token data into a storage structure suitable for persistence.
//
// Parameters:
// - tokenData: The token data to convert
//
// Returns:
// - *KiroTokenStorage: A new token storage instance
func (k *KiroAuth) CreateTokenStorage(tokenData *KiroTokenData) *KiroTokenStorage {
return &KiroTokenStorage{
AccessToken: tokenData.AccessToken,
RefreshToken: tokenData.RefreshToken,
ProfileArn: tokenData.ProfileArn,
ExpiresAt: tokenData.ExpiresAt,
AuthMethod: tokenData.AuthMethod,
Provider: tokenData.Provider,
LastRefresh: time.Now().Format(time.RFC3339),
ClientID: tokenData.ClientID,
ClientSecret: tokenData.ClientSecret,
Region: tokenData.Region,
StartURL: tokenData.StartURL,
Email: tokenData.Email,
}
}
// ValidateToken checks if the token is valid by making a test API call.
// This method verifies the token by attempting to fetch usage limits.
//
// Parameters:
// - ctx: The context for the request
// - tokenData: The token data to validate
//
// Returns:
// - error: An error if the token is invalid
func (k *KiroAuth) ValidateToken(ctx context.Context, tokenData *KiroTokenData) error {
_, err := k.GetUsageLimits(ctx, tokenData)
return err
}
// UpdateTokenStorage updates an existing token storage with new token data.
// This method refreshes the token storage with newly obtained access and refresh tokens.
//
// Parameters:
// - storage: The existing token storage to update
// - tokenData: The new token data to apply
func (k *KiroAuth) UpdateTokenStorage(storage *KiroTokenStorage, tokenData *KiroTokenData) {
storage.AccessToken = tokenData.AccessToken
storage.RefreshToken = tokenData.RefreshToken
storage.ProfileArn = tokenData.ProfileArn
storage.ExpiresAt = tokenData.ExpiresAt
storage.AuthMethod = tokenData.AuthMethod
storage.Provider = tokenData.Provider
storage.LastRefresh = time.Now().Format(time.RFC3339)
if tokenData.ClientID != "" {
storage.ClientID = tokenData.ClientID
}
if tokenData.ClientSecret != "" {
storage.ClientSecret = tokenData.ClientSecret
}
if tokenData.Region != "" {
storage.Region = tokenData.Region
}
if tokenData.StartURL != "" {
storage.StartURL = tokenData.StartURL
}
if tokenData.Email != "" {
storage.Email = tokenData.Email
}
}

View File

@@ -0,0 +1,751 @@
package kiro
import (
"encoding/base64"
"encoding/json"
"strings"
"testing"
)
func TestExtractEmailFromJWT(t *testing.T) {
tests := []struct {
name string
token string
expected string
}{
{
name: "Empty token",
token: "",
expected: "",
},
{
name: "Invalid token format",
token: "not.a.valid.jwt",
expected: "",
},
{
name: "Invalid token - not base64",
token: "xxx.yyy.zzz",
expected: "",
},
{
name: "Valid JWT with email",
token: createTestJWT(map[string]any{"email": "test@example.com", "sub": "user123"}),
expected: "test@example.com",
},
{
name: "JWT without email but with preferred_username",
token: createTestJWT(map[string]any{"preferred_username": "user@domain.com", "sub": "user123"}),
expected: "user@domain.com",
},
{
name: "JWT with email-like sub",
token: createTestJWT(map[string]any{"sub": "another@test.com"}),
expected: "another@test.com",
},
{
name: "JWT without any email fields",
token: createTestJWT(map[string]any{"sub": "user123", "name": "Test User"}),
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ExtractEmailFromJWT(tt.token)
if result != tt.expected {
t.Errorf("ExtractEmailFromJWT() = %q, want %q", result, tt.expected)
}
})
}
}
func TestSanitizeEmailForFilename(t *testing.T) {
tests := []struct {
name string
email string
expected string
}{
{
name: "Empty email",
email: "",
expected: "",
},
{
name: "Simple email",
email: "user@example.com",
expected: "user@example.com",
},
{
name: "Email with space",
email: "user name@example.com",
expected: "user_name@example.com",
},
{
name: "Email with special chars",
email: "user:name@example.com",
expected: "user_name@example.com",
},
{
name: "Email with multiple special chars",
email: "user/name:test@example.com",
expected: "user_name_test@example.com",
},
{
name: "Path traversal attempt",
email: "../../../etc/passwd",
expected: "_.__.__._etc_passwd",
},
{
name: "Path traversal with backslash",
email: `..\..\..\..\windows\system32`,
expected: "_.__.__.__._windows_system32",
},
{
name: "Null byte injection attempt",
email: "user\x00@evil.com",
expected: "user_@evil.com",
},
// URL-encoded path traversal tests
{
name: "URL-encoded slash",
email: "user%2Fpath@example.com",
expected: "user_path@example.com",
},
{
name: "URL-encoded backslash",
email: "user%5Cpath@example.com",
expected: "user_path@example.com",
},
{
name: "URL-encoded dot",
email: "%2E%2E%2Fetc%2Fpasswd",
expected: "___etc_passwd",
},
{
name: "URL-encoded null",
email: "user%00@evil.com",
expected: "user_@evil.com",
},
{
name: "Double URL-encoding attack",
email: "%252F%252E%252E",
expected: "_252F_252E_252E", // % replaced with _, remaining chars preserved (safe)
},
{
name: "Mixed case URL-encoding",
email: "%2f%2F%5c%5C",
expected: "____",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeEmailForFilename(tt.email)
if result != tt.expected {
t.Errorf("SanitizeEmailForFilename() = %q, want %q", result, tt.expected)
}
})
}
}
// createTestJWT creates a test JWT token with the given claims
func createTestJWT(claims map[string]any) string {
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`))
payloadBytes, _ := json.Marshal(claims)
payload := base64.RawURLEncoding.EncodeToString(payloadBytes)
signature := base64.RawURLEncoding.EncodeToString([]byte("fake-signature"))
return header + "." + payload + "." + signature
}
func TestExtractIDCIdentifier(t *testing.T) {
tests := []struct {
name string
startURL string
expected string
}{
{
name: "Empty URL",
startURL: "",
expected: "",
},
{
name: "Standard IDC URL with d- prefix",
startURL: "https://d-1234567890.awsapps.com/start",
expected: "d-1234567890",
},
{
name: "IDC URL with company name",
startURL: "https://my-company.awsapps.com/start",
expected: "my-company",
},
{
name: "IDC URL with simple name",
startURL: "https://acme-corp.awsapps.com/start",
expected: "acme-corp",
},
{
name: "IDC URL without https",
startURL: "http://d-9876543210.awsapps.com/start",
expected: "d-9876543210",
},
{
name: "IDC URL with subdomain only",
startURL: "https://test.awsapps.com/start",
expected: "test",
},
{
name: "Builder ID URL",
startURL: "https://view.awsapps.com/start",
expected: "view",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ExtractIDCIdentifier(tt.startURL)
if result != tt.expected {
t.Errorf("ExtractIDCIdentifier() = %q, want %q", result, tt.expected)
}
})
}
}
func TestGenerateTokenFileName(t *testing.T) {
tests := []struct {
name string
tokenData *KiroTokenData
exact string // exact match (for cases with email)
prefix string // prefix match (for cases without email, where sequence is appended)
}{
{
name: "IDC with email",
tokenData: &KiroTokenData{
AuthMethod: "idc",
Email: "user@example.com",
StartURL: "https://d-1234567890.awsapps.com/start",
},
exact: "kiro-idc-user-example-com.json",
},
{
name: "IDC without email but with startUrl",
tokenData: &KiroTokenData{
AuthMethod: "idc",
Email: "",
StartURL: "https://d-1234567890.awsapps.com/start",
},
prefix: "kiro-idc-d-1234567890-",
},
{
name: "IDC with company name in startUrl",
tokenData: &KiroTokenData{
AuthMethod: "idc",
Email: "",
StartURL: "https://my-company.awsapps.com/start",
},
prefix: "kiro-idc-my-company-",
},
{
name: "IDC without email and without startUrl",
tokenData: &KiroTokenData{
AuthMethod: "idc",
Email: "",
StartURL: "",
},
prefix: "kiro-idc-",
},
{
name: "Builder ID with email",
tokenData: &KiroTokenData{
AuthMethod: "builder-id",
Email: "user@gmail.com",
StartURL: "https://view.awsapps.com/start",
},
exact: "kiro-builder-id-user-gmail-com.json",
},
{
name: "Builder ID without email",
tokenData: &KiroTokenData{
AuthMethod: "builder-id",
Email: "",
StartURL: "https://view.awsapps.com/start",
},
prefix: "kiro-builder-id-",
},
{
name: "Social auth with email",
tokenData: &KiroTokenData{
AuthMethod: "google",
Email: "user@gmail.com",
},
exact: "kiro-google-user-gmail-com.json",
},
{
name: "Empty auth method",
tokenData: &KiroTokenData{
AuthMethod: "",
Email: "",
},
prefix: "kiro-unknown-",
},
{
name: "Email with special characters",
tokenData: &KiroTokenData{
AuthMethod: "idc",
Email: "user.name+tag@sub.example.com",
StartURL: "https://d-1234567890.awsapps.com/start",
},
exact: "kiro-idc-user-name+tag-sub-example-com.json",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GenerateTokenFileName(tt.tokenData)
if tt.exact != "" {
if result != tt.exact {
t.Errorf("GenerateTokenFileName() = %q, want %q", result, tt.exact)
}
} else if tt.prefix != "" {
if !strings.HasPrefix(result, tt.prefix) || !strings.HasSuffix(result, ".json") {
t.Errorf("GenerateTokenFileName() = %q, want prefix %q with .json suffix", result, tt.prefix)
}
}
})
}
}
func TestParseProfileARN(t *testing.T) {
tests := []struct {
name string
arn string
expected *ProfileARN
}{
{
name: "Empty ARN",
arn: "",
expected: nil,
},
{
name: "Invalid format - too few parts",
arn: "arn:aws:codewhisperer",
expected: nil,
},
{
name: "Invalid prefix - not arn",
arn: "notarn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
expected: nil,
},
{
name: "Invalid service - not codewhisperer",
arn: "arn:aws:s3:us-east-1:123456789012:bucket/mybucket",
expected: nil,
},
{
name: "Invalid region - no hyphen",
arn: "arn:aws:codewhisperer:useast1:123456789012:profile/ABC",
expected: nil,
},
{
name: "Empty partition",
arn: "arn::codewhisperer:us-east-1:123456789012:profile/ABC",
expected: nil,
},
{
name: "Empty region",
arn: "arn:aws:codewhisperer::123456789012:profile/ABC",
expected: nil,
},
{
name: "Valid ARN - us-east-1",
arn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL",
expected: &ProfileARN{
Raw: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEFGHIJKL",
Partition: "aws",
Service: "codewhisperer",
Region: "us-east-1",
AccountID: "123456789012",
ResourceType: "profile",
ResourceID: "ABCDEFGHIJKL",
},
},
{
name: "Valid ARN - ap-southeast-1",
arn: "arn:aws:codewhisperer:ap-southeast-1:987654321098:profile/ZYXWVUTSRQ",
expected: &ProfileARN{
Raw: "arn:aws:codewhisperer:ap-southeast-1:987654321098:profile/ZYXWVUTSRQ",
Partition: "aws",
Service: "codewhisperer",
Region: "ap-southeast-1",
AccountID: "987654321098",
ResourceType: "profile",
ResourceID: "ZYXWVUTSRQ",
},
},
{
name: "Valid ARN - eu-west-1",
arn: "arn:aws:codewhisperer:eu-west-1:111222333444:profile/PROFILE123",
expected: &ProfileARN{
Raw: "arn:aws:codewhisperer:eu-west-1:111222333444:profile/PROFILE123",
Partition: "aws",
Service: "codewhisperer",
Region: "eu-west-1",
AccountID: "111222333444",
ResourceType: "profile",
ResourceID: "PROFILE123",
},
},
{
name: "Valid ARN - aws-cn partition",
arn: "arn:aws-cn:codewhisperer:cn-north-1:123456789012:profile/CHINAID",
expected: &ProfileARN{
Raw: "arn:aws-cn:codewhisperer:cn-north-1:123456789012:profile/CHINAID",
Partition: "aws-cn",
Service: "codewhisperer",
Region: "cn-north-1",
AccountID: "123456789012",
ResourceType: "profile",
ResourceID: "CHINAID",
},
},
{
name: "Valid ARN - resource without slash",
arn: "arn:aws:codewhisperer:us-west-2:123456789012:profile",
expected: &ProfileARN{
Raw: "arn:aws:codewhisperer:us-west-2:123456789012:profile",
Partition: "aws",
Service: "codewhisperer",
Region: "us-west-2",
AccountID: "123456789012",
ResourceType: "profile",
ResourceID: "",
},
},
{
name: "Valid ARN - resource with colon",
arn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC:extra",
expected: &ProfileARN{
Raw: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC:extra",
Partition: "aws",
Service: "codewhisperer",
Region: "us-east-1",
AccountID: "123456789012",
ResourceType: "profile",
ResourceID: "ABC:extra",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ParseProfileARN(tt.arn)
if tt.expected == nil {
if result != nil {
t.Errorf("ParseProfileARN(%q) = %+v, want nil", tt.arn, result)
}
return
}
if result == nil {
t.Errorf("ParseProfileARN(%q) = nil, want %+v", tt.arn, tt.expected)
return
}
if result.Raw != tt.expected.Raw {
t.Errorf("Raw = %q, want %q", result.Raw, tt.expected.Raw)
}
if result.Partition != tt.expected.Partition {
t.Errorf("Partition = %q, want %q", result.Partition, tt.expected.Partition)
}
if result.Service != tt.expected.Service {
t.Errorf("Service = %q, want %q", result.Service, tt.expected.Service)
}
if result.Region != tt.expected.Region {
t.Errorf("Region = %q, want %q", result.Region, tt.expected.Region)
}
if result.AccountID != tt.expected.AccountID {
t.Errorf("AccountID = %q, want %q", result.AccountID, tt.expected.AccountID)
}
if result.ResourceType != tt.expected.ResourceType {
t.Errorf("ResourceType = %q, want %q", result.ResourceType, tt.expected.ResourceType)
}
if result.ResourceID != tt.expected.ResourceID {
t.Errorf("ResourceID = %q, want %q", result.ResourceID, tt.expected.ResourceID)
}
})
}
}
func TestExtractRegionFromProfileArn(t *testing.T) {
tests := []struct {
name string
profileArn string
expected string
}{
{
name: "Empty ARN",
profileArn: "",
expected: "",
},
{
name: "Invalid ARN",
profileArn: "invalid-arn",
expected: "",
},
{
name: "Valid ARN - us-east-1",
profileArn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
expected: "us-east-1",
},
{
name: "Valid ARN - ap-southeast-1",
profileArn: "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
expected: "ap-southeast-1",
},
{
name: "Valid ARN - eu-central-1",
profileArn: "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC",
expected: "eu-central-1",
},
{
name: "Non-codewhisperer ARN",
profileArn: "arn:aws:s3:us-east-1:123456789012:bucket/mybucket",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ExtractRegionFromProfileArn(tt.profileArn)
if result != tt.expected {
t.Errorf("ExtractRegionFromProfileArn(%q) = %q, want %q", tt.profileArn, result, tt.expected)
}
})
}
}
func TestGetKiroAPIEndpoint(t *testing.T) {
tests := []struct {
name string
region string
expected string
}{
{
name: "Empty region - defaults to us-east-1",
region: "",
expected: "https://q.us-east-1.amazonaws.com",
},
{
name: "us-east-1",
region: "us-east-1",
expected: "https://q.us-east-1.amazonaws.com",
},
{
name: "us-west-2",
region: "us-west-2",
expected: "https://q.us-west-2.amazonaws.com",
},
{
name: "ap-southeast-1",
region: "ap-southeast-1",
expected: "https://q.ap-southeast-1.amazonaws.com",
},
{
name: "eu-west-1",
region: "eu-west-1",
expected: "https://q.eu-west-1.amazonaws.com",
},
{
name: "cn-north-1",
region: "cn-north-1",
expected: "https://q.cn-north-1.amazonaws.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetKiroAPIEndpoint(tt.region)
if result != tt.expected {
t.Errorf("GetKiroAPIEndpoint(%q) = %q, want %q", tt.region, result, tt.expected)
}
})
}
}
func TestGetKiroAPIEndpointFromProfileArn(t *testing.T) {
tests := []struct {
name string
profileArn string
expected string
}{
{
name: "Empty ARN - defaults to us-east-1",
profileArn: "",
expected: "https://q.us-east-1.amazonaws.com",
},
{
name: "Invalid ARN - defaults to us-east-1",
profileArn: "invalid-arn",
expected: "https://q.us-east-1.amazonaws.com",
},
{
name: "Valid ARN - us-east-1",
profileArn: "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
expected: "https://q.us-east-1.amazonaws.com",
},
{
name: "Valid ARN - ap-southeast-1",
profileArn: "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
expected: "https://q.ap-southeast-1.amazonaws.com",
},
{
name: "Valid ARN - eu-central-1",
profileArn: "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC",
expected: "https://q.eu-central-1.amazonaws.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetKiroAPIEndpointFromProfileArn(tt.profileArn)
if result != tt.expected {
t.Errorf("GetKiroAPIEndpointFromProfileArn(%q) = %q, want %q", tt.profileArn, result, tt.expected)
}
})
}
}
func TestGetCodeWhispererLegacyEndpoint(t *testing.T) {
tests := []struct {
name string
region string
expected string
}{
{
name: "Empty region - defaults to us-east-1",
region: "",
expected: "https://codewhisperer.us-east-1.amazonaws.com",
},
{
name: "us-east-1",
region: "us-east-1",
expected: "https://codewhisperer.us-east-1.amazonaws.com",
},
{
name: "us-west-2",
region: "us-west-2",
expected: "https://codewhisperer.us-west-2.amazonaws.com",
},
{
name: "ap-northeast-1",
region: "ap-northeast-1",
expected: "https://codewhisperer.ap-northeast-1.amazonaws.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetCodeWhispererLegacyEndpoint(tt.region)
if result != tt.expected {
t.Errorf("GetCodeWhispererLegacyEndpoint(%q) = %q, want %q", tt.region, result, tt.expected)
}
})
}
}
func TestExtractRegionFromMetadata(t *testing.T) {
tests := []struct {
name string
metadata map[string]interface{}
expected string
}{
{
name: "Nil metadata - defaults to us-east-1",
metadata: nil,
expected: "us-east-1",
},
{
name: "Empty metadata - defaults to us-east-1",
metadata: map[string]interface{}{},
expected: "us-east-1",
},
{
name: "Priority 1: api_region override",
metadata: map[string]interface{}{
"api_region": "eu-west-1",
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
},
expected: "eu-west-1",
},
{
name: "Priority 2: profile_arn when api_region is empty",
metadata: map[string]interface{}{
"api_region": "",
"profile_arn": "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
},
expected: "ap-southeast-1",
},
{
name: "Priority 2: profile_arn when api_region is missing",
metadata: map[string]interface{}{
"profile_arn": "arn:aws:codewhisperer:eu-central-1:123456789012:profile/ABC",
},
expected: "eu-central-1",
},
{
name: "Fallback: default when profile_arn is invalid",
metadata: map[string]interface{}{
"profile_arn": "invalid-arn",
},
expected: "us-east-1",
},
{
name: "Fallback: default when profile_arn is empty",
metadata: map[string]interface{}{
"profile_arn": "",
},
expected: "us-east-1",
},
{
name: "OIDC region is NOT used for API region",
metadata: map[string]interface{}{
"region": "ap-northeast-2", // OIDC region - should be ignored
},
expected: "us-east-1",
},
{
name: "api_region takes precedence over OIDC region",
metadata: map[string]interface{}{
"api_region": "us-west-2",
"region": "ap-northeast-2", // OIDC region - should be ignored
},
expected: "us-west-2",
},
{
name: "Non-string api_region is ignored",
metadata: map[string]interface{}{
"api_region": 123, // wrong type
"profile_arn": "arn:aws:codewhisperer:ap-south-1:123456789012:profile/ABC",
},
expected: "ap-south-1",
},
{
name: "Non-string profile_arn is ignored",
metadata: map[string]interface{}{
"profile_arn": 123, // wrong type
},
expected: "us-east-1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ExtractRegionFromMetadata(tt.metadata)
if result != tt.expected {
t.Errorf("ExtractRegionFromMetadata(%v) = %q, want %q", tt.metadata, result, tt.expected)
}
})
}
}

View File

@@ -0,0 +1,247 @@
package kiro
import (
"context"
"log"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"golang.org/x/sync/semaphore"
)
type Token struct {
ID string
AccessToken string
RefreshToken string
ExpiresAt time.Time
LastVerified time.Time
ClientID string
ClientSecret string
AuthMethod string
Provider string
StartURL string
Region string
}
type TokenRepository interface {
FindOldestUnverified(limit int) []*Token
UpdateToken(token *Token) error
}
type RefresherOption func(*BackgroundRefresher)
func WithInterval(interval time.Duration) RefresherOption {
return func(r *BackgroundRefresher) {
r.interval = interval
}
}
func WithBatchSize(size int) RefresherOption {
return func(r *BackgroundRefresher) {
r.batchSize = size
}
}
func WithConcurrency(concurrency int) RefresherOption {
return func(r *BackgroundRefresher) {
r.concurrency = concurrency
}
}
type BackgroundRefresher struct {
interval time.Duration
batchSize int
concurrency int
tokenRepo TokenRepository
stopCh chan struct{}
wg sync.WaitGroup
oauth *KiroOAuth
ssoClient *SSOOIDCClient
callbackMu sync.RWMutex // 保护回调函数的并发访问
onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) // 刷新成功回调
}
func NewBackgroundRefresher(repo TokenRepository, opts ...RefresherOption) *BackgroundRefresher {
r := &BackgroundRefresher{
interval: time.Minute,
batchSize: 50,
concurrency: 10,
tokenRepo: repo,
stopCh: make(chan struct{}),
oauth: nil, // Lazy init - will be set when config available
ssoClient: nil, // Lazy init - will be set when config available
}
for _, opt := range opts {
opt(r)
}
return r
}
// WithConfig sets the configuration for OAuth and SSO clients.
func WithConfig(cfg *config.Config) RefresherOption {
return func(r *BackgroundRefresher) {
r.oauth = NewKiroOAuth(cfg)
r.ssoClient = NewSSOOIDCClient(cfg)
}
}
// WithOnTokenRefreshed sets the callback function to be called when a token is successfully refreshed.
// The callback receives the token ID (filename) and the new token data.
// This allows external components (e.g., Watcher) to be notified of token updates.
func WithOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) RefresherOption {
return func(r *BackgroundRefresher) {
r.callbackMu.Lock()
r.onTokenRefreshed = callback
r.callbackMu.Unlock()
}
}
func (r *BackgroundRefresher) Start(ctx context.Context) {
r.wg.Add(1)
go func() {
defer r.wg.Done()
ticker := time.NewTicker(r.interval)
defer ticker.Stop()
r.refreshBatch(ctx)
for {
select {
case <-ctx.Done():
return
case <-r.stopCh:
return
case <-ticker.C:
r.refreshBatch(ctx)
}
}
}()
}
func (r *BackgroundRefresher) Stop() {
close(r.stopCh)
r.wg.Wait()
}
func (r *BackgroundRefresher) refreshBatch(ctx context.Context) {
tokens := r.tokenRepo.FindOldestUnverified(r.batchSize)
if len(tokens) == 0 {
return
}
sem := semaphore.NewWeighted(int64(r.concurrency))
var wg sync.WaitGroup
for i, token := range tokens {
if i > 0 {
select {
case <-ctx.Done():
return
case <-r.stopCh:
return
case <-time.After(100 * time.Millisecond):
}
}
if err := sem.Acquire(ctx, 1); err != nil {
return
}
wg.Add(1)
go func(t *Token) {
defer wg.Done()
defer sem.Release(1)
r.refreshSingle(ctx, t)
}(token)
}
wg.Wait()
}
func (r *BackgroundRefresher) refreshSingle(ctx context.Context, token *Token) {
// Normalize auth method to lowercase for case-insensitive matching
authMethod := strings.ToLower(token.AuthMethod)
// Create refresh function based on auth method
refreshFunc := func(ctx context.Context) (*KiroTokenData, error) {
switch authMethod {
case "idc":
return r.ssoClient.RefreshTokenWithRegion(
ctx,
token.ClientID,
token.ClientSecret,
token.RefreshToken,
token.Region,
token.StartURL,
)
case "builder-id":
return r.ssoClient.RefreshToken(
ctx,
token.ClientID,
token.ClientSecret,
token.RefreshToken,
)
default:
return r.oauth.RefreshTokenWithFingerprint(ctx, token.RefreshToken, token.ID)
}
}
// Use graceful degradation for better reliability
result := RefreshWithGracefulDegradation(
ctx,
refreshFunc,
token.AccessToken,
token.ExpiresAt,
)
if result.Error != nil {
log.Printf("failed to refresh token %s: %v", token.ID, result.Error)
return
}
newTokenData := result.TokenData
if result.UsedFallback {
log.Printf("token %s: using existing token as fallback (refresh failed but token still valid)", token.ID)
// Don't update the token file if we're using fallback
// Just update LastVerified to prevent immediate re-check
token.LastVerified = time.Now()
return
}
token.AccessToken = newTokenData.AccessToken
if newTokenData.RefreshToken != "" {
token.RefreshToken = newTokenData.RefreshToken
}
token.LastVerified = time.Now()
if newTokenData.ExpiresAt != "" {
if expTime, parseErr := time.Parse(time.RFC3339, newTokenData.ExpiresAt); parseErr == nil {
token.ExpiresAt = expTime
}
}
if err := r.tokenRepo.UpdateToken(token); err != nil {
log.Printf("failed to update token %s: %v", token.ID, err)
return
}
// 方案 A: 刷新成功后触发回调,通知 Watcher 更新内存中的 Auth 对象
r.callbackMu.RLock()
callback := r.onTokenRefreshed
r.callbackMu.RUnlock()
if callback != nil {
// 使用 defer recover 隔离回调 panic防止崩溃整个进程
func() {
defer func() {
if rec := recover(); rec != nil {
log.Printf("background refresh: callback panic for token %s: %v", token.ID, rec)
}
}()
log.Printf("background refresh: notifying token refresh callback for %s", token.ID)
callback(token.ID, newTokenData)
}()
}
}

View File

@@ -0,0 +1,153 @@
// Package kiro provides CodeWhisperer API client for fetching user info.
package kiro
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
// CodeWhispererClient handles CodeWhisperer API calls.
type CodeWhispererClient struct {
httpClient *http.Client
}
// UsageLimitsResponse represents the getUsageLimits API response.
type UsageLimitsResponse struct {
DaysUntilReset *int `json:"daysUntilReset,omitempty"`
NextDateReset *float64 `json:"nextDateReset,omitempty"`
UserInfo *UserInfo `json:"userInfo,omitempty"`
SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
UsageBreakdownList []UsageBreakdown `json:"usageBreakdownList,omitempty"`
}
// UserInfo contains user information from the API.
type UserInfo struct {
Email string `json:"email,omitempty"`
UserID string `json:"userId,omitempty"`
}
// SubscriptionInfo contains subscription details.
type SubscriptionInfo struct {
SubscriptionTitle string `json:"subscriptionTitle,omitempty"`
Type string `json:"type,omitempty"`
}
// UsageBreakdown contains usage details.
type UsageBreakdown struct {
UsageLimit *int `json:"usageLimit,omitempty"`
CurrentUsage *int `json:"currentUsage,omitempty"`
UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision,omitempty"`
CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision,omitempty"`
NextDateReset *float64 `json:"nextDateReset,omitempty"`
DisplayName string `json:"displayName,omitempty"`
ResourceType string `json:"resourceType,omitempty"`
}
// NewCodeWhispererClient creates a new CodeWhisperer client.
func NewCodeWhispererClient(cfg *config.Config, machineID string) *CodeWhispererClient {
client := &http.Client{Timeout: 30 * time.Second}
if cfg != nil {
client = util.SetProxy(&cfg.SDKConfig, client)
}
return &CodeWhispererClient{
httpClient: client,
}
}
// GetUsageLimits fetches usage limits and user info from CodeWhisperer API.
// This is the recommended way to get user email after login.
func (c *CodeWhispererClient) GetUsageLimits(ctx context.Context, accessToken, clientID, refreshToken, profileArn string) (*UsageLimitsResponse, error) {
queryParams := map[string]string{
"origin": "AI_EDITOR",
"resourceType": "AGENTIC_REQUEST",
}
// Determine endpoint based on profileArn region
endpoint := GetKiroAPIEndpointFromProfileArn(profileArn)
if profileArn != "" {
queryParams["profileArn"] = profileArn
} else {
queryParams["isEmailRequired"] = "true"
}
url := buildURL(endpoint, pathGetUsageLimits, queryParams)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
accountKey := GetAccountKey(clientID, refreshToken)
setRuntimeHeaders(req, accessToken, accountKey)
log.Debugf("codewhisperer: GET %s", url)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
log.Debugf("codewhisperer: status=%d, body=%s", resp.StatusCode, string(body))
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
}
var result UsageLimitsResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
return &result, nil
}
// FetchUserEmailFromAPI fetches user email using CodeWhisperer getUsageLimits API.
// This is more reliable than JWT parsing as it uses the official API.
func (c *CodeWhispererClient) FetchUserEmailFromAPI(ctx context.Context, accessToken, clientID, refreshToken string) string {
resp, err := c.GetUsageLimits(ctx, accessToken, clientID, refreshToken, "")
if err != nil {
log.Debugf("codewhisperer: failed to get usage limits: %v", err)
return ""
}
if resp.UserInfo != nil && resp.UserInfo.Email != "" {
log.Debugf("codewhisperer: got email from API: %s", resp.UserInfo.Email)
return resp.UserInfo.Email
}
log.Debugf("codewhisperer: no email in response")
return ""
}
// FetchUserEmailWithFallback fetches user email with multiple fallback methods.
// Priority: 1. CodeWhisperer API 2. userinfo endpoint 3. JWT parsing
func FetchUserEmailWithFallback(ctx context.Context, cfg *config.Config, accessToken, clientID, refreshToken string) string {
// Method 1: Try CodeWhisperer API (most reliable)
cwClient := NewCodeWhispererClient(cfg, "")
email := cwClient.FetchUserEmailFromAPI(ctx, accessToken, clientID, refreshToken)
if email != "" {
return email
}
// Method 2: Try SSO OIDC userinfo endpoint
ssoClient := NewSSOOIDCClient(cfg)
email = ssoClient.FetchUserEmail(ctx, accessToken)
if email != "" {
return email
}
// Method 3: Fallback to JWT parsing
return ExtractEmailFromJWT(accessToken)
}

View File

@@ -0,0 +1,112 @@
package kiro
import (
"sync"
"time"
)
const (
CooldownReason429 = "rate_limit_exceeded"
CooldownReasonSuspended = "account_suspended"
CooldownReasonQuotaExhausted = "quota_exhausted"
DefaultShortCooldown = 1 * time.Minute
MaxShortCooldown = 5 * time.Minute
LongCooldown = 24 * time.Hour
)
type CooldownManager struct {
mu sync.RWMutex
cooldowns map[string]time.Time
reasons map[string]string
}
func NewCooldownManager() *CooldownManager {
return &CooldownManager{
cooldowns: make(map[string]time.Time),
reasons: make(map[string]string),
}
}
func (cm *CooldownManager) SetCooldown(tokenKey string, duration time.Duration, reason string) {
cm.mu.Lock()
defer cm.mu.Unlock()
cm.cooldowns[tokenKey] = time.Now().Add(duration)
cm.reasons[tokenKey] = reason
}
func (cm *CooldownManager) IsInCooldown(tokenKey string) bool {
cm.mu.RLock()
defer cm.mu.RUnlock()
endTime, exists := cm.cooldowns[tokenKey]
if !exists {
return false
}
return time.Now().Before(endTime)
}
func (cm *CooldownManager) GetRemainingCooldown(tokenKey string) time.Duration {
cm.mu.RLock()
defer cm.mu.RUnlock()
endTime, exists := cm.cooldowns[tokenKey]
if !exists {
return 0
}
remaining := time.Until(endTime)
if remaining < 0 {
return 0
}
return remaining
}
func (cm *CooldownManager) GetCooldownReason(tokenKey string) string {
cm.mu.RLock()
defer cm.mu.RUnlock()
return cm.reasons[tokenKey]
}
func (cm *CooldownManager) ClearCooldown(tokenKey string) {
cm.mu.Lock()
defer cm.mu.Unlock()
delete(cm.cooldowns, tokenKey)
delete(cm.reasons, tokenKey)
}
func (cm *CooldownManager) CleanupExpired() {
cm.mu.Lock()
defer cm.mu.Unlock()
now := time.Now()
for tokenKey, endTime := range cm.cooldowns {
if now.After(endTime) {
delete(cm.cooldowns, tokenKey)
delete(cm.reasons, tokenKey)
}
}
}
func (cm *CooldownManager) StartCleanupRoutine(interval time.Duration, stopCh <-chan struct{}) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
cm.CleanupExpired()
case <-stopCh:
return
}
}
}
func CalculateCooldownFor429(retryCount int) time.Duration {
duration := DefaultShortCooldown * time.Duration(1<<retryCount)
if duration > MaxShortCooldown {
return MaxShortCooldown
}
return duration
}
func CalculateCooldownUntilNextDay() time.Duration {
now := time.Now()
nextDay := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, now.Location())
return time.Until(nextDay)
}

View File

@@ -0,0 +1,240 @@
package kiro
import (
"sync"
"testing"
"time"
)
func TestNewCooldownManager(t *testing.T) {
cm := NewCooldownManager()
if cm == nil {
t.Fatal("expected non-nil CooldownManager")
}
if cm.cooldowns == nil {
t.Error("expected non-nil cooldowns map")
}
if cm.reasons == nil {
t.Error("expected non-nil reasons map")
}
}
func TestSetCooldown(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("token1", 1*time.Minute, CooldownReason429)
if !cm.IsInCooldown("token1") {
t.Error("expected token to be in cooldown")
}
if cm.GetCooldownReason("token1") != CooldownReason429 {
t.Errorf("expected reason %s, got %s", CooldownReason429, cm.GetCooldownReason("token1"))
}
}
func TestIsInCooldown_NotSet(t *testing.T) {
cm := NewCooldownManager()
if cm.IsInCooldown("nonexistent") {
t.Error("expected non-existent token to not be in cooldown")
}
}
func TestIsInCooldown_Expired(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("token1", 1*time.Millisecond, CooldownReason429)
time.Sleep(10 * time.Millisecond)
if cm.IsInCooldown("token1") {
t.Error("expected expired cooldown to return false")
}
}
func TestGetRemainingCooldown(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("token1", 1*time.Second, CooldownReason429)
remaining := cm.GetRemainingCooldown("token1")
if remaining <= 0 || remaining > 1*time.Second {
t.Errorf("expected remaining cooldown between 0 and 1s, got %v", remaining)
}
}
func TestGetRemainingCooldown_NotSet(t *testing.T) {
cm := NewCooldownManager()
remaining := cm.GetRemainingCooldown("nonexistent")
if remaining != 0 {
t.Errorf("expected 0 remaining for non-existent, got %v", remaining)
}
}
func TestGetRemainingCooldown_Expired(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("token1", 1*time.Millisecond, CooldownReason429)
time.Sleep(10 * time.Millisecond)
remaining := cm.GetRemainingCooldown("token1")
if remaining != 0 {
t.Errorf("expected 0 remaining for expired, got %v", remaining)
}
}
func TestGetCooldownReason(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("token1", 1*time.Minute, CooldownReasonSuspended)
reason := cm.GetCooldownReason("token1")
if reason != CooldownReasonSuspended {
t.Errorf("expected reason %s, got %s", CooldownReasonSuspended, reason)
}
}
func TestGetCooldownReason_NotSet(t *testing.T) {
cm := NewCooldownManager()
reason := cm.GetCooldownReason("nonexistent")
if reason != "" {
t.Errorf("expected empty reason for non-existent, got %s", reason)
}
}
func TestClearCooldown(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("token1", 1*time.Minute, CooldownReason429)
cm.ClearCooldown("token1")
if cm.IsInCooldown("token1") {
t.Error("expected cooldown to be cleared")
}
if cm.GetCooldownReason("token1") != "" {
t.Error("expected reason to be cleared")
}
}
func TestClearCooldown_NonExistent(t *testing.T) {
cm := NewCooldownManager()
cm.ClearCooldown("nonexistent")
}
func TestCleanupExpired(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("expired1", 1*time.Millisecond, CooldownReason429)
cm.SetCooldown("expired2", 1*time.Millisecond, CooldownReason429)
cm.SetCooldown("active", 1*time.Hour, CooldownReason429)
time.Sleep(10 * time.Millisecond)
cm.CleanupExpired()
if cm.GetCooldownReason("expired1") != "" {
t.Error("expected expired1 to be cleaned up")
}
if cm.GetCooldownReason("expired2") != "" {
t.Error("expected expired2 to be cleaned up")
}
if cm.GetCooldownReason("active") != CooldownReason429 {
t.Error("expected active to remain")
}
}
func TestCalculateCooldownFor429_FirstRetry(t *testing.T) {
duration := CalculateCooldownFor429(0)
if duration != DefaultShortCooldown {
t.Errorf("expected %v for retry 0, got %v", DefaultShortCooldown, duration)
}
}
func TestCalculateCooldownFor429_Exponential(t *testing.T) {
d1 := CalculateCooldownFor429(1)
d2 := CalculateCooldownFor429(2)
if d2 <= d1 {
t.Errorf("expected d2 > d1, got d1=%v, d2=%v", d1, d2)
}
}
func TestCalculateCooldownFor429_MaxCap(t *testing.T) {
duration := CalculateCooldownFor429(10)
if duration > MaxShortCooldown {
t.Errorf("expected max %v, got %v", MaxShortCooldown, duration)
}
}
func TestCalculateCooldownUntilNextDay(t *testing.T) {
duration := CalculateCooldownUntilNextDay()
if duration <= 0 || duration > 24*time.Hour {
t.Errorf("expected duration between 0 and 24h, got %v", duration)
}
}
func TestCooldownManager_ConcurrentAccess(t *testing.T) {
cm := NewCooldownManager()
const numGoroutines = 50
const numOperations = 100
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
tokenKey := "token" + string(rune('a'+id%10))
for j := 0; j < numOperations; j++ {
switch j % 6 {
case 0:
cm.SetCooldown(tokenKey, time.Duration(j)*time.Millisecond, CooldownReason429)
case 1:
cm.IsInCooldown(tokenKey)
case 2:
cm.GetRemainingCooldown(tokenKey)
case 3:
cm.GetCooldownReason(tokenKey)
case 4:
cm.ClearCooldown(tokenKey)
case 5:
cm.CleanupExpired()
}
}
}(i)
}
wg.Wait()
}
func TestCooldownReasonConstants(t *testing.T) {
if CooldownReason429 != "rate_limit_exceeded" {
t.Errorf("unexpected CooldownReason429: %s", CooldownReason429)
}
if CooldownReasonSuspended != "account_suspended" {
t.Errorf("unexpected CooldownReasonSuspended: %s", CooldownReasonSuspended)
}
if CooldownReasonQuotaExhausted != "quota_exhausted" {
t.Errorf("unexpected CooldownReasonQuotaExhausted: %s", CooldownReasonQuotaExhausted)
}
}
func TestDefaultConstants(t *testing.T) {
if DefaultShortCooldown != 1*time.Minute {
t.Errorf("unexpected DefaultShortCooldown: %v", DefaultShortCooldown)
}
if MaxShortCooldown != 5*time.Minute {
t.Errorf("unexpected MaxShortCooldown: %v", MaxShortCooldown)
}
if LongCooldown != 24*time.Hour {
t.Errorf("unexpected LongCooldown: %v", LongCooldown)
}
}
func TestSetCooldown_OverwritesPrevious(t *testing.T) {
cm := NewCooldownManager()
cm.SetCooldown("token1", 1*time.Hour, CooldownReason429)
cm.SetCooldown("token1", 1*time.Minute, CooldownReasonSuspended)
reason := cm.GetCooldownReason("token1")
if reason != CooldownReasonSuspended {
t.Errorf("expected reason to be overwritten to %s, got %s", CooldownReasonSuspended, reason)
}
remaining := cm.GetRemainingCooldown("token1")
if remaining > 1*time.Minute {
t.Errorf("expected remaining <= 1 minute, got %v", remaining)
}
}

View File

@@ -0,0 +1,278 @@
package kiro
import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"fmt"
"math/rand"
"net/http"
"runtime"
"slices"
"sync"
"time"
"github.com/google/uuid"
)
// Fingerprint holds multi-dimensional fingerprint data for runtime request disguise.
type Fingerprint struct {
OIDCSDKVersion string // 3.7xx (AWS SDK JS)
RuntimeSDKVersion string // 1.0.x (runtime API)
StreamingSDKVersion string // 1.0.x (streaming API)
OSType string // darwin/windows/linux
OSVersion string
NodeVersion string
KiroVersion string
KiroHash string // SHA256
}
// FingerprintConfig holds external fingerprint overrides.
type FingerprintConfig struct {
OIDCSDKVersion string
RuntimeSDKVersion string
StreamingSDKVersion string
OSType string
OSVersion string
NodeVersion string
KiroVersion string
KiroHash string
}
// FingerprintManager manages per-account fingerprint generation and caching.
type FingerprintManager struct {
mu sync.RWMutex
fingerprints map[string]*Fingerprint // tokenKey -> fingerprint
rng *rand.Rand
config *FingerprintConfig // External config (Optional)
}
var (
// SDK versions
oidcSDKVersions = []string{
"3.980.0", "3.975.0", "3.972.0", "3.808.0",
"3.738.0", "3.737.0", "3.736.0", "3.735.0",
}
// SDKVersions for getUsageLimits/ListAvailableModels/GetProfile (runtime API)
runtimeSDKVersions = []string{"1.0.0"}
// SDKVersions for generateAssistantResponse (streaming API)
streamingSDKVersions = []string{"1.0.27"}
// Valid OS types
osTypes = []string{"darwin", "windows", "linux"}
// OS versions
osVersions = map[string][]string{
"darwin": {"25.2.0", "25.1.0", "25.0.0", "24.5.0", "24.4.0", "24.3.0"},
"windows": {"10.0.26200", "10.0.26100", "10.0.22631", "10.0.22621", "10.0.19045"},
"linux": {"6.12.0", "6.11.0", "6.8.0", "6.6.0", "6.5.0", "6.1.0"},
}
// Node versions
nodeVersions = []string{
"22.21.1", "22.21.0", "22.20.0", "22.19.0", "22.18.0",
"20.18.0", "20.17.0", "20.16.0",
}
// Kiro IDE versions
kiroVersions = []string{
"0.10.32", "0.10.16", "0.10.10",
"0.9.47", "0.9.40", "0.9.2",
"0.8.206", "0.8.140", "0.8.135", "0.8.86",
}
// Global singleton
globalFingerprintManager *FingerprintManager
globalFingerprintManagerOnce sync.Once
)
func GlobalFingerprintManager() *FingerprintManager {
globalFingerprintManagerOnce.Do(func() {
globalFingerprintManager = NewFingerprintManager()
})
return globalFingerprintManager
}
func SetGlobalFingerprintConfig(cfg *FingerprintConfig) {
GlobalFingerprintManager().SetConfig(cfg)
}
// SetConfig applies the config and clears the fingerprint cache.
func (fm *FingerprintManager) SetConfig(cfg *FingerprintConfig) {
fm.mu.Lock()
defer fm.mu.Unlock()
fm.config = cfg
// Clear cached fingerprints so they regenerate with the new config
fm.fingerprints = make(map[string]*Fingerprint)
}
func NewFingerprintManager() *FingerprintManager {
return &FingerprintManager{
fingerprints: make(map[string]*Fingerprint),
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}
// GetFingerprint returns the fingerprint for tokenKey, creating one if it doesn't exist.
func (fm *FingerprintManager) GetFingerprint(tokenKey string) *Fingerprint {
fm.mu.RLock()
if fp, exists := fm.fingerprints[tokenKey]; exists {
fm.mu.RUnlock()
return fp
}
fm.mu.RUnlock()
fm.mu.Lock()
defer fm.mu.Unlock()
if fp, exists := fm.fingerprints[tokenKey]; exists {
return fp
}
fp := fm.generateFingerprint(tokenKey)
fm.fingerprints[tokenKey] = fp
return fp
}
func (fm *FingerprintManager) generateFingerprint(tokenKey string) *Fingerprint {
if fm.config != nil {
return fm.generateFromConfig(tokenKey)
}
return fm.generateRandom(tokenKey)
}
// generateFromConfig uses config values, falling back to random for empty fields.
func (fm *FingerprintManager) generateFromConfig(tokenKey string) *Fingerprint {
cfg := fm.config
// Helper: config value or random selection
configOrRandom := func(configVal string, choices []string) string {
if configVal != "" {
return configVal
}
return choices[fm.rng.Intn(len(choices))]
}
osType := cfg.OSType
if osType == "" {
osType = runtime.GOOS
if !slices.Contains(osTypes, osType) {
osType = osTypes[fm.rng.Intn(len(osTypes))]
}
}
osVersion := cfg.OSVersion
if osVersion == "" {
if versions, ok := osVersions[osType]; ok {
osVersion = versions[fm.rng.Intn(len(versions))]
}
}
kiroHash := cfg.KiroHash
if kiroHash == "" {
hash := sha256.Sum256([]byte(tokenKey))
kiroHash = hex.EncodeToString(hash[:])
}
return &Fingerprint{
OIDCSDKVersion: configOrRandom(cfg.OIDCSDKVersion, oidcSDKVersions),
RuntimeSDKVersion: configOrRandom(cfg.RuntimeSDKVersion, runtimeSDKVersions),
StreamingSDKVersion: configOrRandom(cfg.StreamingSDKVersion, streamingSDKVersions),
OSType: osType,
OSVersion: osVersion,
NodeVersion: configOrRandom(cfg.NodeVersion, nodeVersions),
KiroVersion: configOrRandom(cfg.KiroVersion, kiroVersions),
KiroHash: kiroHash,
}
}
// generateRandom generates a deterministic fingerprint seeded by accountKey hash.
func (fm *FingerprintManager) generateRandom(accountKey string) *Fingerprint {
// Use accountKey hash as seed for deterministic random selection
hash := sha256.Sum256([]byte(accountKey))
seed := int64(binary.BigEndian.Uint64(hash[:8]))
rng := rand.New(rand.NewSource(seed))
osType := runtime.GOOS
if !slices.Contains(osTypes, osType) {
osType = osTypes[rng.Intn(len(osTypes))]
}
osVersion := osVersions[osType][rng.Intn(len(osVersions[osType]))]
return &Fingerprint{
OIDCSDKVersion: oidcSDKVersions[rng.Intn(len(oidcSDKVersions))],
RuntimeSDKVersion: runtimeSDKVersions[rng.Intn(len(runtimeSDKVersions))],
StreamingSDKVersion: streamingSDKVersions[rng.Intn(len(streamingSDKVersions))],
OSType: osType,
OSVersion: osVersion,
NodeVersion: nodeVersions[rng.Intn(len(nodeVersions))],
KiroVersion: kiroVersions[rng.Intn(len(kiroVersions))],
KiroHash: hex.EncodeToString(hash[:]),
}
}
// GenerateAccountKey returns a 16-char hex key derived from SHA256(seed).
func GenerateAccountKey(seed string) string {
hash := sha256.Sum256([]byte(seed))
return hex.EncodeToString(hash[:8])
}
// GetAccountKey derives an account key from clientID > refreshToken > random UUID.
func GetAccountKey(clientID, refreshToken string) string {
// 1. Prefer ClientID
if clientID != "" {
return GenerateAccountKey(clientID)
}
// 2. Fallback to RefreshToken
if refreshToken != "" {
return GenerateAccountKey(refreshToken)
}
// 3. Random fallback
return GenerateAccountKey(uuid.New().String())
}
// BuildUserAgent format: aws-sdk-js/{SDKVersion} ua/2.1 os/{OSType}#{OSVersion} lang/js md/nodejs#{NodeVersion} api/codewhispererstreaming#{SDKVersion} m/E KiroIDE-{KiroVersion}-{KiroHash}
func (fp *Fingerprint) BuildUserAgent() string {
return fmt.Sprintf(
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s",
fp.StreamingSDKVersion,
fp.OSType,
fp.OSVersion,
fp.NodeVersion,
fp.StreamingSDKVersion,
fp.KiroVersion,
fp.KiroHash,
)
}
// BuildAmzUserAgent format: aws-sdk-js/{SDKVersion} KiroIDE-{KiroVersion}-{KiroHash}
func (fp *Fingerprint) BuildAmzUserAgent() string {
return fmt.Sprintf(
"aws-sdk-js/%s KiroIDE-%s-%s",
fp.StreamingSDKVersion,
fp.KiroVersion,
fp.KiroHash,
)
}
func SetOIDCHeaders(req *http.Request) {
fp := GlobalFingerprintManager().GetFingerprint("oidc-session")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("x-amz-user-agent", fmt.Sprintf("aws-sdk-js/%s KiroIDE", fp.OIDCSDKVersion))
req.Header.Set("User-Agent", fmt.Sprintf(
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/%s#%s m/E KiroIDE",
fp.OIDCSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, "sso-oidc", fp.OIDCSDKVersion))
req.Header.Set("amz-sdk-invocation-id", uuid.New().String())
req.Header.Set("amz-sdk-request", "attempt=1; max=4")
}
func setRuntimeHeaders(req *http.Request, accessToken string, accountKey string) {
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
machineID := fp.KiroHash
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("x-amz-user-agent", fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s",
fp.RuntimeSDKVersion, fp.KiroVersion, machineID))
req.Header.Set("User-Agent", fmt.Sprintf(
"aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererruntime#%s m/N,E KiroIDE-%s-%s",
fp.RuntimeSDKVersion, fp.OSType, fp.OSVersion, fp.NodeVersion, fp.RuntimeSDKVersion,
fp.KiroVersion, machineID))
req.Header.Set("amz-sdk-invocation-id", uuid.New().String())
req.Header.Set("amz-sdk-request", "attempt=1; max=1")
}

View File

@@ -0,0 +1,778 @@
package kiro
import (
"net/http"
"runtime"
"strings"
"sync"
"testing"
)
func TestNewFingerprintManager(t *testing.T) {
fm := NewFingerprintManager()
if fm == nil {
t.Fatal("expected non-nil FingerprintManager")
}
if fm.fingerprints == nil {
t.Error("expected non-nil fingerprints map")
}
if fm.rng == nil {
t.Error("expected non-nil rng")
}
}
func TestGetFingerprint_NewToken(t *testing.T) {
fm := NewFingerprintManager()
fp := fm.GetFingerprint("token1")
if fp == nil {
t.Fatal("expected non-nil Fingerprint")
}
if fp.OIDCSDKVersion == "" {
t.Error("expected non-empty OIDCSDKVersion")
}
if fp.RuntimeSDKVersion == "" {
t.Error("expected non-empty RuntimeSDKVersion")
}
if fp.StreamingSDKVersion == "" {
t.Error("expected non-empty StreamingSDKVersion")
}
if fp.OSType == "" {
t.Error("expected non-empty OSType")
}
if fp.OSVersion == "" {
t.Error("expected non-empty OSVersion")
}
if fp.NodeVersion == "" {
t.Error("expected non-empty NodeVersion")
}
if fp.KiroVersion == "" {
t.Error("expected non-empty KiroVersion")
}
if fp.KiroHash == "" {
t.Error("expected non-empty KiroHash")
}
}
func TestGetFingerprint_SameTokenReturnsSameFingerprint(t *testing.T) {
fm := NewFingerprintManager()
fp1 := fm.GetFingerprint("token1")
fp2 := fm.GetFingerprint("token1")
if fp1 != fp2 {
t.Error("expected same fingerprint for same token")
}
}
func TestGetFingerprint_DifferentTokens(t *testing.T) {
fm := NewFingerprintManager()
fp1 := fm.GetFingerprint("token1")
fp2 := fm.GetFingerprint("token2")
if fp1 == fp2 {
t.Error("expected different fingerprints for different tokens")
}
}
func TestBuildUserAgent(t *testing.T) {
fm := NewFingerprintManager()
fp := fm.GetFingerprint("token1")
ua := fp.BuildUserAgent()
if ua == "" {
t.Error("expected non-empty User-Agent")
}
amzUA := fp.BuildAmzUserAgent()
if amzUA == "" {
t.Error("expected non-empty X-Amz-User-Agent")
}
}
func TestGetFingerprint_OSVersionMatchesOSType(t *testing.T) {
fm := NewFingerprintManager()
for i := 0; i < 20; i++ {
fp := fm.GetFingerprint("token" + string(rune('a'+i)))
validVersions := osVersions[fp.OSType]
found := false
for _, v := range validVersions {
if v == fp.OSVersion {
found = true
break
}
}
if !found {
t.Errorf("OS version %s not valid for OS type %s", fp.OSVersion, fp.OSType)
}
}
}
func TestGenerateFromConfig_OSTypeFromRuntimeGOOS(t *testing.T) {
fm := NewFingerprintManager()
// Set config with empty OSType to trigger runtime.GOOS fallback
fm.SetConfig(&FingerprintConfig{
OIDCSDKVersion: "3.738.0", // Set other fields to use config path
})
fp := fm.GetFingerprint("test-token")
// Expected OS type based on runtime.GOOS mapping
var expectedOS string
switch runtime.GOOS {
case "darwin":
expectedOS = "darwin"
case "windows":
expectedOS = "windows"
default:
expectedOS = "linux"
}
if fp.OSType != expectedOS {
t.Errorf("expected OSType '%s' from runtime.GOOS '%s', got '%s'",
expectedOS, runtime.GOOS, fp.OSType)
}
}
func TestFingerprintManager_ConcurrentAccess(t *testing.T) {
fm := NewFingerprintManager()
const numGoroutines = 100
const numOperations = 100
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := range numGoroutines {
go func(id int) {
defer wg.Done()
for j := range numOperations {
tokenKey := "token" + string(rune('a'+id%26))
switch j % 2 {
case 0:
fm.GetFingerprint(tokenKey)
case 1:
fp := fm.GetFingerprint(tokenKey)
_ = fp.BuildUserAgent()
_ = fp.BuildAmzUserAgent()
}
}
}(i)
}
wg.Wait()
}
func TestKiroHashStability(t *testing.T) {
fm := NewFingerprintManager()
// Same token should always return same hash
fp1 := fm.GetFingerprint("token1")
fp2 := fm.GetFingerprint("token1")
if fp1.KiroHash != fp2.KiroHash {
t.Errorf("same token should have same hash: %s vs %s", fp1.KiroHash, fp2.KiroHash)
}
// Different tokens should have different hashes
fp3 := fm.GetFingerprint("token2")
if fp1.KiroHash == fp3.KiroHash {
t.Errorf("different tokens should have different hashes")
}
}
func TestKiroHashFormat(t *testing.T) {
fm := NewFingerprintManager()
fp := fm.GetFingerprint("token1")
if len(fp.KiroHash) != 64 {
t.Errorf("expected KiroHash length 64 (SHA256 hex), got %d", len(fp.KiroHash))
}
for _, c := range fp.KiroHash {
if (c < '0' || c > '9') && (c < 'a' || c > 'f') {
t.Errorf("invalid hex character in KiroHash: %c", c)
}
}
}
func TestGlobalFingerprintManager(t *testing.T) {
fm1 := GlobalFingerprintManager()
fm2 := GlobalFingerprintManager()
if fm1 == nil {
t.Fatal("expected non-nil GlobalFingerprintManager")
}
if fm1 != fm2 {
t.Error("expected GlobalFingerprintManager to return same instance")
}
}
func TestSetOIDCHeaders(t *testing.T) {
req, _ := http.NewRequest("GET", "http://example.com", nil)
SetOIDCHeaders(req)
if req.Header.Get("Content-Type") != "application/json" {
t.Error("expected Content-Type header to be set")
}
amzUA := req.Header.Get("x-amz-user-agent")
if amzUA == "" {
t.Error("expected x-amz-user-agent header to be set")
}
if !strings.Contains(amzUA, "aws-sdk-js/") {
t.Errorf("x-amz-user-agent should contain aws-sdk-js: %s", amzUA)
}
if !strings.Contains(amzUA, "KiroIDE") {
t.Errorf("x-amz-user-agent should contain KiroIDE: %s", amzUA)
}
ua := req.Header.Get("User-Agent")
if ua == "" {
t.Error("expected User-Agent header to be set")
}
if !strings.Contains(ua, "api/sso-oidc") {
t.Errorf("User-Agent should contain api name: %s", ua)
}
if req.Header.Get("amz-sdk-invocation-id") == "" {
t.Error("expected amz-sdk-invocation-id header to be set")
}
if req.Header.Get("amz-sdk-request") != "attempt=1; max=4" {
t.Errorf("unexpected amz-sdk-request header: %s", req.Header.Get("amz-sdk-request"))
}
}
func TestBuildURL(t *testing.T) {
tests := []struct {
name string
endpoint string
path string
queryParams map[string]string
want string
wantContains []string
}{
{
name: "no query params",
endpoint: "https://api.example.com",
path: "getUsageLimits",
queryParams: nil,
want: "https://api.example.com/getUsageLimits",
},
{
name: "empty query params",
endpoint: "https://api.example.com",
path: "getUsageLimits",
queryParams: map[string]string{},
want: "https://api.example.com/getUsageLimits",
},
{
name: "single query param",
endpoint: "https://api.example.com",
path: "getUsageLimits",
queryParams: map[string]string{
"origin": "AI_EDITOR",
},
want: "https://api.example.com/getUsageLimits?origin=AI_EDITOR",
},
{
name: "multiple query params",
endpoint: "https://api.example.com",
path: "getUsageLimits",
queryParams: map[string]string{
"origin": "AI_EDITOR",
"resourceType": "AGENTIC_REQUEST",
"profileArn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABCDEF",
},
wantContains: []string{
"https://api.example.com/getUsageLimits?",
"origin=AI_EDITOR",
"profileArn=arn%3Aaws%3Acodewhisperer%3Aus-east-1%3A123456789012%3Aprofile%2FABCDEF",
"resourceType=AGENTIC_REQUEST",
},
},
{
name: "omit empty params",
endpoint: "https://api.example.com",
path: "getUsageLimits",
queryParams: map[string]string{
"origin": "AI_EDITOR",
"profileArn": "",
},
want: "https://api.example.com/getUsageLimits?origin=AI_EDITOR",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := buildURL(tt.endpoint, tt.path, tt.queryParams)
if tt.want != "" {
if got != tt.want {
t.Errorf("buildURL() = %v, want %v", got, tt.want)
}
}
if tt.wantContains != nil {
for _, substr := range tt.wantContains {
if !strings.Contains(got, substr) {
t.Errorf("buildURL() = %v, want to contain %v", got, substr)
}
}
}
})
}
}
func TestBuildUserAgentFormat(t *testing.T) {
fm := NewFingerprintManager()
fp := fm.GetFingerprint("token1")
ua := fp.BuildUserAgent()
requiredParts := []string{
"aws-sdk-js/",
"ua/2.1",
"os/",
"lang/js",
"md/nodejs#",
"api/codewhispererstreaming#",
"m/E",
"KiroIDE-",
}
for _, part := range requiredParts {
if !strings.Contains(ua, part) {
t.Errorf("User-Agent missing required part %q: %s", part, ua)
}
}
}
func TestBuildAmzUserAgentFormat(t *testing.T) {
fm := NewFingerprintManager()
fp := fm.GetFingerprint("token1")
amzUA := fp.BuildAmzUserAgent()
requiredParts := []string{
"aws-sdk-js/",
"KiroIDE-",
}
for _, part := range requiredParts {
if !strings.Contains(amzUA, part) {
t.Errorf("X-Amz-User-Agent missing required part %q: %s", part, amzUA)
}
}
// Amz-User-Agent should be shorter than User-Agent
ua := fp.BuildUserAgent()
if len(amzUA) >= len(ua) {
t.Error("X-Amz-User-Agent should be shorter than User-Agent")
}
}
func TestSetRuntimeHeaders(t *testing.T) {
req, _ := http.NewRequest("GET", "http://example.com", nil)
accessToken := "test-access-token-1234567890"
clientID := "test-client-id-12345"
accountKey := GenerateAccountKey(clientID)
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
machineID := fp.KiroHash
setRuntimeHeaders(req, accessToken, accountKey)
// Check Authorization header
if req.Header.Get("Authorization") != "Bearer "+accessToken {
t.Errorf("expected Authorization header 'Bearer %s', got '%s'", accessToken, req.Header.Get("Authorization"))
}
// Check x-amz-user-agent header
amzUA := req.Header.Get("x-amz-user-agent")
if amzUA == "" {
t.Error("expected x-amz-user-agent header to be set")
}
if !strings.Contains(amzUA, "aws-sdk-js/") {
t.Errorf("x-amz-user-agent should contain aws-sdk-js: %s", amzUA)
}
if !strings.Contains(amzUA, "KiroIDE-") {
t.Errorf("x-amz-user-agent should contain KiroIDE: %s", amzUA)
}
if !strings.Contains(amzUA, machineID) {
t.Errorf("x-amz-user-agent should contain machineID: %s", amzUA)
}
// Check User-Agent header
ua := req.Header.Get("User-Agent")
if ua == "" {
t.Error("expected User-Agent header to be set")
}
if !strings.Contains(ua, "api/codewhispererruntime#") {
t.Errorf("User-Agent should contain api/codewhispererruntime: %s", ua)
}
if !strings.Contains(ua, "m/N,E") {
t.Errorf("User-Agent should contain m/N,E: %s", ua)
}
// Check amz-sdk-invocation-id (should be a UUID)
invocationID := req.Header.Get("amz-sdk-invocation-id")
if invocationID == "" {
t.Error("expected amz-sdk-invocation-id header to be set")
}
if len(invocationID) != 36 {
t.Errorf("expected amz-sdk-invocation-id to be UUID (36 chars), got %d", len(invocationID))
}
// Check amz-sdk-request
if req.Header.Get("amz-sdk-request") != "attempt=1; max=1" {
t.Errorf("unexpected amz-sdk-request header: %s", req.Header.Get("amz-sdk-request"))
}
}
func TestSDKVersionsAreValid(t *testing.T) {
// Verify all OIDC SDK versions match expected format (3.xxx.x)
for _, v := range oidcSDKVersions {
if !strings.HasPrefix(v, "3.") {
t.Errorf("OIDC SDK version should start with 3.: %s", v)
}
parts := strings.Split(v, ".")
if len(parts) != 3 {
t.Errorf("OIDC SDK version should have 3 parts: %s", v)
}
}
for _, v := range runtimeSDKVersions {
parts := strings.Split(v, ".")
if len(parts) != 3 {
t.Errorf("Runtime SDK version should have 3 parts: %s", v)
}
}
for _, v := range streamingSDKVersions {
parts := strings.Split(v, ".")
if len(parts) != 3 {
t.Errorf("Streaming SDK version should have 3 parts: %s", v)
}
}
}
func TestKiroVersionsAreValid(t *testing.T) {
// Verify all Kiro versions match expected format (0.x.xxx)
for _, v := range kiroVersions {
if !strings.HasPrefix(v, "0.") {
t.Errorf("Kiro version should start with 0.: %s", v)
}
parts := strings.Split(v, ".")
if len(parts) != 3 {
t.Errorf("Kiro version should have 3 parts: %s", v)
}
}
}
func TestNodeVersionsAreValid(t *testing.T) {
// Verify all Node versions match expected format (xx.xx.x)
for _, v := range nodeVersions {
parts := strings.Split(v, ".")
if len(parts) != 3 {
t.Errorf("Node version should have 3 parts: %s", v)
}
// Should be Node 20.x or 22.x
if !strings.HasPrefix(v, "20.") && !strings.HasPrefix(v, "22.") {
t.Errorf("Node version should be 20.x or 22.x LTS: %s", v)
}
}
}
func TestFingerprintManager_SetConfig(t *testing.T) {
fm := NewFingerprintManager()
// Without config, should generate random fingerprint
fp1 := fm.GetFingerprint("token1")
if fp1 == nil {
t.Fatal("expected non-nil fingerprint")
}
// Set config with all fields
cfg := &FingerprintConfig{
OIDCSDKVersion: "3.999.0",
RuntimeSDKVersion: "9.9.9",
StreamingSDKVersion: "8.8.8",
OSType: "darwin",
OSVersion: "99.0.0",
NodeVersion: "99.99.99",
KiroVersion: "9.9.999",
KiroHash: "customhash123",
}
fm.SetConfig(cfg)
// After setting config, should use config values
fp2 := fm.GetFingerprint("token2")
if fp2.OIDCSDKVersion != "3.999.0" {
t.Errorf("expected OIDCSDKVersion '3.999.0', got '%s'", fp2.OIDCSDKVersion)
}
if fp2.RuntimeSDKVersion != "9.9.9" {
t.Errorf("expected RuntimeSDKVersion '9.9.9', got '%s'", fp2.RuntimeSDKVersion)
}
if fp2.StreamingSDKVersion != "8.8.8" {
t.Errorf("expected StreamingSDKVersion '8.8.8', got '%s'", fp2.StreamingSDKVersion)
}
if fp2.OSType != "darwin" {
t.Errorf("expected OSType 'darwin', got '%s'", fp2.OSType)
}
if fp2.OSVersion != "99.0.0" {
t.Errorf("expected OSVersion '99.0.0', got '%s'", fp2.OSVersion)
}
if fp2.NodeVersion != "99.99.99" {
t.Errorf("expected NodeVersion '99.99.99', got '%s'", fp2.NodeVersion)
}
if fp2.KiroVersion != "9.9.999" {
t.Errorf("expected KiroVersion '9.9.999', got '%s'", fp2.KiroVersion)
}
if fp2.KiroHash != "customhash123" {
t.Errorf("expected KiroHash 'customhash123', got '%s'", fp2.KiroHash)
}
}
func TestFingerprintManager_SetConfig_PartialFields(t *testing.T) {
fm := NewFingerprintManager()
// Set config with only some fields
cfg := &FingerprintConfig{
KiroVersion: "1.2.345",
KiroHash: "myhash",
// Other fields empty - should use random
}
fm.SetConfig(cfg)
fp := fm.GetFingerprint("token1")
// Configured fields should use config values
if fp.KiroVersion != "1.2.345" {
t.Errorf("expected KiroVersion '1.2.345', got '%s'", fp.KiroVersion)
}
if fp.KiroHash != "myhash" {
t.Errorf("expected KiroHash 'myhash', got '%s'", fp.KiroHash)
}
// Empty fields should be randomly selected (non-empty)
if fp.OIDCSDKVersion == "" {
t.Error("expected non-empty OIDCSDKVersion")
}
if fp.OSType == "" {
t.Error("expected non-empty OSType")
}
if fp.NodeVersion == "" {
t.Error("expected non-empty NodeVersion")
}
}
func TestFingerprintManager_SetConfig_ClearsCache(t *testing.T) {
fm := NewFingerprintManager()
// Get fingerprint before config
fp1 := fm.GetFingerprint("token1")
originalHash := fp1.KiroHash
// Set config
cfg := &FingerprintConfig{
KiroHash: "newcustomhash",
}
fm.SetConfig(cfg)
// Same token should now return different fingerprint (cache cleared)
fp2 := fm.GetFingerprint("token1")
if fp2.KiroHash == originalHash {
t.Error("expected cache to be cleared after SetConfig")
}
if fp2.KiroHash != "newcustomhash" {
t.Errorf("expected KiroHash 'newcustomhash', got '%s'", fp2.KiroHash)
}
}
func TestGenerateAccountKey(t *testing.T) {
tests := []struct {
name string
seed string
check func(t *testing.T, result string)
}{
{
name: "Empty seed",
seed: "",
check: func(t *testing.T, result string) {
if result == "" {
t.Error("expected non-empty result for empty seed")
}
if len(result) != 16 {
t.Errorf("expected 16 char hex string, got %d chars", len(result))
}
},
},
{
name: "Simple seed",
seed: "test-client-id",
check: func(t *testing.T, result string) {
if len(result) != 16 {
t.Errorf("expected 16 char hex string, got %d chars", len(result))
}
// Verify it's valid hex
for _, c := range result {
if (c < '0' || c > '9') && (c < 'a' || c > 'f') {
t.Errorf("invalid hex character: %c", c)
}
}
},
},
{
name: "Same seed produces same result",
seed: "deterministic-seed",
check: func(t *testing.T, result string) {
result2 := GenerateAccountKey("deterministic-seed")
if result != result2 {
t.Errorf("same seed should produce same result: %s vs %s", result, result2)
}
},
},
{
name: "Different seeds produce different results",
seed: "seed-one",
check: func(t *testing.T, result string) {
result2 := GenerateAccountKey("seed-two")
if result == result2 {
t.Errorf("different seeds should produce different results: %s vs %s", result, result2)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GenerateAccountKey(tt.seed)
tt.check(t, result)
})
}
}
func TestGetAccountKey(t *testing.T) {
tests := []struct {
name string
clientID string
refreshToken string
check func(t *testing.T, result string)
}{
{
name: "Priority 1: clientID when both provided",
clientID: "client-id-123",
refreshToken: "refresh-token-456",
check: func(t *testing.T, result string) {
expected := GenerateAccountKey("client-id-123")
if result != expected {
t.Errorf("expected clientID-based key %s, got %s", expected, result)
}
},
},
{
name: "Priority 2: refreshToken when clientID is empty",
clientID: "",
refreshToken: "refresh-token-789",
check: func(t *testing.T, result string) {
expected := GenerateAccountKey("refresh-token-789")
if result != expected {
t.Errorf("expected refreshToken-based key %s, got %s", expected, result)
}
},
},
{
name: "Priority 3: random when both empty",
clientID: "",
refreshToken: "",
check: func(t *testing.T, result string) {
if len(result) != 16 {
t.Errorf("expected 16 char key, got %d chars", len(result))
}
// Should be different each time (random UUID)
result2 := GetAccountKey("", "")
if result == result2 {
t.Log("warning: random keys are the same (possible but unlikely)")
}
},
},
{
name: "clientID only",
clientID: "solo-client-id",
refreshToken: "",
check: func(t *testing.T, result string) {
expected := GenerateAccountKey("solo-client-id")
if result != expected {
t.Errorf("expected clientID-based key %s, got %s", expected, result)
}
},
},
{
name: "refreshToken only",
clientID: "",
refreshToken: "solo-refresh-token",
check: func(t *testing.T, result string) {
expected := GenerateAccountKey("solo-refresh-token")
if result != expected {
t.Errorf("expected refreshToken-based key %s, got %s", expected, result)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetAccountKey(tt.clientID, tt.refreshToken)
tt.check(t, result)
})
}
}
func TestGetAccountKey_Deterministic(t *testing.T) {
// Verify that GetAccountKey produces deterministic results for same inputs
clientID := "test-client-id-abc"
refreshToken := "test-refresh-token-xyz"
// Call multiple times with same inputs
results := make([]string, 10)
for i := range 10 {
results[i] = GetAccountKey(clientID, refreshToken)
}
// All results should be identical
for i := 1; i < 10; i++ {
if results[i] != results[0] {
t.Errorf("GetAccountKey should be deterministic: got %s and %s", results[0], results[i])
}
}
}
func TestFingerprintDeterministic(t *testing.T) {
// Verify that fingerprints are deterministic based on accountKey
fm := NewFingerprintManager()
accountKey := GenerateAccountKey("test-client-id")
// Get fingerprint multiple times
fp1 := fm.GetFingerprint(accountKey)
fp2 := fm.GetFingerprint(accountKey)
// Should be the same pointer (cached)
if fp1 != fp2 {
t.Error("expected same fingerprint pointer for same key")
}
// Create new manager and verify same values
fm2 := NewFingerprintManager()
fp3 := fm2.GetFingerprint(accountKey)
// Values should be identical (deterministic generation)
if fp1.KiroHash != fp3.KiroHash {
t.Errorf("KiroHash should be deterministic: %s vs %s", fp1.KiroHash, fp3.KiroHash)
}
if fp1.OSType != fp3.OSType {
t.Errorf("OSType should be deterministic: %s vs %s", fp1.OSType, fp3.OSType)
}
if fp1.OSVersion != fp3.OSVersion {
t.Errorf("OSVersion should be deterministic: %s vs %s", fp1.OSVersion, fp3.OSVersion)
}
if fp1.KiroVersion != fp3.KiroVersion {
t.Errorf("KiroVersion should be deterministic: %s vs %s", fp1.KiroVersion, fp3.KiroVersion)
}
if fp1.NodeVersion != fp3.NodeVersion {
t.Errorf("NodeVersion should be deterministic: %s vs %s", fp1.NodeVersion, fp3.NodeVersion)
}
}

View File

@@ -0,0 +1,174 @@
package kiro
import (
"math/rand"
"sync"
"time"
)
// Jitter configuration constants
const (
// JitterPercent is the default percentage of jitter to apply (±30%)
JitterPercent = 0.30
// Human-like delay ranges
ShortDelayMin = 50 * time.Millisecond // Minimum for rapid consecutive operations
ShortDelayMax = 200 * time.Millisecond // Maximum for rapid consecutive operations
NormalDelayMin = 1 * time.Second // Minimum for normal thinking time
NormalDelayMax = 3 * time.Second // Maximum for normal thinking time
LongDelayMin = 5 * time.Second // Minimum for reading/resting
LongDelayMax = 10 * time.Second // Maximum for reading/resting
// Probability thresholds for human-like behavior
ShortDelayProbability = 0.20 // 20% chance of short delay (consecutive ops)
LongDelayProbability = 0.05 // 5% chance of long delay (reading/resting)
NormalDelayProbability = 0.75 // 75% chance of normal delay (thinking)
)
var (
jitterRand *rand.Rand
jitterRandOnce sync.Once
jitterMu sync.Mutex
lastRequestTime time.Time
)
// initJitterRand initializes the random number generator for jitter calculations.
// Uses a time-based seed for unpredictable but reproducible randomness.
func initJitterRand() {
jitterRandOnce.Do(func() {
jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
})
}
// RandomDelay generates a random delay between min and max duration.
// Thread-safe implementation using mutex protection.
func RandomDelay(min, max time.Duration) time.Duration {
initJitterRand()
jitterMu.Lock()
defer jitterMu.Unlock()
if min >= max {
return min
}
rangeMs := max.Milliseconds() - min.Milliseconds()
randomMs := jitterRand.Int63n(rangeMs)
return min + time.Duration(randomMs)*time.Millisecond
}
// JitterDelay adds jitter to a base delay.
// Applies ±jitterPercent variation to the base delay.
// For example, JitterDelay(1*time.Second, 0.30) returns a value between 700ms and 1300ms.
func JitterDelay(baseDelay time.Duration, jitterPercent float64) time.Duration {
initJitterRand()
jitterMu.Lock()
defer jitterMu.Unlock()
if jitterPercent <= 0 || jitterPercent > 1 {
jitterPercent = JitterPercent
}
// Calculate jitter range: base * jitterPercent
jitterRange := float64(baseDelay) * jitterPercent
// Generate random value in range [-jitterRange, +jitterRange]
jitter := (jitterRand.Float64()*2 - 1) * jitterRange
result := time.Duration(float64(baseDelay) + jitter)
if result < 0 {
return 0
}
return result
}
// JitterDelayDefault applies the default ±30% jitter to a base delay.
func JitterDelayDefault(baseDelay time.Duration) time.Duration {
return JitterDelay(baseDelay, JitterPercent)
}
// HumanLikeDelay generates a delay that mimics human behavior patterns.
// The delay is selected based on probability distribution:
// - 20% chance: Short delay (50-200ms) - simulates consecutive rapid operations
// - 75% chance: Normal delay (1-3s) - simulates thinking/reading time
// - 5% chance: Long delay (5-10s) - simulates breaks/reading longer content
//
// Returns the delay duration (caller should call time.Sleep with this value).
func HumanLikeDelay() time.Duration {
initJitterRand()
jitterMu.Lock()
defer jitterMu.Unlock()
// Track time since last request for adaptive behavior
now := time.Now()
timeSinceLastRequest := now.Sub(lastRequestTime)
lastRequestTime = now
// If requests are very close together, use short delay
if timeSinceLastRequest < 500*time.Millisecond && timeSinceLastRequest > 0 {
rangeMs := ShortDelayMax.Milliseconds() - ShortDelayMin.Milliseconds()
randomMs := jitterRand.Int63n(rangeMs)
return ShortDelayMin + time.Duration(randomMs)*time.Millisecond
}
// Otherwise, use probability-based selection
roll := jitterRand.Float64()
var min, max time.Duration
switch {
case roll < ShortDelayProbability:
// Short delay - consecutive operations
min, max = ShortDelayMin, ShortDelayMax
case roll < ShortDelayProbability+LongDelayProbability:
// Long delay - reading/resting
min, max = LongDelayMin, LongDelayMax
default:
// Normal delay - thinking time
min, max = NormalDelayMin, NormalDelayMax
}
rangeMs := max.Milliseconds() - min.Milliseconds()
randomMs := jitterRand.Int63n(rangeMs)
return min + time.Duration(randomMs)*time.Millisecond
}
// ApplyHumanLikeDelay applies human-like delay by sleeping.
// This is a convenience function that combines HumanLikeDelay with time.Sleep.
func ApplyHumanLikeDelay() {
delay := HumanLikeDelay()
if delay > 0 {
time.Sleep(delay)
}
}
// ExponentialBackoffWithJitter calculates retry delay using exponential backoff with jitter.
// Formula: min(baseDelay * 2^attempt + jitter, maxDelay)
// This helps prevent thundering herd problem when multiple clients retry simultaneously.
func ExponentialBackoffWithJitter(attempt int, baseDelay, maxDelay time.Duration) time.Duration {
if attempt < 0 {
attempt = 0
}
// Calculate exponential backoff: baseDelay * 2^attempt
backoff := baseDelay * time.Duration(1<<uint(attempt))
if backoff > maxDelay {
backoff = maxDelay
}
// Add ±30% jitter
return JitterDelay(backoff, JitterPercent)
}
// ShouldSkipDelay determines if delay should be skipped based on context.
// Returns true for streaming responses, WebSocket connections, etc.
// This function can be extended to check additional skip conditions.
func ShouldSkipDelay(isStreaming bool) bool {
return isStreaming
}
// ResetLastRequestTime resets the last request time tracker.
// Useful for testing or when starting a new session.
func ResetLastRequestTime() {
jitterMu.Lock()
defer jitterMu.Unlock()
lastRequestTime = time.Time{}
}

View File

@@ -0,0 +1,187 @@
package kiro
import (
"math"
"sync"
"time"
)
// TokenMetrics holds performance metrics for a single token.
type TokenMetrics struct {
SuccessRate float64 // Success rate (0.0 - 1.0)
AvgLatency float64 // Average latency in milliseconds
QuotaRemaining float64 // Remaining quota (0.0 - 1.0)
LastUsed time.Time // Last usage timestamp
FailCount int // Consecutive failure count
TotalRequests int // Total request count
successCount int // Internal: successful request count
totalLatency float64 // Internal: cumulative latency
}
// TokenScorer manages token metrics and scoring.
type TokenScorer struct {
mu sync.RWMutex
metrics map[string]*TokenMetrics
// Scoring weights
successRateWeight float64
quotaWeight float64
latencyWeight float64
lastUsedWeight float64
failPenaltyMultiplier float64
}
// NewTokenScorer creates a new TokenScorer with default weights.
func NewTokenScorer() *TokenScorer {
return &TokenScorer{
metrics: make(map[string]*TokenMetrics),
successRateWeight: 0.4,
quotaWeight: 0.25,
latencyWeight: 0.2,
lastUsedWeight: 0.15,
failPenaltyMultiplier: 0.1,
}
}
// getOrCreateMetrics returns existing metrics or creates new ones.
func (s *TokenScorer) getOrCreateMetrics(tokenKey string) *TokenMetrics {
if m, ok := s.metrics[tokenKey]; ok {
return m
}
m := &TokenMetrics{
SuccessRate: 1.0,
QuotaRemaining: 1.0,
}
s.metrics[tokenKey] = m
return m
}
// RecordRequest records the result of a request for a token.
func (s *TokenScorer) RecordRequest(tokenKey string, success bool, latency time.Duration) {
s.mu.Lock()
defer s.mu.Unlock()
m := s.getOrCreateMetrics(tokenKey)
m.TotalRequests++
m.LastUsed = time.Now()
m.totalLatency += float64(latency.Milliseconds())
if success {
m.successCount++
m.FailCount = 0
} else {
m.FailCount++
}
// Update derived metrics
if m.TotalRequests > 0 {
m.SuccessRate = float64(m.successCount) / float64(m.TotalRequests)
m.AvgLatency = m.totalLatency / float64(m.TotalRequests)
}
}
// SetQuotaRemaining updates the remaining quota for a token.
func (s *TokenScorer) SetQuotaRemaining(tokenKey string, quota float64) {
s.mu.Lock()
defer s.mu.Unlock()
m := s.getOrCreateMetrics(tokenKey)
m.QuotaRemaining = quota
}
// GetMetrics returns a copy of the metrics for a token.
func (s *TokenScorer) GetMetrics(tokenKey string) *TokenMetrics {
s.mu.RLock()
defer s.mu.RUnlock()
if m, ok := s.metrics[tokenKey]; ok {
copy := *m
return &copy
}
return nil
}
// CalculateScore computes the score for a token (higher is better).
func (s *TokenScorer) CalculateScore(tokenKey string) float64 {
s.mu.RLock()
defer s.mu.RUnlock()
m, ok := s.metrics[tokenKey]
if !ok {
return 1.0 // New tokens get a high initial score
}
// Success rate component (0-1)
successScore := m.SuccessRate
// Quota component (0-1)
quotaScore := m.QuotaRemaining
// Latency component (normalized, lower is better)
// Using exponential decay: score = e^(-latency/1000)
// 1000ms latency -> ~0.37 score, 100ms -> ~0.90 score
latencyScore := math.Exp(-m.AvgLatency / 1000.0)
if m.TotalRequests == 0 {
latencyScore = 1.0
}
// Last used component (prefer tokens not recently used)
// Score increases as time since last use increases
timeSinceUse := time.Since(m.LastUsed).Seconds()
// Normalize: 60 seconds -> ~0.63 score, 0 seconds -> 0 score
lastUsedScore := 1.0 - math.Exp(-timeSinceUse/60.0)
if m.LastUsed.IsZero() {
lastUsedScore = 1.0
}
// Calculate weighted score
score := s.successRateWeight*successScore +
s.quotaWeight*quotaScore +
s.latencyWeight*latencyScore +
s.lastUsedWeight*lastUsedScore
// Apply consecutive failure penalty
if m.FailCount > 0 {
penalty := s.failPenaltyMultiplier * float64(m.FailCount)
score = score * math.Max(0, 1.0-penalty)
}
return score
}
// SelectBestToken selects the token with the highest score.
func (s *TokenScorer) SelectBestToken(tokens []string) string {
if len(tokens) == 0 {
return ""
}
if len(tokens) == 1 {
return tokens[0]
}
bestToken := tokens[0]
bestScore := s.CalculateScore(tokens[0])
for _, token := range tokens[1:] {
score := s.CalculateScore(token)
if score > bestScore {
bestScore = score
bestToken = token
}
}
return bestToken
}
// ResetMetrics clears all metrics for a token.
func (s *TokenScorer) ResetMetrics(tokenKey string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.metrics, tokenKey)
}
// ResetAllMetrics clears all stored metrics.
func (s *TokenScorer) ResetAllMetrics() {
s.mu.Lock()
defer s.mu.Unlock()
s.metrics = make(map[string]*TokenMetrics)
}

View File

@@ -0,0 +1,301 @@
package kiro
import (
"sync"
"testing"
"time"
)
func TestNewTokenScorer(t *testing.T) {
s := NewTokenScorer()
if s == nil {
t.Fatal("expected non-nil TokenScorer")
}
if s.metrics == nil {
t.Error("expected non-nil metrics map")
}
if s.successRateWeight != 0.4 {
t.Errorf("expected successRateWeight 0.4, got %f", s.successRateWeight)
}
if s.quotaWeight != 0.25 {
t.Errorf("expected quotaWeight 0.25, got %f", s.quotaWeight)
}
}
func TestRecordRequest_Success(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
m := s.GetMetrics("token1")
if m == nil {
t.Fatal("expected non-nil metrics")
}
if m.TotalRequests != 1 {
t.Errorf("expected TotalRequests 1, got %d", m.TotalRequests)
}
if m.SuccessRate != 1.0 {
t.Errorf("expected SuccessRate 1.0, got %f", m.SuccessRate)
}
if m.FailCount != 0 {
t.Errorf("expected FailCount 0, got %d", m.FailCount)
}
if m.AvgLatency != 100 {
t.Errorf("expected AvgLatency 100, got %f", m.AvgLatency)
}
}
func TestRecordRequest_Failure(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", false, 200*time.Millisecond)
m := s.GetMetrics("token1")
if m.SuccessRate != 0.0 {
t.Errorf("expected SuccessRate 0.0, got %f", m.SuccessRate)
}
if m.FailCount != 1 {
t.Errorf("expected FailCount 1, got %d", m.FailCount)
}
}
func TestRecordRequest_MixedResults(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
s.RecordRequest("token1", true, 100*time.Millisecond)
s.RecordRequest("token1", false, 100*time.Millisecond)
s.RecordRequest("token1", true, 100*time.Millisecond)
m := s.GetMetrics("token1")
if m.TotalRequests != 4 {
t.Errorf("expected TotalRequests 4, got %d", m.TotalRequests)
}
if m.SuccessRate != 0.75 {
t.Errorf("expected SuccessRate 0.75, got %f", m.SuccessRate)
}
if m.FailCount != 0 {
t.Errorf("expected FailCount 0 (reset on success), got %d", m.FailCount)
}
}
func TestRecordRequest_ConsecutiveFailures(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
s.RecordRequest("token1", false, 100*time.Millisecond)
s.RecordRequest("token1", false, 100*time.Millisecond)
s.RecordRequest("token1", false, 100*time.Millisecond)
m := s.GetMetrics("token1")
if m.FailCount != 3 {
t.Errorf("expected FailCount 3, got %d", m.FailCount)
}
}
func TestSetQuotaRemaining(t *testing.T) {
s := NewTokenScorer()
s.SetQuotaRemaining("token1", 0.5)
m := s.GetMetrics("token1")
if m.QuotaRemaining != 0.5 {
t.Errorf("expected QuotaRemaining 0.5, got %f", m.QuotaRemaining)
}
}
func TestGetMetrics_NonExistent(t *testing.T) {
s := NewTokenScorer()
m := s.GetMetrics("nonexistent")
if m != nil {
t.Error("expected nil metrics for non-existent token")
}
}
func TestGetMetrics_ReturnsCopy(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
m1 := s.GetMetrics("token1")
m1.TotalRequests = 999
m2 := s.GetMetrics("token1")
if m2.TotalRequests == 999 {
t.Error("GetMetrics should return a copy")
}
}
func TestCalculateScore_NewToken(t *testing.T) {
s := NewTokenScorer()
score := s.CalculateScore("newtoken")
if score != 1.0 {
t.Errorf("expected score 1.0 for new token, got %f", score)
}
}
func TestCalculateScore_PerfectToken(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 50*time.Millisecond)
s.SetQuotaRemaining("token1", 1.0)
time.Sleep(100 * time.Millisecond)
score := s.CalculateScore("token1")
if score < 0.5 || score > 1.0 {
t.Errorf("expected high score for perfect token, got %f", score)
}
}
func TestCalculateScore_FailedToken(t *testing.T) {
s := NewTokenScorer()
for i := 0; i < 5; i++ {
s.RecordRequest("token1", false, 1000*time.Millisecond)
}
s.SetQuotaRemaining("token1", 0.1)
score := s.CalculateScore("token1")
if score > 0.5 {
t.Errorf("expected low score for failed token, got %f", score)
}
}
func TestCalculateScore_FailPenalty(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
scoreNoFail := s.CalculateScore("token1")
s.RecordRequest("token1", false, 100*time.Millisecond)
s.RecordRequest("token1", false, 100*time.Millisecond)
scoreWithFail := s.CalculateScore("token1")
if scoreWithFail >= scoreNoFail {
t.Errorf("expected lower score with consecutive failures: noFail=%f, withFail=%f", scoreNoFail, scoreWithFail)
}
}
func TestSelectBestToken_Empty(t *testing.T) {
s := NewTokenScorer()
best := s.SelectBestToken([]string{})
if best != "" {
t.Errorf("expected empty string for empty tokens, got %s", best)
}
}
func TestSelectBestToken_SingleToken(t *testing.T) {
s := NewTokenScorer()
best := s.SelectBestToken([]string{"token1"})
if best != "token1" {
t.Errorf("expected token1, got %s", best)
}
}
func TestSelectBestToken_MultipleTokens(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("bad", false, 1000*time.Millisecond)
s.RecordRequest("bad", false, 1000*time.Millisecond)
s.SetQuotaRemaining("bad", 0.1)
s.RecordRequest("good", true, 50*time.Millisecond)
s.SetQuotaRemaining("good", 0.9)
time.Sleep(50 * time.Millisecond)
best := s.SelectBestToken([]string{"bad", "good"})
if best != "good" {
t.Errorf("expected good token to be selected, got %s", best)
}
}
func TestResetMetrics(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
s.ResetMetrics("token1")
m := s.GetMetrics("token1")
if m != nil {
t.Error("expected nil metrics after reset")
}
}
func TestResetAllMetrics(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
s.RecordRequest("token2", true, 100*time.Millisecond)
s.RecordRequest("token3", true, 100*time.Millisecond)
s.ResetAllMetrics()
if s.GetMetrics("token1") != nil {
t.Error("expected nil metrics for token1 after reset all")
}
if s.GetMetrics("token2") != nil {
t.Error("expected nil metrics for token2 after reset all")
}
}
func TestTokenScorer_ConcurrentAccess(t *testing.T) {
s := NewTokenScorer()
const numGoroutines = 50
const numOperations = 100
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
tokenKey := "token" + string(rune('a'+id%10))
for j := 0; j < numOperations; j++ {
switch j % 6 {
case 0:
s.RecordRequest(tokenKey, j%2 == 0, time.Duration(j)*time.Millisecond)
case 1:
s.SetQuotaRemaining(tokenKey, float64(j%100)/100)
case 2:
s.GetMetrics(tokenKey)
case 3:
s.CalculateScore(tokenKey)
case 4:
s.SelectBestToken([]string{tokenKey, "token_x", "token_y"})
case 5:
if j%20 == 0 {
s.ResetMetrics(tokenKey)
}
}
}
}(i)
}
wg.Wait()
}
func TestAvgLatencyCalculation(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
s.RecordRequest("token1", true, 200*time.Millisecond)
s.RecordRequest("token1", true, 300*time.Millisecond)
m := s.GetMetrics("token1")
if m.AvgLatency != 200 {
t.Errorf("expected AvgLatency 200, got %f", m.AvgLatency)
}
}
func TestLastUsedUpdated(t *testing.T) {
s := NewTokenScorer()
before := time.Now()
s.RecordRequest("token1", true, 100*time.Millisecond)
m := s.GetMetrics("token1")
if m.LastUsed.Before(before) {
t.Error("expected LastUsed to be after test start time")
}
if m.LastUsed.After(time.Now()) {
t.Error("expected LastUsed to be before or equal to now")
}
}
func TestDefaultQuotaForNewToken(t *testing.T) {
s := NewTokenScorer()
s.RecordRequest("token1", true, 100*time.Millisecond)
m := s.GetMetrics("token1")
if m.QuotaRemaining != 1.0 {
t.Errorf("expected default QuotaRemaining 1.0, got %f", m.QuotaRemaining)
}
}

319
internal/auth/kiro/oauth.go Normal file
View File

@@ -0,0 +1,319 @@
// Package kiro provides OAuth2 authentication for Kiro using native Google login.
package kiro
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"html"
"io"
"net"
"net/http"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
const (
// Kiro auth endpoint
kiroAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
// Default callback port
defaultCallbackPort = 9876
// Auth timeout
authTimeout = 10 * time.Minute
)
// KiroTokenResponse represents the response from Kiro token endpoint.
type KiroTokenResponse struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ProfileArn string `json:"profileArn"`
ExpiresIn int `json:"expiresIn"`
}
// KiroOAuth handles the OAuth flow for Kiro authentication.
type KiroOAuth struct {
httpClient *http.Client
cfg *config.Config
machineID string
kiroVersion string
}
// NewKiroOAuth creates a new Kiro OAuth handler.
func NewKiroOAuth(cfg *config.Config) *KiroOAuth {
client := &http.Client{Timeout: 30 * time.Second}
if cfg != nil {
client = util.SetProxy(&cfg.SDKConfig, client)
}
fp := GlobalFingerprintManager().GetFingerprint("login")
return &KiroOAuth{
httpClient: client,
cfg: cfg,
machineID: fp.KiroHash,
kiroVersion: fp.KiroVersion,
}
}
// generateCodeVerifier generates a random code verifier for PKCE.
func generateCodeVerifier() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// generateCodeChallenge generates the code challenge from verifier.
func generateCodeChallenge(verifier string) string {
h := sha256.Sum256([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(h[:])
}
// generateState generates a random state parameter.
func generateState() (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// AuthResult contains the authorization code and state from callback.
type AuthResult struct {
Code string
State string
Error string
}
// startCallbackServer starts a local HTTP server to receive the OAuth callback.
func (o *KiroOAuth) startCallbackServer(ctx context.Context, expectedState string) (string, <-chan AuthResult, error) {
// Try to find an available port - use localhost like Kiro does
listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", defaultCallbackPort))
if err != nil {
// Try with dynamic port (RFC 8252 allows dynamic ports for native apps)
log.Warnf("kiro oauth: default port %d is busy, falling back to dynamic port", defaultCallbackPort)
listener, err = net.Listen("tcp", "localhost:0")
if err != nil {
return "", nil, fmt.Errorf("failed to start callback server: %w", err)
}
}
port := listener.Addr().(*net.TCPAddr).Port
// Use http scheme for local callback server
redirectURI := fmt.Sprintf("http://localhost:%d/oauth/callback", port)
resultChan := make(chan AuthResult, 1)
server := &http.Server{
ReadHeaderTimeout: 10 * time.Second,
}
mux := http.NewServeMux()
mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
errParam := r.URL.Query().Get("error")
if errParam != "" {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, `<html><body><h1>Login Failed</h1><p>%s</p><p>You can close this window.</p></body></html>`, html.EscapeString(errParam))
resultChan <- AuthResult{Error: errParam}
return
}
if state != expectedState {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusBadRequest)
fmt.Fprint(w, `<html><body><h1>Login Failed</h1><p>Invalid state parameter</p><p>You can close this window.</p></body></html>`)
resultChan <- AuthResult{Error: "state mismatch"}
return
}
w.Header().Set("Content-Type", "text/html")
fmt.Fprint(w, `<html><body><h1>Login Successful!</h1><p>You can close this window and return to the terminal.</p></body></html>`)
resultChan <- AuthResult{Code: code, State: state}
})
server.Handler = mux
go func() {
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
log.Debugf("callback server error: %v", err)
}
}()
go func() {
select {
case <-ctx.Done():
case <-time.After(authTimeout):
case <-resultChan:
}
_ = server.Shutdown(context.Background())
}()
return redirectURI, resultChan, nil
}
// LoginWithBuilderID performs OAuth login with AWS Builder ID using device code flow.
func (o *KiroOAuth) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, error) {
ssoClient := NewSSOOIDCClient(o.cfg)
return ssoClient.LoginWithBuilderID(ctx)
}
// LoginWithBuilderIDAuthCode performs OAuth login with AWS Builder ID using authorization code flow.
// This provides a better UX than device code flow as it uses automatic browser callback.
func (o *KiroOAuth) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTokenData, error) {
ssoClient := NewSSOOIDCClient(o.cfg)
return ssoClient.LoginWithBuilderIDAuthCode(ctx)
}
// exchangeCodeForToken exchanges the authorization code for tokens.
func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier, redirectURI string) (*KiroTokenData, error) {
payload := map[string]string{
"code": code,
"code_verifier": codeVerifier,
"redirect_uri": redirectURI,
}
body, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
tokenURL := kiroAuthEndpoint + "/oauth/token"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(string(body)))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", o.kiroVersion, o.machineID))
req.Header.Set("Accept", "application/json, text/plain, */*")
resp, err := o.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("token request failed: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
log.Debugf("token exchange failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("token exchange failed (status %d)", resp.StatusCode)
}
var tokenResp KiroTokenResponse
if err := json.Unmarshal(respBody, &tokenResp); err != nil {
return nil, fmt.Errorf("failed to parse token response: %w", err)
}
// Validate ExpiresIn - use default 1 hour if invalid
expiresIn := tokenResp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
return &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: tokenResp.ProfileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "social",
Provider: "", // Caller should preserve original provider
Region: "us-east-1",
}, nil
}
// RefreshToken refreshes an expired access token.
// Uses KiroIDE-style User-Agent to match official Kiro IDE behavior.
func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) {
return o.RefreshTokenWithFingerprint(ctx, refreshToken, "")
}
// RefreshTokenWithFingerprint refreshes an expired access token with a specific fingerprint.
// tokenKey is used to generate a consistent fingerprint for the token.
func (o *KiroOAuth) RefreshTokenWithFingerprint(ctx context.Context, refreshToken, tokenKey string) (*KiroTokenData, error) {
payload := map[string]string{
"refreshToken": refreshToken,
}
body, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
refreshURL := kiroAuthEndpoint + "/refreshToken"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, strings.NewReader(string(body)))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", o.kiroVersion, o.machineID))
req.Header.Set("Accept", "application/json, text/plain, */*")
resp, err := o.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("refresh request failed: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
}
var tokenResp KiroTokenResponse
if err := json.Unmarshal(respBody, &tokenResp); err != nil {
return nil, fmt.Errorf("failed to parse token response: %w", err)
}
// Validate ExpiresIn - use default 1 hour if invalid
expiresIn := tokenResp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
return &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: tokenResp.ProfileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "social",
Provider: "", // Caller should preserve original provider
Region: "us-east-1",
}, nil
}
// LoginWithGoogle performs OAuth login with Google using Kiro's social auth.
// This uses a custom protocol handler (kiro://) to receive the callback.
func (o *KiroOAuth) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) {
socialClient := NewSocialAuthClient(o.cfg)
return socialClient.LoginWithGoogle(ctx)
}
// LoginWithGitHub performs OAuth login with GitHub using Kiro's social auth.
// This uses a custom protocol handler (kiro://) to receive the callback.
func (o *KiroOAuth) LoginWithGitHub(ctx context.Context) (*KiroTokenData, error) {
socialClient := NewSocialAuthClient(o.cfg)
return socialClient.LoginWithGitHub(ctx)
}

View File

@@ -0,0 +1,975 @@
// Package kiro provides OAuth Web authentication for Kiro.
package kiro
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"html/template"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
const (
defaultSessionExpiry = 10 * time.Minute
pollIntervalSeconds = 5
)
type authSessionStatus string
const (
statusPending authSessionStatus = "pending"
statusSuccess authSessionStatus = "success"
statusFailed authSessionStatus = "failed"
)
type webAuthSession struct {
stateID string
deviceCode string
userCode string
authURL string
verificationURI string
expiresIn int
interval int
status authSessionStatus
startedAt time.Time
completedAt time.Time
expiresAt time.Time
error string
tokenData *KiroTokenData
ssoClient *SSOOIDCClient
clientID string
clientSecret string
region string
cancelFunc context.CancelFunc
authMethod string // "google", "github", "builder-id", "idc"
startURL string // Used for IDC
codeVerifier string // Used for social auth PKCE
codeChallenge string // Used for social auth PKCE
}
type OAuthWebHandler struct {
cfg *config.Config
sessions map[string]*webAuthSession
mu sync.RWMutex
onTokenObtained func(*KiroTokenData)
}
func NewOAuthWebHandler(cfg *config.Config) *OAuthWebHandler {
return &OAuthWebHandler{
cfg: cfg,
sessions: make(map[string]*webAuthSession),
}
}
func (h *OAuthWebHandler) SetTokenCallback(callback func(*KiroTokenData)) {
h.onTokenObtained = callback
}
func (h *OAuthWebHandler) RegisterRoutes(router gin.IRouter) {
oauth := router.Group("/v0/oauth/kiro")
{
oauth.GET("", h.handleSelect)
oauth.GET("/start", h.handleStart)
oauth.GET("/callback", h.handleCallback)
oauth.GET("/social/callback", h.handleSocialCallback)
oauth.GET("/status", h.handleStatus)
oauth.POST("/import", h.handleImportToken)
oauth.POST("/refresh", h.handleManualRefresh)
}
}
func generateStateID() (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
func (h *OAuthWebHandler) handleSelect(c *gin.Context) {
h.renderSelectPage(c)
}
func (h *OAuthWebHandler) handleStart(c *gin.Context) {
method := c.Query("method")
if method == "" {
c.Redirect(http.StatusFound, "/v0/oauth/kiro")
return
}
switch method {
case "google", "github":
// Google/GitHub social login is not supported for third-party apps
// due to AWS Cognito redirect_uri restrictions
h.renderError(c, "Google/GitHub login is not available for third-party applications. Please use AWS Builder ID or import your token from Kiro IDE.")
case "builder-id":
h.startBuilderIDAuth(c)
case "idc":
h.startIDCAuth(c)
default:
h.renderError(c, fmt.Sprintf("Unknown authentication method: %s", method))
}
}
func (h *OAuthWebHandler) startSocialAuth(c *gin.Context, method string) {
stateID, err := generateStateID()
if err != nil {
h.renderError(c, "Failed to generate state parameter")
return
}
codeVerifier, codeChallenge, err := generatePKCE()
if err != nil {
h.renderError(c, "Failed to generate PKCE parameters")
return
}
socialClient := NewSocialAuthClient(h.cfg)
var provider string
if method == "google" {
provider = string(ProviderGoogle)
} else {
provider = string(ProviderGitHub)
}
redirectURI := h.getSocialCallbackURL(c)
authURL := socialClient.buildLoginURL(provider, redirectURI, codeChallenge, stateID)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
session := &webAuthSession{
stateID: stateID,
authMethod: method,
authURL: authURL,
status: statusPending,
startedAt: time.Now(),
expiresIn: 600,
codeVerifier: codeVerifier,
codeChallenge: codeChallenge,
region: "us-east-1",
cancelFunc: cancel,
}
h.mu.Lock()
h.sessions[stateID] = session
h.mu.Unlock()
go func() {
<-ctx.Done()
h.mu.Lock()
if session.status == statusPending {
session.status = statusFailed
session.error = "Authentication timed out"
}
h.mu.Unlock()
}()
c.Redirect(http.StatusFound, authURL)
}
func (h *OAuthWebHandler) getSocialCallbackURL(c *gin.Context) string {
scheme := "http"
if c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https" {
scheme = "https"
}
return fmt.Sprintf("%s://%s/v0/oauth/kiro/social/callback", scheme, c.Request.Host)
}
func (h *OAuthWebHandler) startBuilderIDAuth(c *gin.Context) {
stateID, err := generateStateID()
if err != nil {
h.renderError(c, "Failed to generate state parameter")
return
}
region := defaultIDCRegion
startURL := builderIDStartURL
ssoClient := NewSSOOIDCClient(h.cfg)
regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region)
if err != nil {
log.Errorf("OAuth Web: failed to register client: %v", err)
h.renderError(c, fmt.Sprintf("Failed to register client: %v", err))
return
}
authResp, err := ssoClient.StartDeviceAuthorizationWithIDC(
c.Request.Context(),
regResp.ClientID,
regResp.ClientSecret,
startURL,
region,
)
if err != nil {
log.Errorf("OAuth Web: failed to start device authorization: %v", err)
h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err))
return
}
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second)
session := &webAuthSession{
stateID: stateID,
deviceCode: authResp.DeviceCode,
userCode: authResp.UserCode,
authURL: authResp.VerificationURIComplete,
verificationURI: authResp.VerificationURI,
expiresIn: authResp.ExpiresIn,
interval: authResp.Interval,
status: statusPending,
startedAt: time.Now(),
ssoClient: ssoClient,
clientID: regResp.ClientID,
clientSecret: regResp.ClientSecret,
region: region,
authMethod: "builder-id",
startURL: startURL,
cancelFunc: cancel,
}
h.mu.Lock()
h.sessions[stateID] = session
h.mu.Unlock()
go h.pollForToken(ctx, session)
h.renderStartPage(c, session)
}
func (h *OAuthWebHandler) startIDCAuth(c *gin.Context) {
startURL := c.Query("startUrl")
region := c.Query("region")
if startURL == "" {
h.renderError(c, "Missing startUrl parameter for IDC authentication")
return
}
if region == "" {
region = defaultIDCRegion
}
stateID, err := generateStateID()
if err != nil {
h.renderError(c, "Failed to generate state parameter")
return
}
ssoClient := NewSSOOIDCClient(h.cfg)
regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region)
if err != nil {
log.Errorf("OAuth Web: failed to register client: %v", err)
h.renderError(c, fmt.Sprintf("Failed to register client: %v", err))
return
}
authResp, err := ssoClient.StartDeviceAuthorizationWithIDC(
c.Request.Context(),
regResp.ClientID,
regResp.ClientSecret,
startURL,
region,
)
if err != nil {
log.Errorf("OAuth Web: failed to start device authorization: %v", err)
h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err))
return
}
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second)
session := &webAuthSession{
stateID: stateID,
deviceCode: authResp.DeviceCode,
userCode: authResp.UserCode,
authURL: authResp.VerificationURIComplete,
verificationURI: authResp.VerificationURI,
expiresIn: authResp.ExpiresIn,
interval: authResp.Interval,
status: statusPending,
startedAt: time.Now(),
ssoClient: ssoClient,
clientID: regResp.ClientID,
clientSecret: regResp.ClientSecret,
region: region,
authMethod: "idc",
startURL: startURL,
cancelFunc: cancel,
}
h.mu.Lock()
h.sessions[stateID] = session
h.mu.Unlock()
go h.pollForToken(ctx, session)
h.renderStartPage(c, session)
}
func (h *OAuthWebHandler) pollForToken(ctx context.Context, session *webAuthSession) {
defer session.cancelFunc()
interval := time.Duration(session.interval) * time.Second
if interval < time.Duration(pollIntervalSeconds)*time.Second {
interval = time.Duration(pollIntervalSeconds) * time.Second
}
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
h.mu.Lock()
if session.status == statusPending {
session.status = statusFailed
session.error = "Authentication timed out"
}
h.mu.Unlock()
return
case <-ticker.C:
tokenResp, err := h.ssoClient(session).CreateTokenWithRegion(
ctx,
session.clientID,
session.clientSecret,
session.deviceCode,
session.region,
)
if err != nil {
errStr := err.Error()
if errStr == ErrAuthorizationPending.Error() {
continue
}
if errStr == ErrSlowDown.Error() {
interval += 5 * time.Second
ticker.Reset(interval)
continue
}
h.mu.Lock()
session.status = statusFailed
session.error = errStr
session.completedAt = time.Now()
h.mu.Unlock()
log.Errorf("OAuth Web: token polling failed: %v", err)
return
}
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
// Fetch profileArn for IDC
var profileArn string
if session.authMethod == "idc" {
profileArn = session.ssoClient.FetchProfileArn(ctx, tokenResp.AccessToken, session.clientID, tokenResp.RefreshToken)
}
email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken, session.clientID, tokenResp.RefreshToken)
tokenData := &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: profileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: session.authMethod,
Provider: "AWS",
ClientID: session.clientID,
ClientSecret: session.clientSecret,
Email: email,
Region: session.region,
StartURL: session.startURL,
}
h.mu.Lock()
session.status = statusSuccess
session.completedAt = time.Now()
session.expiresAt = expiresAt
session.tokenData = tokenData
h.mu.Unlock()
if h.onTokenObtained != nil {
h.onTokenObtained(tokenData)
}
// Save token to file
h.saveTokenToFile(tokenData)
log.Infof("OAuth Web: authentication successful for %s", email)
return
}
}
}
// saveTokenToFile saves the token data to the auth directory
func (h *OAuthWebHandler) saveTokenToFile(tokenData *KiroTokenData) {
// Get auth directory from config or use default
authDir := ""
if h.cfg != nil && h.cfg.AuthDir != "" {
var err error
authDir, err = util.ResolveAuthDir(h.cfg.AuthDir)
if err != nil {
log.Errorf("OAuth Web: failed to resolve auth directory: %v", err)
}
}
// Fall back to default location
if authDir == "" {
home, err := os.UserHomeDir()
if err != nil {
log.Errorf("OAuth Web: failed to get home directory: %v", err)
return
}
authDir = filepath.Join(home, ".cli-proxy-api")
}
// Create directory if not exists
if err := os.MkdirAll(authDir, 0700); err != nil {
log.Errorf("OAuth Web: failed to create auth directory: %v", err)
return
}
// Generate filename using the unified function
fileName := GenerateTokenFileName(tokenData)
authFilePath := filepath.Join(authDir, fileName)
// Convert to storage format and save
storage := &KiroTokenStorage{
Type: "kiro",
AccessToken: tokenData.AccessToken,
RefreshToken: tokenData.RefreshToken,
ProfileArn: tokenData.ProfileArn,
ExpiresAt: tokenData.ExpiresAt,
AuthMethod: tokenData.AuthMethod,
Provider: tokenData.Provider,
LastRefresh: time.Now().Format(time.RFC3339),
ClientID: tokenData.ClientID,
ClientSecret: tokenData.ClientSecret,
Region: tokenData.Region,
StartURL: tokenData.StartURL,
Email: tokenData.Email,
}
if err := storage.SaveTokenToFile(authFilePath); err != nil {
log.Errorf("OAuth Web: failed to save token to file: %v", err)
return
}
log.Infof("OAuth Web: token saved to %s", authFilePath)
}
func (h *OAuthWebHandler) ssoClient(session *webAuthSession) *SSOOIDCClient {
return session.ssoClient
}
func (h *OAuthWebHandler) handleCallback(c *gin.Context) {
stateID := c.Query("state")
errParam := c.Query("error")
if errParam != "" {
h.renderError(c, errParam)
return
}
if stateID == "" {
h.renderError(c, "Missing state parameter")
return
}
h.mu.RLock()
session, exists := h.sessions[stateID]
h.mu.RUnlock()
if !exists {
h.renderError(c, "Invalid or expired session")
return
}
if session.status == statusSuccess {
h.renderSuccess(c, session)
} else if session.status == statusFailed {
h.renderError(c, session.error)
} else {
c.Redirect(http.StatusFound, "/v0/oauth/kiro/start")
}
}
func (h *OAuthWebHandler) handleSocialCallback(c *gin.Context) {
stateID := c.Query("state")
code := c.Query("code")
errParam := c.Query("error")
if errParam != "" {
h.renderError(c, errParam)
return
}
if stateID == "" {
h.renderError(c, "Missing state parameter")
return
}
if code == "" {
h.renderError(c, "Missing authorization code")
return
}
h.mu.RLock()
session, exists := h.sessions[stateID]
h.mu.RUnlock()
if !exists {
h.renderError(c, "Invalid or expired session")
return
}
if session.authMethod != "google" && session.authMethod != "github" {
h.renderError(c, "Invalid session type for social callback")
return
}
socialClient := NewSocialAuthClient(h.cfg)
redirectURI := h.getSocialCallbackURL(c)
tokenReq := &CreateTokenRequest{
Code: code,
CodeVerifier: session.codeVerifier,
RedirectURI: redirectURI,
}
tokenResp, err := socialClient.CreateToken(c.Request.Context(), tokenReq)
if err != nil {
log.Errorf("OAuth Web: social token exchange failed: %v", err)
h.mu.Lock()
session.status = statusFailed
session.error = fmt.Sprintf("Token exchange failed: %v", err)
session.completedAt = time.Now()
h.mu.Unlock()
h.renderError(c, session.error)
return
}
expiresIn := tokenResp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
email := ExtractEmailFromJWT(tokenResp.AccessToken)
var provider string
if session.authMethod == "google" {
provider = string(ProviderGoogle)
} else {
provider = string(ProviderGitHub)
}
tokenData := &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: tokenResp.ProfileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: session.authMethod,
Provider: provider,
Email: email,
Region: "us-east-1",
}
h.mu.Lock()
session.status = statusSuccess
session.completedAt = time.Now()
session.expiresAt = expiresAt
session.tokenData = tokenData
h.mu.Unlock()
if session.cancelFunc != nil {
session.cancelFunc()
}
if h.onTokenObtained != nil {
h.onTokenObtained(tokenData)
}
// Save token to file
h.saveTokenToFile(tokenData)
log.Infof("OAuth Web: social authentication successful for %s via %s", email, provider)
h.renderSuccess(c, session)
}
func (h *OAuthWebHandler) handleStatus(c *gin.Context) {
stateID := c.Query("state")
if stateID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing state parameter"})
return
}
h.mu.RLock()
session, exists := h.sessions[stateID]
h.mu.RUnlock()
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
return
}
response := gin.H{
"status": string(session.status),
}
switch session.status {
case statusPending:
elapsed := time.Since(session.startedAt).Seconds()
remaining := float64(session.expiresIn) - elapsed
if remaining < 0 {
remaining = 0
}
response["remaining_seconds"] = int(remaining)
case statusSuccess:
response["completed_at"] = session.completedAt.Format(time.RFC3339)
response["expires_at"] = session.expiresAt.Format(time.RFC3339)
case statusFailed:
response["error"] = session.error
response["failed_at"] = session.completedAt.Format(time.RFC3339)
}
c.JSON(http.StatusOK, response)
}
func (h *OAuthWebHandler) renderStartPage(c *gin.Context, session *webAuthSession) {
tmpl, err := template.New("start").Parse(oauthWebStartPageHTML)
if err != nil {
log.Errorf("OAuth Web: failed to parse template: %v", err)
c.String(http.StatusInternalServerError, "Template error")
return
}
data := map[string]interface{}{
"AuthURL": session.authURL,
"UserCode": session.userCode,
"ExpiresIn": session.expiresIn,
"StateID": session.stateID,
}
c.Header("Content-Type", "text/html; charset=utf-8")
if err := tmpl.Execute(c.Writer, data); err != nil {
log.Errorf("OAuth Web: failed to render template: %v", err)
}
}
func (h *OAuthWebHandler) renderSelectPage(c *gin.Context) {
tmpl, err := template.New("select").Parse(oauthWebSelectPageHTML)
if err != nil {
log.Errorf("OAuth Web: failed to parse select template: %v", err)
c.String(http.StatusInternalServerError, "Template error")
return
}
c.Header("Content-Type", "text/html; charset=utf-8")
if err := tmpl.Execute(c.Writer, nil); err != nil {
log.Errorf("OAuth Web: failed to render select template: %v", err)
}
}
func (h *OAuthWebHandler) renderError(c *gin.Context, errMsg string) {
tmpl, err := template.New("error").Parse(oauthWebErrorPageHTML)
if err != nil {
log.Errorf("OAuth Web: failed to parse error template: %v", err)
c.String(http.StatusInternalServerError, "Template error")
return
}
data := map[string]interface{}{
"Error": errMsg,
}
c.Header("Content-Type", "text/html; charset=utf-8")
c.Status(http.StatusBadRequest)
if err := tmpl.Execute(c.Writer, data); err != nil {
log.Errorf("OAuth Web: failed to render error template: %v", err)
}
}
func (h *OAuthWebHandler) renderSuccess(c *gin.Context, session *webAuthSession) {
tmpl, err := template.New("success").Parse(oauthWebSuccessPageHTML)
if err != nil {
log.Errorf("OAuth Web: failed to parse success template: %v", err)
c.String(http.StatusInternalServerError, "Template error")
return
}
data := map[string]interface{}{
"ExpiresAt": session.expiresAt.Format(time.RFC3339),
}
c.Header("Content-Type", "text/html; charset=utf-8")
if err := tmpl.Execute(c.Writer, data); err != nil {
log.Errorf("OAuth Web: failed to render success template: %v", err)
}
}
func (h *OAuthWebHandler) CleanupExpiredSessions() {
h.mu.Lock()
defer h.mu.Unlock()
now := time.Now()
for id, session := range h.sessions {
if session.status != statusPending && now.Sub(session.completedAt) > 30*time.Minute {
delete(h.sessions, id)
} else if session.status == statusPending && now.Sub(session.startedAt) > defaultSessionExpiry {
session.cancelFunc()
delete(h.sessions, id)
}
}
}
func (h *OAuthWebHandler) GetSession(stateID string) (*webAuthSession, bool) {
h.mu.RLock()
defer h.mu.RUnlock()
session, exists := h.sessions[stateID]
return session, exists
}
// ImportTokenRequest represents the request body for token import
type ImportTokenRequest struct {
RefreshToken string `json:"refreshToken"`
}
// handleImportToken handles manual refresh token import from Kiro IDE
func (h *OAuthWebHandler) handleImportToken(c *gin.Context) {
var req ImportTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"error": "Invalid request body",
})
return
}
refreshToken := strings.TrimSpace(req.RefreshToken)
if refreshToken == "" {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"error": "Refresh token is required",
})
return
}
// Validate token format
if !strings.HasPrefix(refreshToken, "aorAAAAAG") {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"error": "Invalid token format. Token should start with aorAAAAAG...",
})
return
}
// Create social auth client to refresh and validate the token
socialClient := NewSocialAuthClient(h.cfg)
// Refresh the token to validate it and get access token
tokenData, err := socialClient.RefreshSocialToken(c.Request.Context(), refreshToken)
if err != nil {
log.Errorf("OAuth Web: token refresh failed during import: %v", err)
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"error": fmt.Sprintf("Token validation failed: %v", err),
})
return
}
// Set the original refresh token (the refreshed one might be empty)
if tokenData.RefreshToken == "" {
tokenData.RefreshToken = refreshToken
}
tokenData.AuthMethod = "social"
tokenData.Provider = "imported"
// Notify callback if set
if h.onTokenObtained != nil {
h.onTokenObtained(tokenData)
}
// Save token to file
h.saveTokenToFile(tokenData)
// Generate filename for response using the unified function
fileName := GenerateTokenFileName(tokenData)
log.Infof("OAuth Web: token imported successfully")
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Token imported successfully",
"fileName": fileName,
})
}
// handleManualRefresh handles manual token refresh requests from the web UI.
// This allows users to trigger a token refresh when needed, without waiting
// for the automatic 30-second check and 20-minute-before-expiry refresh cycle.
// Uses the same refresh logic as kiro_executor.Refresh for consistency.
func (h *OAuthWebHandler) handleManualRefresh(c *gin.Context) {
authDir := ""
if h.cfg != nil && h.cfg.AuthDir != "" {
var err error
authDir, err = util.ResolveAuthDir(h.cfg.AuthDir)
if err != nil {
log.Errorf("OAuth Web: failed to resolve auth directory: %v", err)
}
}
if authDir == "" {
home, err := os.UserHomeDir()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"error": "Failed to get home directory",
})
return
}
authDir = filepath.Join(home, ".cli-proxy-api")
}
// Find all kiro token files in the auth directory
files, err := os.ReadDir(authDir)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"error": fmt.Sprintf("Failed to read auth directory: %v", err),
})
return
}
var refreshedCount int
var errors []string
for _, file := range files {
if file.IsDir() {
continue
}
name := file.Name()
if !strings.HasPrefix(name, "kiro-") || !strings.HasSuffix(name, ".json") {
continue
}
filePath := filepath.Join(authDir, name)
data, err := os.ReadFile(filePath)
if err != nil {
errors = append(errors, fmt.Sprintf("%s: read error - %v", name, err))
continue
}
var storage KiroTokenStorage
if err := json.Unmarshal(data, &storage); err != nil {
errors = append(errors, fmt.Sprintf("%s: parse error - %v", name, err))
continue
}
if storage.RefreshToken == "" {
errors = append(errors, fmt.Sprintf("%s: no refresh token", name))
continue
}
// Refresh token using the same logic as kiro_executor.Refresh
tokenData, err := h.refreshTokenData(c.Request.Context(), &storage)
if err != nil {
errors = append(errors, fmt.Sprintf("%s: refresh failed - %v", name, err))
continue
}
// Update storage with new token data
storage.AccessToken = tokenData.AccessToken
if tokenData.RefreshToken != "" {
storage.RefreshToken = tokenData.RefreshToken
}
storage.ExpiresAt = tokenData.ExpiresAt
storage.LastRefresh = time.Now().Format(time.RFC3339)
if tokenData.ProfileArn != "" {
storage.ProfileArn = tokenData.ProfileArn
}
// Write updated token back to file
updatedData, err := json.MarshalIndent(storage, "", " ")
if err != nil {
errors = append(errors, fmt.Sprintf("%s: marshal error - %v", name, err))
continue
}
tmpFile := filePath + ".tmp"
if err := os.WriteFile(tmpFile, updatedData, 0600); err != nil {
errors = append(errors, fmt.Sprintf("%s: write error - %v", name, err))
continue
}
if err := os.Rename(tmpFile, filePath); err != nil {
errors = append(errors, fmt.Sprintf("%s: rename error - %v", name, err))
continue
}
log.Infof("OAuth Web: manually refreshed token in %s, expires at %s", name, tokenData.ExpiresAt)
refreshedCount++
// Notify callback if set
if h.onTokenObtained != nil {
h.onTokenObtained(tokenData)
}
}
if refreshedCount == 0 && len(errors) > 0 {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"error": fmt.Sprintf("All refresh attempts failed: %v", errors),
})
return
}
response := gin.H{
"success": true,
"message": fmt.Sprintf("Refreshed %d token(s)", refreshedCount),
"refreshedCount": refreshedCount,
}
if len(errors) > 0 {
response["warnings"] = errors
}
c.JSON(http.StatusOK, response)
}
// refreshTokenData refreshes a token using the appropriate method based on auth type.
// This mirrors the logic in kiro_executor.Refresh for consistency.
func (h *OAuthWebHandler) refreshTokenData(ctx context.Context, storage *KiroTokenStorage) (*KiroTokenData, error) {
ssoClient := NewSSOOIDCClient(h.cfg)
switch {
case storage.ClientID != "" && storage.ClientSecret != "" && storage.AuthMethod == "idc" && storage.Region != "":
// IDC refresh with region-specific endpoint
log.Debugf("OAuth Web: using SSO OIDC refresh for IDC (region=%s)", storage.Region)
return ssoClient.RefreshTokenWithRegion(ctx, storage.ClientID, storage.ClientSecret, storage.RefreshToken, storage.Region, storage.StartURL)
case storage.ClientID != "" && storage.ClientSecret != "" && storage.AuthMethod == "builder-id":
// Builder ID refresh with default endpoint
log.Debugf("OAuth Web: using SSO OIDC refresh for AWS Builder ID")
return ssoClient.RefreshToken(ctx, storage.ClientID, storage.ClientSecret, storage.RefreshToken)
default:
// Fallback to Kiro's OAuth refresh endpoint (for social auth: Google/GitHub)
log.Debugf("OAuth Web: using Kiro OAuth refresh endpoint")
oauth := NewKiroOAuth(h.cfg)
return oauth.RefreshToken(ctx, storage.RefreshToken)
}
}

View File

@@ -0,0 +1,779 @@
// Package kiro provides OAuth Web authentication templates.
package kiro
const (
oauthWebStartPageHTML = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>AWS SSO Authentication</title>
<style>
* { box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
margin: 0;
padding: 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
display: flex;
justify-content: center;
align-items: center;
}
.container {
max-width: 500px;
width: 100%;
background: #fff;
padding: 40px;
border-radius: 12px;
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
}
h1 {
margin: 0 0 10px;
color: #333;
font-size: 24px;
text-align: center;
}
.subtitle {
text-align: center;
color: #666;
margin-bottom: 30px;
}
.step {
background: #f8f9fa;
padding: 20px;
border-radius: 8px;
margin-bottom: 15px;
}
.step-title {
display: flex;
align-items: center;
font-weight: 600;
color: #333;
margin-bottom: 10px;
}
.step-number {
width: 28px;
height: 28px;
background: #667eea;
color: white;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
font-size: 14px;
margin-right: 12px;
}
.user-code {
background: #e7f3ff;
border: 2px dashed #2196F3;
border-radius: 8px;
padding: 20px;
text-align: center;
margin-top: 10px;
}
.user-code-label {
font-size: 12px;
color: #666;
text-transform: uppercase;
letter-spacing: 1px;
margin-bottom: 8px;
}
.user-code-value {
font-size: 32px;
font-weight: bold;
font-family: monospace;
color: #2196F3;
letter-spacing: 4px;
}
.auth-btn {
display: block;
width: 100%;
padding: 15px;
background: #667eea;
color: white;
text-align: center;
text-decoration: none;
border-radius: 8px;
font-weight: 600;
font-size: 16px;
transition: all 0.3s;
border: none;
cursor: pointer;
margin-top: 20px;
}
.auth-btn:hover {
background: #5568d3;
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
}
.status {
margin-top: 30px;
padding: 20px;
background: #f8f9fa;
border-radius: 8px;
text-align: center;
}
.status-pending { border-left: 4px solid #ffc107; }
.status-success { border-left: 4px solid #28a745; }
.status-failed { border-left: 4px solid #dc3545; }
.spinner {
border: 3px solid #f3f3f3;
border-top: 3px solid #667eea;
border-radius: 50%;
width: 40px;
height: 40px;
animation: spin 1s linear infinite;
margin: 0 auto 15px;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.timer {
font-size: 24px;
font-weight: bold;
color: #667eea;
margin: 10px 0;
}
.timer.warning { color: #ffc107; }
.timer.danger { color: #dc3545; }
.status-message { color: #666; line-height: 1.6; }
.success-icon, .error-icon { font-size: 48px; margin-bottom: 15px; }
.info-box {
background: #e7f3ff;
border-left: 4px solid #2196F3;
padding: 15px;
margin-top: 20px;
border-radius: 4px;
font-size: 14px;
color: #666;
}
</style>
</head>
<body>
<div class="container">
<h1>🔐 AWS SSO Authentication</h1>
<p class="subtitle">Follow the steps below to complete authentication</p>
<div class="step">
<div class="step-title">
<span class="step-number">1</span>
Click the button below to open the authorization page
</div>
<a href="{{.AuthURL}}" target="_blank" class="auth-btn" id="authBtn">
🚀 Open Authorization Page
</a>
</div>
<div class="step">
<div class="step-title">
<span class="step-number">2</span>
Enter the verification code below
</div>
<div class="user-code">
<div class="user-code-label">Verification Code</div>
<div class="user-code-value">{{.UserCode}}</div>
</div>
</div>
<div class="step">
<div class="step-title">
<span class="step-number">3</span>
Complete AWS SSO login
</div>
<p style="color: #666; font-size: 14px; margin-top: 10px;">
Use your AWS SSO account to login and authorize
</p>
</div>
<div class="status status-pending" id="statusBox">
<div class="spinner" id="spinner"></div>
<div class="timer" id="timer">{{.ExpiresIn}}s</div>
<div class="status-message" id="statusMessage">
Waiting for authorization...
</div>
</div>
<div class="info-box">
💡 <strong>Tip:</strong> The authorization page will open in a new tab. This page will automatically update once authorization is complete.
</div>
</div>
<script>
let pollInterval;
let timerInterval;
let remainingSeconds = {{.ExpiresIn}};
const stateID = "{{.StateID}}";
setTimeout(() => {
document.getElementById('authBtn').click();
}, 500);
function pollStatus() {
fetch('/v0/oauth/kiro/status?state=' + stateID)
.then(response => response.json())
.then(data => {
console.log('Status:', data);
if (data.status === 'success') {
clearInterval(pollInterval);
clearInterval(timerInterval);
showSuccess(data);
} else if (data.status === 'failed') {
clearInterval(pollInterval);
clearInterval(timerInterval);
showError(data);
} else {
remainingSeconds = data.remaining_seconds || 0;
}
})
.catch(error => {
console.error('Poll error:', error);
});
}
function updateTimer() {
const timerEl = document.getElementById('timer');
const minutes = Math.floor(remainingSeconds / 60);
const seconds = remainingSeconds % 60;
timerEl.textContent = minutes + ':' + seconds.toString().padStart(2, '0');
if (remainingSeconds < 60) {
timerEl.className = 'timer danger';
} else if (remainingSeconds < 180) {
timerEl.className = 'timer warning';
} else {
timerEl.className = 'timer';
}
remainingSeconds--;
if (remainingSeconds < 0) {
clearInterval(timerInterval);
clearInterval(pollInterval);
showError({ error: 'Authentication timed out. Please refresh and try again.' });
}
}
function showSuccess(data) {
const statusBox = document.getElementById('statusBox');
statusBox.className = 'status status-success';
statusBox.innerHTML = '<div class="success-icon">✅</div>' +
'<div class="status-message">' +
'<strong>Authentication Successful!</strong><br>' +
'Token expires: ' + new Date(data.expires_at).toLocaleString() +
'</div>';
}
function showError(data) {
const statusBox = document.getElementById('statusBox');
statusBox.className = 'status status-failed';
statusBox.innerHTML = '<div class="error-icon">❌</div>' +
'<div class="status-message">' +
'<strong>Authentication Failed</strong><br>' +
(data.error || 'Unknown error') +
'</div>' +
'<button class="auth-btn" onclick="location.reload()" style="margin-top: 15px;">' +
'🔄 Retry' +
'</button>';
}
pollInterval = setInterval(pollStatus, 3000);
timerInterval = setInterval(updateTimer, 1000);
pollStatus();
</script>
</body>
</html>`
oauthWebErrorPageHTML = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Authentication Failed</title>
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
max-width: 600px;
margin: 50px auto;
padding: 20px;
background: #f5f5f5;
}
.error {
background: #fff;
padding: 30px;
border-radius: 8px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
border-left: 4px solid #dc3545;
}
h1 { color: #dc3545; margin-top: 0; }
.error-message { color: #666; line-height: 1.6; }
.retry-btn {
display: inline-block;
margin-top: 20px;
padding: 10px 20px;
background: #007bff;
color: white;
text-decoration: none;
border-radius: 4px;
}
.retry-btn:hover { background: #0056b3; }
</style>
</head>
<body>
<div class="error">
<h1>❌ Authentication Failed</h1>
<div class="error-message">
<p><strong>Error:</strong></p>
<p>{{.Error}}</p>
</div>
<a href="/v0/oauth/kiro/start" class="retry-btn">🔄 Retry</a>
</div>
</body>
</html>`
oauthWebSuccessPageHTML = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Authentication Successful</title>
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
max-width: 600px;
margin: 50px auto;
padding: 20px;
background: #f5f5f5;
}
.success {
background: #fff;
padding: 30px;
border-radius: 8px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
border-left: 4px solid #28a745;
text-align: center;
}
h1 { color: #28a745; margin-top: 0; }
.success-message { color: #666; line-height: 1.6; }
.icon { font-size: 48px; margin-bottom: 15px; }
.expires { font-size: 14px; color: #999; margin-top: 15px; }
</style>
</head>
<body>
<div class="success">
<div class="icon">✅</div>
<h1>Authentication Successful!</h1>
<div class="success-message">
<p>You can close this window.</p>
</div>
<div class="expires">Token expires: {{.ExpiresAt}}</div>
</div>
</body>
</html>`
oauthWebSelectPageHTML = `<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Select Authentication Method</title>
<style>
* { box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
margin: 0;
padding: 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
display: flex;
justify-content: center;
align-items: center;
}
.container {
max-width: 500px;
width: 100%;
background: #fff;
padding: 40px;
border-radius: 12px;
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
}
h1 {
margin: 0 0 10px;
color: #333;
font-size: 24px;
text-align: center;
}
.subtitle {
text-align: center;
color: #666;
margin-bottom: 30px;
}
.auth-methods {
display: flex;
flex-direction: column;
gap: 15px;
}
.auth-btn {
display: flex;
align-items: center;
width: 100%;
padding: 15px 20px;
background: #667eea;
color: white;
text-decoration: none;
border-radius: 8px;
font-weight: 600;
font-size: 16px;
transition: all 0.3s;
border: none;
cursor: pointer;
}
.auth-btn:hover {
background: #5568d3;
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
}
.auth-btn .icon {
font-size: 24px;
margin-right: 15px;
width: 32px;
text-align: center;
}
.auth-btn.google { background: #4285F4; }
.auth-btn.google:hover { background: #3367D6; }
.auth-btn.github { background: #24292e; }
.auth-btn.github:hover { background: #1a1e22; }
.auth-btn.aws { background: #FF9900; }
.auth-btn.aws:hover { background: #E68A00; }
.auth-btn.idc { background: #232F3E; }
.auth-btn.idc:hover { background: #1a242f; }
.idc-form {
background: #f8f9fa;
padding: 20px;
border-radius: 8px;
margin-top: 15px;
display: none;
}
.idc-form.show {
display: block;
}
.form-group {
margin-bottom: 15px;
}
.form-group label {
display: block;
font-weight: 600;
color: #333;
margin-bottom: 8px;
font-size: 14px;
}
.form-group input {
width: 100%;
padding: 12px;
border: 2px solid #e0e0e0;
border-radius: 6px;
font-size: 14px;
transition: border-color 0.3s;
}
.form-group input:focus {
outline: none;
border-color: #667eea;
}
.form-group .hint {
font-size: 12px;
color: #999;
margin-top: 5px;
}
.submit-btn {
display: block;
width: 100%;
padding: 15px;
background: #232F3E;
color: white;
text-align: center;
text-decoration: none;
border-radius: 8px;
font-weight: 600;
font-size: 16px;
transition: all 0.3s;
border: none;
cursor: pointer;
}
.submit-btn:hover {
background: #1a242f;
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(35, 47, 62, 0.4);
}
.divider {
display: flex;
align-items: center;
margin: 20px 0;
}
.divider::before,
.divider::after {
content: "";
flex: 1;
border-bottom: 1px solid #e0e0e0;
}
.divider span {
padding: 0 15px;
color: #999;
font-size: 14px;
}
.info-box {
background: #e7f3ff;
border-left: 4px solid #2196F3;
padding: 15px;
margin-top: 20px;
border-radius: 4px;
font-size: 14px;
color: #666;
}
.warning-box {
background: #fff3cd;
border-left: 4px solid #ffc107;
padding: 15px;
margin-top: 20px;
border-radius: 4px;
font-size: 14px;
color: #856404;
}
.auth-btn.manual { background: #6c757d; }
.auth-btn.manual:hover { background: #5a6268; }
.auth-btn.refresh { background: #17a2b8; }
.auth-btn.refresh:hover { background: #138496; }
.auth-btn.refresh:disabled { background: #7fb3bd; cursor: not-allowed; }
.manual-form {
background: #f8f9fa;
padding: 20px;
border-radius: 8px;
margin-top: 15px;
display: none;
}
.manual-form.show {
display: block;
}
.form-group textarea {
width: 100%;
padding: 12px;
border: 2px solid #e0e0e0;
border-radius: 6px;
font-size: 14px;
font-family: monospace;
transition: border-color 0.3s;
resize: vertical;
min-height: 80px;
}
.form-group textarea:focus {
outline: none;
border-color: #667eea;
}
.status-message {
padding: 15px;
border-radius: 6px;
margin-top: 15px;
display: none;
}
.status-message.success {
background: #d4edda;
color: #155724;
display: block;
}
.status-message.error {
background: #f8d7da;
color: #721c24;
display: block;
}
</style>
</head>
<body>
<div class="container">
<h1>🔐 Select Authentication Method</h1>
<p class="subtitle">Choose how you want to authenticate with Kiro</p>
<div class="auth-methods">
<a href="/v0/oauth/kiro/start?method=builder-id" class="auth-btn aws">
<span class="icon">🔶</span>
AWS Builder ID (Recommended)
</a>
<button type="button" class="auth-btn idc" onclick="toggleIdcForm()">
<span class="icon">🏢</span>
AWS Identity Center (IDC)
</button>
<div class="divider"><span>or</span></div>
<button type="button" class="auth-btn manual" onclick="toggleManualForm()">
<span class="icon">📋</span>
Import RefreshToken from Kiro IDE
</button>
<button type="button" class="auth-btn refresh" onclick="manualRefresh()" id="refreshBtn">
<span class="icon">🔄</span>
Manual Refresh All Tokens
</button>
<div class="status-message" id="refreshStatus"></div>
</div>
<div class="idc-form" id="idcForm">
<form action="/v0/oauth/kiro/start" method="get">
<input type="hidden" name="method" value="idc">
<div class="form-group">
<label for="startUrl">Start URL</label>
<input type="url" id="startUrl" name="startUrl" placeholder="https://your-org.awsapps.com/start" required>
<div class="hint">Your AWS Identity Center Start URL</div>
</div>
<div class="form-group">
<label for="region">Region</label>
<input type="text" id="region" name="region" value="us-east-1" placeholder="us-east-1">
<div class="hint">AWS Region for your Identity Center</div>
</div>
<button type="submit" class="submit-btn">
🚀 Continue with IDC
</button>
</form>
</div>
<div class="manual-form" id="manualForm">
<form id="importForm" onsubmit="submitImport(event)">
<div class="form-group">
<label for="refreshToken">Refresh Token</label>
<textarea id="refreshToken" name="refreshToken" placeholder="Paste your refreshToken here (starts with aorAAAAAG...)" required></textarea>
<div class="hint">Copy from Kiro IDE: ~/.kiro/kiro-auth-token.json → refreshToken field</div>
</div>
<button type="submit" class="submit-btn" id="importBtn">
📥 Import Token
</button>
<div class="status-message" id="importStatus"></div>
</form>
</div>
<div class="warning-box">
⚠️ <strong>Note:</strong> Google and GitHub login are not available for third-party applications due to AWS Cognito restrictions. Please use AWS Builder ID or import your token from Kiro IDE.
</div>
<div class="info-box">
💡 <strong>How to get RefreshToken:</strong><br>
1. Open Kiro IDE and login with Google/GitHub<br>
2. Find the token file: <code>~/.kiro/kiro-auth-token.json</code><br>
3. Copy the <code>refreshToken</code> value and paste it above
</div>
</div>
<script>
function toggleIdcForm() {
const idcForm = document.getElementById('idcForm');
const manualForm = document.getElementById('manualForm');
manualForm.classList.remove('show');
idcForm.classList.toggle('show');
if (idcForm.classList.contains('show')) {
document.getElementById('startUrl').focus();
}
}
function toggleManualForm() {
const idcForm = document.getElementById('idcForm');
const manualForm = document.getElementById('manualForm');
idcForm.classList.remove('show');
manualForm.classList.toggle('show');
if (manualForm.classList.contains('show')) {
document.getElementById('refreshToken').focus();
}
}
async function submitImport(event) {
event.preventDefault();
const refreshToken = document.getElementById('refreshToken').value.trim();
const statusEl = document.getElementById('importStatus');
const btn = document.getElementById('importBtn');
if (!refreshToken) {
statusEl.className = 'status-message error';
statusEl.textContent = 'Please enter a refresh token';
return;
}
if (!refreshToken.startsWith('aorAAAAAG')) {
statusEl.className = 'status-message error';
statusEl.textContent = 'Invalid token format. Token should start with aorAAAAAG...';
return;
}
btn.disabled = true;
btn.textContent = '⏳ Importing...';
statusEl.className = 'status-message';
statusEl.style.display = 'none';
try {
const response = await fetch('/v0/oauth/kiro/import', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ refreshToken: refreshToken })
});
const data = await response.json();
if (response.ok && data.success) {
statusEl.className = 'status-message success';
statusEl.textContent = '✅ Token imported successfully! File: ' + (data.fileName || 'kiro-token.json');
} else {
statusEl.className = 'status-message error';
statusEl.textContent = '❌ ' + (data.error || data.message || 'Import failed');
}
} catch (error) {
statusEl.className = 'status-message error';
statusEl.textContent = '❌ Network error: ' + error.message;
} finally {
btn.disabled = false;
btn.textContent = '📥 Import Token';
}
}
async function manualRefresh() {
const btn = document.getElementById('refreshBtn');
const statusEl = document.getElementById('refreshStatus');
btn.disabled = true;
btn.innerHTML = '<span class="icon">⏳</span> Refreshing...';
statusEl.className = 'status-message';
statusEl.style.display = 'none';
try {
const response = await fetch('/v0/oauth/kiro/refresh', {
method: 'POST',
headers: { 'Content-Type': 'application/json' }
});
const data = await response.json();
if (response.ok && data.success) {
statusEl.className = 'status-message success';
let msg = '✅ ' + data.message;
if (data.warnings && data.warnings.length > 0) {
msg += ' (Warnings: ' + data.warnings.join('; ') + ')';
}
statusEl.textContent = msg;
} else {
statusEl.className = 'status-message error';
statusEl.textContent = '❌ ' + (data.error || data.message || 'Refresh failed');
}
} catch (error) {
statusEl.className = 'status-message error';
statusEl.textContent = '❌ Network error: ' + error.message;
} finally {
btn.disabled = false;
btn.innerHTML = '<span class="icon">🔄</span> Manual Refresh All Tokens';
}
}
</script>
</body>
</html>`
)

View File

@@ -0,0 +1,725 @@
// Package kiro provides custom protocol handler registration for Kiro OAuth.
// This enables the CLI to intercept kiro:// URIs for social authentication (Google/GitHub).
package kiro
import (
"context"
"fmt"
"html"
"net"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
// KiroProtocol is the custom URI scheme used by Kiro
KiroProtocol = "kiro"
// KiroAuthority is the URI authority for authentication callbacks
KiroAuthority = "kiro.kiroAgent"
// KiroAuthPath is the path for successful authentication
KiroAuthPath = "/authenticate-success"
// KiroRedirectURI is the full redirect URI for social auth
KiroRedirectURI = "kiro://kiro.kiroAgent/authenticate-success"
// DefaultHandlerPort is the default port for the local callback server
DefaultHandlerPort = 19876
// HandlerTimeout is how long to wait for the OAuth callback
HandlerTimeout = 10 * time.Minute
)
// ProtocolHandler manages the custom kiro:// protocol handler for OAuth callbacks.
type ProtocolHandler struct {
port int
server *http.Server
listener net.Listener
resultChan chan *AuthCallback
stopChan chan struct{}
mu sync.Mutex
running bool
}
// AuthCallback contains the OAuth callback parameters.
type AuthCallback struct {
Code string
State string
Error string
}
// NewProtocolHandler creates a new protocol handler.
func NewProtocolHandler() *ProtocolHandler {
return &ProtocolHandler{
port: DefaultHandlerPort,
resultChan: make(chan *AuthCallback, 1),
stopChan: make(chan struct{}),
}
}
// Start starts the local callback server that receives redirects from the protocol handler.
func (h *ProtocolHandler) Start(ctx context.Context) (int, error) {
h.mu.Lock()
defer h.mu.Unlock()
if h.running {
return h.port, nil
}
// Drain any stale results from previous runs
select {
case <-h.resultChan:
default:
}
// Reset stopChan for reuse - close old channel first to unblock any waiting goroutines
if h.stopChan != nil {
select {
case <-h.stopChan:
// Already closed
default:
close(h.stopChan)
}
}
h.stopChan = make(chan struct{})
// Try ports in known range (must match handler script port range)
var listener net.Listener
var err error
portRange := []int{DefaultHandlerPort, DefaultHandlerPort + 1, DefaultHandlerPort + 2, DefaultHandlerPort + 3, DefaultHandlerPort + 4}
for _, port := range portRange {
listener, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err == nil {
break
}
log.Debugf("kiro protocol handler: port %d busy, trying next", port)
}
if listener == nil {
return 0, fmt.Errorf("failed to start callback server: all ports %d-%d are busy", DefaultHandlerPort, DefaultHandlerPort+4)
}
h.listener = listener
h.port = listener.Addr().(*net.TCPAddr).Port
mux := http.NewServeMux()
mux.HandleFunc("/oauth/callback", h.handleCallback)
h.server = &http.Server{
Handler: mux,
ReadHeaderTimeout: 10 * time.Second,
}
go func() {
if err := h.server.Serve(listener); err != nil && err != http.ErrServerClosed {
log.Debugf("kiro protocol handler server error: %v", err)
}
}()
h.running = true
log.Debugf("kiro protocol handler started on port %d", h.port)
// Auto-shutdown after context done, timeout, or explicit stop
// Capture references to prevent race with new Start() calls
currentStopChan := h.stopChan
currentServer := h.server
currentListener := h.listener
go func() {
select {
case <-ctx.Done():
case <-time.After(HandlerTimeout):
case <-currentStopChan:
return // Already stopped, exit goroutine
}
// Only stop if this is still the current server/listener instance
h.mu.Lock()
if h.server == currentServer && h.listener == currentListener {
h.mu.Unlock()
h.Stop()
} else {
h.mu.Unlock()
}
}()
return h.port, nil
}
// Stop stops the callback server.
func (h *ProtocolHandler) Stop() {
h.mu.Lock()
defer h.mu.Unlock()
if !h.running {
return
}
// Signal the auto-shutdown goroutine to exit.
// This select pattern is safe because stopChan is only modified while holding h.mu,
// and we hold the lock here. The select prevents panic from double-close.
select {
case <-h.stopChan:
// Already closed
default:
close(h.stopChan)
}
if h.server != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = h.server.Shutdown(ctx)
}
h.running = false
log.Debug("kiro protocol handler stopped")
}
// WaitForCallback waits for the OAuth callback and returns the result.
func (h *ProtocolHandler) WaitForCallback(ctx context.Context) (*AuthCallback, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(HandlerTimeout):
return nil, fmt.Errorf("timeout waiting for OAuth callback")
case result := <-h.resultChan:
return result, nil
}
}
// GetPort returns the port the handler is listening on.
func (h *ProtocolHandler) GetPort() int {
return h.port
}
// handleCallback processes the OAuth callback from the protocol handler script.
func (h *ProtocolHandler) handleCallback(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
errParam := r.URL.Query().Get("error")
result := &AuthCallback{
Code: code,
State: state,
Error: errParam,
}
// Send result
select {
case h.resultChan <- result:
default:
// Channel full, ignore duplicate callbacks
}
// Send success response
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if errParam != "" {
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, `<!DOCTYPE html>
<html>
<head><title>Login Failed</title></head>
<body>
<h1>Login Failed</h1>
<p>Error: %s</p>
<p>You can close this window.</p>
</body>
</html>`, html.EscapeString(errParam))
} else {
fmt.Fprint(w, `<!DOCTYPE html>
<html>
<head><title>Login Successful</title></head>
<body>
<h1>Login Successful!</h1>
<p>You can close this window and return to the terminal.</p>
<script>window.close();</script>
</body>
</html>`)
}
}
// IsProtocolHandlerInstalled checks if the kiro:// protocol handler is installed.
func IsProtocolHandlerInstalled() bool {
switch runtime.GOOS {
case "linux":
return isLinuxHandlerInstalled()
case "windows":
return isWindowsHandlerInstalled()
case "darwin":
return isDarwinHandlerInstalled()
default:
return false
}
}
// InstallProtocolHandler installs the kiro:// protocol handler for the current platform.
func InstallProtocolHandler(handlerPort int) error {
switch runtime.GOOS {
case "linux":
return installLinuxHandler(handlerPort)
case "windows":
return installWindowsHandler(handlerPort)
case "darwin":
return installDarwinHandler(handlerPort)
default:
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
}
}
// UninstallProtocolHandler removes the kiro:// protocol handler.
func UninstallProtocolHandler() error {
switch runtime.GOOS {
case "linux":
return uninstallLinuxHandler()
case "windows":
return uninstallWindowsHandler()
case "darwin":
return uninstallDarwinHandler()
default:
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
}
}
// --- Linux Implementation ---
func getLinuxDesktopPath() string {
homeDir, _ := os.UserHomeDir()
return filepath.Join(homeDir, ".local", "share", "applications", "kiro-oauth-handler.desktop")
}
func getLinuxHandlerScriptPath() string {
homeDir, _ := os.UserHomeDir()
return filepath.Join(homeDir, ".local", "bin", "kiro-oauth-handler")
}
func isLinuxHandlerInstalled() bool {
desktopPath := getLinuxDesktopPath()
_, err := os.Stat(desktopPath)
return err == nil
}
func installLinuxHandler(handlerPort int) error {
// Create directories
homeDir, err := os.UserHomeDir()
if err != nil {
return err
}
binDir := filepath.Join(homeDir, ".local", "bin")
appDir := filepath.Join(homeDir, ".local", "share", "applications")
if err := os.MkdirAll(binDir, 0755); err != nil {
return fmt.Errorf("failed to create bin directory: %w", err)
}
if err := os.MkdirAll(appDir, 0755); err != nil {
return fmt.Errorf("failed to create applications directory: %w", err)
}
// Create handler script - tries multiple ports to handle dynamic port allocation
scriptPath := getLinuxHandlerScriptPath()
scriptContent := fmt.Sprintf(`#!/bin/bash
# Kiro OAuth Protocol Handler
# Handles kiro:// URIs - tries CLI first, then forwards to Kiro IDE
URL="$1"
# Check curl availability
if ! command -v curl &> /dev/null; then
echo "Error: curl is required for Kiro OAuth handler" >&2
exit 1
fi
# Extract code and state from URL
[[ "$URL" =~ code=([^&]+) ]] && CODE="${BASH_REMATCH[1]}"
[[ "$URL" =~ state=([^&]+) ]] && STATE="${BASH_REMATCH[1]}"
[[ "$URL" =~ error=([^&]+) ]] && ERROR="${BASH_REMATCH[1]}"
# Try CLI proxy on multiple possible ports (default + dynamic range)
CLI_OK=0
for PORT in %d %d %d %d %d; do
if [ -n "$ERROR" ]; then
curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?error=$ERROR" && CLI_OK=1 && break
elif [ -n "$CODE" ] && [ -n "$STATE" ]; then
curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?code=$CODE&state=$STATE" && CLI_OK=1 && break
fi
done
# If CLI not available, forward to Kiro IDE
if [ $CLI_OK -eq 0 ] && [ -x "/usr/share/kiro/kiro" ]; then
/usr/share/kiro/kiro --open-url "$URL" &
fi
`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4)
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
return fmt.Errorf("failed to write handler script: %w", err)
}
// Create .desktop file
desktopPath := getLinuxDesktopPath()
desktopContent := fmt.Sprintf(`[Desktop Entry]
Name=Kiro OAuth Handler
Comment=Handle kiro:// protocol for CLI Proxy API authentication
Exec=%s %%u
Type=Application
Terminal=false
NoDisplay=true
MimeType=x-scheme-handler/kiro;
Categories=Utility;
`, scriptPath)
if err := os.WriteFile(desktopPath, []byte(desktopContent), 0644); err != nil {
return fmt.Errorf("failed to write desktop file: %w", err)
}
// Register handler with xdg-mime
cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro")
if err := cmd.Run(); err != nil {
log.Warnf("xdg-mime registration failed (may need manual setup): %v", err)
}
// Update desktop database
cmd = exec.Command("update-desktop-database", appDir)
_ = cmd.Run() // Ignore errors, not critical
log.Info("Kiro protocol handler installed for Linux")
return nil
}
func uninstallLinuxHandler() error {
desktopPath := getLinuxDesktopPath()
scriptPath := getLinuxHandlerScriptPath()
if err := os.Remove(desktopPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove desktop file: %w", err)
}
if err := os.Remove(scriptPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove handler script: %w", err)
}
log.Info("Kiro protocol handler uninstalled")
return nil
}
// --- Windows Implementation ---
func isWindowsHandlerInstalled() bool {
// Check registry key existence
cmd := exec.Command("reg", "query", `HKCU\Software\Classes\kiro`, "/ve")
return cmd.Run() == nil
}
func installWindowsHandler(handlerPort int) error {
homeDir, err := os.UserHomeDir()
if err != nil {
return err
}
// Create handler script (PowerShell)
scriptDir := filepath.Join(homeDir, ".cliproxyapi")
if err := os.MkdirAll(scriptDir, 0755); err != nil {
return fmt.Errorf("failed to create script directory: %w", err)
}
scriptPath := filepath.Join(scriptDir, "kiro-oauth-handler.ps1")
scriptContent := fmt.Sprintf(`# Kiro OAuth Protocol Handler for Windows
param([string]$url)
# Load required assembly for HttpUtility
Add-Type -AssemblyName System.Web
# Parse URL parameters
$uri = [System.Uri]$url
$query = [System.Web.HttpUtility]::ParseQueryString($uri.Query)
$code = $query["code"]
$state = $query["state"]
$errorParam = $query["error"]
# Try multiple ports (default + dynamic range)
$ports = @(%d, %d, %d, %d, %d)
$success = $false
foreach ($port in $ports) {
if ($success) { break }
$callbackUrl = "http://127.0.0.1:$port/oauth/callback"
try {
if ($errorParam) {
$fullUrl = $callbackUrl + "?error=" + $errorParam
Invoke-WebRequest -Uri $fullUrl -UseBasicParsing -TimeoutSec 1 -ErrorAction Stop | Out-Null
$success = $true
} elseif ($code -and $state) {
$fullUrl = $callbackUrl + "?code=" + $code + "&state=" + $state
Invoke-WebRequest -Uri $fullUrl -UseBasicParsing -TimeoutSec 1 -ErrorAction Stop | Out-Null
$success = $true
}
} catch {
# Try next port
}
}
`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4)
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0644); err != nil {
return fmt.Errorf("failed to write handler script: %w", err)
}
// Create batch wrapper
batchPath := filepath.Join(scriptDir, "kiro-oauth-handler.bat")
batchContent := fmt.Sprintf("@echo off\npowershell -ExecutionPolicy Bypass -File \"%s\" %%1\n", scriptPath)
if err := os.WriteFile(batchPath, []byte(batchContent), 0644); err != nil {
return fmt.Errorf("failed to write batch wrapper: %w", err)
}
// Register in Windows registry
commands := [][]string{
{"reg", "add", `HKCU\Software\Classes\kiro`, "/ve", "/d", "URL:Kiro Protocol", "/f"},
{"reg", "add", `HKCU\Software\Classes\kiro`, "/v", "URL Protocol", "/d", "", "/f"},
{"reg", "add", `HKCU\Software\Classes\kiro\shell`, "/f"},
{"reg", "add", `HKCU\Software\Classes\kiro\shell\open`, "/f"},
{"reg", "add", `HKCU\Software\Classes\kiro\shell\open\command`, "/ve", "/d", fmt.Sprintf("\"%s\" \"%%1\"", batchPath), "/f"},
}
for _, args := range commands {
cmd := exec.Command(args[0], args[1:]...)
if err := cmd.Run(); err != nil {
return fmt.Errorf("failed to run registry command: %w", err)
}
}
log.Info("Kiro protocol handler installed for Windows")
return nil
}
func uninstallWindowsHandler() error {
// Remove registry keys
cmd := exec.Command("reg", "delete", `HKCU\Software\Classes\kiro`, "/f")
if err := cmd.Run(); err != nil {
log.Warnf("failed to remove registry key: %v", err)
}
// Remove scripts
homeDir, _ := os.UserHomeDir()
scriptDir := filepath.Join(homeDir, ".cliproxyapi")
_ = os.Remove(filepath.Join(scriptDir, "kiro-oauth-handler.ps1"))
_ = os.Remove(filepath.Join(scriptDir, "kiro-oauth-handler.bat"))
log.Info("Kiro protocol handler uninstalled")
return nil
}
// --- macOS Implementation ---
func getDarwinAppPath() string {
homeDir, _ := os.UserHomeDir()
return filepath.Join(homeDir, "Applications", "KiroOAuthHandler.app")
}
func isDarwinHandlerInstalled() bool {
appPath := getDarwinAppPath()
_, err := os.Stat(appPath)
return err == nil
}
func installDarwinHandler(handlerPort int) error {
// Create app bundle structure
appPath := getDarwinAppPath()
contentsPath := filepath.Join(appPath, "Contents")
macOSPath := filepath.Join(contentsPath, "MacOS")
if err := os.MkdirAll(macOSPath, 0755); err != nil {
return fmt.Errorf("failed to create app bundle: %w", err)
}
// Create Info.plist
plistPath := filepath.Join(contentsPath, "Info.plist")
plistContent := `<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>CFBundleIdentifier</key>
<string>com.cliproxyapi.kiro-oauth-handler</string>
<key>CFBundleName</key>
<string>KiroOAuthHandler</string>
<key>CFBundleExecutable</key>
<string>kiro-oauth-handler</string>
<key>CFBundleVersion</key>
<string>1.0</string>
<key>CFBundleURLTypes</key>
<array>
<dict>
<key>CFBundleURLName</key>
<string>Kiro Protocol</string>
<key>CFBundleURLSchemes</key>
<array>
<string>kiro</string>
</array>
</dict>
</array>
<key>LSBackgroundOnly</key>
<true/>
</dict>
</plist>`
if err := os.WriteFile(plistPath, []byte(plistContent), 0644); err != nil {
return fmt.Errorf("failed to write Info.plist: %w", err)
}
// Create executable script - tries multiple ports to handle dynamic port allocation
execPath := filepath.Join(macOSPath, "kiro-oauth-handler")
execContent := fmt.Sprintf(`#!/bin/bash
# Kiro OAuth Protocol Handler for macOS
URL="$1"
# Check curl availability (should always exist on macOS)
if [ ! -x /usr/bin/curl ]; then
echo "Error: curl is required for Kiro OAuth handler" >&2
exit 1
fi
# Extract code and state from URL
[[ "$URL" =~ code=([^&]+) ]] && CODE="${BASH_REMATCH[1]}"
[[ "$URL" =~ state=([^&]+) ]] && STATE="${BASH_REMATCH[1]}"
[[ "$URL" =~ error=([^&]+) ]] && ERROR="${BASH_REMATCH[1]}"
# Try multiple ports (default + dynamic range)
for PORT in %d %d %d %d %d; do
if [ -n "$ERROR" ]; then
/usr/bin/curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?error=$ERROR" && exit 0
elif [ -n "$CODE" ] && [ -n "$STATE" ]; then
/usr/bin/curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?code=$CODE&state=$STATE" && exit 0
fi
done
`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4)
if err := os.WriteFile(execPath, []byte(execContent), 0755); err != nil {
return fmt.Errorf("failed to write executable: %w", err)
}
// Register the app with Launch Services
cmd := exec.Command("/System/Library/Frameworks/CoreServices.framework/Frameworks/LaunchServices.framework/Support/lsregister",
"-f", appPath)
if err := cmd.Run(); err != nil {
log.Warnf("lsregister failed (handler may still work): %v", err)
}
log.Info("Kiro protocol handler installed for macOS")
return nil
}
func uninstallDarwinHandler() error {
appPath := getDarwinAppPath()
// Unregister from Launch Services
cmd := exec.Command("/System/Library/Frameworks/CoreServices.framework/Frameworks/LaunchServices.framework/Support/lsregister",
"-u", appPath)
_ = cmd.Run()
// Remove app bundle
if err := os.RemoveAll(appPath); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to remove app bundle: %w", err)
}
log.Info("Kiro protocol handler uninstalled")
return nil
}
// ParseKiroURI parses a kiro:// URI and extracts the callback parameters.
func ParseKiroURI(rawURI string) (*AuthCallback, error) {
u, err := url.Parse(rawURI)
if err != nil {
return nil, fmt.Errorf("invalid URI: %w", err)
}
if u.Scheme != KiroProtocol {
return nil, fmt.Errorf("invalid scheme: expected %s, got %s", KiroProtocol, u.Scheme)
}
if u.Host != KiroAuthority {
return nil, fmt.Errorf("invalid authority: expected %s, got %s", KiroAuthority, u.Host)
}
query := u.Query()
return &AuthCallback{
Code: query.Get("code"),
State: query.Get("state"),
Error: query.Get("error"),
}, nil
}
// GetHandlerInstructions returns platform-specific instructions for manual handler setup.
func GetHandlerInstructions() string {
switch runtime.GOOS {
case "linux":
return `To manually set up the Kiro protocol handler on Linux:
1. Create ~/.local/share/applications/kiro-oauth-handler.desktop:
[Desktop Entry]
Name=Kiro OAuth Handler
Exec=~/.local/bin/kiro-oauth-handler %u
Type=Application
Terminal=false
MimeType=x-scheme-handler/kiro;
2. Create ~/.local/bin/kiro-oauth-handler (make it executable):
#!/bin/bash
URL="$1"
# ... (see generated script for full content)
3. Run: xdg-mime default kiro-oauth-handler.desktop x-scheme-handler/kiro`
case "windows":
return `To manually set up the Kiro protocol handler on Windows:
1. Open Registry Editor (regedit.exe)
2. Create key: HKEY_CURRENT_USER\Software\Classes\kiro
3. Set default value to: URL:Kiro Protocol
4. Create string value "URL Protocol" with empty data
5. Create subkey: shell\open\command
6. Set default value to: "C:\path\to\handler.bat" "%1"`
case "darwin":
return `To manually set up the Kiro protocol handler on macOS:
1. Create ~/Applications/KiroOAuthHandler.app bundle
2. Add Info.plist with CFBundleURLTypes containing "kiro" scheme
3. Create executable in Contents/MacOS/
4. Run: /System/Library/.../lsregister -f ~/Applications/KiroOAuthHandler.app`
default:
return "Protocol handler setup is not supported on this platform."
}
}
// SetupProtocolHandlerIfNeeded checks and installs the protocol handler if needed.
func SetupProtocolHandlerIfNeeded(handlerPort int) error {
if IsProtocolHandlerInstalled() {
log.Debug("Kiro protocol handler already installed")
return nil
}
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
fmt.Println("║ Kiro Protocol Handler Setup Required ║")
fmt.Println("╚══════════════════════════════════════════════════════════╝")
fmt.Println("\nTo enable Google/GitHub login, we need to install a protocol handler.")
fmt.Println("This allows your browser to redirect back to the CLI after authentication.")
fmt.Println("\nInstalling protocol handler...")
if err := InstallProtocolHandler(handlerPort); err != nil {
fmt.Printf("\n⚠ Automatic installation failed: %v\n", err)
fmt.Println("\nManual setup instructions:")
fmt.Println(strings.Repeat("-", 60))
fmt.Println(GetHandlerInstructions())
return err
}
fmt.Println("\n✓ Protocol handler installed successfully!")
return nil
}

View File

@@ -0,0 +1,316 @@
package kiro
import (
"math"
"math/rand"
"strings"
"sync"
"time"
)
const (
DefaultMinTokenInterval = 1 * time.Second
DefaultMaxTokenInterval = 2 * time.Second
DefaultDailyMaxRequests = 500
DefaultJitterPercent = 0.3
DefaultBackoffBase = 30 * time.Second
DefaultBackoffMax = 5 * time.Minute
DefaultBackoffMultiplier = 1.5
DefaultSuspendCooldown = 1 * time.Hour
)
// TokenState Token 状态
type TokenState struct {
LastRequest time.Time
RequestCount int
CooldownEnd time.Time
FailCount int
DailyRequests int
DailyResetTime time.Time
IsSuspended bool
SuspendedAt time.Time
SuspendReason string
}
// RateLimiter 频率限制器
type RateLimiter struct {
mu sync.RWMutex
states map[string]*TokenState
minTokenInterval time.Duration
maxTokenInterval time.Duration
dailyMaxRequests int
jitterPercent float64
backoffBase time.Duration
backoffMax time.Duration
backoffMultiplier float64
suspendCooldown time.Duration
rng *rand.Rand
}
// NewRateLimiter 创建默认配置的频率限制器
func NewRateLimiter() *RateLimiter {
return &RateLimiter{
states: make(map[string]*TokenState),
minTokenInterval: DefaultMinTokenInterval,
maxTokenInterval: DefaultMaxTokenInterval,
dailyMaxRequests: DefaultDailyMaxRequests,
jitterPercent: DefaultJitterPercent,
backoffBase: DefaultBackoffBase,
backoffMax: DefaultBackoffMax,
backoffMultiplier: DefaultBackoffMultiplier,
suspendCooldown: DefaultSuspendCooldown,
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}
// RateLimiterConfig 频率限制器配置
type RateLimiterConfig struct {
MinTokenInterval time.Duration
MaxTokenInterval time.Duration
DailyMaxRequests int
JitterPercent float64
BackoffBase time.Duration
BackoffMax time.Duration
BackoffMultiplier float64
SuspendCooldown time.Duration
}
// NewRateLimiterWithConfig 使用自定义配置创建频率限制器
func NewRateLimiterWithConfig(cfg RateLimiterConfig) *RateLimiter {
rl := NewRateLimiter()
if cfg.MinTokenInterval > 0 {
rl.minTokenInterval = cfg.MinTokenInterval
}
if cfg.MaxTokenInterval > 0 {
rl.maxTokenInterval = cfg.MaxTokenInterval
}
if cfg.DailyMaxRequests > 0 {
rl.dailyMaxRequests = cfg.DailyMaxRequests
}
if cfg.JitterPercent > 0 {
rl.jitterPercent = cfg.JitterPercent
}
if cfg.BackoffBase > 0 {
rl.backoffBase = cfg.BackoffBase
}
if cfg.BackoffMax > 0 {
rl.backoffMax = cfg.BackoffMax
}
if cfg.BackoffMultiplier > 0 {
rl.backoffMultiplier = cfg.BackoffMultiplier
}
if cfg.SuspendCooldown > 0 {
rl.suspendCooldown = cfg.SuspendCooldown
}
return rl
}
// getOrCreateState 获取或创建 Token 状态
func (rl *RateLimiter) getOrCreateState(tokenKey string) *TokenState {
state, exists := rl.states[tokenKey]
if !exists {
state = &TokenState{
DailyResetTime: time.Now().Truncate(24 * time.Hour).Add(24 * time.Hour),
}
rl.states[tokenKey] = state
}
return state
}
// resetDailyIfNeeded 如果需要则重置每日计数
func (rl *RateLimiter) resetDailyIfNeeded(state *TokenState) {
now := time.Now()
if now.After(state.DailyResetTime) {
state.DailyRequests = 0
state.DailyResetTime = now.Truncate(24 * time.Hour).Add(24 * time.Hour)
}
}
// calculateInterval 计算带抖动的随机间隔
func (rl *RateLimiter) calculateInterval() time.Duration {
baseInterval := rl.minTokenInterval + time.Duration(rl.rng.Int63n(int64(rl.maxTokenInterval-rl.minTokenInterval)))
jitter := time.Duration(float64(baseInterval) * rl.jitterPercent * (rl.rng.Float64()*2 - 1))
return baseInterval + jitter
}
// WaitForToken 等待 Token 可用(带抖动的随机间隔)
func (rl *RateLimiter) WaitForToken(tokenKey string) {
rl.mu.Lock()
state := rl.getOrCreateState(tokenKey)
rl.resetDailyIfNeeded(state)
now := time.Now()
// 检查是否在冷却期
if now.Before(state.CooldownEnd) {
waitTime := state.CooldownEnd.Sub(now)
rl.mu.Unlock()
time.Sleep(waitTime)
rl.mu.Lock()
state = rl.getOrCreateState(tokenKey)
now = time.Now()
}
// 计算距离上次请求的间隔
interval := rl.calculateInterval()
nextAllowedTime := state.LastRequest.Add(interval)
if now.Before(nextAllowedTime) {
waitTime := nextAllowedTime.Sub(now)
rl.mu.Unlock()
time.Sleep(waitTime)
rl.mu.Lock()
state = rl.getOrCreateState(tokenKey)
}
state.LastRequest = time.Now()
state.RequestCount++
state.DailyRequests++
rl.mu.Unlock()
}
// MarkTokenFailed 标记 Token 失败
func (rl *RateLimiter) MarkTokenFailed(tokenKey string) {
rl.mu.Lock()
defer rl.mu.Unlock()
state := rl.getOrCreateState(tokenKey)
state.FailCount++
state.CooldownEnd = time.Now().Add(rl.calculateBackoff(state.FailCount))
}
// MarkTokenSuccess 标记 Token 成功
func (rl *RateLimiter) MarkTokenSuccess(tokenKey string) {
rl.mu.Lock()
defer rl.mu.Unlock()
state := rl.getOrCreateState(tokenKey)
state.FailCount = 0
state.CooldownEnd = time.Time{}
}
// CheckAndMarkSuspended 检测暂停错误并标记
func (rl *RateLimiter) CheckAndMarkSuspended(tokenKey string, errorMsg string) bool {
suspendKeywords := []string{
"suspended",
"banned",
"disabled",
"account has been",
"access denied",
"rate limit exceeded",
"too many requests",
"quota exceeded",
}
lowerMsg := strings.ToLower(errorMsg)
for _, keyword := range suspendKeywords {
if strings.Contains(lowerMsg, keyword) {
rl.mu.Lock()
defer rl.mu.Unlock()
state := rl.getOrCreateState(tokenKey)
state.IsSuspended = true
state.SuspendedAt = time.Now()
state.SuspendReason = errorMsg
state.CooldownEnd = time.Now().Add(rl.suspendCooldown)
return true
}
}
return false
}
// IsTokenAvailable 检查 Token 是否可用
func (rl *RateLimiter) IsTokenAvailable(tokenKey string) bool {
rl.mu.RLock()
defer rl.mu.RUnlock()
state, exists := rl.states[tokenKey]
if !exists {
return true
}
now := time.Now()
// 检查是否被暂停
if state.IsSuspended {
if now.After(state.SuspendedAt.Add(rl.suspendCooldown)) {
return true
}
return false
}
// 检查是否在冷却期
if now.Before(state.CooldownEnd) {
return false
}
// 检查每日请求限制
rl.mu.RUnlock()
rl.mu.Lock()
rl.resetDailyIfNeeded(state)
dailyRequests := state.DailyRequests
dailyMax := rl.dailyMaxRequests
rl.mu.Unlock()
rl.mu.RLock()
if dailyRequests >= dailyMax {
return false
}
return true
}
// calculateBackoff 计算指数退避时间
func (rl *RateLimiter) calculateBackoff(failCount int) time.Duration {
if failCount <= 0 {
return 0
}
backoff := float64(rl.backoffBase) * math.Pow(rl.backoffMultiplier, float64(failCount-1))
// 添加抖动
jitter := backoff * rl.jitterPercent * (rl.rng.Float64()*2 - 1)
backoff += jitter
if time.Duration(backoff) > rl.backoffMax {
return rl.backoffMax
}
return time.Duration(backoff)
}
// GetTokenState 获取 Token 状态(只读)
func (rl *RateLimiter) GetTokenState(tokenKey string) *TokenState {
rl.mu.RLock()
defer rl.mu.RUnlock()
state, exists := rl.states[tokenKey]
if !exists {
return nil
}
// 返回副本以防止外部修改
stateCopy := *state
return &stateCopy
}
// ClearTokenState 清除 Token 状态
func (rl *RateLimiter) ClearTokenState(tokenKey string) {
rl.mu.Lock()
defer rl.mu.Unlock()
delete(rl.states, tokenKey)
}
// ResetSuspension 重置暂停状态
func (rl *RateLimiter) ResetSuspension(tokenKey string) {
rl.mu.Lock()
defer rl.mu.Unlock()
state, exists := rl.states[tokenKey]
if exists {
state.IsSuspended = false
state.SuspendedAt = time.Time{}
state.SuspendReason = ""
state.CooldownEnd = time.Time{}
state.FailCount = 0
}
}

View File

@@ -0,0 +1,46 @@
package kiro
import (
"sync"
"time"
log "github.com/sirupsen/logrus"
)
var (
globalRateLimiter *RateLimiter
globalRateLimiterOnce sync.Once
globalCooldownManager *CooldownManager
globalCooldownManagerOnce sync.Once
cooldownStopCh chan struct{}
)
// GetGlobalRateLimiter returns the singleton RateLimiter instance.
func GetGlobalRateLimiter() *RateLimiter {
globalRateLimiterOnce.Do(func() {
globalRateLimiter = NewRateLimiter()
log.Info("kiro: global RateLimiter initialized")
})
return globalRateLimiter
}
// GetGlobalCooldownManager returns the singleton CooldownManager instance.
func GetGlobalCooldownManager() *CooldownManager {
globalCooldownManagerOnce.Do(func() {
globalCooldownManager = NewCooldownManager()
cooldownStopCh = make(chan struct{})
go globalCooldownManager.StartCleanupRoutine(5*time.Minute, cooldownStopCh)
log.Info("kiro: global CooldownManager initialized with cleanup routine")
})
return globalCooldownManager
}
// ShutdownRateLimiters stops the cooldown cleanup routine.
// Should be called during application shutdown.
func ShutdownRateLimiters() {
if cooldownStopCh != nil {
close(cooldownStopCh)
log.Info("kiro: rate limiter cleanup routine stopped")
}
}

View File

@@ -0,0 +1,304 @@
package kiro
import (
"sync"
"testing"
"time"
)
func TestNewRateLimiter(t *testing.T) {
rl := NewRateLimiter()
if rl == nil {
t.Fatal("expected non-nil RateLimiter")
}
if rl.states == nil {
t.Error("expected non-nil states map")
}
if rl.minTokenInterval != DefaultMinTokenInterval {
t.Errorf("expected minTokenInterval %v, got %v", DefaultMinTokenInterval, rl.minTokenInterval)
}
if rl.maxTokenInterval != DefaultMaxTokenInterval {
t.Errorf("expected maxTokenInterval %v, got %v", DefaultMaxTokenInterval, rl.maxTokenInterval)
}
if rl.dailyMaxRequests != DefaultDailyMaxRequests {
t.Errorf("expected dailyMaxRequests %d, got %d", DefaultDailyMaxRequests, rl.dailyMaxRequests)
}
}
func TestNewRateLimiterWithConfig(t *testing.T) {
cfg := RateLimiterConfig{
MinTokenInterval: 5 * time.Second,
MaxTokenInterval: 15 * time.Second,
DailyMaxRequests: 100,
JitterPercent: 0.2,
BackoffBase: 1 * time.Minute,
BackoffMax: 30 * time.Minute,
BackoffMultiplier: 1.5,
SuspendCooldown: 12 * time.Hour,
}
rl := NewRateLimiterWithConfig(cfg)
if rl.minTokenInterval != 5*time.Second {
t.Errorf("expected minTokenInterval 5s, got %v", rl.minTokenInterval)
}
if rl.maxTokenInterval != 15*time.Second {
t.Errorf("expected maxTokenInterval 15s, got %v", rl.maxTokenInterval)
}
if rl.dailyMaxRequests != 100 {
t.Errorf("expected dailyMaxRequests 100, got %d", rl.dailyMaxRequests)
}
}
func TestNewRateLimiterWithConfig_PartialConfig(t *testing.T) {
cfg := RateLimiterConfig{
MinTokenInterval: 5 * time.Second,
}
rl := NewRateLimiterWithConfig(cfg)
if rl.minTokenInterval != 5*time.Second {
t.Errorf("expected minTokenInterval 5s, got %v", rl.minTokenInterval)
}
if rl.maxTokenInterval != DefaultMaxTokenInterval {
t.Errorf("expected default maxTokenInterval, got %v", rl.maxTokenInterval)
}
}
func TestGetTokenState_NonExistent(t *testing.T) {
rl := NewRateLimiter()
state := rl.GetTokenState("nonexistent")
if state != nil {
t.Error("expected nil state for non-existent token")
}
}
func TestIsTokenAvailable_NewToken(t *testing.T) {
rl := NewRateLimiter()
if !rl.IsTokenAvailable("newtoken") {
t.Error("expected new token to be available")
}
}
func TestMarkTokenFailed(t *testing.T) {
rl := NewRateLimiter()
rl.MarkTokenFailed("token1")
state := rl.GetTokenState("token1")
if state == nil {
t.Fatal("expected non-nil state")
}
if state.FailCount != 1 {
t.Errorf("expected FailCount 1, got %d", state.FailCount)
}
if state.CooldownEnd.IsZero() {
t.Error("expected non-zero CooldownEnd")
}
}
func TestMarkTokenSuccess(t *testing.T) {
rl := NewRateLimiter()
rl.MarkTokenFailed("token1")
rl.MarkTokenFailed("token1")
rl.MarkTokenSuccess("token1")
state := rl.GetTokenState("token1")
if state == nil {
t.Fatal("expected non-nil state")
}
if state.FailCount != 0 {
t.Errorf("expected FailCount 0, got %d", state.FailCount)
}
if !state.CooldownEnd.IsZero() {
t.Error("expected zero CooldownEnd after success")
}
}
func TestCheckAndMarkSuspended_Suspended(t *testing.T) {
rl := NewRateLimiter()
testCases := []string{
"Account has been suspended",
"You are banned from this service",
"Account disabled",
"Access denied permanently",
"Rate limit exceeded",
"Too many requests",
"Quota exceeded for today",
}
for i, msg := range testCases {
tokenKey := "token" + string(rune('a'+i))
if !rl.CheckAndMarkSuspended(tokenKey, msg) {
t.Errorf("expected suspension detected for: %s", msg)
}
state := rl.GetTokenState(tokenKey)
if !state.IsSuspended {
t.Errorf("expected IsSuspended true for: %s", msg)
}
}
}
func TestCheckAndMarkSuspended_NotSuspended(t *testing.T) {
rl := NewRateLimiter()
normalErrors := []string{
"connection timeout",
"internal server error",
"bad request",
"invalid token format",
}
for i, msg := range normalErrors {
tokenKey := "token" + string(rune('a'+i))
if rl.CheckAndMarkSuspended(tokenKey, msg) {
t.Errorf("unexpected suspension for: %s", msg)
}
}
}
func TestIsTokenAvailable_Suspended(t *testing.T) {
rl := NewRateLimiter()
rl.CheckAndMarkSuspended("token1", "Account suspended")
if rl.IsTokenAvailable("token1") {
t.Error("expected suspended token to be unavailable")
}
}
func TestClearTokenState(t *testing.T) {
rl := NewRateLimiter()
rl.MarkTokenFailed("token1")
rl.ClearTokenState("token1")
state := rl.GetTokenState("token1")
if state != nil {
t.Error("expected nil state after clear")
}
}
func TestResetSuspension(t *testing.T) {
rl := NewRateLimiter()
rl.CheckAndMarkSuspended("token1", "Account suspended")
rl.ResetSuspension("token1")
state := rl.GetTokenState("token1")
if state.IsSuspended {
t.Error("expected IsSuspended false after reset")
}
if state.FailCount != 0 {
t.Errorf("expected FailCount 0, got %d", state.FailCount)
}
}
func TestResetSuspension_NonExistent(t *testing.T) {
rl := NewRateLimiter()
rl.ResetSuspension("nonexistent")
}
func TestCalculateBackoff_ZeroFailCount(t *testing.T) {
rl := NewRateLimiter()
backoff := rl.calculateBackoff(0)
if backoff != 0 {
t.Errorf("expected 0 backoff for 0 fails, got %v", backoff)
}
}
func TestCalculateBackoff_Exponential(t *testing.T) {
cfg := RateLimiterConfig{
BackoffBase: 1 * time.Minute,
BackoffMax: 60 * time.Minute,
BackoffMultiplier: 2.0,
JitterPercent: 0.3,
}
rl := NewRateLimiterWithConfig(cfg)
backoff1 := rl.calculateBackoff(1)
if backoff1 < 40*time.Second || backoff1 > 80*time.Second {
t.Errorf("expected ~1min (with jitter) for fail 1, got %v", backoff1)
}
backoff2 := rl.calculateBackoff(2)
if backoff2 < 80*time.Second || backoff2 > 160*time.Second {
t.Errorf("expected ~2min (with jitter) for fail 2, got %v", backoff2)
}
}
func TestCalculateBackoff_MaxCap(t *testing.T) {
cfg := RateLimiterConfig{
BackoffBase: 1 * time.Minute,
BackoffMax: 10 * time.Minute,
BackoffMultiplier: 2.0,
JitterPercent: 0,
}
rl := NewRateLimiterWithConfig(cfg)
backoff := rl.calculateBackoff(10)
if backoff > 10*time.Minute {
t.Errorf("expected backoff capped at 10min, got %v", backoff)
}
}
func TestGetTokenState_ReturnsCopy(t *testing.T) {
rl := NewRateLimiter()
rl.MarkTokenFailed("token1")
state1 := rl.GetTokenState("token1")
state1.FailCount = 999
state2 := rl.GetTokenState("token1")
if state2.FailCount == 999 {
t.Error("GetTokenState should return a copy")
}
}
func TestRateLimiter_ConcurrentAccess(t *testing.T) {
rl := NewRateLimiter()
const numGoroutines = 50
const numOperations = 50
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
tokenKey := "token" + string(rune('a'+id%10))
for j := 0; j < numOperations; j++ {
switch j % 6 {
case 0:
rl.IsTokenAvailable(tokenKey)
case 1:
rl.MarkTokenFailed(tokenKey)
case 2:
rl.MarkTokenSuccess(tokenKey)
case 3:
rl.GetTokenState(tokenKey)
case 4:
rl.CheckAndMarkSuspended(tokenKey, "test error")
case 5:
rl.ResetSuspension(tokenKey)
}
}
}(i)
}
wg.Wait()
}
func TestCalculateInterval_WithinRange(t *testing.T) {
cfg := RateLimiterConfig{
MinTokenInterval: 10 * time.Second,
MaxTokenInterval: 30 * time.Second,
JitterPercent: 0.3,
}
rl := NewRateLimiterWithConfig(cfg)
minAllowed := 7 * time.Second
maxAllowed := 40 * time.Second
for i := 0; i < 100; i++ {
interval := rl.calculateInterval()
if interval < minAllowed || interval > maxAllowed {
t.Errorf("interval %v outside expected range [%v, %v]", interval, minAllowed, maxAllowed)
}
}
}

View File

@@ -0,0 +1,202 @@
package kiro
import (
"context"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
// RefreshManager is a singleton manager for background token refreshing.
type RefreshManager struct {
mu sync.Mutex
refresher *BackgroundRefresher
ctx context.Context
cancel context.CancelFunc
started bool
onTokenRefreshed func(tokenID string, tokenData *KiroTokenData)
}
var (
globalRefreshManager *RefreshManager
managerOnce sync.Once
)
// GetRefreshManager returns the global RefreshManager singleton.
func GetRefreshManager() *RefreshManager {
managerOnce.Do(func() {
globalRefreshManager = &RefreshManager{}
})
return globalRefreshManager
}
// Initialize sets up the background refresher.
func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.started {
log.Debug("refresh manager: already initialized")
return nil
}
if baseDir == "" {
log.Warn("refresh manager: base directory not provided, skipping initialization")
return nil
}
resolvedBaseDir, err := util.ResolveAuthDir(baseDir)
if err != nil {
log.Warnf("refresh manager: failed to resolve auth directory %s: %v", baseDir, err)
}
if resolvedBaseDir != "" {
baseDir = resolvedBaseDir
}
repo := NewFileTokenRepository(baseDir)
opts := []RefresherOption{
WithInterval(time.Minute),
WithBatchSize(50),
WithConcurrency(10),
WithConfig(cfg),
}
// Pass callback to BackgroundRefresher if already set
if m.onTokenRefreshed != nil {
opts = append(opts, WithOnTokenRefreshed(m.onTokenRefreshed))
}
m.refresher = NewBackgroundRefresher(repo, opts...)
log.Infof("refresh manager: initialized with base directory %s", baseDir)
return nil
}
// Start begins background token refreshing.
func (m *RefreshManager) Start() {
m.mu.Lock()
defer m.mu.Unlock()
if m.started {
log.Debug("refresh manager: already started")
return
}
if m.refresher == nil {
log.Warn("refresh manager: not initialized, cannot start")
return
}
m.ctx, m.cancel = context.WithCancel(context.Background())
m.refresher.Start(m.ctx)
m.started = true
log.Info("refresh manager: background refresh started")
}
// Stop halts background token refreshing.
func (m *RefreshManager) Stop() {
m.mu.Lock()
defer m.mu.Unlock()
if !m.started {
return
}
if m.cancel != nil {
m.cancel()
}
if m.refresher != nil {
m.refresher.Stop()
}
m.started = false
log.Info("refresh manager: background refresh stopped")
}
// IsRunning reports whether background refreshing is active.
func (m *RefreshManager) IsRunning() bool {
m.mu.Lock()
defer m.mu.Unlock()
return m.started
}
// UpdateBaseDir changes the token directory at runtime.
func (m *RefreshManager) UpdateBaseDir(baseDir string) {
m.mu.Lock()
defer m.mu.Unlock()
if m.refresher != nil && m.refresher.tokenRepo != nil {
if repo, ok := m.refresher.tokenRepo.(*FileTokenRepository); ok {
repo.SetBaseDir(baseDir)
log.Infof("refresh manager: updated base directory to %s", baseDir)
}
}
}
// SetOnTokenRefreshed registers a callback invoked after a successful token refresh.
// Can be called at any time; supports runtime callback updates.
func (m *RefreshManager) SetOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) {
m.mu.Lock()
defer m.mu.Unlock()
m.onTokenRefreshed = callback
// Update the refresher's callback in a thread-safe manner if already created
if m.refresher != nil {
m.refresher.callbackMu.Lock()
m.refresher.onTokenRefreshed = callback
m.refresher.callbackMu.Unlock()
}
log.Debug("refresh manager: token refresh callback registered")
}
// InitializeAndStart initializes and starts background refreshing (convenience method).
func InitializeAndStart(baseDir string, cfg *config.Config) {
// Initialize global fingerprint config
initGlobalFingerprintConfig(cfg)
manager := GetRefreshManager()
if err := manager.Initialize(baseDir, cfg); err != nil {
log.Errorf("refresh manager: initialization failed: %v", err)
return
}
manager.Start()
}
// initGlobalFingerprintConfig loads fingerprint settings from application config.
func initGlobalFingerprintConfig(cfg *config.Config) {
if cfg == nil || cfg.KiroFingerprint == nil {
return
}
fpCfg := cfg.KiroFingerprint
SetGlobalFingerprintConfig(&FingerprintConfig{
OIDCSDKVersion: fpCfg.OIDCSDKVersion,
RuntimeSDKVersion: fpCfg.RuntimeSDKVersion,
StreamingSDKVersion: fpCfg.StreamingSDKVersion,
OSType: fpCfg.OSType,
OSVersion: fpCfg.OSVersion,
NodeVersion: fpCfg.NodeVersion,
KiroVersion: fpCfg.KiroVersion,
KiroHash: fpCfg.KiroHash,
})
log.Debug("kiro: global fingerprint config loaded")
}
// InitFingerprintConfig initializes the global fingerprint config from application config.
func InitFingerprintConfig(cfg *config.Config) {
initGlobalFingerprintConfig(cfg)
}
// StopGlobalRefreshManager stops the global refresh manager.
func StopGlobalRefreshManager() {
if globalRefreshManager != nil {
globalRefreshManager.Stop()
}
}

View File

@@ -0,0 +1,159 @@
// Package kiro provides refresh utilities for Kiro token management.
package kiro
import (
"context"
"fmt"
"time"
log "github.com/sirupsen/logrus"
)
// RefreshResult contains the result of a token refresh attempt.
type RefreshResult struct {
TokenData *KiroTokenData
Error error
UsedFallback bool // True if we used the existing token as fallback
}
// RefreshWithGracefulDegradation attempts to refresh a token with graceful degradation.
// If refresh fails but the existing access token is still valid, it returns the existing token.
// This matches kiro-openai-gateway's behavior for better reliability.
//
// Parameters:
// - ctx: Context for the request
// - refreshFunc: Function to perform the actual refresh
// - existingAccessToken: Current access token (for fallback)
// - expiresAt: Expiration time of the existing token
//
// Returns:
// - RefreshResult containing the new or existing token data
func RefreshWithGracefulDegradation(
ctx context.Context,
refreshFunc func(ctx context.Context) (*KiroTokenData, error),
existingAccessToken string,
expiresAt time.Time,
) RefreshResult {
// Try to refresh the token
newTokenData, err := refreshFunc(ctx)
if err == nil {
return RefreshResult{
TokenData: newTokenData,
Error: nil,
UsedFallback: false,
}
}
// Refresh failed - check if we can use the existing token
log.Warnf("kiro: token refresh failed: %v", err)
// Check if existing token is still valid (not expired)
if existingAccessToken != "" && time.Now().Before(expiresAt) {
remainingTime := time.Until(expiresAt)
log.Warnf("kiro: using existing access token (expires in %v). Will retry refresh later.", remainingTime.Round(time.Second))
return RefreshResult{
TokenData: &KiroTokenData{
AccessToken: existingAccessToken,
ExpiresAt: expiresAt.Format(time.RFC3339),
},
Error: nil,
UsedFallback: true,
}
}
// Token is expired and refresh failed - return the error
return RefreshResult{
TokenData: nil,
Error: fmt.Errorf("token refresh failed and existing token is expired: %w", err),
UsedFallback: false,
}
}
// IsTokenExpiringSoon checks if a token is expiring within the given threshold.
// Default threshold is 5 minutes if not specified.
func IsTokenExpiringSoon(expiresAt time.Time, threshold time.Duration) bool {
if threshold == 0 {
threshold = 5 * time.Minute
}
return time.Now().Add(threshold).After(expiresAt)
}
// IsTokenExpired checks if a token has already expired.
func IsTokenExpired(expiresAt time.Time) bool {
return time.Now().After(expiresAt)
}
// ParseExpiresAt parses an expiration time string in RFC3339 format.
// Returns zero time if parsing fails.
func ParseExpiresAt(expiresAtStr string) time.Time {
if expiresAtStr == "" {
return time.Time{}
}
t, err := time.Parse(time.RFC3339, expiresAtStr)
if err != nil {
log.Debugf("kiro: failed to parse expiresAt '%s': %v", expiresAtStr, err)
return time.Time{}
}
return t
}
// RefreshConfig contains configuration for token refresh behavior.
type RefreshConfig struct {
// MaxRetries is the maximum number of refresh attempts (default: 1)
MaxRetries int
// RetryDelay is the delay between retry attempts (default: 1 second)
RetryDelay time.Duration
// RefreshThreshold is how early to refresh before expiration (default: 5 minutes)
RefreshThreshold time.Duration
// EnableGracefulDegradation allows using existing token if refresh fails (default: true)
EnableGracefulDegradation bool
}
// DefaultRefreshConfig returns the default refresh configuration.
func DefaultRefreshConfig() RefreshConfig {
return RefreshConfig{
MaxRetries: 1,
RetryDelay: time.Second,
RefreshThreshold: 5 * time.Minute,
EnableGracefulDegradation: true,
}
}
// RefreshWithRetry attempts to refresh a token with retry logic.
func RefreshWithRetry(
ctx context.Context,
refreshFunc func(ctx context.Context) (*KiroTokenData, error),
config RefreshConfig,
) (*KiroTokenData, error) {
var lastErr error
maxAttempts := config.MaxRetries + 1
if maxAttempts < 1 {
maxAttempts = 1
}
for attempt := 1; attempt <= maxAttempts; attempt++ {
tokenData, err := refreshFunc(ctx)
if err == nil {
if attempt > 1 {
log.Infof("kiro: token refresh succeeded on attempt %d", attempt)
}
return tokenData, nil
}
lastErr = err
log.Warnf("kiro: token refresh attempt %d/%d failed: %v", attempt, maxAttempts, err)
// Don't sleep after the last attempt
if attempt < maxAttempts {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(config.RetryDelay):
}
}
}
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxAttempts, lastErr)
}

View File

@@ -0,0 +1,488 @@
// Package kiro provides social authentication (Google/GitHub) for Kiro via AuthServiceClient.
package kiro
import (
"bufio"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"html"
"io"
"net"
"net/http"
"net/url"
"os"
"os/exec"
"runtime"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"golang.org/x/term"
)
const (
// Kiro AuthService endpoint
kiroAuthServiceEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
// OAuth timeout
socialAuthTimeout = 10 * time.Minute
// Default callback port for social auth HTTP server
socialAuthCallbackPort = 9876
)
// SocialProvider represents the social login provider.
type SocialProvider string
const (
// ProviderGoogle is Google OAuth provider
ProviderGoogle SocialProvider = "Google"
// ProviderGitHub is GitHub OAuth provider
ProviderGitHub SocialProvider = "Github"
// Note: AWS Builder ID is NOT supported by Kiro's auth service.
// It only supports: Google, Github, Cognito
// AWS Builder ID must use device code flow via SSO OIDC.
)
// CreateTokenRequest is sent to Kiro's /oauth/token endpoint.
type CreateTokenRequest struct {
Code string `json:"code"`
CodeVerifier string `json:"code_verifier"`
RedirectURI string `json:"redirect_uri"`
InvitationCode string `json:"invitation_code,omitempty"`
}
// SocialTokenResponse from Kiro's /oauth/token endpoint for social auth.
type SocialTokenResponse struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ProfileArn string `json:"profileArn"`
ExpiresIn int `json:"expiresIn"`
}
// RefreshTokenRequest is sent to Kiro's /refreshToken endpoint.
type RefreshTokenRequest struct {
RefreshToken string `json:"refreshToken"`
}
// WebCallbackResult contains the OAuth callback result from HTTP server.
type WebCallbackResult struct {
Code string
State string
Error string
}
// SocialAuthClient handles social authentication with Kiro.
type SocialAuthClient struct {
httpClient *http.Client
cfg *config.Config
protocolHandler *ProtocolHandler
machineID string
kiroVersion string
}
// NewSocialAuthClient creates a new social auth client.
func NewSocialAuthClient(cfg *config.Config) *SocialAuthClient {
client := &http.Client{Timeout: 30 * time.Second}
if cfg != nil {
client = util.SetProxy(&cfg.SDKConfig, client)
}
fp := GlobalFingerprintManager().GetFingerprint("login")
return &SocialAuthClient{
httpClient: client,
cfg: cfg,
protocolHandler: NewProtocolHandler(),
machineID: fp.KiroHash,
kiroVersion: fp.KiroVersion,
}
}
// startWebCallbackServer starts a local HTTP server to receive the OAuth callback.
// This is used instead of the kiro:// protocol handler to avoid redirect_mismatch errors.
func (c *SocialAuthClient) startWebCallbackServer(ctx context.Context, expectedState string) (string, <-chan WebCallbackResult, error) {
// Try to find an available port - use localhost like Kiro does
listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", socialAuthCallbackPort))
if err != nil {
// Try with dynamic port (RFC 8252 allows dynamic ports for native apps)
log.Warnf("kiro social auth: default port %d is busy, falling back to dynamic port", socialAuthCallbackPort)
listener, err = net.Listen("tcp", "localhost:0")
if err != nil {
return "", nil, fmt.Errorf("failed to start callback server: %w", err)
}
}
port := listener.Addr().(*net.TCPAddr).Port
// Use http scheme for local callback server
redirectURI := fmt.Sprintf("http://localhost:%d/oauth/callback", port)
resultChan := make(chan WebCallbackResult, 1)
server := &http.Server{
ReadHeaderTimeout: 10 * time.Second,
}
mux := http.NewServeMux()
mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
errParam := r.URL.Query().Get("error")
if errParam != "" {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, `<!DOCTYPE html>
<html><head><title>Login Failed</title></head>
<body><h1>Login Failed</h1><p>%s</p><p>You can close this window.</p></body></html>`, html.EscapeString(errParam))
resultChan <- WebCallbackResult{Error: errParam}
return
}
if state != expectedState {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusBadRequest)
fmt.Fprint(w, `<!DOCTYPE html>
<html><head><title>Login Failed</title></head>
<body><h1>Login Failed</h1><p>Invalid state parameter</p><p>You can close this window.</p></body></html>`)
resultChan <- WebCallbackResult{Error: "state mismatch"}
return
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
fmt.Fprint(w, `<!DOCTYPE html>
<html><head><title>Login Successful</title></head>
<body><h1>Login Successful!</h1><p>You can close this window and return to the terminal.</p>
<script>window.close();</script></body></html>`)
resultChan <- WebCallbackResult{Code: code, State: state}
})
server.Handler = mux
go func() {
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
log.Debugf("kiro social auth callback server error: %v", err)
}
}()
go func() {
select {
case <-ctx.Done():
case <-time.After(socialAuthTimeout):
case <-resultChan:
}
_ = server.Shutdown(context.Background())
}()
return redirectURI, resultChan, nil
}
// generatePKCE generates PKCE code verifier and challenge.
func generatePKCE() (verifier, challenge string, err error) {
// Generate 32 bytes of random data for verifier
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", "", fmt.Errorf("failed to generate random bytes: %w", err)
}
verifier = base64.RawURLEncoding.EncodeToString(b)
// Generate SHA256 hash of verifier for challenge
h := sha256.Sum256([]byte(verifier))
challenge = base64.RawURLEncoding.EncodeToString(h[:])
return verifier, challenge, nil
}
// generateState generates a random state parameter.
func generateStateParam() (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// buildLoginURL constructs the Kiro OAuth login URL.
// The login endpoint expects a GET request with query parameters.
// Format: /login?idp=Google&redirect_uri=...&code_challenge=...&code_challenge_method=S256&state=...&prompt=select_account
// The prompt=select_account parameter forces the account selection screen even if already logged in.
func (c *SocialAuthClient) buildLoginURL(provider, redirectURI, codeChallenge, state string) string {
return fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account",
kiroAuthServiceEndpoint,
provider,
url.QueryEscape(redirectURI),
codeChallenge,
state,
)
}
// CreateToken exchanges the authorization code for tokens.
func (c *SocialAuthClient) CreateToken(ctx context.Context, req *CreateTokenRequest) (*SocialTokenResponse, error) {
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal token request: %w", err)
}
tokenURL := kiroAuthServiceEndpoint + "/oauth/token"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(string(body)))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", c.kiroVersion, c.machineID))
httpReq.Header.Set("Accept", "application/json, text/plain, */*")
resp, err := c.httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("token request failed: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read token response: %w", err)
}
if resp.StatusCode != http.StatusOK {
log.Debugf("token exchange failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("token exchange failed (status %d)", resp.StatusCode)
}
var tokenResp SocialTokenResponse
if err := json.Unmarshal(respBody, &tokenResp); err != nil {
return nil, fmt.Errorf("failed to parse token response: %w", err)
}
return &tokenResp, nil
}
// RefreshSocialToken refreshes an expired social auth token.
func (c *SocialAuthClient) RefreshSocialToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) {
body, err := json.Marshal(&RefreshTokenRequest{RefreshToken: refreshToken})
if err != nil {
return nil, fmt.Errorf("failed to marshal refresh request: %w", err)
}
refreshURL := kiroAuthServiceEndpoint + "/refreshToken"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, strings.NewReader(string(body)))
if err != nil {
return nil, fmt.Errorf("failed to create refresh request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("User-Agent", fmt.Sprintf("KiroIDE-%s-%s", c.kiroVersion, c.machineID))
httpReq.Header.Set("Accept", "application/json, text/plain, */*")
resp, err := c.httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("refresh request failed: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read refresh response: %w", err)
}
if resp.StatusCode != http.StatusOK {
log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
}
var tokenResp SocialTokenResponse
if err := json.Unmarshal(respBody, &tokenResp); err != nil {
return nil, fmt.Errorf("failed to parse refresh response: %w", err)
}
// Validate ExpiresIn - use default 1 hour if invalid
expiresIn := tokenResp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600 // Default 1 hour
}
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
return &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: tokenResp.ProfileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "social",
Provider: "", // Caller should preserve original provider
Region: "us-east-1",
}, nil
}
// LoginWithSocial performs OAuth login with Google or GitHub.
// Uses local HTTP callback server instead of custom protocol handler to avoid redirect_mismatch errors.
func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialProvider) (*KiroTokenData, error) {
providerName := string(provider)
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
fmt.Printf("║ Kiro Authentication (%s) ║\n", providerName)
fmt.Println("╚══════════════════════════════════════════════════════════╝")
// Step 1: Start local HTTP callback server (instead of kiro:// protocol handler)
// This avoids redirect_mismatch errors with AWS Cognito
fmt.Println("\nSetting up authentication...")
// Step 2: Generate PKCE codes
codeVerifier, codeChallenge, err := generatePKCE()
if err != nil {
return nil, fmt.Errorf("failed to generate PKCE: %w", err)
}
// Step 3: Generate state
state, err := generateStateParam()
if err != nil {
return nil, fmt.Errorf("failed to generate state: %w", err)
}
// Step 4: Start local HTTP callback server
redirectURI, resultChan, err := c.startWebCallbackServer(ctx, state)
if err != nil {
return nil, fmt.Errorf("failed to start callback server: %w", err)
}
log.Debugf("kiro social auth: callback server started at %s", redirectURI)
// Step 5: Build the login URL using HTTP redirect URI
authURL := c.buildLoginURL(providerName, redirectURI, codeChallenge, state)
// Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito)
// Incognito mode enables multi-account support by bypassing cached sessions
if c.cfg != nil {
browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
if !c.cfg.IncognitoBrowser {
log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.")
} else {
log.Debug("kiro: using incognito mode for multi-account support")
}
} else {
browser.SetIncognitoMode(true) // Default to incognito if no config
log.Debug("kiro: using incognito mode for multi-account support (default)")
}
// Step 6: Open browser for user authentication
fmt.Println("\n════════════════════════════════════════════════════════════")
fmt.Printf(" Opening browser for %s authentication...\n", providerName)
fmt.Println("════════════════════════════════════════════════════════════")
fmt.Printf("\n URL: %s\n\n", authURL)
if err := browser.OpenURL(authURL); err != nil {
log.Warnf("Could not open browser automatically: %v", err)
fmt.Println(" ⚠ Could not open browser automatically.")
fmt.Println(" Please open the URL above in your browser manually.")
} else {
fmt.Println(" (Browser opened automatically)")
}
fmt.Println("\n Waiting for authentication callback...")
// Step 7: Wait for callback from HTTP server
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(socialAuthTimeout):
return nil, fmt.Errorf("authentication timed out")
case callback := <-resultChan:
if callback.Error != "" {
return nil, fmt.Errorf("authentication error: %s", callback.Error)
}
// State is already validated by the callback server
if callback.Code == "" {
return nil, fmt.Errorf("no authorization code received")
}
fmt.Println("\n✓ Authorization received!")
// Step 8: Exchange code for tokens
fmt.Println("Exchanging code for tokens...")
tokenReq := &CreateTokenRequest{
Code: callback.Code,
CodeVerifier: codeVerifier,
RedirectURI: redirectURI, // Use HTTP redirect URI, not kiro:// protocol
}
tokenResp, err := c.CreateToken(ctx, tokenReq)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
}
fmt.Println("\n✓ Authentication successful!")
// Close the browser window
if err := browser.CloseBrowser(); err != nil {
log.Debugf("Failed to close browser: %v", err)
}
// Validate ExpiresIn - use default 1 hour if invalid
expiresIn := tokenResp.ExpiresIn
if expiresIn <= 0 {
expiresIn = 3600
}
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
// Try to extract email from JWT access token first
email := ExtractEmailFromJWT(tokenResp.AccessToken)
// If no email in JWT, ask user for account label (only in interactive mode)
if email == "" && isInteractiveTerminal() {
fmt.Print("\n Enter account label for file naming (optional, press Enter to skip): ")
reader := bufio.NewReader(os.Stdin)
var err error
email, err = reader.ReadString('\n')
if err != nil {
log.Debugf("Failed to read account label: %v", err)
}
email = strings.TrimSpace(email)
}
return &KiroTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ProfileArn: tokenResp.ProfileArn,
ExpiresAt: expiresAt.Format(time.RFC3339),
AuthMethod: "social",
Provider: providerName,
Email: email, // JWT email or user-provided label
Region: "us-east-1",
}, nil
}
}
// LoginWithGoogle performs OAuth login with Google.
func (c *SocialAuthClient) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) {
return c.LoginWithSocial(ctx, ProviderGoogle)
}
// LoginWithGitHub performs OAuth login with GitHub.
func (c *SocialAuthClient) LoginWithGitHub(ctx context.Context) (*KiroTokenData, error) {
return c.LoginWithSocial(ctx, ProviderGitHub)
}
// forceDefaultProtocolHandler sets our protocol handler as the default for kiro:// URLs.
// This prevents the "Open with" dialog from appearing on Linux.
// On non-Linux platforms, this is a no-op as they use different mechanisms.
func forceDefaultProtocolHandler() {
if runtime.GOOS != "linux" {
return // Non-Linux platforms use different handler mechanisms
}
// Set our handler as default using xdg-mime
cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro")
if err := cmd.Run(); err != nil {
log.Warnf("Failed to set default protocol handler: %v. You may see a handler selection dialog.", err)
}
}
// isInteractiveTerminal checks if stdin is connected to an interactive terminal.
// Returns false in CI/automated environments or when stdin is piped.
func isInteractiveTerminal() bool {
return term.IsTerminal(int(os.Stdin.Fd()))
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,261 @@
package kiro
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
type recordingRoundTripper struct {
lastReq *http.Request
}
func (rt *recordingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
rt.lastReq = req
body := `{"nextToken":null,"profiles":[{"arn":"arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC","profileName":"test"}]}`
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(body)),
Header: make(http.Header),
}, nil
}
func TestTryListAvailableProfiles_UsesClientIDForAccountKey(t *testing.T) {
rt := &recordingRoundTripper{}
client := &SSOOIDCClient{
httpClient: &http.Client{Transport: rt},
}
profileArn := client.tryListAvailableProfiles(context.Background(), "access-token", "client-id-123", "refresh-token-456")
if profileArn == "" {
t.Fatal("expected profileArn, got empty result")
}
accountKey := GetAccountKey("client-id-123", "refresh-token-456")
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
expected := fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s", fp.RuntimeSDKVersion, fp.KiroVersion, fp.KiroHash)
got := rt.lastReq.Header.Get("X-Amz-User-Agent")
if got != expected {
t.Errorf("X-Amz-User-Agent = %q, want %q", got, expected)
}
}
func TestTryListAvailableProfiles_UsesRefreshTokenWhenClientIDMissing(t *testing.T) {
rt := &recordingRoundTripper{}
client := &SSOOIDCClient{
httpClient: &http.Client{Transport: rt},
}
profileArn := client.tryListAvailableProfiles(context.Background(), "access-token", "", "refresh-token-789")
if profileArn == "" {
t.Fatal("expected profileArn, got empty result")
}
accountKey := GetAccountKey("", "refresh-token-789")
fp := GlobalFingerprintManager().GetFingerprint(accountKey)
expected := fmt.Sprintf("aws-sdk-js/%s KiroIDE-%s-%s", fp.RuntimeSDKVersion, fp.KiroVersion, fp.KiroHash)
got := rt.lastReq.Header.Get("X-Amz-User-Agent")
if got != expected {
t.Errorf("X-Amz-User-Agent = %q, want %q", got, expected)
}
}
func TestRegisterClientForAuthCodeWithIDC(t *testing.T) {
var capturedReq struct {
Method string
Path string
Headers http.Header
Body map[string]interface{}
}
mockResp := RegisterClientResponse{
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
ClientIDIssuedAt: 1700000000,
ClientSecretExpiresAt: 1700086400,
}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedReq.Method = r.Method
capturedReq.Path = r.URL.Path
capturedReq.Headers = r.Header.Clone()
bodyBytes, _ := io.ReadAll(r.Body)
json.Unmarshal(bodyBytes, &capturedReq.Body)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(mockResp)
}))
defer ts.Close()
// Extract host to build a region that resolves to our test server.
// Override getOIDCEndpoint by passing region="" and patching the endpoint.
// Since getOIDCEndpoint builds "https://oidc.{region}.amazonaws.com", we
// instead inject the test server URL directly via a custom HTTP client transport.
client := &SSOOIDCClient{
httpClient: ts.Client(),
}
// We need to route the request to our test server. Use a transport that rewrites the URL.
client.httpClient.Transport = &rewriteTransport{
base: ts.Client().Transport,
targetURL: ts.URL,
}
resp, err := client.RegisterClientForAuthCodeWithIDC(
context.Background(),
"http://127.0.0.1:19877/oauth/callback",
"https://my-idc-instance.awsapps.com/start",
"us-east-1",
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Verify request method and path
if capturedReq.Method != http.MethodPost {
t.Errorf("method = %q, want POST", capturedReq.Method)
}
if capturedReq.Path != "/client/register" {
t.Errorf("path = %q, want /client/register", capturedReq.Path)
}
// Verify headers
if ct := capturedReq.Headers.Get("Content-Type"); ct != "application/json" {
t.Errorf("Content-Type = %q, want application/json", ct)
}
ua := capturedReq.Headers.Get("User-Agent")
if !strings.Contains(ua, "KiroIDE") {
t.Errorf("User-Agent %q does not contain KiroIDE", ua)
}
if !strings.Contains(ua, "sso-oidc") {
t.Errorf("User-Agent %q does not contain sso-oidc", ua)
}
xua := capturedReq.Headers.Get("X-Amz-User-Agent")
if !strings.Contains(xua, "KiroIDE") {
t.Errorf("x-amz-user-agent %q does not contain KiroIDE", xua)
}
// Verify body fields
if v, _ := capturedReq.Body["clientName"].(string); v != "Kiro IDE" {
t.Errorf("clientName = %q, want %q", v, "Kiro IDE")
}
if v, _ := capturedReq.Body["clientType"].(string); v != "public" {
t.Errorf("clientType = %q, want %q", v, "public")
}
if v, _ := capturedReq.Body["issuerUrl"].(string); v != "https://my-idc-instance.awsapps.com/start" {
t.Errorf("issuerUrl = %q, want %q", v, "https://my-idc-instance.awsapps.com/start")
}
// Verify scopes array
scopesRaw, ok := capturedReq.Body["scopes"].([]interface{})
if !ok || len(scopesRaw) != 5 {
t.Fatalf("scopes: got %v, want 5-element array", capturedReq.Body["scopes"])
}
expectedScopes := []string{
"codewhisperer:completions", "codewhisperer:analysis",
"codewhisperer:conversations", "codewhisperer:transformations",
"codewhisperer:taskassist",
}
for i, s := range expectedScopes {
if scopesRaw[i].(string) != s {
t.Errorf("scopes[%d] = %q, want %q", i, scopesRaw[i], s)
}
}
// Verify grantTypes
grantTypesRaw, ok := capturedReq.Body["grantTypes"].([]interface{})
if !ok || len(grantTypesRaw) != 2 {
t.Fatalf("grantTypes: got %v, want 2-element array", capturedReq.Body["grantTypes"])
}
if grantTypesRaw[0].(string) != "authorization_code" || grantTypesRaw[1].(string) != "refresh_token" {
t.Errorf("grantTypes = %v, want [authorization_code, refresh_token]", grantTypesRaw)
}
// Verify redirectUris
redirectRaw, ok := capturedReq.Body["redirectUris"].([]interface{})
if !ok || len(redirectRaw) != 1 {
t.Fatalf("redirectUris: got %v, want 1-element array", capturedReq.Body["redirectUris"])
}
if redirectRaw[0].(string) != "http://127.0.0.1:19877/oauth/callback" {
t.Errorf("redirectUris[0] = %q, want %q", redirectRaw[0], "http://127.0.0.1:19877/oauth/callback")
}
// Verify response parsing
if resp.ClientID != "test-client-id" {
t.Errorf("ClientID = %q, want %q", resp.ClientID, "test-client-id")
}
if resp.ClientSecret != "test-client-secret" {
t.Errorf("ClientSecret = %q, want %q", resp.ClientSecret, "test-client-secret")
}
if resp.ClientIDIssuedAt != 1700000000 {
t.Errorf("ClientIDIssuedAt = %d, want %d", resp.ClientIDIssuedAt, 1700000000)
}
if resp.ClientSecretExpiresAt != 1700086400 {
t.Errorf("ClientSecretExpiresAt = %d, want %d", resp.ClientSecretExpiresAt, 1700086400)
}
}
// rewriteTransport redirects all requests to the test server URL.
type rewriteTransport struct {
base http.RoundTripper
targetURL string
}
func (t *rewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) {
target, _ := url.Parse(t.targetURL)
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
if t.base != nil {
return t.base.RoundTrip(req)
}
return http.DefaultTransport.RoundTrip(req)
}
func TestBuildAuthorizationURL(t *testing.T) {
scopes := "codewhisperer:completions,codewhisperer:analysis,codewhisperer:conversations,codewhisperer:transformations,codewhisperer:taskassist"
endpoint := "https://oidc.us-east-1.amazonaws.com"
redirectURI := "http://127.0.0.1:19877/oauth/callback"
authURL := buildAuthorizationURL(endpoint, "test-client-id", redirectURI, scopes, "random-state", "test-challenge")
// Verify colons and commas in scopes are percent-encoded
if !strings.Contains(authURL, "codewhisperer%3Acompletions") {
t.Errorf("expected colons in scopes to be percent-encoded, got: %s", authURL)
}
if !strings.Contains(authURL, "completions%2Ccodewhisperer") {
t.Errorf("expected commas in scopes to be percent-encoded, got: %s", authURL)
}
// Parse back and verify all parameters round-trip correctly
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("failed to parse auth URL: %v", err)
}
if !strings.HasPrefix(authURL, endpoint+"/authorize?") {
t.Errorf("expected URL to start with %s/authorize?, got: %s", endpoint, authURL)
}
q := parsed.Query()
checks := map[string]string{
"response_type": "code",
"client_id": "test-client-id",
"redirect_uri": redirectURI,
"scopes": scopes,
"state": "random-state",
"code_challenge": "test-challenge",
"code_challenge_method": "S256",
}
for key, want := range checks {
if got := q.Get(key); got != want {
t.Errorf("%s = %q, want %q", key, got, want)
}
}
}

View File

@@ -0,0 +1,89 @@
package kiro
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
)
// KiroTokenStorage holds the persistent token data for Kiro authentication.
type KiroTokenStorage struct {
// Type is the provider type for management UI recognition (must be "kiro")
Type string `json:"type"`
// AccessToken is the OAuth2 access token for API access
AccessToken string `json:"access_token"`
// RefreshToken is used to obtain new access tokens
RefreshToken string `json:"refresh_token"`
// ProfileArn is the AWS CodeWhisperer profile ARN
ProfileArn string `json:"profile_arn"`
// ExpiresAt is the timestamp when the token expires
ExpiresAt string `json:"expires_at"`
// AuthMethod indicates the authentication method used
AuthMethod string `json:"auth_method"`
// Provider indicates the OAuth provider
Provider string `json:"provider"`
// LastRefresh is the timestamp of the last token refresh
LastRefresh string `json:"last_refresh"`
// ClientID is the OAuth client ID (required for token refresh)
ClientID string `json:"client_id,omitempty"`
// ClientSecret is the OAuth client secret (required for token refresh)
ClientSecret string `json:"client_secret,omitempty"`
// Region is the OIDC region for IDC login and token refresh
Region string `json:"region,omitempty"`
// StartURL is the AWS Identity Center start URL (for IDC auth)
StartURL string `json:"start_url,omitempty"`
// Email is the user's email address
Email string `json:"email,omitempty"`
}
// SaveTokenToFile persists the token storage to the specified file path.
func (s *KiroTokenStorage) SaveTokenToFile(authFilePath string) error {
dir := filepath.Dir(authFilePath)
if err := os.MkdirAll(dir, 0700); err != nil {
return fmt.Errorf("failed to create directory: %w", err)
}
data, err := json.MarshalIndent(s, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal token storage: %w", err)
}
if err := os.WriteFile(authFilePath, data, 0600); err != nil {
return fmt.Errorf("failed to write token file: %w", err)
}
return nil
}
// LoadFromFile loads token storage from the specified file path.
func LoadFromFile(authFilePath string) (*KiroTokenStorage, error) {
data, err := os.ReadFile(authFilePath)
if err != nil {
return nil, fmt.Errorf("failed to read token file: %w", err)
}
var storage KiroTokenStorage
if err := json.Unmarshal(data, &storage); err != nil {
return nil, fmt.Errorf("failed to parse token file: %w", err)
}
return &storage, nil
}
// ToTokenData converts storage to KiroTokenData for API use.
func (s *KiroTokenStorage) ToTokenData() *KiroTokenData {
return &KiroTokenData{
AccessToken: s.AccessToken,
RefreshToken: s.RefreshToken,
ProfileArn: s.ProfileArn,
ExpiresAt: s.ExpiresAt,
AuthMethod: s.AuthMethod,
Provider: s.Provider,
ClientID: s.ClientID,
ClientSecret: s.ClientSecret,
Region: s.Region,
StartURL: s.StartURL,
Email: s.Email,
}
}

View File

@@ -0,0 +1,260 @@
package kiro
import (
"context"
"encoding/json"
"fmt"
"io/fs"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// FileTokenRepository 实现 TokenRepository 接口,基于文件系统存储
type FileTokenRepository struct {
mu sync.RWMutex
baseDir string
}
// NewFileTokenRepository 创建一个新的文件 token 存储库
func NewFileTokenRepository(baseDir string) *FileTokenRepository {
return &FileTokenRepository{
baseDir: baseDir,
}
}
// SetBaseDir 设置基础目录
func (r *FileTokenRepository) SetBaseDir(dir string) {
r.mu.Lock()
r.baseDir = strings.TrimSpace(dir)
r.mu.Unlock()
}
// FindOldestUnverified 查找需要刷新的 token按最后验证时间排序
func (r *FileTokenRepository) FindOldestUnverified(limit int) []*Token {
r.mu.RLock()
baseDir := r.baseDir
r.mu.RUnlock()
if baseDir == "" {
log.Debug("token repository: base directory not configured")
return nil
}
var tokens []*Token
err := filepath.WalkDir(baseDir, func(path string, d fs.DirEntry, walkErr error) error {
if walkErr != nil {
return nil // 忽略错误,继续遍历
}
if d.IsDir() {
return nil
}
if !strings.HasSuffix(strings.ToLower(d.Name()), ".json") {
return nil
}
// 只处理 kiro 相关的 token 文件
if !strings.HasPrefix(d.Name(), "kiro-") {
return nil
}
token, err := r.readTokenFile(path)
if err != nil {
log.Debugf("token repository: failed to read token file %s: %v", path, err)
return nil
}
if token != nil && token.RefreshToken != "" {
// 检查 token 是否需要刷新(过期前 5 分钟)
if token.ExpiresAt.IsZero() || time.Until(token.ExpiresAt) < 5*time.Minute {
tokens = append(tokens, token)
}
}
return nil
})
if err != nil {
log.Warnf("token repository: error walking directory: %v", err)
}
// 按最后验证时间排序(最旧的优先)
sort.Slice(tokens, func(i, j int) bool {
return tokens[i].LastVerified.Before(tokens[j].LastVerified)
})
// 限制返回数量
if limit > 0 && len(tokens) > limit {
tokens = tokens[:limit]
}
return tokens
}
// UpdateToken 更新 token 并持久化到文件
func (r *FileTokenRepository) UpdateToken(token *Token) error {
if token == nil {
return fmt.Errorf("token repository: token is nil")
}
r.mu.RLock()
baseDir := r.baseDir
r.mu.RUnlock()
if baseDir == "" {
return fmt.Errorf("token repository: base directory not configured")
}
// 构建文件路径
filePath := filepath.Join(baseDir, token.ID)
if !strings.HasSuffix(filePath, ".json") {
filePath += ".json"
}
// 读取现有文件内容
existingData := make(map[string]any)
if data, err := os.ReadFile(filePath); err == nil {
_ = json.Unmarshal(data, &existingData)
}
// 更新字段
existingData["access_token"] = token.AccessToken
existingData["refresh_token"] = token.RefreshToken
existingData["last_refresh"] = time.Now().Format(time.RFC3339)
if !token.ExpiresAt.IsZero() {
existingData["expires_at"] = token.ExpiresAt.Format(time.RFC3339)
}
// 保持原有的关键字段
if token.ClientID != "" {
existingData["client_id"] = token.ClientID
}
if token.ClientSecret != "" {
existingData["client_secret"] = token.ClientSecret
}
if token.AuthMethod != "" {
existingData["auth_method"] = token.AuthMethod
}
if token.Region != "" {
existingData["region"] = token.Region
}
if token.StartURL != "" {
existingData["start_url"] = token.StartURL
}
// 序列化并写入文件
raw, err := json.MarshalIndent(existingData, "", " ")
if err != nil {
return fmt.Errorf("token repository: marshal failed: %w", err)
}
// 原子写入:先写入临时文件,再重命名
tmpPath := filePath + ".tmp"
if err := os.WriteFile(tmpPath, raw, 0o600); err != nil {
return fmt.Errorf("token repository: write temp file failed: %w", err)
}
if err := os.Rename(tmpPath, filePath); err != nil {
_ = os.Remove(tmpPath)
return fmt.Errorf("token repository: rename failed: %w", err)
}
log.Debugf("token repository: updated token %s", token.ID)
return nil
}
// readTokenFile 从文件读取 token
func (r *FileTokenRepository) readTokenFile(path string) (*Token, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var metadata map[string]any
if err := json.Unmarshal(data, &metadata); err != nil {
return nil, err
}
// 检查是否是 kiro token
tokenType, _ := metadata["type"].(string)
if tokenType != "kiro" {
return nil, nil
}
// 检查 auth_method (case-insensitive comparison to handle "IdC", "IDC", "idc", etc.)
authMethod, _ := metadata["auth_method"].(string)
authMethod = strings.ToLower(authMethod)
if authMethod != "idc" && authMethod != "builder-id" {
return nil, nil // 只处理 IDC 和 Builder ID token
}
token := &Token{
ID: filepath.Base(path),
AuthMethod: authMethod,
}
// 解析各字段
token.AccessToken, _ = metadata["access_token"].(string)
token.RefreshToken, _ = metadata["refresh_token"].(string)
token.ClientID, _ = metadata["client_id"].(string)
token.ClientSecret, _ = metadata["client_secret"].(string)
token.Region, _ = metadata["region"].(string)
token.StartURL, _ = metadata["start_url"].(string)
token.Provider, _ = metadata["provider"].(string)
// 解析时间字段
if expiresAtStr, ok := metadata["expires_at"].(string); ok && expiresAtStr != "" {
if t, err := time.Parse(time.RFC3339, expiresAtStr); err == nil {
token.ExpiresAt = t
}
}
if lastRefreshStr, ok := metadata["last_refresh"].(string); ok && lastRefreshStr != "" {
if t, err := time.Parse(time.RFC3339, lastRefreshStr); err == nil {
token.LastVerified = t
}
}
return token, nil
}
// ListKiroTokens 列出所有 Kiro token用于调试
func (r *FileTokenRepository) ListKiroTokens(ctx context.Context) ([]*Token, error) {
r.mu.RLock()
baseDir := r.baseDir
r.mu.RUnlock()
if baseDir == "" {
return nil, fmt.Errorf("token repository: base directory not configured")
}
var tokens []*Token
err := filepath.WalkDir(baseDir, func(path string, d fs.DirEntry, walkErr error) error {
if walkErr != nil {
return nil
}
if d.IsDir() {
return nil
}
if !strings.HasPrefix(d.Name(), "kiro-") || !strings.HasSuffix(d.Name(), ".json") {
return nil
}
token, err := r.readTokenFile(path)
if err != nil {
return nil
}
if token != nil {
tokens = append(tokens, token)
}
return nil
})
return tokens, err
}

View File

@@ -0,0 +1,236 @@
// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API.
// This file implements usage quota checking and monitoring.
package kiro
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
)
// UsageQuotaResponse represents the API response structure for usage quota checking.
type UsageQuotaResponse struct {
UsageBreakdownList []UsageBreakdownExtended `json:"usageBreakdownList"`
SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
NextDateReset float64 `json:"nextDateReset,omitempty"`
}
// UsageBreakdownExtended represents detailed usage information for quota checking.
// Note: UsageBreakdown is already defined in codewhisperer_client.go
type UsageBreakdownExtended struct {
ResourceType string `json:"resourceType"`
UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
FreeTrialInfo *FreeTrialInfoExtended `json:"freeTrialInfo,omitempty"`
}
// FreeTrialInfoExtended represents free trial usage information.
type FreeTrialInfoExtended struct {
FreeTrialStatus string `json:"freeTrialStatus"`
UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
}
// QuotaStatus represents the quota status for a token.
type QuotaStatus struct {
TotalLimit float64
CurrentUsage float64
RemainingQuota float64
IsExhausted bool
ResourceType string
NextReset time.Time
}
// UsageChecker provides methods for checking token quota usage.
type UsageChecker struct {
httpClient *http.Client
}
// NewUsageChecker creates a new UsageChecker instance.
func NewUsageChecker(cfg *config.Config) *UsageChecker {
return &UsageChecker{
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}),
}
}
// NewUsageCheckerWithClient creates a UsageChecker with a custom HTTP client.
func NewUsageCheckerWithClient(client *http.Client) *UsageChecker {
return &UsageChecker{
httpClient: client,
}
}
// CheckUsage retrieves usage limits for the given token.
func (c *UsageChecker) CheckUsage(ctx context.Context, tokenData *KiroTokenData) (*UsageQuotaResponse, error) {
if tokenData == nil {
return nil, fmt.Errorf("token data is nil")
}
if tokenData.AccessToken == "" {
return nil, fmt.Errorf("access token is empty")
}
queryParams := map[string]string{
"origin": "AI_EDITOR",
"profileArn": tokenData.ProfileArn,
"resourceType": "AGENTIC_REQUEST",
}
// Use endpoint from profileArn if available
endpoint := GetKiroAPIEndpointFromProfileArn(tokenData.ProfileArn)
url := buildURL(endpoint, pathGetUsageLimits, queryParams)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
accountKey := GetAccountKey(tokenData.ClientID, tokenData.RefreshToken)
setRuntimeHeaders(req, tokenData.AccessToken, accountKey)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
var result UsageQuotaResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("failed to parse usage response: %w", err)
}
return &result, nil
}
// CheckUsageByAccessToken retrieves usage limits using an access token and profile ARN directly.
func (c *UsageChecker) CheckUsageByAccessToken(ctx context.Context, accessToken, profileArn string) (*UsageQuotaResponse, error) {
tokenData := &KiroTokenData{
AccessToken: accessToken,
ProfileArn: profileArn,
}
return c.CheckUsage(ctx, tokenData)
}
// GetRemainingQuota calculates the remaining quota from usage limits.
func GetRemainingQuota(usage *UsageQuotaResponse) float64 {
if usage == nil || len(usage.UsageBreakdownList) == 0 {
return 0
}
var totalRemaining float64
for _, breakdown := range usage.UsageBreakdownList {
remaining := breakdown.UsageLimitWithPrecision - breakdown.CurrentUsageWithPrecision
if remaining > 0 {
totalRemaining += remaining
}
if breakdown.FreeTrialInfo != nil {
freeRemaining := breakdown.FreeTrialInfo.UsageLimitWithPrecision - breakdown.FreeTrialInfo.CurrentUsageWithPrecision
if freeRemaining > 0 {
totalRemaining += freeRemaining
}
}
}
return totalRemaining
}
// IsQuotaExhausted checks if the quota is exhausted based on usage limits.
func IsQuotaExhausted(usage *UsageQuotaResponse) bool {
if usage == nil || len(usage.UsageBreakdownList) == 0 {
return true
}
for _, breakdown := range usage.UsageBreakdownList {
if breakdown.CurrentUsageWithPrecision < breakdown.UsageLimitWithPrecision {
return false
}
if breakdown.FreeTrialInfo != nil {
if breakdown.FreeTrialInfo.CurrentUsageWithPrecision < breakdown.FreeTrialInfo.UsageLimitWithPrecision {
return false
}
}
}
return true
}
// GetQuotaStatus retrieves a comprehensive quota status for a token.
func (c *UsageChecker) GetQuotaStatus(ctx context.Context, tokenData *KiroTokenData) (*QuotaStatus, error) {
usage, err := c.CheckUsage(ctx, tokenData)
if err != nil {
return nil, err
}
status := &QuotaStatus{
IsExhausted: IsQuotaExhausted(usage),
}
if len(usage.UsageBreakdownList) > 0 {
breakdown := usage.UsageBreakdownList[0]
status.TotalLimit = breakdown.UsageLimitWithPrecision
status.CurrentUsage = breakdown.CurrentUsageWithPrecision
status.RemainingQuota = breakdown.UsageLimitWithPrecision - breakdown.CurrentUsageWithPrecision
status.ResourceType = breakdown.ResourceType
if breakdown.FreeTrialInfo != nil {
status.TotalLimit += breakdown.FreeTrialInfo.UsageLimitWithPrecision
status.CurrentUsage += breakdown.FreeTrialInfo.CurrentUsageWithPrecision
freeRemaining := breakdown.FreeTrialInfo.UsageLimitWithPrecision - breakdown.FreeTrialInfo.CurrentUsageWithPrecision
if freeRemaining > 0 {
status.RemainingQuota += freeRemaining
}
}
}
if usage.NextDateReset > 0 {
status.NextReset = time.Unix(int64(usage.NextDateReset/1000), 0)
}
return status, nil
}
// CalculateAvailableCount calculates the available request count based on usage limits.
func CalculateAvailableCount(usage *UsageQuotaResponse) float64 {
return GetRemainingQuota(usage)
}
// GetUsagePercentage calculates the usage percentage.
func GetUsagePercentage(usage *UsageQuotaResponse) float64 {
if usage == nil || len(usage.UsageBreakdownList) == 0 {
return 100.0
}
var totalLimit, totalUsage float64
for _, breakdown := range usage.UsageBreakdownList {
totalLimit += breakdown.UsageLimitWithPrecision
totalUsage += breakdown.CurrentUsageWithPrecision
if breakdown.FreeTrialInfo != nil {
totalLimit += breakdown.FreeTrialInfo.UsageLimitWithPrecision
totalUsage += breakdown.FreeTrialInfo.CurrentUsageWithPrecision
}
}
if totalLimit == 0 {
return 100.0
}
return (totalUsage / totalLimit) * 100
}

View File

@@ -6,14 +6,49 @@ import (
"fmt"
"os/exec"
"runtime"
"strings"
"sync"
pkgbrowser "github.com/pkg/browser"
log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open"
)
// incognitoMode controls whether to open URLs in incognito/private mode.
// This is useful for OAuth flows where you want to use a different account.
var incognitoMode bool
// lastBrowserProcess stores the last opened browser process for cleanup
var lastBrowserProcess *exec.Cmd
var browserMutex sync.Mutex
// SetIncognitoMode enables or disables incognito/private browsing mode.
func SetIncognitoMode(enabled bool) {
incognitoMode = enabled
}
// IsIncognitoMode returns whether incognito mode is enabled.
func IsIncognitoMode() bool {
return incognitoMode
}
// CloseBrowser closes the last opened browser process.
func CloseBrowser() error {
browserMutex.Lock()
defer browserMutex.Unlock()
if lastBrowserProcess == nil || lastBrowserProcess.Process == nil {
return nil
}
err := lastBrowserProcess.Process.Kill()
lastBrowserProcess = nil
return err
}
// OpenURL opens the specified URL in the default web browser.
// It first attempts to use a platform-agnostic library and falls back to
// platform-specific commands if that fails.
// It uses the pkg/browser library which provides robust cross-platform support
// for Windows, macOS, and Linux.
// If incognito mode is enabled, it will open in a private/incognito window.
//
// Parameters:
// - url: The URL to open.
@@ -21,16 +56,22 @@ import (
// Returns:
// - An error if the URL cannot be opened, otherwise nil.
func OpenURL(url string) error {
fmt.Printf("Attempting to open URL in browser: %s\n", url)
log.Debugf("Opening URL in browser: %s (incognito=%v)", url, incognitoMode)
// Try using the open-golang library first
err := open.Run(url)
// If incognito mode is enabled, use platform-specific incognito commands
if incognitoMode {
log.Debug("Using incognito mode")
return openURLIncognito(url)
}
// Use pkg/browser for cross-platform support
err := pkgbrowser.OpenURL(url)
if err == nil {
log.Debug("Successfully opened URL using open-golang library")
log.Debug("Successfully opened URL using pkg/browser library")
return nil
}
log.Debugf("open-golang failed: %v, trying platform-specific commands", err)
log.Debugf("pkg/browser failed: %v, trying platform-specific commands", err)
// Fallback to platform-specific commands
return openURLPlatformSpecific(url)
@@ -78,18 +119,379 @@ func openURLPlatformSpecific(url string) error {
return nil
}
// openURLIncognito opens a URL in incognito/private browsing mode.
// It first tries to detect the default browser and use its incognito flag.
// Falls back to a chain of known browsers if detection fails.
//
// Parameters:
// - url: The URL to open.
//
// Returns:
// - An error if the URL cannot be opened, otherwise nil.
func openURLIncognito(url string) error {
// First, try to detect and use the default browser
if cmd := tryDefaultBrowserIncognito(url); cmd != nil {
log.Debugf("Using detected default browser: %s %v", cmd.Path, cmd.Args[1:])
if err := cmd.Start(); err == nil {
storeBrowserProcess(cmd)
log.Debug("Successfully opened URL in default browser's incognito mode")
return nil
}
log.Debugf("Failed to start default browser, trying fallback chain")
}
// Fallback to known browser chain
cmd := tryFallbackBrowsersIncognito(url)
if cmd == nil {
log.Warn("No browser with incognito support found, falling back to normal mode")
return openURLPlatformSpecific(url)
}
log.Debugf("Running incognito command: %s %v", cmd.Path, cmd.Args[1:])
err := cmd.Start()
if err != nil {
log.Warnf("Failed to open incognito browser: %v, falling back to normal mode", err)
return openURLPlatformSpecific(url)
}
storeBrowserProcess(cmd)
log.Debug("Successfully opened URL in incognito/private mode")
return nil
}
// storeBrowserProcess safely stores the browser process for later cleanup.
func storeBrowserProcess(cmd *exec.Cmd) {
browserMutex.Lock()
lastBrowserProcess = cmd
browserMutex.Unlock()
}
// tryDefaultBrowserIncognito attempts to detect the default browser and return
// an exec.Cmd configured with the appropriate incognito flag.
func tryDefaultBrowserIncognito(url string) *exec.Cmd {
switch runtime.GOOS {
case "darwin":
return tryDefaultBrowserMacOS(url)
case "windows":
return tryDefaultBrowserWindows(url)
case "linux":
return tryDefaultBrowserLinux(url)
}
return nil
}
// tryDefaultBrowserMacOS detects the default browser on macOS.
func tryDefaultBrowserMacOS(url string) *exec.Cmd {
// Try to get default browser from Launch Services
out, err := exec.Command("defaults", "read", "com.apple.LaunchServices/com.apple.launchservices.secure", "LSHandlers").Output()
if err != nil {
return nil
}
output := string(out)
var browserName string
// Parse the output to find the http/https handler
if containsBrowserID(output, "com.google.chrome") {
browserName = "chrome"
} else if containsBrowserID(output, "org.mozilla.firefox") {
browserName = "firefox"
} else if containsBrowserID(output, "com.apple.safari") {
browserName = "safari"
} else if containsBrowserID(output, "com.brave.browser") {
browserName = "brave"
} else if containsBrowserID(output, "com.microsoft.edgemac") {
browserName = "edge"
}
return createMacOSIncognitoCmd(browserName, url)
}
// containsBrowserID checks if the LaunchServices output contains a browser ID.
func containsBrowserID(output, bundleID string) bool {
return strings.Contains(output, bundleID)
}
// createMacOSIncognitoCmd creates the appropriate incognito command for macOS browsers.
func createMacOSIncognitoCmd(browserName, url string) *exec.Cmd {
switch browserName {
case "chrome":
// Try direct path first
chromePath := "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome"
if _, err := exec.LookPath(chromePath); err == nil {
return exec.Command(chromePath, "--incognito", url)
}
return exec.Command("open", "-na", "Google Chrome", "--args", "--incognito", url)
case "firefox":
return exec.Command("open", "-na", "Firefox", "--args", "--private-window", url)
case "safari":
// Safari doesn't have CLI incognito, try AppleScript
return tryAppleScriptSafariPrivate(url)
case "brave":
return exec.Command("open", "-na", "Brave Browser", "--args", "--incognito", url)
case "edge":
return exec.Command("open", "-na", "Microsoft Edge", "--args", "--inprivate", url)
}
return nil
}
// tryAppleScriptSafariPrivate attempts to open Safari in private browsing mode using AppleScript.
func tryAppleScriptSafariPrivate(url string) *exec.Cmd {
// AppleScript to open a new private window in Safari
script := fmt.Sprintf(`
tell application "Safari"
activate
tell application "System Events"
keystroke "n" using {command down, shift down}
delay 0.5
end tell
set URL of document 1 to "%s"
end tell
`, url)
cmd := exec.Command("osascript", "-e", script)
// Test if this approach works by checking if Safari is available
if _, err := exec.LookPath("/Applications/Safari.app/Contents/MacOS/Safari"); err != nil {
log.Debug("Safari not found, AppleScript private window not available")
return nil
}
log.Debug("Attempting Safari private window via AppleScript")
return cmd
}
// tryDefaultBrowserWindows detects the default browser on Windows via registry.
func tryDefaultBrowserWindows(url string) *exec.Cmd {
// Query registry for default browser
out, err := exec.Command("reg", "query",
`HKEY_CURRENT_USER\Software\Microsoft\Windows\Shell\Associations\UrlAssociations\http\UserChoice`,
"/v", "ProgId").Output()
if err != nil {
return nil
}
output := string(out)
var browserName string
// Map ProgId to browser name
if strings.Contains(output, "ChromeHTML") {
browserName = "chrome"
} else if strings.Contains(output, "FirefoxURL") {
browserName = "firefox"
} else if strings.Contains(output, "MSEdgeHTM") {
browserName = "edge"
} else if strings.Contains(output, "BraveHTML") {
browserName = "brave"
}
return createWindowsIncognitoCmd(browserName, url)
}
// createWindowsIncognitoCmd creates the appropriate incognito command for Windows browsers.
func createWindowsIncognitoCmd(browserName, url string) *exec.Cmd {
switch browserName {
case "chrome":
paths := []string{
"chrome",
`C:\Program Files\Google\Chrome\Application\chrome.exe`,
`C:\Program Files (x86)\Google\Chrome\Application\chrome.exe`,
}
for _, p := range paths {
if _, err := exec.LookPath(p); err == nil {
return exec.Command(p, "--incognito", url)
}
}
case "firefox":
if path, err := exec.LookPath("firefox"); err == nil {
return exec.Command(path, "--private-window", url)
}
case "edge":
paths := []string{
"msedge",
`C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe`,
`C:\Program Files\Microsoft\Edge\Application\msedge.exe`,
}
for _, p := range paths {
if _, err := exec.LookPath(p); err == nil {
return exec.Command(p, "--inprivate", url)
}
}
case "brave":
paths := []string{
`C:\Program Files\BraveSoftware\Brave-Browser\Application\brave.exe`,
`C:\Program Files (x86)\BraveSoftware\Brave-Browser\Application\brave.exe`,
}
for _, p := range paths {
if _, err := exec.LookPath(p); err == nil {
return exec.Command(p, "--incognito", url)
}
}
}
return nil
}
// tryDefaultBrowserLinux detects the default browser on Linux using xdg-settings.
func tryDefaultBrowserLinux(url string) *exec.Cmd {
out, err := exec.Command("xdg-settings", "get", "default-web-browser").Output()
if err != nil {
return nil
}
desktop := string(out)
var browserName string
// Map .desktop file to browser name
if strings.Contains(desktop, "google-chrome") || strings.Contains(desktop, "chrome") {
browserName = "chrome"
} else if strings.Contains(desktop, "firefox") {
browserName = "firefox"
} else if strings.Contains(desktop, "chromium") {
browserName = "chromium"
} else if strings.Contains(desktop, "brave") {
browserName = "brave"
} else if strings.Contains(desktop, "microsoft-edge") || strings.Contains(desktop, "msedge") {
browserName = "edge"
}
return createLinuxIncognitoCmd(browserName, url)
}
// createLinuxIncognitoCmd creates the appropriate incognito command for Linux browsers.
func createLinuxIncognitoCmd(browserName, url string) *exec.Cmd {
switch browserName {
case "chrome":
paths := []string{"google-chrome", "google-chrome-stable"}
for _, p := range paths {
if path, err := exec.LookPath(p); err == nil {
return exec.Command(path, "--incognito", url)
}
}
case "firefox":
paths := []string{"firefox", "firefox-esr"}
for _, p := range paths {
if path, err := exec.LookPath(p); err == nil {
return exec.Command(path, "--private-window", url)
}
}
case "chromium":
paths := []string{"chromium", "chromium-browser"}
for _, p := range paths {
if path, err := exec.LookPath(p); err == nil {
return exec.Command(path, "--incognito", url)
}
}
case "brave":
if path, err := exec.LookPath("brave-browser"); err == nil {
return exec.Command(path, "--incognito", url)
}
case "edge":
if path, err := exec.LookPath("microsoft-edge"); err == nil {
return exec.Command(path, "--inprivate", url)
}
}
return nil
}
// tryFallbackBrowsersIncognito tries a chain of known browsers as fallback.
func tryFallbackBrowsersIncognito(url string) *exec.Cmd {
switch runtime.GOOS {
case "darwin":
return tryFallbackBrowsersMacOS(url)
case "windows":
return tryFallbackBrowsersWindows(url)
case "linux":
return tryFallbackBrowsersLinuxChain(url)
}
return nil
}
// tryFallbackBrowsersMacOS tries known browsers on macOS.
func tryFallbackBrowsersMacOS(url string) *exec.Cmd {
// Try Chrome
chromePath := "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome"
if _, err := exec.LookPath(chromePath); err == nil {
return exec.Command(chromePath, "--incognito", url)
}
// Try Firefox
if _, err := exec.LookPath("/Applications/Firefox.app/Contents/MacOS/firefox"); err == nil {
return exec.Command("open", "-na", "Firefox", "--args", "--private-window", url)
}
// Try Brave
if _, err := exec.LookPath("/Applications/Brave Browser.app/Contents/MacOS/Brave Browser"); err == nil {
return exec.Command("open", "-na", "Brave Browser", "--args", "--incognito", url)
}
// Try Edge
if _, err := exec.LookPath("/Applications/Microsoft Edge.app/Contents/MacOS/Microsoft Edge"); err == nil {
return exec.Command("open", "-na", "Microsoft Edge", "--args", "--inprivate", url)
}
// Last resort: try Safari with AppleScript
if cmd := tryAppleScriptSafariPrivate(url); cmd != nil {
log.Info("Using Safari with AppleScript for private browsing (may require accessibility permissions)")
return cmd
}
return nil
}
// tryFallbackBrowsersWindows tries known browsers on Windows.
func tryFallbackBrowsersWindows(url string) *exec.Cmd {
// Chrome
chromePaths := []string{
"chrome",
`C:\Program Files\Google\Chrome\Application\chrome.exe`,
`C:\Program Files (x86)\Google\Chrome\Application\chrome.exe`,
}
for _, p := range chromePaths {
if _, err := exec.LookPath(p); err == nil {
return exec.Command(p, "--incognito", url)
}
}
// Firefox
if path, err := exec.LookPath("firefox"); err == nil {
return exec.Command(path, "--private-window", url)
}
// Edge (usually available on Windows 10+)
edgePaths := []string{
"msedge",
`C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe`,
`C:\Program Files\Microsoft\Edge\Application\msedge.exe`,
}
for _, p := range edgePaths {
if _, err := exec.LookPath(p); err == nil {
return exec.Command(p, "--inprivate", url)
}
}
return nil
}
// tryFallbackBrowsersLinuxChain tries known browsers on Linux.
func tryFallbackBrowsersLinuxChain(url string) *exec.Cmd {
type browserConfig struct {
name string
flag string
}
browsers := []browserConfig{
{"google-chrome", "--incognito"},
{"google-chrome-stable", "--incognito"},
{"chromium", "--incognito"},
{"chromium-browser", "--incognito"},
{"firefox", "--private-window"},
{"firefox-esr", "--private-window"},
{"brave-browser", "--incognito"},
{"microsoft-edge", "--inprivate"},
}
for _, b := range browsers {
if path, err := exec.LookPath(b.name); err == nil {
return exec.Command(path, b.flag, url)
}
}
return nil
}
// IsAvailable checks if the system has a command available to open a web browser.
// It verifies the presence of necessary commands for the current operating system.
//
// Returns:
// - true if a browser can be opened, false otherwise.
func IsAvailable() bool {
// First check if open-golang can work
testErr := open.Run("about:blank")
if testErr == nil {
return true
}
// Check platform-specific commands
switch runtime.GOOS {
case "darwin":

View File

@@ -6,7 +6,7 @@ import (
// newAuthManager creates a new authentication manager instance with all supported
// authenticators and a file-based token store. It initializes authenticators for
// Gemini, Codex, Claude, and Qwen providers.
// Gemini, Codex, Claude, Qwen, IFlow, Antigravity, and GitHub Copilot providers.
//
// Returns:
// - *sdkAuth.Manager: A configured authentication manager instance
@@ -20,6 +20,9 @@ func newAuthManager() *sdkAuth.Manager {
sdkAuth.NewIFlowAuthenticator(),
sdkAuth.NewAntigravityAuthenticator(),
sdkAuth.NewKimiAuthenticator(),
sdkAuth.NewKiroAuthenticator(),
sdkAuth.NewGitHubCopilotAuthenticator(),
sdkAuth.NewKiloAuthenticator(),
)
return manager
}

View File

@@ -0,0 +1,44 @@
package cmd
import (
"context"
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
log "github.com/sirupsen/logrus"
)
// DoGitHubCopilotLogin triggers the OAuth device flow for GitHub Copilot and saves tokens.
// It initiates the device flow authentication, displays the user code for the user to enter
// at GitHub's verification URL, and waits for authorization before saving the tokens.
//
// Parameters:
// - cfg: The application configuration containing proxy and auth directory settings
// - options: Login options including browser behavior settings
func DoGitHubCopilotLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: map[string]string{},
Prompt: options.Prompt,
}
record, savedPath, err := manager.Login(context.Background(), "github-copilot", cfg, authOpts)
if err != nil {
log.Errorf("GitHub Copilot authentication failed: %v", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
if record != nil && record.Label != "" {
fmt.Printf("Authenticated as %s\n", record.Label)
}
fmt.Println("GitHub Copilot authentication successful!")
}

View File

@@ -0,0 +1,54 @@
package cmd
import (
"context"
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
)
// DoKiloLogin handles the Kilo device flow using the shared authentication manager.
// It initiates the device-based authentication process for Kilo AI services and saves
// the authentication tokens to the configured auth directory.
//
// Parameters:
// - cfg: The application configuration
// - options: Login options including browser behavior and prompts
func DoKiloLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
manager := newAuthManager()
promptFn := options.Prompt
if promptFn == nil {
promptFn = func(prompt string) (string, error) {
fmt.Print(prompt)
var value string
fmt.Scanln(&value)
return strings.TrimSpace(value), nil
}
}
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "kilo", cfg, authOpts)
if err != nil {
fmt.Printf("Kilo authentication failed: %v\n", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
fmt.Println("Kilo authentication successful!")
}

257
internal/cmd/kiro_login.go Normal file
View File

@@ -0,0 +1,257 @@
package cmd
import (
"context"
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
log "github.com/sirupsen/logrus"
)
// DoKiroLogin triggers the Kiro authentication flow with Google OAuth.
// This is the default login method (same as --kiro-google-login).
//
// Parameters:
// - cfg: The application configuration
// - options: Login options including Prompt field
func DoKiroLogin(cfg *config.Config, options *LoginOptions) {
// Use Google login as default
DoKiroGoogleLogin(cfg, options)
}
// DoKiroGoogleLogin triggers Kiro authentication with Google OAuth.
// This uses a custom protocol handler (kiro://) to receive the callback.
//
// Parameters:
// - cfg: The application configuration
// - options: Login options including prompts
func DoKiroGoogleLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
// Note: Kiro defaults to incognito mode for multi-account support.
// Users can override with --no-incognito if they want to use existing browser sessions.
manager := newAuthManager()
// Use KiroAuthenticator with Google login
authenticator := sdkAuth.NewKiroAuthenticator()
record, err := authenticator.LoginWithGoogle(context.Background(), cfg, &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: map[string]string{},
Prompt: options.Prompt,
})
if err != nil {
log.Errorf("Kiro Google authentication failed: %v", err)
fmt.Println("\nTroubleshooting:")
fmt.Println("1. Make sure the protocol handler is installed")
fmt.Println("2. Complete the Google login in the browser")
fmt.Println("3. If callback fails, try: --kiro-import (after logging in via Kiro IDE)")
return
}
// Save the auth record
savedPath, err := manager.SaveAuth(record, cfg)
if err != nil {
log.Errorf("Failed to save auth: %v", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
if record != nil && record.Label != "" {
fmt.Printf("Authenticated as %s\n", record.Label)
}
fmt.Println("Kiro Google authentication successful!")
}
// DoKiroAWSLogin triggers Kiro authentication with AWS Builder ID.
// This uses the device code flow for AWS SSO OIDC authentication.
//
// Parameters:
// - cfg: The application configuration
// - options: Login options including prompts
func DoKiroAWSLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
// Note: Kiro defaults to incognito mode for multi-account support.
// Users can override with --no-incognito if they want to use existing browser sessions.
manager := newAuthManager()
// Use KiroAuthenticator with AWS Builder ID login (device code flow)
authenticator := sdkAuth.NewKiroAuthenticator()
record, err := authenticator.Login(context.Background(), cfg, &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: map[string]string{},
Prompt: options.Prompt,
})
if err != nil {
log.Errorf("Kiro AWS authentication failed: %v", err)
fmt.Println("\nTroubleshooting:")
fmt.Println("1. Make sure you have an AWS Builder ID")
fmt.Println("2. Complete the authorization in the browser")
fmt.Println("3. If callback fails, try: --kiro-import (after logging in via Kiro IDE)")
return
}
// Save the auth record
savedPath, err := manager.SaveAuth(record, cfg)
if err != nil {
log.Errorf("Failed to save auth: %v", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
if record != nil && record.Label != "" {
fmt.Printf("Authenticated as %s\n", record.Label)
}
fmt.Println("Kiro AWS authentication successful!")
}
// DoKiroAWSAuthCodeLogin triggers Kiro authentication with AWS Builder ID using authorization code flow.
// This provides a better UX than device code flow as it uses automatic browser callback.
//
// Parameters:
// - cfg: The application configuration
// - options: Login options including prompts
func DoKiroAWSAuthCodeLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
// Note: Kiro defaults to incognito mode for multi-account support.
// Users can override with --no-incognito if they want to use existing browser sessions.
manager := newAuthManager()
// Use KiroAuthenticator with AWS Builder ID login (authorization code flow)
authenticator := sdkAuth.NewKiroAuthenticator()
record, err := authenticator.LoginWithAuthCode(context.Background(), cfg, &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: map[string]string{},
Prompt: options.Prompt,
})
if err != nil {
log.Errorf("Kiro AWS authentication (auth code) failed: %v", err)
fmt.Println("\nTroubleshooting:")
fmt.Println("1. Make sure you have an AWS Builder ID")
fmt.Println("2. Complete the authorization in the browser")
fmt.Println("3. If callback fails, try: --kiro-aws-login (device code flow)")
return
}
// Save the auth record
savedPath, err := manager.SaveAuth(record, cfg)
if err != nil {
log.Errorf("Failed to save auth: %v", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
if record != nil && record.Label != "" {
fmt.Printf("Authenticated as %s\n", record.Label)
}
fmt.Println("Kiro AWS authentication successful!")
}
// DoKiroImport imports Kiro token from Kiro IDE's token file.
// This is useful for users who have already logged in via Kiro IDE
// and want to use the same credentials in CLI Proxy API.
//
// Parameters:
// - cfg: The application configuration
// - options: Login options (currently unused for import)
func DoKiroImport(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
manager := newAuthManager()
// Use ImportFromKiroIDE instead of Login
authenticator := sdkAuth.NewKiroAuthenticator()
record, err := authenticator.ImportFromKiroIDE(context.Background(), cfg)
if err != nil {
log.Errorf("Kiro token import failed: %v", err)
fmt.Println("\nMake sure you have logged in to Kiro IDE first:")
fmt.Println("1. Open Kiro IDE")
fmt.Println("2. Click 'Sign in with Google' (or GitHub)")
fmt.Println("3. Complete the login process")
fmt.Println("4. Run this command again")
return
}
// Save the imported auth record
savedPath, err := manager.SaveAuth(record, cfg)
if err != nil {
log.Errorf("Failed to save auth: %v", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
if record != nil && record.Label != "" {
fmt.Printf("Imported as %s\n", record.Label)
}
fmt.Println("Kiro token import successful!")
}
func DoKiroIDCLogin(cfg *config.Config, options *LoginOptions, startURL, region, flow string) {
if options == nil {
options = &LoginOptions{}
}
if startURL == "" {
log.Errorf("Kiro IDC login requires --kiro-idc-start-url")
fmt.Println("\nUsage: --kiro-idc-login --kiro-idc-start-url https://d-xxx.awsapps.com/start")
return
}
manager := newAuthManager()
authenticator := sdkAuth.NewKiroAuthenticator()
metadata := map[string]string{
"start-url": startURL,
"region": region,
"flow": flow,
}
record, err := authenticator.Login(context.Background(), cfg, &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
Metadata: metadata,
Prompt: options.Prompt,
})
if err != nil {
log.Errorf("Kiro IDC authentication failed: %v", err)
fmt.Println("\nTroubleshooting:")
fmt.Println("1. Make sure your IDC Start URL is correct")
fmt.Println("2. Complete the authorization in the browser")
fmt.Println("3. If auth code flow fails, try: --kiro-idc-flow device")
return
}
savedPath, err := manager.SaveAuth(record, cfg)
if err != nil {
log.Errorf("Failed to save auth: %v", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
if record != nil && record.Label != "" {
fmt.Printf("Authenticated as %s\n", record.Label)
}
fmt.Println("Kiro IDC authentication successful!")
}

View File

@@ -87,6 +87,17 @@ type Config struct {
// GeminiKey defines Gemini API key configurations with optional routing overrides.
GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"`
// KiroKey defines a list of Kiro (AWS CodeWhisperer) configurations.
KiroKey []KiroKey `yaml:"kiro" json:"kiro"`
// KiroFingerprint defines a global fingerprint configuration for all Kiro requests.
// When set, all Kiro requests will use this fixed fingerprint instead of random generation.
KiroFingerprint *KiroFingerprintConfig `yaml:"kiro-fingerprint,omitempty" json:"kiro-fingerprint,omitempty"`
// KiroPreferredEndpoint sets the global default preferred endpoint for all Kiro providers.
// Values: "ide" (default, CodeWhisperer) or "cli" (Amazon Q).
KiroPreferredEndpoint string `yaml:"kiro-preferred-endpoint" json:"kiro-preferred-endpoint"`
// Codex defines a list of Codex API key configurations as specified in the YAML configuration file.
CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"`
@@ -108,11 +119,12 @@ type Config struct {
AmpCode AmpCode `yaml:"ampcode" json:"ampcode"`
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
// OAuthModelAlias defines global model name aliases for OAuth/file-backed auth channels.
// These aliases affect both model listing and model routing for supported channels:
// gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
// gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
//
// NOTE: This does not apply to existing per-credential model alias features under:
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
@@ -121,6 +133,11 @@ type Config struct {
// Payload defines default and override rules for provider payload parameters.
Payload PayloadConfig `yaml:"payload" json:"payload"`
// IncognitoBrowser enables opening OAuth URLs in incognito/private browsing mode.
// This is useful when you want to login with a different account without logging out
// from your current session. Default: false.
IncognitoBrowser bool `yaml:"incognito-browser" json:"incognito-browser"`
legacyMigrationPending bool `yaml:"-" json:"-"`
}
@@ -450,6 +467,52 @@ type GeminiModel struct {
func (m GeminiModel) GetName() string { return m.Name }
func (m GeminiModel) GetAlias() string { return m.Alias }
// KiroKey represents the configuration for Kiro (AWS CodeWhisperer) authentication.
type KiroKey struct {
// TokenFile is the path to the Kiro token file (default: ~/.aws/sso/cache/kiro-auth-token.json)
TokenFile string `yaml:"token-file,omitempty" json:"token-file,omitempty"`
// AccessToken is the OAuth access token for direct configuration.
AccessToken string `yaml:"access-token,omitempty" json:"access-token,omitempty"`
// RefreshToken is the OAuth refresh token for token renewal.
RefreshToken string `yaml:"refresh-token,omitempty" json:"refresh-token,omitempty"`
// ProfileArn is the AWS CodeWhisperer profile ARN.
ProfileArn string `yaml:"profile-arn,omitempty" json:"profile-arn,omitempty"`
// Region is the AWS region (default: us-east-1).
Region string `yaml:"region,omitempty" json:"region,omitempty"`
// StartURL is the IAM Identity Center (IDC) start URL for SSO login.
StartURL string `yaml:"start-url,omitempty" json:"start-url,omitempty"`
// ProxyURL optionally overrides the global proxy for this configuration.
ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
// AgentTaskType sets the Kiro API task type. Known values: "vibe", "dev", "chat".
// Leave empty to let API use defaults. Different values may inject different system prompts.
AgentTaskType string `yaml:"agent-task-type,omitempty" json:"agent-task-type,omitempty"`
// PreferredEndpoint sets the preferred Kiro API endpoint/quota.
// Values: "codewhisperer" (default, IDE quota) or "amazonq" (CLI quota).
PreferredEndpoint string `yaml:"preferred-endpoint,omitempty" json:"preferred-endpoint,omitempty"`
}
// KiroFingerprintConfig defines a global fingerprint configuration for Kiro requests.
// When configured, all Kiro requests will use this fixed fingerprint instead of random generation.
// Empty fields will fall back to random selection from built-in pools.
type KiroFingerprintConfig struct {
OIDCSDKVersion string `yaml:"oidc-sdk-version,omitempty" json:"oidc-sdk-version,omitempty"`
RuntimeSDKVersion string `yaml:"runtime-sdk-version,omitempty" json:"runtime-sdk-version,omitempty"`
StreamingSDKVersion string `yaml:"streaming-sdk-version,omitempty" json:"streaming-sdk-version,omitempty"`
OSType string `yaml:"os-type,omitempty" json:"os-type,omitempty"`
OSVersion string `yaml:"os-version,omitempty" json:"os-version,omitempty"`
NodeVersion string `yaml:"node-version,omitempty" json:"node-version,omitempty"`
KiroVersion string `yaml:"kiro-version,omitempty" json:"kiro-version,omitempty"`
KiroHash string `yaml:"kiro-hash,omitempty" json:"kiro-hash,omitempty"`
}
// OpenAICompatibility represents the configuration for OpenAI API compatibility
// with external providers, allowing model aliases to be routed through OpenAI API format.
type OpenAICompatibility struct {
@@ -556,6 +619,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
cfg.Pprof.Addr = DefaultPprofAddr
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force)
if err = yaml.Unmarshal(data, &cfg); err != nil {
if optional {
// In cloud deploy mode, if YAML parsing fails, return empty config instead of error.
@@ -628,6 +692,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
// Sanitize Claude key headers
cfg.SanitizeClaudeKeys()
// Sanitize Kiro keys: trim whitespace from credential fields
cfg.SanitizeKiroKeys()
// Sanitize OpenAI compatibility providers: drop entries without base-url
cfg.SanitizeOpenAICompatibility()
@@ -717,14 +784,46 @@ func payloadRawString(value any) ([]byte, bool) {
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
// It also injects default aliases for channels that have built-in defaults (e.g., kiro)
// when no user-configured aliases exist for those channels.
func (cfg *Config) SanitizeOAuthModelAlias() {
if cfg == nil || len(cfg.OAuthModelAlias) == 0 {
if cfg == nil {
return
}
// Inject channel defaults when the channel is absent in user config.
// Presence is checked case-insensitively and includes explicit nil/empty markers.
if cfg.OAuthModelAlias == nil {
cfg.OAuthModelAlias = make(map[string][]OAuthModelAlias)
}
hasChannel := func(channel string) bool {
for k := range cfg.OAuthModelAlias {
if strings.EqualFold(strings.TrimSpace(k), channel) {
return true
}
}
return false
}
if !hasChannel("kiro") {
cfg.OAuthModelAlias["kiro"] = defaultKiroAliases()
}
if !hasChannel("github-copilot") {
cfg.OAuthModelAlias["github-copilot"] = defaultGitHubCopilotAliases()
}
if len(cfg.OAuthModelAlias) == 0 {
return
}
out := make(map[string][]OAuthModelAlias, len(cfg.OAuthModelAlias))
for rawChannel, aliases := range cfg.OAuthModelAlias {
channel := strings.ToLower(strings.TrimSpace(rawChannel))
if channel == "" || len(aliases) == 0 {
if channel == "" {
continue
}
// Preserve channels that were explicitly set to empty/nil they act
// as "disabled" markers so default injection won't re-add them (#222).
if len(aliases) == 0 {
out[channel] = nil
continue
}
seenAlias := make(map[string]struct{}, len(aliases))
@@ -809,6 +908,23 @@ func (cfg *Config) SanitizeClaudeKeys() {
}
}
// SanitizeKiroKeys trims whitespace from Kiro credential fields.
func (cfg *Config) SanitizeKiroKeys() {
if cfg == nil || len(cfg.KiroKey) == 0 {
return
}
for i := range cfg.KiroKey {
entry := &cfg.KiroKey[i]
entry.TokenFile = strings.TrimSpace(entry.TokenFile)
entry.AccessToken = strings.TrimSpace(entry.AccessToken)
entry.RefreshToken = strings.TrimSpace(entry.RefreshToken)
entry.ProfileArn = strings.TrimSpace(entry.ProfileArn)
entry.Region = strings.TrimSpace(entry.Region)
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
entry.PreferredEndpoint = strings.TrimSpace(entry.PreferredEndpoint)
}
}
// SanitizeGeminiKeys deduplicates and normalizes Gemini credentials.
func (cfg *Config) SanitizeGeminiKeys() {
if cfg == nil {

View File

@@ -20,6 +20,45 @@ var antigravityModelConversionTable = map[string]string{
"gemini-claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
}
// defaultKiroAliases returns the default oauth-model-alias configuration
// for the kiro channel. Maps kiro-prefixed model names to standard Claude model
// names so that clients like Claude Code can use standard names directly.
func defaultKiroAliases() []OAuthModelAlias {
return []OAuthModelAlias{
// Sonnet 4.6
{Name: "kiro-claude-sonnet-4-6", Alias: "claude-sonnet-4-6", Fork: true},
// Sonnet 4.5
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5-20250929", Fork: true},
{Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5", Fork: true},
// Sonnet 4
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4-20250514", Fork: true},
{Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4", Fork: true},
// Opus 4.6
{Name: "kiro-claude-opus-4-6", Alias: "claude-opus-4-6", Fork: true},
// Opus 4.5
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5-20251101", Fork: true},
{Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5", Fork: true},
// Haiku 4.5
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5-20251001", Fork: true},
{Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5", Fork: true},
}
}
// defaultGitHubCopilotAliases returns default oauth-model-alias entries that
// expose Claude hyphen-style IDs for GitHub Copilot Claude models.
// This keeps compatibility with clients (e.g. Claude Code) that use
// Anthropic-style model IDs like "claude-opus-4-6".
func defaultGitHubCopilotAliases() []OAuthModelAlias {
return []OAuthModelAlias{
{Name: "claude-haiku-4.5", Alias: "claude-haiku-4-5", Fork: true},
{Name: "claude-opus-4.1", Alias: "claude-opus-4-1", Fork: true},
{Name: "claude-opus-4.5", Alias: "claude-opus-4-5", Fork: true},
{Name: "claude-opus-4.6", Alias: "claude-opus-4-6", Fork: true},
{Name: "claude-sonnet-4.5", Alias: "claude-sonnet-4-5", Fork: true},
{Name: "claude-sonnet-4.6", Alias: "claude-sonnet-4-6", Fork: true},
}
}
// defaultAntigravityAliases returns the default oauth-model-alias configuration
// for the antigravity channel when neither field exists.
func defaultAntigravityAliases() []OAuthModelAlias {

View File

@@ -54,3 +54,208 @@ func TestSanitizeOAuthModelAlias_AllowsMultipleAliasesForSameName(t *testing.T)
}
}
}
func TestSanitizeOAuthModelAlias_InjectsDefaultKiroAliases(t *testing.T) {
// When no kiro aliases are configured, defaults should be injected
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"codex": {
{Name: "gpt-5", Alias: "g5"},
},
},
}
cfg.SanitizeOAuthModelAlias()
kiroAliases := cfg.OAuthModelAlias["kiro"]
if len(kiroAliases) == 0 {
t.Fatal("expected default kiro aliases to be injected")
}
// Check that standard Claude model names are present
aliasSet := make(map[string]bool)
for _, a := range kiroAliases {
aliasSet[a.Alias] = true
}
expectedAliases := []string{
"claude-sonnet-4-5-20250929",
"claude-sonnet-4-5",
"claude-sonnet-4-20250514",
"claude-sonnet-4",
"claude-opus-4-6",
"claude-opus-4-5-20251101",
"claude-opus-4-5",
"claude-haiku-4-5-20251001",
"claude-haiku-4-5",
}
for _, expected := range expectedAliases {
if !aliasSet[expected] {
t.Fatalf("expected default kiro alias %q to be present", expected)
}
}
// All should have fork=true
for _, a := range kiroAliases {
if !a.Fork {
t.Fatalf("expected all default kiro aliases to have fork=true, got fork=false for %q", a.Alias)
}
}
// Codex aliases should still be preserved
if len(cfg.OAuthModelAlias["codex"]) != 1 {
t.Fatal("expected codex aliases to be preserved")
}
}
func TestSanitizeOAuthModelAlias_InjectsDefaultGitHubCopilotAliases(t *testing.T) {
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"codex": {
{Name: "gpt-5", Alias: "g5"},
},
},
}
cfg.SanitizeOAuthModelAlias()
copilotAliases := cfg.OAuthModelAlias["github-copilot"]
if len(copilotAliases) == 0 {
t.Fatal("expected default github-copilot aliases to be injected")
}
aliasSet := make(map[string]bool, len(copilotAliases))
for _, a := range copilotAliases {
aliasSet[a.Alias] = true
if !a.Fork {
t.Fatalf("expected all default github-copilot aliases to have fork=true, got fork=false for %q", a.Alias)
}
}
expectedAliases := []string{
"claude-haiku-4-5",
"claude-opus-4-1",
"claude-opus-4-5",
"claude-opus-4-6",
"claude-sonnet-4-5",
"claude-sonnet-4-6",
}
for _, expected := range expectedAliases {
if !aliasSet[expected] {
t.Fatalf("expected default github-copilot alias %q to be present", expected)
}
}
}
func TestSanitizeOAuthModelAlias_DoesNotOverrideUserKiroAliases(t *testing.T) {
// When user has configured kiro aliases, defaults should NOT be injected
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"kiro": {
{Name: "kiro-claude-sonnet-4", Alias: "my-custom-sonnet", Fork: true},
},
},
}
cfg.SanitizeOAuthModelAlias()
kiroAliases := cfg.OAuthModelAlias["kiro"]
if len(kiroAliases) != 1 {
t.Fatalf("expected 1 user-configured kiro alias, got %d", len(kiroAliases))
}
if kiroAliases[0].Alias != "my-custom-sonnet" {
t.Fatalf("expected user alias to be preserved, got %q", kiroAliases[0].Alias)
}
}
func TestSanitizeOAuthModelAlias_DoesNotOverrideUserGitHubCopilotAliases(t *testing.T) {
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"github-copilot": {
{Name: "claude-opus-4.6", Alias: "my-opus", Fork: true},
},
},
}
cfg.SanitizeOAuthModelAlias()
copilotAliases := cfg.OAuthModelAlias["github-copilot"]
if len(copilotAliases) != 1 {
t.Fatalf("expected 1 user-configured github-copilot alias, got %d", len(copilotAliases))
}
if copilotAliases[0].Alias != "my-opus" {
t.Fatalf("expected user alias to be preserved, got %q", copilotAliases[0].Alias)
}
}
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletion(t *testing.T) {
// When user explicitly deletes kiro aliases (key exists with nil value),
// defaults should NOT be re-injected on subsequent sanitize calls (#222).
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"kiro": nil, // explicitly deleted
"codex": {{Name: "gpt-5", Alias: "g5"}},
},
}
cfg.SanitizeOAuthModelAlias()
kiroAliases := cfg.OAuthModelAlias["kiro"]
if len(kiroAliases) != 0 {
t.Fatalf("expected kiro aliases to remain empty after explicit deletion, got %d aliases", len(kiroAliases))
}
// The key itself must still be present to prevent re-injection on next reload
if _, exists := cfg.OAuthModelAlias["kiro"]; !exists {
t.Fatal("expected kiro key to be preserved as nil marker after sanitization")
}
// Other channels should be unaffected
if len(cfg.OAuthModelAlias["codex"]) != 1 {
t.Fatal("expected codex aliases to be preserved")
}
}
func TestSanitizeOAuthModelAlias_GitHubCopilotDoesNotReinjectAfterExplicitDeletion(t *testing.T) {
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"github-copilot": nil, // explicitly deleted
},
}
cfg.SanitizeOAuthModelAlias()
copilotAliases := cfg.OAuthModelAlias["github-copilot"]
if len(copilotAliases) != 0 {
t.Fatalf("expected github-copilot aliases to remain empty after explicit deletion, got %d aliases", len(copilotAliases))
}
if _, exists := cfg.OAuthModelAlias["github-copilot"]; !exists {
t.Fatal("expected github-copilot key to be preserved as nil marker after sanitization")
}
}
func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletionEmpty(t *testing.T) {
// Same as above but with empty slice instead of nil (PUT with empty body).
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"kiro": {}, // explicitly set to empty
},
}
cfg.SanitizeOAuthModelAlias()
if len(cfg.OAuthModelAlias["kiro"]) != 0 {
t.Fatalf("expected kiro aliases to remain empty, got %d aliases", len(cfg.OAuthModelAlias["kiro"]))
}
if _, exists := cfg.OAuthModelAlias["kiro"]; !exists {
t.Fatal("expected kiro key to be preserved")
}
}
func TestSanitizeOAuthModelAlias_InjectsDefaultKiroWhenEmpty(t *testing.T) {
// When OAuthModelAlias is nil, kiro defaults should still be injected
cfg := &Config{}
cfg.SanitizeOAuthModelAlias()
kiroAliases := cfg.OAuthModelAlias["kiro"]
if len(kiroAliases) == 0 {
t.Fatal("expected default kiro aliases to be injected when OAuthModelAlias is nil")
}
}

View File

@@ -24,4 +24,10 @@ const (
// Antigravity represents the Antigravity response format identifier.
Antigravity = "antigravity"
// Kiro represents the AWS CodeWhisperer (Kiro) provider identifier.
Kiro = "kiro"
// Kilo represents the Kilo AI provider identifier.
Kilo = "kilo"
)

View File

@@ -85,6 +85,7 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
func SetupBaseLogger() {
setupOnce.Do(func() {
log.SetOutput(os.Stdout)
log.SetLevel(log.InfoLevel)
log.SetReportCaller(true)
log.SetFormatter(&LogFormatter{})

View File

@@ -0,0 +1,21 @@
// Package registry provides model definitions for various AI service providers.
package registry
// GetKiloModels returns the Kilo model definitions
func GetKiloModels() []*ModelInfo {
return []*ModelInfo{
// --- Base Models ---
{
ID: "kilo/auto",
Object: "model",
Created: 1732752000,
OwnedBy: "kilo",
Type: "kilo",
DisplayName: "Kilo Auto",
Description: "Automatic model selection by Kilo",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
}
}

View File

@@ -0,0 +1,303 @@
// Package registry provides Kiro model conversion utilities.
// This file handles converting dynamic Kiro API model lists to the internal ModelInfo format,
// and merging with static metadata for thinking support and other capabilities.
package registry
import (
"strings"
"time"
)
// KiroAPIModel represents a model from Kiro API response.
// This is a local copy to avoid import cycles with the kiro package.
// The structure mirrors kiro.KiroModel for easy data conversion.
type KiroAPIModel struct {
// ModelID is the unique identifier for the model (e.g., "claude-sonnet-4.5")
ModelID string
// ModelName is the human-readable name
ModelName string
// Description is the model description
Description string
// RateMultiplier is the credit multiplier for this model
RateMultiplier float64
// RateUnit is the unit for rate calculation (e.g., "credit")
RateUnit string
// MaxInputTokens is the maximum input token limit
MaxInputTokens int
}
// DefaultKiroThinkingSupport defines the default thinking configuration for Kiro models.
// All Kiro models support thinking with the following budget range.
var DefaultKiroThinkingSupport = &ThinkingSupport{
Min: 1024, // Minimum thinking budget tokens
Max: 32000, // Maximum thinking budget tokens
ZeroAllowed: true, // Allow disabling thinking with 0
DynamicAllowed: true, // Allow dynamic thinking budget (-1)
}
// DefaultKiroContextLength is the default context window size for Kiro models.
const DefaultKiroContextLength = 200000
// DefaultKiroMaxCompletionTokens is the default max completion tokens for Kiro models.
const DefaultKiroMaxCompletionTokens = 64000
// ConvertKiroAPIModels converts Kiro API models to internal ModelInfo format.
// It performs the following transformations:
// - Normalizes model ID (e.g., claude-sonnet-4.5 → kiro-claude-sonnet-4-5)
// - Adds default thinking support metadata
// - Sets default context length and max completion tokens if not provided
//
// Parameters:
// - kiroModels: List of models from Kiro API response
//
// Returns:
// - []*ModelInfo: Converted model information list
func ConvertKiroAPIModels(kiroModels []*KiroAPIModel) []*ModelInfo {
if len(kiroModels) == 0 {
return nil
}
now := time.Now().Unix()
result := make([]*ModelInfo, 0, len(kiroModels))
for _, km := range kiroModels {
// Skip nil models
if km == nil {
continue
}
// Skip models without valid ID
if km.ModelID == "" {
continue
}
// Normalize the model ID to kiro-* format
normalizedID := normalizeKiroModelID(km.ModelID)
// Create ModelInfo with converted data
info := &ModelInfo{
ID: normalizedID,
Object: "model",
Created: now,
OwnedBy: "aws",
Type: "kiro",
DisplayName: generateKiroDisplayName(km.ModelName, normalizedID),
Description: km.Description,
// Use MaxInputTokens from API if available, otherwise use default
ContextLength: getContextLength(km.MaxInputTokens),
MaxCompletionTokens: DefaultKiroMaxCompletionTokens,
// All Kiro models support thinking
Thinking: cloneThinkingSupport(DefaultKiroThinkingSupport),
}
result = append(result, info)
}
return result
}
// GenerateAgenticVariants creates -agentic variants for each model.
// Agentic variants are optimized for coding agents with chunked writes.
//
// Parameters:
// - models: Base models to generate variants for
//
// Returns:
// - []*ModelInfo: Combined list of base models and their agentic variants
func GenerateAgenticVariants(models []*ModelInfo) []*ModelInfo {
if len(models) == 0 {
return nil
}
// Pre-allocate result with capacity for both base models and variants
result := make([]*ModelInfo, 0, len(models)*2)
for _, model := range models {
if model == nil {
continue
}
// Add the base model first
result = append(result, model)
// Skip if model already has -agentic suffix
if strings.HasSuffix(model.ID, "-agentic") {
continue
}
// Skip special models that shouldn't have agentic variants
if model.ID == "kiro-auto" {
continue
}
// Create agentic variant
agenticModel := &ModelInfo{
ID: model.ID + "-agentic",
Object: model.Object,
Created: model.Created,
OwnedBy: model.OwnedBy,
Type: model.Type,
DisplayName: model.DisplayName + " (Agentic)",
Description: generateAgenticDescription(model.Description),
ContextLength: model.ContextLength,
MaxCompletionTokens: model.MaxCompletionTokens,
Thinking: cloneThinkingSupport(model.Thinking),
}
result = append(result, agenticModel)
}
return result
}
// MergeWithStaticMetadata merges dynamic models with static metadata.
// Static metadata takes priority for any overlapping fields.
// This allows manual overrides for specific models while keeping dynamic discovery.
//
// Parameters:
// - dynamicModels: Models from Kiro API (converted to ModelInfo)
// - staticModels: Predefined model metadata (from GetKiroModels())
//
// Returns:
// - []*ModelInfo: Merged model list with static metadata taking priority
func MergeWithStaticMetadata(dynamicModels, staticModels []*ModelInfo) []*ModelInfo {
if len(dynamicModels) == 0 && len(staticModels) == 0 {
return nil
}
// Build a map of static models for quick lookup
staticMap := make(map[string]*ModelInfo, len(staticModels))
for _, sm := range staticModels {
if sm != nil && sm.ID != "" {
staticMap[sm.ID] = sm
}
}
// Build result, preferring static metadata where available
seenIDs := make(map[string]struct{})
result := make([]*ModelInfo, 0, len(dynamicModels)+len(staticModels))
// First, process dynamic models and merge with static if available
for _, dm := range dynamicModels {
if dm == nil || dm.ID == "" {
continue
}
// Skip duplicates
if _, seen := seenIDs[dm.ID]; seen {
continue
}
seenIDs[dm.ID] = struct{}{}
// Check if static metadata exists for this model
if sm, exists := staticMap[dm.ID]; exists {
// Static metadata takes priority - use static model
result = append(result, sm)
} else {
// No static metadata - use dynamic model
result = append(result, dm)
}
}
// Add any static models not in dynamic list
for _, sm := range staticModels {
if sm == nil || sm.ID == "" {
continue
}
if _, seen := seenIDs[sm.ID]; seen {
continue
}
seenIDs[sm.ID] = struct{}{}
result = append(result, sm)
}
return result
}
// normalizeKiroModelID converts Kiro API model IDs to internal format.
// Transformation rules:
// - Adds "kiro-" prefix if not present
// - Replaces dots with hyphens (e.g., 4.5 → 4-5)
// - Handles special cases like "auto" → "kiro-auto"
//
// Examples:
// - "claude-sonnet-4.5" → "kiro-claude-sonnet-4-5"
// - "claude-opus-4.5" → "kiro-claude-opus-4-5"
// - "auto" → "kiro-auto"
// - "kiro-claude-sonnet-4-5" → "kiro-claude-sonnet-4-5" (unchanged)
func normalizeKiroModelID(modelID string) string {
if modelID == "" {
return ""
}
// Trim whitespace
modelID = strings.TrimSpace(modelID)
// Replace dots with hyphens (e.g., 4.5 → 4-5)
normalized := strings.ReplaceAll(modelID, ".", "-")
// Add kiro- prefix if not present
if !strings.HasPrefix(normalized, "kiro-") {
normalized = "kiro-" + normalized
}
return normalized
}
// generateKiroDisplayName creates a human-readable display name.
// Uses the API-provided model name if available, otherwise generates from ID.
func generateKiroDisplayName(modelName, normalizedID string) string {
if modelName != "" {
return "Kiro " + modelName
}
// Generate from normalized ID by removing kiro- prefix and formatting
displayID := strings.TrimPrefix(normalizedID, "kiro-")
// Capitalize first letter of each word
words := strings.Split(displayID, "-")
for i, word := range words {
if len(word) > 0 {
words[i] = strings.ToUpper(word[:1]) + word[1:]
}
}
return "Kiro " + strings.Join(words, " ")
}
// generateAgenticDescription creates description for agentic variants.
func generateAgenticDescription(baseDescription string) string {
if baseDescription == "" {
return "Optimized for coding agents with chunked writes"
}
return baseDescription + " (Agentic mode: chunked writes)"
}
// getContextLength returns the context length, using default if not provided.
func getContextLength(maxInputTokens int) int {
if maxInputTokens > 0 {
return maxInputTokens
}
return DefaultKiroContextLength
}
// cloneThinkingSupport creates a deep copy of ThinkingSupport.
// Returns nil if input is nil.
func cloneThinkingSupport(ts *ThinkingSupport) *ThinkingSupport {
if ts == nil {
return nil
}
clone := &ThinkingSupport{
Min: ts.Min,
Max: ts.Max,
ZeroAllowed: ts.ZeroAllowed,
DynamicAllowed: ts.DynamicAllowed,
}
// Deep copy Levels slice if present
if len(ts.Levels) > 0 {
clone.Levels = make([]string, len(ts.Levels))
copy(clone.Levels, ts.Levels)
}
return clone
}

View File

@@ -20,6 +20,11 @@ import (
// - qwen
// - iflow
// - kimi
// - kiro
// - kilo
// - github-copilot
// - kiro
// - amazonq
// - antigravity (returns static overrides only)
func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
key := strings.ToLower(strings.TrimSpace(channel))
@@ -42,6 +47,14 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
return GetIFlowModels()
case "kimi":
return GetKimiModels()
case "github-copilot":
return GetGitHubCopilotModels()
case "kiro":
return GetKiroModels()
case "kilo":
return GetKiloModels()
case "amazonq":
return GetAmazonQModels()
case "antigravity":
cfg := GetAntigravityModelConfig()
if len(cfg) == 0 {
@@ -87,6 +100,10 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
GetQwenModels(),
GetIFlowModels(),
GetKimiModels(),
GetGitHubCopilotModels(),
GetKiroModels(),
GetKiloModels(),
GetAmazonQModels(),
}
for _, models := range allModels {
for _, m := range models {
@@ -107,3 +124,675 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
return nil
}
// GetGitHubCopilotModels returns the available models for GitHub Copilot.
// These models are available through the GitHub Copilot API at api.githubcopilot.com.
func GetGitHubCopilotModels() []*ModelInfo {
now := int64(1732752000) // 2024-11-27
gpt4oEntries := []struct {
ID string
DisplayName string
Description string
}{
{ID: "gpt-4o-2024-11-20", DisplayName: "GPT-4o (2024-11-20)", Description: "OpenAI GPT-4o 2024-11-20 via GitHub Copilot"},
{ID: "gpt-4o-2024-08-06", DisplayName: "GPT-4o (2024-08-06)", Description: "OpenAI GPT-4o 2024-08-06 via GitHub Copilot"},
{ID: "gpt-4o-2024-05-13", DisplayName: "GPT-4o (2024-05-13)", Description: "OpenAI GPT-4o 2024-05-13 via GitHub Copilot"},
{ID: "gpt-4o", DisplayName: "GPT-4o", Description: "OpenAI GPT-4o via GitHub Copilot"},
{ID: "gpt-4-o-preview", DisplayName: "GPT-4-o Preview", Description: "OpenAI GPT-4-o Preview via GitHub Copilot"},
}
models := []*ModelInfo{
{
ID: "gpt-4.1",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "GPT-4.1",
Description: "OpenAI GPT-4.1 via GitHub Copilot",
ContextLength: 128000,
MaxCompletionTokens: 16384,
},
}
for _, entry := range gpt4oEntries {
models = append(models, &ModelInfo{
ID: entry.ID,
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: entry.DisplayName,
Description: entry.Description,
ContextLength: 128000,
MaxCompletionTokens: 16384,
})
}
return append(models, []*ModelInfo{
{
ID: "gpt-5",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "GPT-5",
Description: "OpenAI GPT-5 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions", "/responses"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
},
{
ID: "gpt-5-mini",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "GPT-5 Mini",
Description: "OpenAI GPT-5 Mini via GitHub Copilot",
ContextLength: 128000,
MaxCompletionTokens: 16384,
SupportedEndpoints: []string{"/chat/completions", "/responses"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
},
{
ID: "gpt-5-codex",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "GPT-5 Codex",
Description: "OpenAI GPT-5 Codex via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
},
{
ID: "gpt-5.1",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "GPT-5.1",
Description: "OpenAI GPT-5.1 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions", "/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
},
{
ID: "gpt-5.1-codex",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "GPT-5.1 Codex",
Description: "OpenAI GPT-5.1 Codex via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
},
{
ID: "gpt-5.1-codex-mini",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "GPT-5.1 Codex Mini",
Description: "OpenAI GPT-5.1 Codex Mini via GitHub Copilot",
ContextLength: 128000,
MaxCompletionTokens: 16384,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
},
{
ID: "gpt-5.1-codex-max",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "GPT-5.1 Codex Max",
Description: "OpenAI GPT-5.1 Codex Max via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{
ID: "gpt-5.2",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "GPT-5.2",
Description: "OpenAI GPT-5.2 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/chat/completions", "/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{
ID: "gpt-5.2-codex",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "GPT-5.2 Codex",
Description: "OpenAI GPT-5.2 Codex via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{
ID: "gpt-5.3-codex",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "GPT-5.3 Codex",
Description: "OpenAI GPT-5.3 Codex via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32768,
SupportedEndpoints: []string{"/responses"},
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
},
{
ID: "claude-haiku-4.5",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "Claude Haiku 4.5",
Description: "Anthropic Claude Haiku 4.5 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "claude-opus-4.1",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "Claude Opus 4.1",
Description: "Anthropic Claude Opus 4.1 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 32000,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "claude-opus-4.5",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "Claude Opus 4.5",
Description: "Anthropic Claude Opus 4.5 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "claude-opus-4.6",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "Claude Opus 4.6",
Description: "Anthropic Claude Opus 4.6 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "claude-sonnet-4",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "Claude Sonnet 4",
Description: "Anthropic Claude Sonnet 4 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "claude-sonnet-4.5",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "Claude Sonnet 4.5",
Description: "Anthropic Claude Sonnet 4.5 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "claude-sonnet-4.6",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "Claude Sonnet 4.6",
Description: "Anthropic Claude Sonnet 4.6 via GitHub Copilot",
ContextLength: 200000,
MaxCompletionTokens: 64000,
SupportedEndpoints: []string{"/chat/completions"},
},
{
ID: "gemini-2.5-pro",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "Gemini 2.5 Pro",
Description: "Google Gemini 2.5 Pro via GitHub Copilot",
ContextLength: 1048576,
MaxCompletionTokens: 65536,
},
{
ID: "gemini-3-pro-preview",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "Gemini 3 Pro (Preview)",
Description: "Google Gemini 3 Pro Preview via GitHub Copilot",
ContextLength: 1048576,
MaxCompletionTokens: 65536,
},
{
ID: "gemini-3.1-pro-preview",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "Gemini 3.1 Pro (Preview)",
Description: "Google Gemini 3.1 Pro Preview via GitHub Copilot",
ContextLength: 1048576,
MaxCompletionTokens: 65536,
},
{
ID: "gemini-3-flash-preview",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "Gemini 3 Flash (Preview)",
Description: "Google Gemini 3 Flash Preview via GitHub Copilot",
ContextLength: 1048576,
MaxCompletionTokens: 65536,
},
{
ID: "grok-code-fast-1",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "Grok Code Fast 1",
Description: "xAI Grok Code Fast 1 via GitHub Copilot",
ContextLength: 128000,
MaxCompletionTokens: 16384,
},
{
ID: "oswe-vscode-prime",
Object: "model",
Created: now,
OwnedBy: "github-copilot",
Type: "github-copilot",
DisplayName: "Raptor mini (Preview)",
Description: "Raptor mini via GitHub Copilot",
ContextLength: 128000,
MaxCompletionTokens: 16384,
SupportedEndpoints: []string{"/chat/completions", "/responses"},
},
}...)
}
// GetKiroModels returns the Kiro (AWS CodeWhisperer) model definitions
func GetKiroModels() []*ModelInfo {
return []*ModelInfo{
// --- Base Models ---
{
ID: "kiro-auto",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Auto",
Description: "Automatic model selection by Kiro",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-opus-4-6",
Object: "model",
Created: 1736899200, // 2025-01-15
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Opus 4.6",
Description: "Claude Opus 4.6 via Kiro (2.2x credit)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-sonnet-4-6",
Object: "model",
Created: 1739836800, // 2025-02-18
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Sonnet 4.6",
Description: "Claude Sonnet 4.6 via Kiro (1.3x credit)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-opus-4-5",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Opus 4.5",
Description: "Claude Opus 4.5 via Kiro (2.2x credit)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-sonnet-4-5",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Sonnet 4.5",
Description: "Claude Sonnet 4.5 via Kiro (1.3x credit)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-sonnet-4",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Sonnet 4",
Description: "Claude Sonnet 4 via Kiro (1.3x credit)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-haiku-4-5",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Haiku 4.5",
Description: "Claude Haiku 4.5 via Kiro (0.4x credit)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
// --- 第三方模型 (通过 Kiro 接入) ---
{
ID: "kiro-deepseek-3-2",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro DeepSeek 3.2",
Description: "DeepSeek 3.2 via Kiro",
ContextLength: 128000,
MaxCompletionTokens: 32768,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-minimax-m2-1",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro MiniMax M2.1",
Description: "MiniMax M2.1 via Kiro",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-qwen3-coder-next",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Qwen3 Coder Next",
Description: "Qwen3 Coder Next via Kiro",
ContextLength: 128000,
MaxCompletionTokens: 32768,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-gpt-4o",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro GPT-4o",
Description: "OpenAI GPT-4o via Kiro",
ContextLength: 128000,
MaxCompletionTokens: 16384,
},
{
ID: "kiro-gpt-4",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro GPT-4",
Description: "OpenAI GPT-4 via Kiro",
ContextLength: 128000,
MaxCompletionTokens: 8192,
},
{
ID: "kiro-gpt-4-turbo",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro GPT-4 Turbo",
Description: "OpenAI GPT-4 Turbo via Kiro",
ContextLength: 128000,
MaxCompletionTokens: 16384,
},
{
ID: "kiro-gpt-3-5-turbo",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro GPT-3.5 Turbo",
Description: "OpenAI GPT-3.5 Turbo via Kiro",
ContextLength: 16384,
MaxCompletionTokens: 4096,
},
// --- Agentic Variants (Optimized for coding agents with chunked writes) ---
{
ID: "kiro-claude-opus-4-6-agentic",
Object: "model",
Created: 1736899200, // 2025-01-15
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Opus 4.6 (Agentic)",
Description: "Claude Opus 4.6 optimized for coding agents (chunked writes)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-sonnet-4-6-agentic",
Object: "model",
Created: 1739836800, // 2025-02-18
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Sonnet 4.6 (Agentic)",
Description: "Claude Sonnet 4.6 optimized for coding agents (chunked writes)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-opus-4-5-agentic",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Opus 4.5 (Agentic)",
Description: "Claude Opus 4.5 optimized for coding agents (chunked writes)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-sonnet-4-5-agentic",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Sonnet 4.5 (Agentic)",
Description: "Claude Sonnet 4.5 optimized for coding agents (chunked writes)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-sonnet-4-agentic",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Sonnet 4 (Agentic)",
Description: "Claude Sonnet 4 optimized for coding agents (chunked writes)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-claude-haiku-4-5-agentic",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Claude Haiku 4.5 (Agentic)",
Description: "Claude Haiku 4.5 optimized for coding agents (chunked writes)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-deepseek-3-2-agentic",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro DeepSeek 3.2 (Agentic)",
Description: "DeepSeek 3.2 optimized for coding agents (chunked writes)",
ContextLength: 128000,
MaxCompletionTokens: 32768,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-minimax-m2-1-agentic",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro MiniMax M2.1 (Agentic)",
Description: "MiniMax M2.1 optimized for coding agents (chunked writes)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "kiro-qwen3-coder-next-agentic",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Kiro Qwen3 Coder Next (Agentic)",
Description: "Qwen3 Coder Next optimized for coding agents (chunked writes)",
ContextLength: 128000,
MaxCompletionTokens: 32768,
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
},
}
}
// GetAmazonQModels returns the Amazon Q (AWS CodeWhisperer) model definitions.
// These models use the same API as Kiro and share the same executor.
func GetAmazonQModels() []*ModelInfo {
return []*ModelInfo{
{
ID: "amazonq-auto",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro", // Uses Kiro executor - same API
DisplayName: "Amazon Q Auto",
Description: "Automatic model selection by Amazon Q",
ContextLength: 200000,
MaxCompletionTokens: 64000,
},
{
ID: "amazonq-claude-opus-4.5",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Amazon Q Claude Opus 4.5",
Description: "Claude Opus 4.5 via Amazon Q (2.2x credit)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
},
{
ID: "amazonq-claude-sonnet-4.5",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Amazon Q Claude Sonnet 4.5",
Description: "Claude Sonnet 4.5 via Amazon Q (1.3x credit)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
},
{
ID: "amazonq-claude-sonnet-4",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Amazon Q Claude Sonnet 4",
Description: "Claude Sonnet 4 via Amazon Q (1.3x credit)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
},
{
ID: "amazonq-claude-haiku-4.5",
Object: "model",
Created: 1732752000,
OwnedBy: "aws",
Type: "kiro",
DisplayName: "Amazon Q Claude Haiku 4.5",
Description: "Claude Haiku 4.5 via Amazon Q (0.4x credit)",
ContextLength: 200000,
MaxCompletionTokens: 64000,
},
}
}

View File

@@ -51,6 +51,18 @@ func GetClaudeModels() []*ModelInfo {
MaxCompletionTokens: 128000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
{
ID: "claude-sonnet-4-6",
Object: "model",
Created: 1771286400, // 2026-02-17
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.6 Sonnet",
Description: "Best combination of speed and intelligence",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
{
ID: "claude-opus-4-5-20251101",
Object: "model",
@@ -955,12 +967,12 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
"gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3.1-flash-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}},
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-opus-4-6": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 128000},
"claude-sonnet-4-5": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-sonnet-4-6": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-sonnet-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"gpt-oss-120b-medium": {},
"tab_flash_lite_preview": {},

View File

@@ -47,6 +47,8 @@ type ModelInfo struct {
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
// SupportedParameters lists supported parameters
SupportedParameters []string `json:"supported_parameters,omitempty"`
// SupportedEndpoints lists supported API endpoints (e.g., "/chat/completions", "/responses").
SupportedEndpoints []string `json:"supported_endpoints,omitempty"`
// Thinking holds provider-specific reasoning/thinking budget capabilities.
// This is optional and currently used for Gemini thinking budget normalization.
@@ -499,6 +501,9 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
if len(model.SupportedParameters) > 0 {
copyModel.SupportedParameters = append([]string(nil), model.SupportedParameters...)
}
if len(model.SupportedEndpoints) > 0 {
copyModel.SupportedEndpoints = append([]string(nil), model.SupportedEndpoints...)
}
return &copyModel
}
@@ -1023,9 +1028,13 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
if len(model.SupportedParameters) > 0 {
result["supported_parameters"] = model.SupportedParameters
}
if len(model.SupportedEndpoints) > 0 {
result["supported_endpoints"] = model.SupportedEndpoints
}
return result
case "claude":
case "claude", "kiro", "antigravity":
// Claude, Kiro, and Antigravity all use Claude-compatible format for Claude Code client
result := map[string]any{
"id": model.ID,
"object": "model",
@@ -1040,6 +1049,19 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
if model.DisplayName != "" {
result["display_name"] = model.DisplayName
}
// Add thinking support for Claude Code client
// Claude Code checks for "thinking" field (simple boolean) to enable tab toggle
// Also add "extended_thinking" for detailed budget info
if model.Thinking != nil {
result["thinking"] = true
result["extended_thinking"] = map[string]any{
"supported": true,
"min": model.Thinking.Min,
"max": model.Thinking.Max,
"zero_allowed": model.Thinking.ZeroAllowed,
"dynamic_allowed": model.Thinking.DynamicAllowed,
}
}
return result
case "gemini":

View File

@@ -1078,7 +1078,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
if errToken != nil || token == "" {
return fallbackAntigravityPrimaryModels()
}
}
if updatedAuth != nil {
auth = updatedAuth
}

View File

@@ -29,6 +29,7 @@ func startCodexCacheCleanup() {
go func() {
ticker := time.NewTicker(codexCacheCleanupInterval)
defer ticker.Stop()
for range ticker.C {
purgeExpiredCodexCache()
}
@@ -38,8 +39,10 @@ func startCodexCacheCleanup() {
// purgeExpiredCodexCache removes entries that have expired.
func purgeExpiredCodexCache() {
now := time.Now()
codexCacheMu.Lock()
defer codexCacheMu.Unlock()
for key, cache := range codexCacheMap {
if cache.Expire.Before(now) {
delete(codexCacheMap, key)
@@ -66,3 +69,10 @@ func setCodexCache(key string, cache codexCache) {
codexCacheMap[key] = cache
codexCacheMu.Unlock()
}
// deleteCodexCache deletes a cache entry.
func deleteCodexCache(key string) {
codexCacheMu.Lock()
delete(codexCacheMap, key)
codexCacheMu.Unlock()
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,333 @@
package executor
import (
"net/http"
"strings"
"testing"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
func TestGitHubCopilotNormalizeModel_StripsSuffix(t *testing.T) {
t.Parallel()
tests := []struct {
name string
model string
wantModel string
}{
{
name: "suffix stripped",
model: "claude-opus-4.6(medium)",
wantModel: "claude-opus-4.6",
},
{
name: "no suffix unchanged",
model: "claude-opus-4.6",
wantModel: "claude-opus-4.6",
},
{
name: "different suffix stripped",
model: "gpt-4o(high)",
wantModel: "gpt-4o",
},
{
name: "numeric suffix stripped",
model: "gemini-2.5-pro(8192)",
wantModel: "gemini-2.5-pro",
},
}
e := &GitHubCopilotExecutor{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
body := []byte(`{"model":"` + tt.model + `","messages":[]}`)
got := e.normalizeModel(tt.model, body)
gotModel := gjson.GetBytes(got, "model").String()
if gotModel != tt.wantModel {
t.Fatalf("normalizeModel() model = %q, want %q", gotModel, tt.wantModel)
}
})
}
}
func TestUseGitHubCopilotResponsesEndpoint_OpenAIResponseSource(t *testing.T) {
t.Parallel()
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai-response"), "claude-3-5-sonnet") {
t.Fatal("expected openai-response source to use /responses")
}
}
func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) {
t.Parallel()
if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5-codex") {
t.Fatal("expected codex model to use /responses")
}
}
func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {
t.Parallel()
if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "claude-3-5-sonnet") {
t.Fatal("expected default openai source with non-codex model to use /chat/completions")
}
}
func TestNormalizeGitHubCopilotChatTools_KeepFunctionOnly(t *testing.T) {
t.Parallel()
body := []byte(`{"tools":[{"type":"function","function":{"name":"ok"}},{"type":"code_interpreter"}],"tool_choice":"auto"}`)
got := normalizeGitHubCopilotChatTools(body)
tools := gjson.GetBytes(got, "tools").Array()
if len(tools) != 1 {
t.Fatalf("tools len = %d, want 1", len(tools))
}
if tools[0].Get("type").String() != "function" {
t.Fatalf("tool type = %q, want function", tools[0].Get("type").String())
}
}
func TestNormalizeGitHubCopilotChatTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
t.Parallel()
body := []byte(`{"tools":[],"tool_choice":{"type":"function","function":{"name":"x"}}}`)
got := normalizeGitHubCopilotChatTools(body)
if gjson.GetBytes(got, "tool_choice").String() != "auto" {
t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
}
}
func TestNormalizeGitHubCopilotResponsesInput_MissingInputExtractedFromSystemAndMessages(t *testing.T) {
t.Parallel()
body := []byte(`{"system":"sys text","messages":[{"role":"user","content":"user text"},{"role":"assistant","content":[{"type":"text","text":"assistant text"}]}]}`)
got := normalizeGitHubCopilotResponsesInput(body)
in := gjson.GetBytes(got, "input")
if !in.IsArray() {
t.Fatalf("input type = %v, want array", in.Type)
}
raw := in.Raw
if !strings.Contains(raw, "sys text") || !strings.Contains(raw, "user text") || !strings.Contains(raw, "assistant text") {
t.Fatalf("input = %s, want structured array with all texts", raw)
}
if gjson.GetBytes(got, "messages").Exists() {
t.Fatal("messages should be removed after conversion")
}
if gjson.GetBytes(got, "system").Exists() {
t.Fatal("system should be removed after conversion")
}
}
func TestNormalizeGitHubCopilotResponsesInput_NonStringInputStringified(t *testing.T) {
t.Parallel()
body := []byte(`{"input":{"foo":"bar"}}`)
got := normalizeGitHubCopilotResponsesInput(body)
in := gjson.GetBytes(got, "input")
if in.Type != gjson.String {
t.Fatalf("input type = %v, want string", in.Type)
}
if !strings.Contains(in.String(), "foo") {
t.Fatalf("input = %q, want stringified object", in.String())
}
}
func TestNormalizeGitHubCopilotResponsesTools_FlattenFunctionTools(t *testing.T) {
t.Parallel()
body := []byte(`{"tools":[{"type":"function","function":{"name":"sum","description":"d","parameters":{"type":"object"}}},{"type":"web_search"}]}`)
got := normalizeGitHubCopilotResponsesTools(body)
tools := gjson.GetBytes(got, "tools").Array()
if len(tools) != 1 {
t.Fatalf("tools len = %d, want 1", len(tools))
}
if tools[0].Get("name").String() != "sum" {
t.Fatalf("tools[0].name = %q, want sum", tools[0].Get("name").String())
}
if !tools[0].Get("parameters").Exists() {
t.Fatal("expected parameters to be preserved")
}
}
func TestNormalizeGitHubCopilotResponsesTools_ClaudeFormatTools(t *testing.T) {
t.Parallel()
body := []byte(`{"tools":[{"name":"Bash","description":"Run commands","input_schema":{"type":"object","properties":{"command":{"type":"string"}},"required":["command"]}},{"name":"Read","description":"Read files","input_schema":{"type":"object","properties":{"path":{"type":"string"}}}}]}`)
got := normalizeGitHubCopilotResponsesTools(body)
tools := gjson.GetBytes(got, "tools").Array()
if len(tools) != 2 {
t.Fatalf("tools len = %d, want 2", len(tools))
}
if tools[0].Get("type").String() != "function" {
t.Fatalf("tools[0].type = %q, want function", tools[0].Get("type").String())
}
if tools[0].Get("name").String() != "Bash" {
t.Fatalf("tools[0].name = %q, want Bash", tools[0].Get("name").String())
}
if tools[0].Get("description").String() != "Run commands" {
t.Fatalf("tools[0].description = %q, want 'Run commands'", tools[0].Get("description").String())
}
if !tools[0].Get("parameters").Exists() {
t.Fatal("expected parameters to be set from input_schema")
}
if tools[0].Get("parameters.properties.command").Exists() != true {
t.Fatal("expected parameters.properties.command to exist")
}
if tools[1].Get("name").String() != "Read" {
t.Fatalf("tools[1].name = %q, want Read", tools[1].Get("name").String())
}
}
func TestNormalizeGitHubCopilotResponsesTools_FlattenToolChoiceFunctionObject(t *testing.T) {
t.Parallel()
body := []byte(`{"tool_choice":{"type":"function","function":{"name":"sum"}}}`)
got := normalizeGitHubCopilotResponsesTools(body)
if gjson.GetBytes(got, "tool_choice.type").String() != "function" {
t.Fatalf("tool_choice.type = %q, want function", gjson.GetBytes(got, "tool_choice.type").String())
}
if gjson.GetBytes(got, "tool_choice.name").String() != "sum" {
t.Fatalf("tool_choice.name = %q, want sum", gjson.GetBytes(got, "tool_choice.name").String())
}
}
func TestNormalizeGitHubCopilotResponsesTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
t.Parallel()
body := []byte(`{"tool_choice":{"type":"function"}}`)
got := normalizeGitHubCopilotResponsesTools(body)
if gjson.GetBytes(got, "tool_choice").String() != "auto" {
t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
}
}
func TestTranslateGitHubCopilotResponsesNonStreamToClaude_TextMapping(t *testing.T) {
t.Parallel()
resp := []byte(`{"id":"resp_1","model":"gpt-5-codex","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":3,"output_tokens":5}}`)
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
if gjson.Get(out, "type").String() != "message" {
t.Fatalf("type = %q, want message", gjson.Get(out, "type").String())
}
if gjson.Get(out, "content.0.type").String() != "text" {
t.Fatalf("content.0.type = %q, want text", gjson.Get(out, "content.0.type").String())
}
if gjson.Get(out, "content.0.text").String() != "hello" {
t.Fatalf("content.0.text = %q, want hello", gjson.Get(out, "content.0.text").String())
}
}
func TestTranslateGitHubCopilotResponsesNonStreamToClaude_ToolUseMapping(t *testing.T) {
t.Parallel()
resp := []byte(`{"id":"resp_2","model":"gpt-5-codex","output":[{"type":"function_call","id":"fc_1","call_id":"call_1","name":"sum","arguments":"{\"a\":1}"}],"usage":{"input_tokens":1,"output_tokens":2}}`)
out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
if gjson.Get(out, "content.0.type").String() != "tool_use" {
t.Fatalf("content.0.type = %q, want tool_use", gjson.Get(out, "content.0.type").String())
}
if gjson.Get(out, "content.0.name").String() != "sum" {
t.Fatalf("content.0.name = %q, want sum", gjson.Get(out, "content.0.name").String())
}
if gjson.Get(out, "stop_reason").String() != "tool_use" {
t.Fatalf("stop_reason = %q, want tool_use", gjson.Get(out, "stop_reason").String())
}
}
func TestTranslateGitHubCopilotResponsesStreamToClaude_TextLifecycle(t *testing.T) {
t.Parallel()
var param any
created := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5-codex"}}`), &param)
if len(created) == 0 || !strings.Contains(created[0], "message_start") {
t.Fatalf("created events = %#v, want message_start", created)
}
delta := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.output_text.delta","delta":"he"}`), &param)
joinedDelta := strings.Join(delta, "")
if !strings.Contains(joinedDelta, "content_block_start") || !strings.Contains(joinedDelta, "text_delta") {
t.Fatalf("delta events = %#v, want content_block_start + text_delta", delta)
}
completed := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.completed","response":{"usage":{"input_tokens":7,"output_tokens":9}}}`), &param)
joinedCompleted := strings.Join(completed, "")
if !strings.Contains(joinedCompleted, "message_delta") || !strings.Contains(joinedCompleted, "message_stop") {
t.Fatalf("completed events = %#v, want message_delta + message_stop", completed)
}
}
// --- Tests for X-Initiator detection logic (Problem L) ---
func TestApplyHeaders_XInitiator_UserOnly(t *testing.T) {
t.Parallel()
e := &GitHubCopilotExecutor{}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
body := []byte(`{"messages":[{"role":"system","content":"sys"},{"role":"user","content":"hello"}]}`)
e.applyHeaders(req, "token", body)
if got := req.Header.Get("X-Initiator"); got != "user" {
t.Fatalf("X-Initiator = %q, want user", got)
}
}
func TestApplyHeaders_XInitiator_AgentWithAssistantAndUserToolResult(t *testing.T) {
t.Parallel()
e := &GitHubCopilotExecutor{}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
// Claude Code typical flow: last message is user (tool result), but has assistant in history
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"I will read the file"},{"role":"user","content":"tool result here"}]}`)
e.applyHeaders(req, "token", body)
if got := req.Header.Get("X-Initiator"); got != "agent" {
t.Fatalf("X-Initiator = %q, want agent (assistant exists in messages)", got)
}
}
func TestApplyHeaders_XInitiator_AgentWithToolRole(t *testing.T) {
t.Parallel()
e := &GitHubCopilotExecutor{}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"tool","content":"result"}]}`)
e.applyHeaders(req, "token", body)
if got := req.Header.Get("X-Initiator"); got != "agent" {
t.Fatalf("X-Initiator = %q, want agent (tool role exists)", got)
}
}
// --- Tests for x-github-api-version header (Problem M) ---
func TestApplyHeaders_GitHubAPIVersion(t *testing.T) {
t.Parallel()
e := &GitHubCopilotExecutor{}
req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
e.applyHeaders(req, "token", nil)
if got := req.Header.Get("X-Github-Api-Version"); got != "2025-04-01" {
t.Fatalf("X-Github-Api-Version = %q, want 2025-04-01", got)
}
}
// --- Tests for vision detection (Problem P) ---
func TestDetectVisionContent_WithImageURL(t *testing.T) {
t.Parallel()
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`)
if !detectVisionContent(body) {
t.Fatal("expected vision content to be detected")
}
}
func TestDetectVisionContent_WithImageType(t *testing.T) {
t.Parallel()
body := []byte(`{"messages":[{"role":"user","content":[{"type":"image","source":{"data":"abc","media_type":"image/png"}}]}]}`)
if !detectVisionContent(body) {
t.Fatal("expected image type to be detected")
}
}
func TestDetectVisionContent_NoVision(t *testing.T) {
t.Parallel()
body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
if detectVisionContent(body) {
t.Fatal("expected no vision content")
}
}
func TestDetectVisionContent_NoMessages(t *testing.T) {
t.Parallel()
// After Responses API normalization, messages is removed — detection should return false
body := []byte(`{"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"hello"}]}]}`)
if detectVisionContent(body) {
t.Fatal("expected no vision content when messages field is absent")
}
}

View File

@@ -0,0 +1,460 @@
package executor
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
// KiloExecutor handles requests to Kilo API.
type KiloExecutor struct {
cfg *config.Config
}
// NewKiloExecutor creates a new Kilo executor instance.
func NewKiloExecutor(cfg *config.Config) *KiloExecutor {
return &KiloExecutor{cfg: cfg}
}
// Identifier returns the unique identifier for this executor.
func (e *KiloExecutor) Identifier() string { return "kilo" }
// PrepareRequest prepares the HTTP request before execution.
func (e *KiloExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
if req == nil {
return nil
}
accessToken, _ := kiloCredentials(auth)
if strings.TrimSpace(accessToken) == "" {
return fmt.Errorf("kilo: missing access token")
}
req.Header.Set("Authorization", "Bearer "+accessToken)
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(req, attrs)
return nil
}
// HttpRequest executes a raw HTTP request.
func (e *KiloExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("kilo executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
// Execute performs a non-streaming request.
func (e *KiloExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
accessToken, orgID := kiloCredentials(auth)
if accessToken == "" {
return resp, fmt.Errorf("kilo: missing access token")
}
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
endpoint := "/api/openrouter/chat/completions"
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
url := "https://api.kilo.ai" + endpoint
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
if err != nil {
return resp, err
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
if orgID != "" {
httpReq.Header.Set("X-Kilocode-OrganizationID", orgID)
}
httpReq.Header.Set("User-Agent", "cli-proxy-kilo")
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: translated,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer httpResp.Body.Close()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
body, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, body)
reporter.publish(ctx, parseOpenAIUsage(body))
reporter.ensurePublished(ctx)
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
return resp, nil
}
// ExecuteStream performs a streaming request.
func (e *KiloExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
accessToken, orgID := kiloCredentials(auth)
if accessToken == "" {
return nil, fmt.Errorf("kilo: missing access token")
}
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
endpoint := "/api/openrouter/chat/completions"
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
url := "https://api.kilo.ai" + endpoint
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
if err != nil {
return nil, err
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+accessToken)
if orgID != "" {
httpReq.Header.Set("X-Kilocode-OrganizationID", orgID)
}
httpReq.Header.Set("User-Agent", "cli-proxy-kilo")
httpReq.Header.Set("Accept", "text/event-stream")
httpReq.Header.Set("Cache-Control", "no-cache")
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: translated,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
httpResp.Body.Close()
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
defer close(out)
defer httpResp.Body.Close()
scanner := bufio.NewScanner(httpResp.Body)
scanner.Buffer(nil, 52_428_800)
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseOpenAIStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
if len(line) == 0 {
continue
}
if !bytes.HasPrefix(line, []byte("data:")) {
continue
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
reporter.ensurePublished(ctx)
}()
return &cliproxyexecutor.StreamResult{
Headers: httpResp.Header.Clone(),
Chunks: out,
}, nil
}
// Refresh validates the Kilo token.
func (e *KiloExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
if auth == nil {
return nil, fmt.Errorf("missing auth")
}
return auth, nil
}
// CountTokens returns the token count for the given request.
func (e *KiloExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, fmt.Errorf("kilo: count tokens not supported")
}
// kiloCredentials extracts access token and other info from auth.
func kiloCredentials(auth *cliproxyauth.Auth) (accessToken, orgID string) {
if auth == nil {
return "", ""
}
// Prefer kilocode specific keys, then fall back to generic keys.
// Check metadata first, then attributes.
if auth.Metadata != nil {
if token, ok := auth.Metadata["kilocodeToken"].(string); ok && token != "" {
accessToken = token
} else if token, ok := auth.Metadata["access_token"].(string); ok && token != "" {
accessToken = token
}
if org, ok := auth.Metadata["kilocodeOrganizationId"].(string); ok && org != "" {
orgID = org
} else if org, ok := auth.Metadata["organization_id"].(string); ok && org != "" {
orgID = org
}
}
if accessToken == "" && auth.Attributes != nil {
if token := auth.Attributes["kilocodeToken"]; token != "" {
accessToken = token
} else if token := auth.Attributes["access_token"]; token != "" {
accessToken = token
}
}
if orgID == "" && auth.Attributes != nil {
if org := auth.Attributes["kilocodeOrganizationId"]; org != "" {
orgID = org
} else if org := auth.Attributes["organization_id"]; org != "" {
orgID = org
}
}
return accessToken, orgID
}
// FetchKiloModels fetches models from Kilo API.
func FetchKiloModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
accessToken, orgID := kiloCredentials(auth)
if accessToken == "" {
log.Infof("kilo: no access token found, skipping dynamic model fetch (using static kilo/auto)")
return registry.GetKiloModels()
}
log.Debugf("kilo: fetching dynamic models (orgID: %s)", orgID)
httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.kilo.ai/api/openrouter/models", nil)
if err != nil {
log.Warnf("kilo: failed to create model fetch request: %v", err)
return registry.GetKiloModels()
}
req.Header.Set("Authorization", "Bearer "+accessToken)
if orgID != "" {
req.Header.Set("X-Kilocode-OrganizationID", orgID)
}
req.Header.Set("User-Agent", "cli-proxy-kilo")
resp, err := httpClient.Do(req)
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log.Warnf("kilo: fetch models canceled: %v", err)
} else {
log.Warnf("kilo: using static models (API fetch failed: %v)", err)
}
return registry.GetKiloModels()
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Warnf("kilo: failed to read models response: %v", err)
return registry.GetKiloModels()
}
if resp.StatusCode != http.StatusOK {
log.Warnf("kilo: fetch models failed: status %d, body: %s", resp.StatusCode, string(body))
return registry.GetKiloModels()
}
result := gjson.GetBytes(body, "data")
if !result.Exists() {
// Try root if data field is missing
result = gjson.ParseBytes(body)
if !result.IsArray() {
log.Debugf("kilo: response body: %s", string(body))
log.Warn("kilo: invalid API response format (expected array or data field with array)")
return registry.GetKiloModels()
}
}
var dynamicModels []*registry.ModelInfo
now := time.Now().Unix()
count := 0
totalCount := 0
result.ForEach(func(key, value gjson.Result) bool {
totalCount++
id := value.Get("id").String()
pIdxResult := value.Get("preferredIndex")
preferredIndex := pIdxResult.Int()
// Filter models where preferredIndex > 0 (Kilo-curated models)
if preferredIndex <= 0 {
return true
}
// Check if it's free. We look for :free suffix, is_free flag, or zero pricing.
isFree := strings.HasSuffix(id, ":free") || id == "giga-potato" || value.Get("is_free").Bool()
if !isFree {
// Check pricing as fallback
promptPricing := value.Get("pricing.prompt").String()
if promptPricing == "0" || promptPricing == "0.0" {
isFree = true
}
}
if !isFree {
log.Debugf("kilo: skipping curated paid model: %s", id)
return true
}
log.Debugf("kilo: found curated model: %s (preferredIndex: %d)", id, preferredIndex)
dynamicModels = append(dynamicModels, &registry.ModelInfo{
ID: id,
DisplayName: value.Get("name").String(),
ContextLength: int(value.Get("context_length").Int()),
OwnedBy: "kilo",
Type: "kilo",
Object: "model",
Created: now,
})
count++
return true
})
log.Infof("kilo: fetched %d models from API, %d curated free (preferredIndex > 0)", totalCount, count)
if count == 0 && totalCount > 0 {
log.Warn("kilo: no curated free models found (check API response fields)")
}
staticModels := registry.GetKiloModels()
// Always include kilo/auto (first static model)
allModels := append(staticModels[:1], dynamicModels...)
return allModels
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,423 @@
package executor
import (
"fmt"
"testing"
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func TestBuildKiroEndpointConfigs(t *testing.T) {
tests := []struct {
name string
region string
expectedURL string
expectedOrigin string
expectedName string
}{
{
name: "Empty region - defaults to us-east-1",
region: "",
expectedURL: "https://q.us-east-1.amazonaws.com/generateAssistantResponse",
expectedOrigin: "AI_EDITOR",
expectedName: "AmazonQ",
},
{
name: "us-east-1",
region: "us-east-1",
expectedURL: "https://q.us-east-1.amazonaws.com/generateAssistantResponse",
expectedOrigin: "AI_EDITOR",
expectedName: "AmazonQ",
},
{
name: "ap-southeast-1",
region: "ap-southeast-1",
expectedURL: "https://q.ap-southeast-1.amazonaws.com/generateAssistantResponse",
expectedOrigin: "AI_EDITOR",
expectedName: "AmazonQ",
},
{
name: "eu-west-1",
region: "eu-west-1",
expectedURL: "https://q.eu-west-1.amazonaws.com/generateAssistantResponse",
expectedOrigin: "AI_EDITOR",
expectedName: "AmazonQ",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
configs := buildKiroEndpointConfigs(tt.region)
if len(configs) != 2 {
t.Fatalf("expected 2 endpoint configs, got %d", len(configs))
}
// Check primary endpoint (AmazonQ)
primary := configs[0]
if primary.URL != tt.expectedURL {
t.Errorf("primary URL = %q, want %q", primary.URL, tt.expectedURL)
}
if primary.Origin != tt.expectedOrigin {
t.Errorf("primary Origin = %q, want %q", primary.Origin, tt.expectedOrigin)
}
if primary.Name != tt.expectedName {
t.Errorf("primary Name = %q, want %q", primary.Name, tt.expectedName)
}
if primary.AmzTarget != "" {
t.Errorf("primary AmzTarget should be empty, got %q", primary.AmzTarget)
}
// Check fallback endpoint (CodeWhisperer)
fallback := configs[1]
if fallback.Name != "CodeWhisperer" {
t.Errorf("fallback Name = %q, want %q", fallback.Name, "CodeWhisperer")
}
// CodeWhisperer fallback uses the same region as Q endpoint
expectedRegion := tt.region
if expectedRegion == "" {
expectedRegion = kiroDefaultRegion
}
expectedFallbackURL := fmt.Sprintf("https://codewhisperer.%s.amazonaws.com/generateAssistantResponse", expectedRegion)
if fallback.URL != expectedFallbackURL {
t.Errorf("fallback URL = %q, want %q", fallback.URL, expectedFallbackURL)
}
if fallback.AmzTarget == "" {
t.Error("fallback AmzTarget should NOT be empty")
}
})
}
}
func TestGetKiroEndpointConfigs_NilAuth(t *testing.T) {
configs := getKiroEndpointConfigs(nil)
if len(configs) != 2 {
t.Fatalf("expected 2 endpoint configs, got %d", len(configs))
}
// Should return default us-east-1 configs
if configs[0].Name != "AmazonQ" {
t.Errorf("first config Name = %q, want %q", configs[0].Name, "AmazonQ")
}
expectedURL := "https://q.us-east-1.amazonaws.com/generateAssistantResponse"
if configs[0].URL != expectedURL {
t.Errorf("first config URL = %q, want %q", configs[0].URL, expectedURL)
}
}
func TestGetKiroEndpointConfigs_WithRegionFromProfileArn(t *testing.T) {
auth := &cliproxyauth.Auth{
Metadata: map[string]any{
"profile_arn": "arn:aws:codewhisperer:ap-southeast-1:123456789012:profile/ABC",
},
}
configs := getKiroEndpointConfigs(auth)
if len(configs) != 2 {
t.Fatalf("expected 2 endpoint configs, got %d", len(configs))
}
expectedURL := "https://q.ap-southeast-1.amazonaws.com/generateAssistantResponse"
if configs[0].URL != expectedURL {
t.Errorf("primary URL = %q, want %q", configs[0].URL, expectedURL)
}
}
func TestGetKiroEndpointConfigs_WithApiRegionOverride(t *testing.T) {
auth := &cliproxyauth.Auth{
Metadata: map[string]any{
"api_region": "eu-central-1",
"profile_arn": "arn:aws:codewhisperer:us-east-1:123456789012:profile/ABC",
},
}
configs := getKiroEndpointConfigs(auth)
// api_region should take precedence over profile_arn
expectedURL := "https://q.eu-central-1.amazonaws.com/generateAssistantResponse"
if configs[0].URL != expectedURL {
t.Errorf("primary URL = %q, want %q", configs[0].URL, expectedURL)
}
}
func TestGetKiroEndpointConfigs_PreferredEndpoint(t *testing.T) {
tests := []struct {
name string
preference string
expectedFirstName string
}{
{
name: "Prefer codewhisperer",
preference: "codewhisperer",
expectedFirstName: "CodeWhisperer",
},
{
name: "Prefer ide (alias for codewhisperer)",
preference: "ide",
expectedFirstName: "CodeWhisperer",
},
{
name: "Prefer amazonq",
preference: "amazonq",
expectedFirstName: "AmazonQ",
},
{
name: "Prefer q (alias for amazonq)",
preference: "q",
expectedFirstName: "AmazonQ",
},
{
name: "Prefer cli (alias for amazonq)",
preference: "cli",
expectedFirstName: "AmazonQ",
},
{
name: "Unknown preference - no reordering",
preference: "unknown",
expectedFirstName: "AmazonQ",
},
{
name: "Empty preference - no reordering",
preference: "",
expectedFirstName: "AmazonQ",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
auth := &cliproxyauth.Auth{
Metadata: map[string]any{
"preferred_endpoint": tt.preference,
},
}
configs := getKiroEndpointConfigs(auth)
if configs[0].Name != tt.expectedFirstName {
t.Errorf("first endpoint Name = %q, want %q", configs[0].Name, tt.expectedFirstName)
}
})
}
}
func TestGetKiroEndpointConfigs_PreferredEndpointFromAttributes(t *testing.T) {
// Test that preferred_endpoint can also come from Attributes
auth := &cliproxyauth.Auth{
Metadata: map[string]any{},
Attributes: map[string]string{"preferred_endpoint": "codewhisperer"},
}
configs := getKiroEndpointConfigs(auth)
if configs[0].Name != "CodeWhisperer" {
t.Errorf("first endpoint Name = %q, want %q", configs[0].Name, "CodeWhisperer")
}
}
func TestGetKiroEndpointConfigs_MetadataTakesPrecedenceOverAttributes(t *testing.T) {
auth := &cliproxyauth.Auth{
Metadata: map[string]any{"preferred_endpoint": "amazonq"},
Attributes: map[string]string{"preferred_endpoint": "codewhisperer"},
}
configs := getKiroEndpointConfigs(auth)
// Metadata should take precedence
if configs[0].Name != "AmazonQ" {
t.Errorf("first endpoint Name = %q, want %q", configs[0].Name, "AmazonQ")
}
}
func TestGetAuthValue(t *testing.T) {
tests := []struct {
name string
auth *cliproxyauth.Auth
key string
expected string
}{
{
name: "From metadata",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"test_key": "metadata_value"},
},
key: "test_key",
expected: "metadata_value",
},
{
name: "From attributes (fallback)",
auth: &cliproxyauth.Auth{
Attributes: map[string]string{"test_key": "attribute_value"},
},
key: "test_key",
expected: "attribute_value",
},
{
name: "Metadata takes precedence",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"test_key": "metadata_value"},
Attributes: map[string]string{"test_key": "attribute_value"},
},
key: "test_key",
expected: "metadata_value",
},
{
name: "Key not found",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"other_key": "value"},
Attributes: map[string]string{"another_key": "value"},
},
key: "test_key",
expected: "",
},
{
name: "Nil metadata",
auth: &cliproxyauth.Auth{
Attributes: map[string]string{"test_key": "attribute_value"},
},
key: "test_key",
expected: "attribute_value",
},
{
name: "Both nil",
auth: &cliproxyauth.Auth{},
key: "test_key",
expected: "",
},
{
name: "Value is trimmed and lowercased",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"test_key": " UPPER_VALUE "},
},
key: "test_key",
expected: "upper_value",
},
{
name: "Empty string value in metadata - falls back to attributes",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"test_key": ""},
Attributes: map[string]string{"test_key": "attribute_value"},
},
key: "test_key",
expected: "attribute_value",
},
{
name: "Non-string value in metadata - falls back to attributes",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{"test_key": 123},
Attributes: map[string]string{"test_key": "attribute_value"},
},
key: "test_key",
expected: "attribute_value",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getAuthValue(tt.auth, tt.key)
if result != tt.expected {
t.Errorf("getAuthValue() = %q, want %q", result, tt.expected)
}
})
}
}
func TestGetAccountKey(t *testing.T) {
tests := []struct {
name string
auth *cliproxyauth.Auth
checkFn func(t *testing.T, result string)
}{
{
name: "From client_id",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{
"client_id": "test-client-id-123",
"refresh_token": "test-refresh-token-456",
},
},
checkFn: func(t *testing.T, result string) {
expected := kiroauth.GetAccountKey("test-client-id-123", "test-refresh-token-456")
if result != expected {
t.Errorf("expected %s, got %s", expected, result)
}
},
},
{
name: "From refresh_token only",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{
"refresh_token": "test-refresh-token-789",
},
},
checkFn: func(t *testing.T, result string) {
expected := kiroauth.GetAccountKey("", "test-refresh-token-789")
if result != expected {
t.Errorf("expected %s, got %s", expected, result)
}
},
},
{
name: "Nil auth",
auth: nil,
checkFn: func(t *testing.T, result string) {
if len(result) != 16 {
t.Errorf("expected 16 char key, got %d chars", len(result))
}
},
},
{
name: "Nil metadata",
auth: &cliproxyauth.Auth{},
checkFn: func(t *testing.T, result string) {
if len(result) != 16 {
t.Errorf("expected 16 char key, got %d chars", len(result))
}
},
},
{
name: "Empty metadata",
auth: &cliproxyauth.Auth{
Metadata: map[string]any{},
},
checkFn: func(t *testing.T, result string) {
if len(result) != 16 {
t.Errorf("expected 16 char key, got %d chars", len(result))
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getAccountKey(tt.auth)
tt.checkFn(t, result)
})
}
}
func TestEndpointAliases(t *testing.T) {
// Verify all expected aliases are defined
expectedAliases := map[string]string{
"codewhisperer": "codewhisperer",
"ide": "codewhisperer",
"amazonq": "amazonq",
"q": "amazonq",
"cli": "amazonq",
}
for alias, target := range expectedAliases {
if actual, ok := endpointAliases[alias]; !ok {
t.Errorf("missing alias %q", alias)
} else if actual != target {
t.Errorf("alias %q = %q, want %q", alias, actual, target)
}
}
// Verify no unexpected aliases
if len(endpointAliases) != len(expectedAliases) {
t.Errorf("unexpected number of aliases: got %d, want %d", len(endpointAliases), len(expectedAliases))
}
}

View File

@@ -6,6 +6,7 @@ import (
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
@@ -14,11 +15,19 @@ import (
"golang.org/x/net/proxy"
)
// httpClientCache caches HTTP clients by proxy URL to enable connection reuse
var (
httpClientCache = make(map[string]*http.Client)
httpClientCacheMutex sync.RWMutex
)
// newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
// 1. Use auth.ProxyURL if configured (highest priority)
// 2. Use cfg.ProxyURL if auth proxy is not configured
// 3. Use RoundTripper from context if neither are configured
//
// This function caches HTTP clients by proxy URL to enable TCP/TLS connection reuse.
//
// Parameters:
// - ctx: The context containing optional RoundTripper
// - cfg: The application configuration
@@ -28,11 +37,6 @@ import (
// Returns:
// - *http.Client: An HTTP client with configured proxy or transport
func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
httpClient := &http.Client{}
if timeout > 0 {
httpClient.Timeout = timeout
}
// Priority 1: Use auth.ProxyURL if configured
var proxyURL string
if auth != nil {
@@ -44,11 +48,39 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
proxyURL = strings.TrimSpace(cfg.ProxyURL)
}
// Build cache key from proxy URL (empty string for no proxy)
cacheKey := proxyURL
// Check cache first
httpClientCacheMutex.RLock()
if cachedClient, ok := httpClientCache[cacheKey]; ok {
httpClientCacheMutex.RUnlock()
// Return a wrapper with the requested timeout but shared transport
if timeout > 0 {
return &http.Client{
Transport: cachedClient.Transport,
Timeout: timeout,
}
}
return cachedClient
}
httpClientCacheMutex.RUnlock()
// Create new client
httpClient := &http.Client{}
if timeout > 0 {
httpClient.Timeout = timeout
}
// If we have a proxy URL configured, set up the transport
if proxyURL != "" {
transport := buildProxyTransport(proxyURL)
if transport != nil {
httpClient.Transport = transport
// Cache the client
httpClientCacheMutex.Lock()
httpClientCache[cacheKey] = httpClient
httpClientCacheMutex.Unlock()
return httpClient
}
// If proxy setup failed, log and fall through to context RoundTripper
@@ -60,6 +92,13 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
httpClient.Transport = rt
}
// Cache the client for no-proxy case
if proxyURL == "" {
httpClientCacheMutex.Lock()
httpClientCache[cacheKey] = httpClient
httpClientCacheMutex.Unlock()
}
return httpClient
}

View File

@@ -2,43 +2,109 @@ package executor
import (
"fmt"
"regexp"
"strconv"
"strings"
"sync"
"github.com/tidwall/gjson"
"github.com/tiktoken-go/tokenizer"
)
// tokenizerCache stores tokenizer instances to avoid repeated creation
var tokenizerCache sync.Map
// TokenizerWrapper wraps a tokenizer codec with an adjustment factor for models
// where tiktoken may not accurately estimate token counts (e.g., Claude models)
type TokenizerWrapper struct {
Codec tokenizer.Codec
AdjustmentFactor float64 // 1.0 means no adjustment, >1.0 means tiktoken underestimates
}
// Count returns the token count with adjustment factor applied
func (tw *TokenizerWrapper) Count(text string) (int, error) {
count, err := tw.Codec.Count(text)
if err != nil {
return 0, err
}
if tw.AdjustmentFactor != 1.0 && tw.AdjustmentFactor > 0 {
return int(float64(count) * tw.AdjustmentFactor), nil
}
return count, nil
}
// getTokenizer returns a cached tokenizer for the given model.
// This improves performance by avoiding repeated tokenizer creation.
func getTokenizer(model string) (*TokenizerWrapper, error) {
// Check cache first
if cached, ok := tokenizerCache.Load(model); ok {
return cached.(*TokenizerWrapper), nil
}
// Cache miss, create new tokenizer
wrapper, err := tokenizerForModel(model)
if err != nil {
return nil, err
}
// Store in cache (use LoadOrStore to handle race conditions)
actual, _ := tokenizerCache.LoadOrStore(model, wrapper)
return actual.(*TokenizerWrapper), nil
}
// tokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
func tokenizerForModel(model string) (tokenizer.Codec, error) {
// For Claude models, applies a 1.1 adjustment factor since tiktoken may underestimate.
func tokenizerForModel(model string) (*TokenizerWrapper, error) {
sanitized := strings.ToLower(strings.TrimSpace(model))
// Claude models use cl100k_base with 1.1 adjustment factor
// because tiktoken may underestimate Claude's actual token count
if strings.Contains(sanitized, "claude") || strings.HasPrefix(sanitized, "kiro-") || strings.HasPrefix(sanitized, "amazonq-") {
enc, err := tokenizer.Get(tokenizer.Cl100kBase)
if err != nil {
return nil, err
}
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.1}, nil
}
var enc tokenizer.Codec
var err error
switch {
case sanitized == "":
return tokenizer.Get(tokenizer.Cl100kBase)
case strings.HasPrefix(sanitized, "gpt-5"):
return tokenizer.ForModel(tokenizer.GPT5)
enc, err = tokenizer.Get(tokenizer.Cl100kBase)
case strings.HasPrefix(sanitized, "gpt-5.2"):
enc, err = tokenizer.ForModel(tokenizer.GPT5)
case strings.HasPrefix(sanitized, "gpt-5.1"):
return tokenizer.ForModel(tokenizer.GPT5)
enc, err = tokenizer.ForModel(tokenizer.GPT5)
case strings.HasPrefix(sanitized, "gpt-5"):
enc, err = tokenizer.ForModel(tokenizer.GPT5)
case strings.HasPrefix(sanitized, "gpt-4.1"):
return tokenizer.ForModel(tokenizer.GPT41)
enc, err = tokenizer.ForModel(tokenizer.GPT41)
case strings.HasPrefix(sanitized, "gpt-4o"):
return tokenizer.ForModel(tokenizer.GPT4o)
enc, err = tokenizer.ForModel(tokenizer.GPT4o)
case strings.HasPrefix(sanitized, "gpt-4"):
return tokenizer.ForModel(tokenizer.GPT4)
enc, err = tokenizer.ForModel(tokenizer.GPT4)
case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"):
return tokenizer.ForModel(tokenizer.GPT35Turbo)
enc, err = tokenizer.ForModel(tokenizer.GPT35Turbo)
case strings.HasPrefix(sanitized, "o1"):
return tokenizer.ForModel(tokenizer.O1)
enc, err = tokenizer.ForModel(tokenizer.O1)
case strings.HasPrefix(sanitized, "o3"):
return tokenizer.ForModel(tokenizer.O3)
enc, err = tokenizer.ForModel(tokenizer.O3)
case strings.HasPrefix(sanitized, "o4"):
return tokenizer.ForModel(tokenizer.O4Mini)
enc, err = tokenizer.ForModel(tokenizer.O4Mini)
default:
return tokenizer.Get(tokenizer.O200kBase)
enc, err = tokenizer.Get(tokenizer.O200kBase)
}
if err != nil {
return nil, err
}
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.0}, nil
}
// countOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
func countOpenAIChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
if enc == nil {
return 0, fmt.Errorf("encoder is nil")
}
@@ -62,11 +128,206 @@ func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
return 0, nil
}
// Count text tokens
count, err := enc.Count(joined)
if err != nil {
return 0, err
}
return int64(count), nil
// Extract and add image tokens from placeholders
imageTokens := extractImageTokens(joined)
return int64(count) + int64(imageTokens), nil
}
// countClaudeChatTokens approximates prompt tokens for Claude API chat completions payloads.
// This handles Claude's message format with system, messages, and tools.
// Image tokens are estimated based on image dimensions when available.
func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
if enc == nil {
return 0, fmt.Errorf("encoder is nil")
}
if len(payload) == 0 {
return 0, nil
}
root := gjson.ParseBytes(payload)
segments := make([]string, 0, 32)
// Collect system prompt (can be string or array of content blocks)
collectClaudeSystem(root.Get("system"), &segments)
// Collect messages
collectClaudeMessages(root.Get("messages"), &segments)
// Collect tools
collectClaudeTools(root.Get("tools"), &segments)
joined := strings.TrimSpace(strings.Join(segments, "\n"))
if joined == "" {
return 0, nil
}
// Count text tokens
count, err := enc.Count(joined)
if err != nil {
return 0, err
}
// Extract and add image tokens from placeholders
imageTokens := extractImageTokens(joined)
return int64(count) + int64(imageTokens), nil
}
// imageTokenPattern matches [IMAGE:xxx tokens] format for extracting estimated image tokens
var imageTokenPattern = regexp.MustCompile(`\[IMAGE:(\d+) tokens\]`)
// extractImageTokens extracts image token estimates from placeholder text.
// Placeholders are in the format [IMAGE:xxx tokens] where xxx is the estimated token count.
func extractImageTokens(text string) int {
matches := imageTokenPattern.FindAllStringSubmatch(text, -1)
total := 0
for _, match := range matches {
if len(match) > 1 {
if tokens, err := strconv.Atoi(match[1]); err == nil {
total += tokens
}
}
}
return total
}
// estimateImageTokens calculates estimated tokens for an image based on dimensions.
// Based on Claude's image token calculation: tokens ≈ (width * height) / 750
// Minimum 85 tokens, maximum 1590 tokens (for 1568x1568 images).
func estimateImageTokens(width, height float64) int {
if width <= 0 || height <= 0 {
// No valid dimensions, use default estimate (medium-sized image)
return 1000
}
tokens := int(width * height / 750)
// Apply bounds
if tokens < 85 {
tokens = 85
}
if tokens > 1590 {
tokens = 1590
}
return tokens
}
// collectClaudeSystem extracts text from Claude's system field.
// System can be a string or an array of content blocks.
func collectClaudeSystem(system gjson.Result, segments *[]string) {
if !system.Exists() {
return
}
if system.Type == gjson.String {
addIfNotEmpty(segments, system.String())
return
}
if system.IsArray() {
system.ForEach(func(_, block gjson.Result) bool {
blockType := block.Get("type").String()
if blockType == "text" || blockType == "" {
addIfNotEmpty(segments, block.Get("text").String())
}
// Also handle plain string blocks
if block.Type == gjson.String {
addIfNotEmpty(segments, block.String())
}
return true
})
}
}
// collectClaudeMessages extracts text from Claude's messages array.
func collectClaudeMessages(messages gjson.Result, segments *[]string) {
if !messages.Exists() || !messages.IsArray() {
return
}
messages.ForEach(func(_, message gjson.Result) bool {
addIfNotEmpty(segments, message.Get("role").String())
collectClaudeContent(message.Get("content"), segments)
return true
})
}
// collectClaudeContent extracts text from Claude's content field.
// Content can be a string or an array of content blocks.
// For images, estimates token count based on dimensions when available.
func collectClaudeContent(content gjson.Result, segments *[]string) {
if !content.Exists() {
return
}
if content.Type == gjson.String {
addIfNotEmpty(segments, content.String())
return
}
if content.IsArray() {
content.ForEach(func(_, part gjson.Result) bool {
partType := part.Get("type").String()
switch partType {
case "text":
addIfNotEmpty(segments, part.Get("text").String())
case "image":
// Estimate image tokens based on dimensions if available
source := part.Get("source")
if source.Exists() {
width := source.Get("width").Float()
height := source.Get("height").Float()
if width > 0 && height > 0 {
tokens := estimateImageTokens(width, height)
addIfNotEmpty(segments, fmt.Sprintf("[IMAGE:%d tokens]", tokens))
} else {
// No dimensions available, use default estimate
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
}
} else {
// No source info, use default estimate
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
}
case "tool_use":
addIfNotEmpty(segments, part.Get("id").String())
addIfNotEmpty(segments, part.Get("name").String())
if input := part.Get("input"); input.Exists() {
addIfNotEmpty(segments, input.Raw)
}
case "tool_result":
addIfNotEmpty(segments, part.Get("tool_use_id").String())
collectClaudeContent(part.Get("content"), segments)
case "thinking":
addIfNotEmpty(segments, part.Get("thinking").String())
default:
// For unknown types, try to extract any text content
if part.Type == gjson.String {
addIfNotEmpty(segments, part.String())
} else if part.Type == gjson.JSON {
addIfNotEmpty(segments, part.Raw)
}
}
return true
})
}
}
// collectClaudeTools extracts text from Claude's tools array.
func collectClaudeTools(tools gjson.Result, segments *[]string) {
if !tools.Exists() || !tools.IsArray() {
return
}
tools.ForEach(func(_, tool gjson.Result) bool {
addIfNotEmpty(segments, tool.Get("name").String())
addIfNotEmpty(segments, tool.Get("description").String())
if inputSchema := tool.Get("input_schema"); inputSchema.Exists() {
addIfNotEmpty(segments, inputSchema.Raw)
}
return true
})
}
// buildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.

View File

@@ -252,6 +252,44 @@ func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
return detail, true
}
func parseOpenAIResponsesUsageDetail(usageNode gjson.Result) usage.Detail {
detail := usage.Detail{
InputTokens: usageNode.Get("input_tokens").Int(),
OutputTokens: usageNode.Get("output_tokens").Int(),
TotalTokens: usageNode.Get("total_tokens").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens
}
if cached := usageNode.Get("input_tokens_details.cached_tokens"); cached.Exists() {
detail.CachedTokens = cached.Int()
}
if reasoning := usageNode.Get("output_tokens_details.reasoning_tokens"); reasoning.Exists() {
detail.ReasoningTokens = reasoning.Int()
}
return detail
}
func parseOpenAIResponsesUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data).Get("usage")
if !usageNode.Exists() {
return usage.Detail{}
}
return parseOpenAIResponsesUsageDetail(usageNode)
}
func parseOpenAIResponsesStreamUsage(line []byte) (usage.Detail, bool) {
payload := jsonPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return usage.Detail{}, false
}
usageNode := gjson.GetBytes(payload, "usage")
if !usageNode.Exists() {
return usage.Detail{}, false
}
return parseOpenAIResponsesUsageDetail(usageNode), true
}
func parseClaudeUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data).Get("usage")
if !usageNode.Exists() {

View File

@@ -50,6 +50,10 @@ type ToolCallAccumulator struct {
// Returns:
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
var localParam any
if param == nil {
param = &localParam
}
if *param == nil {
*param = &ConvertAnthropicResponseToOpenAIParams{
CreatedAt: 0,

View File

@@ -33,4 +33,7 @@ import (
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/gemini"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/chat-completions"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/responses"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/claude"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/openai"
)

View File

@@ -0,0 +1,20 @@
// Package claude provides translation between Kiro and Claude formats.
package claude
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
Claude,
Kiro,
ConvertClaudeRequestToKiro,
interfaces.TranslateResponse{
Stream: ConvertKiroStreamToClaude,
NonStream: ConvertKiroNonStreamToClaude,
},
)
}

View File

@@ -0,0 +1,21 @@
// Package claude provides translation between Kiro and Claude formats.
// Since Kiro executor generates Claude-compatible SSE format internally (with event: prefix),
// translations are pass-through for streaming, but responses need proper formatting.
package claude
import (
"context"
)
// ConvertKiroStreamToClaude converts Kiro streaming response to Claude format.
// Kiro executor already generates complete SSE format with "event:" prefix,
// so this is a simple pass-through.
func ConvertKiroStreamToClaude(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string {
return []string{string(rawResponse)}
}
// ConvertKiroNonStreamToClaude converts Kiro non-streaming response to Claude format.
// The response is already in Claude format, so this is a pass-through.
func ConvertKiroNonStreamToClaude(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string {
return string(rawResponse)
}

View File

@@ -0,0 +1,992 @@
// Package claude provides request translation functionality for Claude API to Kiro format.
// It handles parsing and transforming Claude API requests into the Kiro/Amazon Q API format,
// extracting model information, system instructions, message contents, and tool declarations.
package claude
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"unicode/utf8"
"github.com/google/uuid"
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
// remoteWebSearchDescription is a minimal fallback for when dynamic fetch from MCP tools/list hasn't completed yet.
const remoteWebSearchDescription = "WebSearch looks up information outside the model's training data. Supports multiple queries to gather comprehensive information."
// Kiro API request structs - field order determines JSON key order
// KiroPayload is the top-level request structure for Kiro API
type KiroPayload struct {
ConversationState KiroConversationState `json:"conversationState"`
ProfileArn string `json:"profileArn,omitempty"`
InferenceConfig *KiroInferenceConfig `json:"inferenceConfig,omitempty"`
}
// KiroInferenceConfig contains inference parameters for the Kiro API.
type KiroInferenceConfig struct {
MaxTokens int `json:"maxTokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
}
// KiroConversationState holds the conversation context
type KiroConversationState struct {
AgentContinuationID string `json:"agentContinuationId,omitempty"`
AgentTaskType string `json:"agentTaskType,omitempty"`
ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL"
ConversationID string `json:"conversationId"`
CurrentMessage KiroCurrentMessage `json:"currentMessage"`
History []KiroHistoryMessage `json:"history,omitempty"`
}
// KiroCurrentMessage wraps the current user message
type KiroCurrentMessage struct {
UserInputMessage KiroUserInputMessage `json:"userInputMessage"`
}
// KiroHistoryMessage represents a message in the conversation history
type KiroHistoryMessage struct {
UserInputMessage *KiroUserInputMessage `json:"userInputMessage,omitempty"`
AssistantResponseMessage *KiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"`
}
// KiroImage represents an image in Kiro API format
type KiroImage struct {
Format string `json:"format"`
Source KiroImageSource `json:"source"`
}
// KiroImageSource contains the image data
type KiroImageSource struct {
Bytes string `json:"bytes"` // base64 encoded image data
}
// KiroUserInputMessage represents a user message
type KiroUserInputMessage struct {
Content string `json:"content"`
ModelID string `json:"modelId"`
Origin string `json:"origin"`
Images []KiroImage `json:"images,omitempty"`
UserInputMessageContext *KiroUserInputMessageContext `json:"userInputMessageContext,omitempty"`
}
// KiroUserInputMessageContext contains tool-related context
type KiroUserInputMessageContext struct {
ToolResults []KiroToolResult `json:"toolResults,omitempty"`
Tools []KiroToolWrapper `json:"tools,omitempty"`
}
// KiroToolResult represents a tool execution result
type KiroToolResult struct {
Content []KiroTextContent `json:"content"`
Status string `json:"status"`
ToolUseID string `json:"toolUseId"`
}
// KiroTextContent represents text content
type KiroTextContent struct {
Text string `json:"text"`
}
// KiroToolWrapper wraps a tool specification
type KiroToolWrapper struct {
ToolSpecification KiroToolSpecification `json:"toolSpecification"`
}
// KiroToolSpecification defines a tool's schema
type KiroToolSpecification struct {
Name string `json:"name"`
Description string `json:"description"`
InputSchema KiroInputSchema `json:"inputSchema"`
}
// KiroInputSchema wraps the JSON schema for tool input
type KiroInputSchema struct {
JSON interface{} `json:"json"`
}
// KiroAssistantResponseMessage represents an assistant message
type KiroAssistantResponseMessage struct {
Content string `json:"content"`
ToolUses []KiroToolUse `json:"toolUses,omitempty"`
}
// KiroToolUse represents a tool invocation by the assistant
type KiroToolUse struct {
ToolUseID string `json:"toolUseId"`
Name string `json:"name"`
Input map[string]interface{} `json:"input"`
IsTruncated bool `json:"-"` // Internal flag, not serialized
TruncationInfo *TruncationInfo `json:"-"` // Truncation details, not serialized
}
// ConvertClaudeRequestToKiro converts a Claude API request to Kiro format.
// This is the main entry point for request translation.
func ConvertClaudeRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte {
// For Kiro, we pass through the Claude format since buildKiroPayload
// expects Claude format and does the conversion internally.
// The actual conversion happens in the executor when building the HTTP request.
return inputRawJSON
}
// BuildKiroPayload constructs the Kiro API request payload from Claude format.
// Supports tool calling - tools are passed via userInputMessageContext.
// origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE.
// isAgentic parameter enables chunked write optimization prompt for -agentic model variants.
// isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode).
// headers parameter allows checking Anthropic-Beta header for thinking mode detection.
// metadata parameter is kept for API compatibility but no longer used for thinking configuration.
// Supports thinking mode - when enabled, injects thinking tags into system prompt.
// Returns the payload and a boolean indicating whether thinking mode was injected.
func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool, headers http.Header, metadata map[string]any) ([]byte, bool) {
// Extract max_tokens for potential use in inferenceConfig
// Handle -1 as "use maximum" (Kiro max output is ~32000 tokens)
const kiroMaxOutputTokens = 32000
var maxTokens int64
if mt := gjson.GetBytes(claudeBody, "max_tokens"); mt.Exists() {
maxTokens = mt.Int()
if maxTokens == -1 {
maxTokens = kiroMaxOutputTokens
log.Debugf("kiro: max_tokens=-1 converted to %d", kiroMaxOutputTokens)
}
}
// Extract temperature if specified
var temperature float64
var hasTemperature bool
if temp := gjson.GetBytes(claudeBody, "temperature"); temp.Exists() {
temperature = temp.Float()
hasTemperature = true
}
// Extract top_p if specified
var topP float64
var hasTopP bool
if tp := gjson.GetBytes(claudeBody, "top_p"); tp.Exists() {
topP = tp.Float()
hasTopP = true
log.Debugf("kiro: extracted top_p: %.2f", topP)
}
// Normalize origin value for Kiro API compatibility
origin = normalizeOrigin(origin)
log.Debugf("kiro: normalized origin value: %s", origin)
messages := gjson.GetBytes(claudeBody, "messages")
// For chat-only mode, don't include tools
var tools gjson.Result
if !isChatOnly {
tools = gjson.GetBytes(claudeBody, "tools")
}
// Extract system prompt
systemPrompt := extractSystemPrompt(claudeBody)
// Check for thinking mode using the comprehensive IsThinkingEnabledWithHeaders function
// This supports Claude API format, OpenAI reasoning_effort, AMP/Cursor format, and Anthropic-Beta header
thinkingEnabled := IsThinkingEnabledWithHeaders(claudeBody, headers)
// Inject timestamp context
timestamp := time.Now().Format("2006-01-02 15:04:05 MST")
timestampContext := fmt.Sprintf("[Context: Current time is %s]", timestamp)
if systemPrompt != "" {
systemPrompt = timestampContext + "\n\n" + systemPrompt
} else {
systemPrompt = timestampContext
}
log.Debugf("kiro: injected timestamp context: %s", timestamp)
// Inject agentic optimization prompt for -agentic model variants
if isAgentic {
if systemPrompt != "" {
systemPrompt += "\n"
}
systemPrompt += kirocommon.KiroAgenticSystemPrompt
}
// Handle tool_choice parameter - Kiro doesn't support it natively, so we inject system prompt hints
// Claude tool_choice values: {"type": "auto/any/tool", "name": "..."}
toolChoiceHint := extractClaudeToolChoiceHint(claudeBody)
if toolChoiceHint != "" {
if systemPrompt != "" {
systemPrompt += "\n"
}
systemPrompt += toolChoiceHint
log.Debugf("kiro: injected tool_choice hint into system prompt")
}
// Convert Claude tools to Kiro format
kiroTools := convertClaudeToolsToKiro(tools)
// Thinking mode implementation:
// Kiro API supports official thinking/reasoning mode via <thinking_mode> tag.
// When set to "enabled", Kiro returns reasoning content as official reasoningContentEvent
// rather than inline <thinking> tags in assistantResponseEvent.
// We cap max_thinking_length to reserve space for tool outputs and prevent truncation.
if thinkingEnabled {
thinkingHint := `<thinking_mode>enabled</thinking_mode>
<max_thinking_length>16000</max_thinking_length>`
if systemPrompt != "" {
systemPrompt = thinkingHint + "\n\n" + systemPrompt
} else {
systemPrompt = thinkingHint
}
log.Infof("kiro: injected thinking prompt (official mode), has_tools: %v", len(kiroTools) > 0)
}
// Process messages and build history
history, currentUserMsg, currentToolResults := processMessages(messages, modelID, origin)
// Build content with system prompt.
// Keep thinking tags on subsequent turns so multi-turn Claude sessions
// continue to emit reasoning events.
if currentUserMsg != nil {
currentUserMsg.Content = buildFinalContent(currentUserMsg.Content, systemPrompt, currentToolResults)
// Deduplicate currentToolResults
currentToolResults = deduplicateToolResults(currentToolResults)
// Build userInputMessageContext with tools and tool results
if len(kiroTools) > 0 || len(currentToolResults) > 0 {
currentUserMsg.UserInputMessageContext = &KiroUserInputMessageContext{
Tools: kiroTools,
ToolResults: currentToolResults,
}
}
}
// Build payload
var currentMessage KiroCurrentMessage
if currentUserMsg != nil {
currentMessage = KiroCurrentMessage{UserInputMessage: *currentUserMsg}
} else {
fallbackContent := ""
if systemPrompt != "" {
fallbackContent = "--- SYSTEM PROMPT ---\n" + systemPrompt + "\n--- END SYSTEM PROMPT ---\n"
}
currentMessage = KiroCurrentMessage{UserInputMessage: KiroUserInputMessage{
Content: fallbackContent,
ModelID: modelID,
Origin: origin,
}}
}
// Build inferenceConfig if we have any inference parameters
// Note: Kiro API doesn't actually use max_tokens for thinking budget
var inferenceConfig *KiroInferenceConfig
if maxTokens > 0 || hasTemperature || hasTopP {
inferenceConfig = &KiroInferenceConfig{}
if maxTokens > 0 {
inferenceConfig.MaxTokens = int(maxTokens)
}
if hasTemperature {
inferenceConfig.Temperature = temperature
}
if hasTopP {
inferenceConfig.TopP = topP
}
}
// Session IDs: extract from messages[].additional_kwargs (LangChain format) or random
conversationID := extractMetadataFromMessages(messages, "conversationId")
continuationID := extractMetadataFromMessages(messages, "continuationId")
if conversationID == "" {
conversationID = uuid.New().String()
}
payload := KiroPayload{
ConversationState: KiroConversationState{
AgentTaskType: "vibe",
ChatTriggerType: "MANUAL",
ConversationID: conversationID,
CurrentMessage: currentMessage,
History: history,
},
ProfileArn: profileArn,
InferenceConfig: inferenceConfig,
}
// Only set AgentContinuationID if client provided
if continuationID != "" {
payload.ConversationState.AgentContinuationID = continuationID
}
result, err := json.Marshal(payload)
if err != nil {
log.Debugf("kiro: failed to marshal payload: %v", err)
return nil, false
}
return result, thinkingEnabled
}
// normalizeOrigin normalizes origin value for Kiro API compatibility
func normalizeOrigin(origin string) string {
switch origin {
case "KIRO_CLI":
return "CLI"
case "KIRO_AI_EDITOR":
return "AI_EDITOR"
case "AMAZON_Q":
return "CLI"
case "KIRO_IDE":
return "AI_EDITOR"
default:
return origin
}
}
// extractMetadataFromMessages extracts metadata from messages[].additional_kwargs (LangChain format).
// Searches from the last message backwards, returns empty string if not found.
func extractMetadataFromMessages(messages gjson.Result, key string) string {
arr := messages.Array()
for i := len(arr) - 1; i >= 0; i-- {
if val := arr[i].Get("additional_kwargs." + key); val.Exists() && val.String() != "" {
return val.String()
}
}
return ""
}
// extractSystemPrompt extracts system prompt from Claude request
func extractSystemPrompt(claudeBody []byte) string {
systemField := gjson.GetBytes(claudeBody, "system")
if systemField.IsArray() {
var sb strings.Builder
for _, block := range systemField.Array() {
if block.Get("type").String() == "text" {
sb.WriteString(block.Get("text").String())
} else if block.Type == gjson.String {
sb.WriteString(block.String())
}
}
return sb.String()
}
return systemField.String()
}
// checkThinkingMode checks if thinking mode is enabled in the Claude request
func checkThinkingMode(claudeBody []byte) (bool, int64) {
thinkingEnabled := false
var budgetTokens int64 = 24000
thinkingField := gjson.GetBytes(claudeBody, "thinking")
if thinkingField.Exists() {
thinkingType := thinkingField.Get("type").String()
if thinkingType == "enabled" {
thinkingEnabled = true
if bt := thinkingField.Get("budget_tokens"); bt.Exists() {
budgetTokens = bt.Int()
if budgetTokens <= 0 {
thinkingEnabled = false
log.Debugf("kiro: thinking mode disabled via budget_tokens <= 0")
}
}
if thinkingEnabled {
log.Debugf("kiro: thinking mode enabled via Claude API parameter, budget_tokens: %d", budgetTokens)
}
}
}
return thinkingEnabled, budgetTokens
}
// hasThinkingTagInBody checks if the request body already contains thinking configuration tags.
// This is used to prevent duplicate injection when client (e.g., AMP/Cursor) already includes thinking config.
func hasThinkingTagInBody(body []byte) bool {
bodyStr := string(body)
return strings.Contains(bodyStr, "<thinking_mode>") || strings.Contains(bodyStr, "<max_thinking_length>")
}
// IsThinkingEnabledFromHeader checks if thinking mode is enabled via Anthropic-Beta header.
// Claude CLI uses "Anthropic-Beta: interleaved-thinking-2025-05-14" to enable thinking.
func IsThinkingEnabledFromHeader(headers http.Header) bool {
if headers == nil {
return false
}
betaHeader := headers.Get("Anthropic-Beta")
if betaHeader == "" {
return false
}
// Check for interleaved-thinking beta feature
if strings.Contains(betaHeader, "interleaved-thinking") {
log.Debugf("kiro: thinking mode enabled via Anthropic-Beta header: %s", betaHeader)
return true
}
return false
}
// IsThinkingEnabled is a public wrapper to check if thinking mode is enabled.
// This is used by the executor to determine whether to parse <thinking> tags in responses.
// When thinking is NOT enabled in the request, <thinking> tags in responses should be
// treated as regular text content, not as thinking blocks.
//
// Supports multiple formats:
// - Claude API format: thinking.type = "enabled"
// - OpenAI format: reasoning_effort parameter
// - AMP/Cursor format: <thinking_mode>interleaved</thinking_mode> in system prompt
func IsThinkingEnabled(body []byte) bool {
return IsThinkingEnabledWithHeaders(body, nil)
}
// IsThinkingEnabledWithHeaders checks if thinking mode is enabled from body or headers.
// This is the comprehensive check that supports all thinking detection methods:
// - Claude API format: thinking.type = "enabled"
// - OpenAI format: reasoning_effort parameter
// - AMP/Cursor format: <thinking_mode>interleaved</thinking_mode> in system prompt
// - Anthropic-Beta header: interleaved-thinking-2025-05-14
func IsThinkingEnabledWithHeaders(body []byte, headers http.Header) bool {
// Check Anthropic-Beta header first (Claude Code uses this)
if IsThinkingEnabledFromHeader(headers) {
return true
}
// Check Claude API format first (thinking.type = "enabled")
enabled, _ := checkThinkingMode(body)
if enabled {
log.Debugf("kiro: IsThinkingEnabled returning true (Claude API format)")
return true
}
// Check OpenAI format: reasoning_effort parameter
// Valid values: "low", "medium", "high", "auto" (not "none")
reasoningEffort := gjson.GetBytes(body, "reasoning_effort")
if reasoningEffort.Exists() {
effort := reasoningEffort.String()
if effort != "" && effort != "none" {
log.Debugf("kiro: thinking mode enabled via OpenAI reasoning_effort: %s", effort)
return true
}
}
// Check AMP/Cursor format: <thinking_mode>interleaved</thinking_mode> in system prompt
// This is how AMP client passes thinking configuration
bodyStr := string(body)
if strings.Contains(bodyStr, "<thinking_mode>") && strings.Contains(bodyStr, "</thinking_mode>") {
// Extract thinking mode value
startTag := "<thinking_mode>"
endTag := "</thinking_mode>"
startIdx := strings.Index(bodyStr, startTag)
if startIdx >= 0 {
startIdx += len(startTag)
endIdx := strings.Index(bodyStr[startIdx:], endTag)
if endIdx >= 0 {
thinkingMode := bodyStr[startIdx : startIdx+endIdx]
if thinkingMode == "interleaved" || thinkingMode == "enabled" {
log.Debugf("kiro: thinking mode enabled via AMP/Cursor format: %s", thinkingMode)
return true
}
}
}
}
// Check OpenAI format: max_completion_tokens with reasoning (o1-style)
// Some clients use this to indicate reasoning mode
if gjson.GetBytes(body, "max_completion_tokens").Exists() {
// If max_completion_tokens is set, check if model name suggests reasoning
model := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(model), "thinking") ||
strings.Contains(strings.ToLower(model), "reason") {
log.Debugf("kiro: thinking mode enabled via model name hint: %s", model)
return true
}
}
// Check model name directly for thinking hints.
// This enables thinking variants even when clients don't send explicit thinking fields.
model := strings.TrimSpace(gjson.GetBytes(body, "model").String())
modelLower := strings.ToLower(model)
if strings.Contains(modelLower, "thinking") || strings.Contains(modelLower, "-reason") {
log.Debugf("kiro: thinking mode enabled via model name hint: %s", model)
return true
}
log.Debugf("kiro: IsThinkingEnabled returning false (no thinking mode detected)")
return false
}
// shortenToolNameIfNeeded shortens tool names that exceed 64 characters.
// MCP tools often have long names like "mcp__server-name__tool-name".
// This preserves the "mcp__" prefix and last segment when possible.
func shortenToolNameIfNeeded(name string) string {
const limit = 64
if len(name) <= limit {
return name
}
// For MCP tools, try to preserve prefix and last segment
if strings.HasPrefix(name, "mcp__") {
idx := strings.LastIndex(name, "__")
if idx > 0 {
cand := "mcp__" + name[idx+2:]
if len(cand) > limit {
return cand[:limit]
}
return cand
}
}
return name[:limit]
}
func ensureKiroInputSchema(parameters interface{}) interface{} {
if parameters != nil {
return parameters
}
return map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
}
}
// convertClaudeToolsToKiro converts Claude tools to Kiro format
func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper {
var kiroTools []KiroToolWrapper
if !tools.IsArray() {
return kiroTools
}
for _, tool := range tools.Array() {
name := tool.Get("name").String()
description := tool.Get("description").String()
inputSchemaResult := tool.Get("input_schema")
var inputSchema interface{}
if inputSchemaResult.Exists() && inputSchemaResult.Type != gjson.Null {
inputSchema = inputSchemaResult.Value()
}
inputSchema = ensureKiroInputSchema(inputSchema)
// Shorten tool name if it exceeds 64 characters (common with MCP tools)
originalName := name
name = shortenToolNameIfNeeded(name)
if name != originalName {
log.Debugf("kiro: shortened tool name from '%s' to '%s'", originalName, name)
}
// CRITICAL FIX: Kiro API requires non-empty description
if strings.TrimSpace(description) == "" {
description = fmt.Sprintf("Tool: %s", name)
log.Debugf("kiro: tool '%s' has empty description, using default: %s", name, description)
}
// Rename web_search → remote_web_search for Kiro API compatibility
if name == "web_search" {
name = "remote_web_search"
// Prefer dynamically fetched description, fall back to hardcoded constant
if cached := GetWebSearchDescription(); cached != "" {
description = cached
} else {
description = remoteWebSearchDescription
}
log.Debugf("kiro: renamed tool web_search → remote_web_search")
}
// Truncate long descriptions (individual tool limit)
if len(description) > kirocommon.KiroMaxToolDescLen {
truncLen := kirocommon.KiroMaxToolDescLen - 30
for truncLen > 0 && !utf8.RuneStart(description[truncLen]) {
truncLen--
}
description = description[:truncLen] + "... (description truncated)"
}
kiroTools = append(kiroTools, KiroToolWrapper{
ToolSpecification: KiroToolSpecification{
Name: name,
Description: description,
InputSchema: KiroInputSchema{JSON: inputSchema},
},
})
}
// Apply dynamic compression if total tools size exceeds threshold
// This prevents 500 errors when Claude Code sends too many tools
kiroTools = compressToolsIfNeeded(kiroTools)
return kiroTools
}
// processMessages processes Claude messages and builds Kiro history
func processMessages(messages gjson.Result, modelID, origin string) ([]KiroHistoryMessage, *KiroUserInputMessage, []KiroToolResult) {
var history []KiroHistoryMessage
var currentUserMsg *KiroUserInputMessage
var currentToolResults []KiroToolResult
// Merge adjacent messages with the same role
messagesArray := kirocommon.MergeAdjacentMessages(messages.Array())
// FIX: Kiro API requires history to start with a user message.
// Some clients (e.g., OpenClaw) send conversations starting with an assistant message,
// which is valid for the Claude API but causes "Improperly formed request" on Kiro.
// Prepend a placeholder user message so the history alternation is correct.
if len(messagesArray) > 0 && messagesArray[0].Get("role").String() == "assistant" {
placeholder := `{"role":"user","content":"."}`
messagesArray = append([]gjson.Result{gjson.Parse(placeholder)}, messagesArray...)
log.Infof("kiro: messages started with assistant role, prepended placeholder user message for Kiro API compatibility")
}
for i, msg := range messagesArray {
role := msg.Get("role").String()
isLastMessage := i == len(messagesArray)-1
if role == "user" {
userMsg, toolResults := BuildUserMessageStruct(msg, modelID, origin)
// CRITICAL: Kiro API requires content to be non-empty for ALL user messages
// This includes both history messages and the current message.
// When user message contains only tool_result (no text), content will be empty.
// This commonly happens in compaction requests from OpenCode.
if strings.TrimSpace(userMsg.Content) == "" {
if len(toolResults) > 0 {
userMsg.Content = kirocommon.DefaultUserContentWithToolResults
} else {
userMsg.Content = kirocommon.DefaultUserContent
}
log.Debugf("kiro: user content was empty, using default: %s", userMsg.Content)
}
if isLastMessage {
currentUserMsg = &userMsg
currentToolResults = toolResults
} else {
// For history messages, embed tool results in context
if len(toolResults) > 0 {
userMsg.UserInputMessageContext = &KiroUserInputMessageContext{
ToolResults: toolResults,
}
}
history = append(history, KiroHistoryMessage{
UserInputMessage: &userMsg,
})
}
} else if role == "assistant" {
assistantMsg := BuildAssistantMessageStruct(msg)
if isLastMessage {
history = append(history, KiroHistoryMessage{
AssistantResponseMessage: &assistantMsg,
})
// Create a "Continue" user message as currentMessage
currentUserMsg = &KiroUserInputMessage{
Content: "Continue",
ModelID: modelID,
Origin: origin,
}
} else {
history = append(history, KiroHistoryMessage{
AssistantResponseMessage: &assistantMsg,
})
}
}
}
// POST-PROCESSING: Remove orphaned tool_results that have no matching tool_use
// in any assistant message. This happens when Claude Code compaction truncates
// the conversation and removes the assistant message containing the tool_use,
// but keeps the user message with the corresponding tool_result.
// Without this fix, Kiro API returns "Improperly formed request".
validToolUseIDs := make(map[string]bool)
for _, h := range history {
if h.AssistantResponseMessage != nil {
for _, tu := range h.AssistantResponseMessage.ToolUses {
validToolUseIDs[tu.ToolUseID] = true
}
}
}
// Filter orphaned tool results from history user messages
for i, h := range history {
if h.UserInputMessage != nil && h.UserInputMessage.UserInputMessageContext != nil {
ctx := h.UserInputMessage.UserInputMessageContext
if len(ctx.ToolResults) > 0 {
filtered := make([]KiroToolResult, 0, len(ctx.ToolResults))
for _, tr := range ctx.ToolResults {
if validToolUseIDs[tr.ToolUseID] {
filtered = append(filtered, tr)
} else {
log.Debugf("kiro: dropping orphaned tool_result in history[%d]: toolUseId=%s (no matching tool_use)", i, tr.ToolUseID)
}
}
ctx.ToolResults = filtered
if len(ctx.ToolResults) == 0 && len(ctx.Tools) == 0 {
h.UserInputMessage.UserInputMessageContext = nil
}
}
}
}
// Filter orphaned tool results from current message
if len(currentToolResults) > 0 {
filtered := make([]KiroToolResult, 0, len(currentToolResults))
for _, tr := range currentToolResults {
if validToolUseIDs[tr.ToolUseID] {
filtered = append(filtered, tr)
} else {
log.Debugf("kiro: dropping orphaned tool_result in currentMessage: toolUseId=%s (no matching tool_use)", tr.ToolUseID)
}
}
if len(filtered) != len(currentToolResults) {
log.Infof("kiro: dropped %d orphaned tool_result(s) from currentMessage (compaction artifact)", len(currentToolResults)-len(filtered))
}
currentToolResults = filtered
}
return history, currentUserMsg, currentToolResults
}
// buildFinalContent builds the final content with system prompt
func buildFinalContent(content, systemPrompt string, toolResults []KiroToolResult) string {
var contentBuilder strings.Builder
if systemPrompt != "" {
contentBuilder.WriteString("--- SYSTEM PROMPT ---\n")
contentBuilder.WriteString(systemPrompt)
contentBuilder.WriteString("\n--- END SYSTEM PROMPT ---\n\n")
}
contentBuilder.WriteString(content)
finalContent := contentBuilder.String()
// CRITICAL: Kiro API requires content to be non-empty
if strings.TrimSpace(finalContent) == "" {
if len(toolResults) > 0 {
finalContent = "Tool results provided."
} else {
finalContent = "Continue"
}
log.Debugf("kiro: content was empty, using default: %s", finalContent)
}
return finalContent
}
// deduplicateToolResults removes duplicate tool results
func deduplicateToolResults(toolResults []KiroToolResult) []KiroToolResult {
if len(toolResults) == 0 {
return toolResults
}
seenIDs := make(map[string]bool)
unique := make([]KiroToolResult, 0, len(toolResults))
for _, tr := range toolResults {
if !seenIDs[tr.ToolUseID] {
seenIDs[tr.ToolUseID] = true
unique = append(unique, tr)
} else {
log.Debugf("kiro: skipping duplicate toolResult in currentMessage: %s", tr.ToolUseID)
}
}
return unique
}
// extractClaudeToolChoiceHint extracts tool_choice from Claude request and returns a system prompt hint.
// Claude tool_choice values:
// - {"type": "auto"}: Model decides (default, no hint needed)
// - {"type": "any"}: Must use at least one tool
// - {"type": "tool", "name": "..."}: Must use specific tool
func extractClaudeToolChoiceHint(claudeBody []byte) string {
toolChoice := gjson.GetBytes(claudeBody, "tool_choice")
if !toolChoice.Exists() {
return ""
}
toolChoiceType := toolChoice.Get("type").String()
switch toolChoiceType {
case "any":
return "[INSTRUCTION: You MUST use at least one of the available tools to respond. Do not respond with text only - always make a tool call.]"
case "tool":
toolName := toolChoice.Get("name").String()
if toolName != "" {
return fmt.Sprintf("[INSTRUCTION: You MUST use the tool named '%s' to respond. Do not use any other tool or respond with text only.]", toolName)
}
case "auto":
// Default behavior, no hint needed
return ""
}
return ""
}
// BuildUserMessageStruct builds a user message and extracts tool results
func BuildUserMessageStruct(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) {
content := msg.Get("content")
var contentBuilder strings.Builder
var toolResults []KiroToolResult
var images []KiroImage
// Track seen toolUseIds to deduplicate
seenToolUseIDs := make(map[string]bool)
if content.IsArray() {
for _, part := range content.Array() {
partType := part.Get("type").String()
switch partType {
case "text":
contentBuilder.WriteString(part.Get("text").String())
case "image":
mediaType := part.Get("source.media_type").String()
data := part.Get("source.data").String()
format := ""
if idx := strings.LastIndex(mediaType, "/"); idx != -1 {
format = mediaType[idx+1:]
}
if format != "" && data != "" {
images = append(images, KiroImage{
Format: format,
Source: KiroImageSource{
Bytes: data,
},
})
}
case "tool_result":
toolUseID := part.Get("tool_use_id").String()
// Skip duplicate toolUseIds
if seenToolUseIDs[toolUseID] {
log.Debugf("kiro: skipping duplicate tool_result with toolUseId: %s", toolUseID)
continue
}
seenToolUseIDs[toolUseID] = true
isError := part.Get("is_error").Bool()
resultContent := part.Get("content")
var textContents []KiroTextContent
// Check if this tool_result contains error from our SOFT_LIMIT_REACHED tool_use
// The client will return an error when trying to execute a tool with marker input
resultStr := resultContent.String()
isSoftLimitError := strings.Contains(resultStr, "SOFT_LIMIT_REACHED") ||
strings.Contains(resultStr, "_status") ||
strings.Contains(resultStr, "truncated") ||
strings.Contains(resultStr, "missing required") ||
strings.Contains(resultStr, "invalid input") ||
strings.Contains(resultStr, "Error writing file")
if isError && isSoftLimitError {
// Replace error content with SOFT_LIMIT_REACHED guidance
log.Infof("kiro: detected SOFT_LIMIT_REACHED in tool_result for %s, replacing with guidance", toolUseID)
softLimitMsg := `SOFT_LIMIT_REACHED
Your previous tool call was incomplete due to API output size limits.
The content was PARTIALLY transmitted but NOT executed.
REQUIRED ACTION:
1. Split your content into smaller chunks (max 300 lines per call)
2. For file writes: Create file with first chunk, then use append for remaining
3. Do NOT regenerate content you already attempted - continue from where you stopped
STATUS: This is NOT an error. Continue with smaller chunks.`
textContents = append(textContents, KiroTextContent{Text: softLimitMsg})
// Mark as SUCCESS so Claude doesn't treat it as a failure
isError = false
} else if resultContent.IsArray() {
for _, item := range resultContent.Array() {
if item.Get("type").String() == "text" {
textContents = append(textContents, KiroTextContent{Text: item.Get("text").String()})
} else if item.Type == gjson.String {
textContents = append(textContents, KiroTextContent{Text: item.String()})
}
}
} else if resultContent.Type == gjson.String {
textContents = append(textContents, KiroTextContent{Text: resultContent.String()})
}
if len(textContents) == 0 {
textContents = append(textContents, KiroTextContent{Text: "Tool use was cancelled by the user"})
}
status := "success"
if isError {
status = "error"
}
toolResults = append(toolResults, KiroToolResult{
ToolUseID: toolUseID,
Content: textContents,
Status: status,
})
}
}
} else {
contentBuilder.WriteString(content.String())
}
userMsg := KiroUserInputMessage{
Content: contentBuilder.String(),
ModelID: modelID,
Origin: origin,
}
if len(images) > 0 {
userMsg.Images = images
}
return userMsg, toolResults
}
// BuildAssistantMessageStruct builds an assistant message with tool uses
func BuildAssistantMessageStruct(msg gjson.Result) KiroAssistantResponseMessage {
content := msg.Get("content")
var contentBuilder strings.Builder
var toolUses []KiroToolUse
if content.IsArray() {
for _, part := range content.Array() {
partType := part.Get("type").String()
switch partType {
case "text":
contentBuilder.WriteString(part.Get("text").String())
case "tool_use":
toolUseID := part.Get("id").String()
toolName := part.Get("name").String()
toolInput := part.Get("input")
var inputMap map[string]interface{}
if toolInput.IsObject() {
inputMap = make(map[string]interface{})
toolInput.ForEach(func(key, value gjson.Result) bool {
inputMap[key.String()] = value.Value()
return true
})
}
// Rename web_search → remote_web_search to match convertClaudeToolsToKiro
if toolName == "web_search" {
toolName = "remote_web_search"
}
toolUses = append(toolUses, KiroToolUse{
ToolUseID: toolUseID,
Name: toolName,
Input: inputMap,
})
}
}
} else {
contentBuilder.WriteString(content.String())
}
// CRITICAL FIX: Kiro API requires non-empty content for assistant messages
// This can happen with compaction requests where assistant messages have only tool_use
// (no text content). Without this fix, Kiro API returns "Improperly formed request" error.
finalContent := contentBuilder.String()
if strings.TrimSpace(finalContent) == "" {
if len(toolUses) > 0 {
finalContent = kirocommon.DefaultAssistantContentWithTools
} else {
finalContent = kirocommon.DefaultAssistantContent
}
log.Debugf("kiro: assistant content was empty, using default: %s", finalContent)
}
return KiroAssistantResponseMessage{
Content: finalContent,
ToolUses: toolUses,
}
}

View File

@@ -0,0 +1,230 @@
// Package claude provides response translation functionality for Kiro API to Claude format.
// This package handles the conversion of Kiro API responses into Claude-compatible format,
// including support for thinking blocks and tool use.
package claude
import (
"crypto/sha256"
"encoding/base64"
"encoding/json"
"strings"
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
log "github.com/sirupsen/logrus"
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
)
// generateThinkingSignature generates a signature for thinking content.
// This is required by Claude API for thinking blocks in non-streaming responses.
// The signature is a base64-encoded hash of the thinking content.
func generateThinkingSignature(thinkingContent string) string {
if thinkingContent == "" {
return ""
}
// Generate a deterministic signature based on content hash
hash := sha256.Sum256([]byte(thinkingContent))
return base64.StdEncoding.EncodeToString(hash[:])
}
// Local references to kirocommon constants for thinking block parsing
var (
thinkingStartTag = kirocommon.ThinkingStartTag
thinkingEndTag = kirocommon.ThinkingEndTag
)
// BuildClaudeResponse constructs a Claude-compatible response.
// Supports tool_use blocks when tools are present in the response.
// Supports thinking blocks - parses <thinking> tags and converts to Claude thinking content blocks.
// stopReason is passed from upstream; fallback logic applied if empty.
func BuildClaudeResponse(content string, toolUses []KiroToolUse, model string, usageInfo usage.Detail, stopReason string) []byte {
var contentBlocks []map[string]interface{}
// Extract thinking blocks and text from content
if content != "" {
blocks := ExtractThinkingFromContent(content)
contentBlocks = append(contentBlocks, blocks...)
// Log if thinking blocks were extracted
for _, block := range blocks {
if block["type"] == "thinking" {
thinkingContent := block["thinking"].(string)
log.Infof("kiro: buildClaudeResponse extracted thinking block (len: %d)", len(thinkingContent))
}
}
}
// Add tool_use blocks - emit truncated tools with SOFT_LIMIT_REACHED marker
hasTruncatedTools := false
for _, toolUse := range toolUses {
if toolUse.IsTruncated && toolUse.TruncationInfo != nil {
// Emit tool_use with SOFT_LIMIT_REACHED marker input
hasTruncatedTools = true
log.Infof("kiro: buildClaudeResponse emitting truncated tool with SOFT_LIMIT_REACHED: %s (ID: %s)", toolUse.Name, toolUse.ToolUseID)
markerInput := map[string]interface{}{
"_status": "SOFT_LIMIT_REACHED",
"_message": "Tool output was truncated. Split content into smaller chunks (max 300 lines). Due to potential model hallucination, you MUST re-fetch the current working directory and generate the correct file_path.",
}
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "tool_use",
"id": toolUse.ToolUseID,
"name": toolUse.Name,
"input": markerInput,
})
} else {
// Normal tool use
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "tool_use",
"id": toolUse.ToolUseID,
"name": toolUse.Name,
"input": toolUse.Input,
})
}
}
// Log if we used SOFT_LIMIT_REACHED
if hasTruncatedTools {
log.Infof("kiro: buildClaudeResponse using SOFT_LIMIT_REACHED - keeping stop_reason=tool_use")
}
// Ensure at least one content block (Claude API requires non-empty content)
if len(contentBlocks) == 0 {
contentBlocks = append(contentBlocks, map[string]interface{}{
"type": "text",
"text": "",
})
}
// Use upstream stopReason; apply fallback logic if not provided
// SOFT_LIMIT_REACHED: Keep stop_reason = "tool_use" so Claude continues the loop
if stopReason == "" {
stopReason = "end_turn"
if len(toolUses) > 0 {
stopReason = "tool_use"
}
log.Debugf("kiro: buildClaudeResponse using fallback stop_reason: %s", stopReason)
}
// Log warning if response was truncated due to max_tokens
if stopReason == "max_tokens" {
log.Warnf("kiro: response truncated due to max_tokens limit (buildClaudeResponse)")
}
response := map[string]interface{}{
"id": "msg_" + uuid.New().String()[:24],
"type": "message",
"role": "assistant",
"model": model,
"content": contentBlocks,
"stop_reason": stopReason,
"usage": map[string]interface{}{
"input_tokens": usageInfo.InputTokens,
"output_tokens": usageInfo.OutputTokens,
},
}
result, _ := json.Marshal(response)
return result
}
// ExtractThinkingFromContent parses content to extract thinking blocks and text.
// Returns a list of content blocks in the order they appear in the content.
// Handles interleaved thinking and text blocks correctly.
func ExtractThinkingFromContent(content string) []map[string]interface{} {
var blocks []map[string]interface{}
if content == "" {
return blocks
}
// Check if content contains thinking tags at all
if !strings.Contains(content, thinkingStartTag) {
// No thinking tags, return as plain text
return []map[string]interface{}{
{
"type": "text",
"text": content,
},
}
}
log.Debugf("kiro: extractThinkingFromContent - found thinking tags in content (len: %d)", len(content))
remaining := content
for len(remaining) > 0 {
// Look for <thinking> tag
startIdx := strings.Index(remaining, thinkingStartTag)
if startIdx == -1 {
// No more thinking tags, add remaining as text
if strings.TrimSpace(remaining) != "" {
blocks = append(blocks, map[string]interface{}{
"type": "text",
"text": remaining,
})
}
break
}
// Add text before thinking tag (if any meaningful content)
if startIdx > 0 {
textBefore := remaining[:startIdx]
if strings.TrimSpace(textBefore) != "" {
blocks = append(blocks, map[string]interface{}{
"type": "text",
"text": textBefore,
})
}
}
// Move past the opening tag
remaining = remaining[startIdx+len(thinkingStartTag):]
// Find closing tag
endIdx := strings.Index(remaining, thinkingEndTag)
if endIdx == -1 {
// No closing tag found, treat rest as thinking content (incomplete response)
if strings.TrimSpace(remaining) != "" {
// Generate signature for thinking content (required by Claude API)
signature := generateThinkingSignature(remaining)
blocks = append(blocks, map[string]interface{}{
"type": "thinking",
"thinking": remaining,
"signature": signature,
})
log.Warnf("kiro: extractThinkingFromContent - missing closing </thinking> tag")
}
break
}
// Extract thinking content between tags
thinkContent := remaining[:endIdx]
if strings.TrimSpace(thinkContent) != "" {
// Generate signature for thinking content (required by Claude API)
signature := generateThinkingSignature(thinkContent)
blocks = append(blocks, map[string]interface{}{
"type": "thinking",
"thinking": thinkContent,
"signature": signature,
})
log.Debugf("kiro: extractThinkingFromContent - extracted thinking block (len: %d)", len(thinkContent))
}
// Move past the closing tag
remaining = remaining[endIdx+len(thinkingEndTag):]
}
// If no blocks were created (all whitespace), return empty text block
if len(blocks) == 0 {
blocks = append(blocks, map[string]interface{}{
"type": "text",
"text": "",
})
}
return blocks
}

View File

@@ -0,0 +1,306 @@
// Package claude provides streaming SSE event building for Claude format.
// This package handles the construction of Claude-compatible Server-Sent Events (SSE)
// for streaming responses from Kiro API.
package claude
import (
"encoding/json"
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
)
// BuildClaudeMessageStartEvent creates the message_start SSE event
func BuildClaudeMessageStartEvent(model string, inputTokens int64) []byte {
event := map[string]interface{}{
"type": "message_start",
"message": map[string]interface{}{
"id": "msg_" + uuid.New().String()[:24],
"type": "message",
"role": "assistant",
"content": []interface{}{},
"model": model,
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]interface{}{"input_tokens": inputTokens, "output_tokens": 0},
},
}
result, _ := json.Marshal(event)
return []byte("event: message_start\ndata: " + string(result))
}
// BuildClaudeContentBlockStartEvent creates a content_block_start SSE event
func BuildClaudeContentBlockStartEvent(index int, blockType, toolUseID, toolName string) []byte {
var contentBlock map[string]interface{}
switch blockType {
case "tool_use":
contentBlock = map[string]interface{}{
"type": "tool_use",
"id": toolUseID,
"name": toolName,
"input": map[string]interface{}{},
}
case "thinking":
contentBlock = map[string]interface{}{
"type": "thinking",
"thinking": "",
}
default:
contentBlock = map[string]interface{}{
"type": "text",
"text": "",
}
}
event := map[string]interface{}{
"type": "content_block_start",
"index": index,
"content_block": contentBlock,
}
result, _ := json.Marshal(event)
return []byte("event: content_block_start\ndata: " + string(result))
}
// BuildClaudeStreamEvent creates a text_delta content_block_delta SSE event
func BuildClaudeStreamEvent(contentDelta string, index int) []byte {
event := map[string]interface{}{
"type": "content_block_delta",
"index": index,
"delta": map[string]interface{}{
"type": "text_delta",
"text": contentDelta,
},
}
result, _ := json.Marshal(event)
return []byte("event: content_block_delta\ndata: " + string(result))
}
// BuildClaudeInputJsonDeltaEvent creates an input_json_delta event for tool use streaming
func BuildClaudeInputJsonDeltaEvent(partialJSON string, index int) []byte {
event := map[string]interface{}{
"type": "content_block_delta",
"index": index,
"delta": map[string]interface{}{
"type": "input_json_delta",
"partial_json": partialJSON,
},
}
result, _ := json.Marshal(event)
return []byte("event: content_block_delta\ndata: " + string(result))
}
// BuildClaudeContentBlockStopEvent creates a content_block_stop SSE event
func BuildClaudeContentBlockStopEvent(index int) []byte {
event := map[string]interface{}{
"type": "content_block_stop",
"index": index,
}
result, _ := json.Marshal(event)
return []byte("event: content_block_stop\ndata: " + string(result))
}
// BuildClaudeThinkingBlockStopEvent creates a content_block_stop SSE event for thinking blocks.
func BuildClaudeThinkingBlockStopEvent(index int) []byte {
event := map[string]interface{}{
"type": "content_block_stop",
"index": index,
}
result, _ := json.Marshal(event)
return []byte("event: content_block_stop\ndata: " + string(result))
}
// BuildClaudeMessageDeltaEvent creates the message_delta event with stop_reason and usage
func BuildClaudeMessageDeltaEvent(stopReason string, usageInfo usage.Detail) []byte {
deltaEvent := map[string]interface{}{
"type": "message_delta",
"delta": map[string]interface{}{
"stop_reason": stopReason,
"stop_sequence": nil,
},
"usage": map[string]interface{}{
"input_tokens": usageInfo.InputTokens,
"output_tokens": usageInfo.OutputTokens,
},
}
deltaResult, _ := json.Marshal(deltaEvent)
return []byte("event: message_delta\ndata: " + string(deltaResult))
}
// BuildClaudeMessageStopOnlyEvent creates only the message_stop event
func BuildClaudeMessageStopOnlyEvent() []byte {
stopEvent := map[string]interface{}{
"type": "message_stop",
}
stopResult, _ := json.Marshal(stopEvent)
return []byte("event: message_stop\ndata: " + string(stopResult))
}
// BuildClaudePingEventWithUsage creates a ping event with embedded usage information.
// This is used for real-time usage estimation during streaming.
func BuildClaudePingEventWithUsage(inputTokens, outputTokens int64) []byte {
event := map[string]interface{}{
"type": "ping",
"usage": map[string]interface{}{
"input_tokens": inputTokens,
"output_tokens": outputTokens,
"total_tokens": inputTokens + outputTokens,
"estimated": true,
},
}
result, _ := json.Marshal(event)
return []byte("event: ping\ndata: " + string(result))
}
// BuildClaudeThinkingDeltaEvent creates a thinking_delta event for Claude API compatibility.
// This is used when streaming thinking content wrapped in <thinking> tags.
func BuildClaudeThinkingDeltaEvent(thinkingDelta string, index int) []byte {
event := map[string]interface{}{
"type": "content_block_delta",
"index": index,
"delta": map[string]interface{}{
"type": "thinking_delta",
"thinking": thinkingDelta,
},
}
result, _ := json.Marshal(event)
return []byte("event: content_block_delta\ndata: " + string(result))
}
// PendingTagSuffix detects if the buffer ends with a partial prefix of the given tag.
// Returns the length of the partial match (0 if no match).
// Based on amq2api implementation for handling cross-chunk tag boundaries.
func PendingTagSuffix(buffer, tag string) int {
if buffer == "" || tag == "" {
return 0
}
maxLen := len(buffer)
if maxLen > len(tag)-1 {
maxLen = len(tag) - 1
}
for length := maxLen; length > 0; length-- {
if len(buffer) >= length && buffer[len(buffer)-length:] == tag[:length] {
return length
}
}
return 0
}
// GenerateSearchIndicatorEvents generates ONLY the search indicator SSE events
// (server_tool_use + web_search_tool_result) without text summary or message termination.
// These events trigger Claude Code's search indicator UI.
// The caller is responsible for sending message_start before and message_delta/stop after.
func GenerateSearchIndicatorEvents(
query string,
toolUseID string,
searchResults *WebSearchResults,
startIndex int,
) [][]byte {
events := make([][]byte, 0, 5)
// 1. content_block_start (server_tool_use)
event1 := map[string]interface{}{
"type": "content_block_start",
"index": startIndex,
"content_block": map[string]interface{}{
"id": toolUseID,
"type": "server_tool_use",
"name": "web_search",
"input": map[string]interface{}{},
},
}
data1, _ := json.Marshal(event1)
events = append(events, []byte("event: content_block_start\ndata: "+string(data1)+"\n\n"))
// 2. content_block_delta (input_json_delta)
inputJSON, _ := json.Marshal(map[string]string{"query": query})
event2 := map[string]interface{}{
"type": "content_block_delta",
"index": startIndex,
"delta": map[string]interface{}{
"type": "input_json_delta",
"partial_json": string(inputJSON),
},
}
data2, _ := json.Marshal(event2)
events = append(events, []byte("event: content_block_delta\ndata: "+string(data2)+"\n\n"))
// 3. content_block_stop (server_tool_use)
event3 := map[string]interface{}{
"type": "content_block_stop",
"index": startIndex,
}
data3, _ := json.Marshal(event3)
events = append(events, []byte("event: content_block_stop\ndata: "+string(data3)+"\n\n"))
// 4. content_block_start (web_search_tool_result)
searchContent := make([]map[string]interface{}, 0)
if searchResults != nil {
for _, r := range searchResults.Results {
snippet := ""
if r.Snippet != nil {
snippet = *r.Snippet
}
searchContent = append(searchContent, map[string]interface{}{
"type": "web_search_result",
"title": r.Title,
"url": r.URL,
"encrypted_content": snippet,
"page_age": nil,
})
}
}
event4 := map[string]interface{}{
"type": "content_block_start",
"index": startIndex + 1,
"content_block": map[string]interface{}{
"type": "web_search_tool_result",
"tool_use_id": toolUseID,
"content": searchContent,
},
}
data4, _ := json.Marshal(event4)
events = append(events, []byte("event: content_block_start\ndata: "+string(data4)+"\n\n"))
// 5. content_block_stop (web_search_tool_result)
event5 := map[string]interface{}{
"type": "content_block_stop",
"index": startIndex + 1,
}
data5, _ := json.Marshal(event5)
events = append(events, []byte("event: content_block_stop\ndata: "+string(data5)+"\n\n"))
return events
}
// BuildFallbackTextEvents generates SSE events for a fallback text response
// when the Kiro API fails during the search loop. Uses BuildClaude*Event()
// functions to align with streamToChannel patterns.
// Returns raw SSE byte slices ready to be sent to the client channel.
func BuildFallbackTextEvents(contentBlockIndex int, query string, results *WebSearchResults) [][]byte {
summary := FormatSearchContextPrompt(query, results)
outputTokens := len(summary) / 4
if len(summary) > 0 && outputTokens == 0 {
outputTokens = 1
}
var events [][]byte
// content_block_start (text)
events = append(events, BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", ""))
// content_block_delta (text_delta)
events = append(events, BuildClaudeStreamEvent(summary, contentBlockIndex))
// content_block_stop
events = append(events, BuildClaudeContentBlockStopEvent(contentBlockIndex))
// message_delta with end_turn
events = append(events, BuildClaudeMessageDeltaEvent("end_turn", usage.Detail{
OutputTokens: int64(outputTokens),
}))
// message_stop
events = append(events, BuildClaudeMessageStopOnlyEvent())
return events
}

View File

@@ -0,0 +1,350 @@
package claude
import (
"encoding/json"
"strings"
log "github.com/sirupsen/logrus"
)
// sseEvent represents a Server-Sent Event
type sseEvent struct {
Event string
Data interface{}
}
// ToSSEString converts the event to SSE wire format
func (e *sseEvent) ToSSEString() string {
dataBytes, _ := json.Marshal(e.Data)
return "event: " + e.Event + "\ndata: " + string(dataBytes) + "\n\n"
}
// AdjustStreamIndices adjusts content block indices in SSE event data by adding an offset.
// It also suppresses duplicate message_start events (returns shouldForward=false).
// This is used to combine search indicator events (indices 0,1) with Kiro model response events.
//
// The data parameter is a single SSE "data:" line payload (JSON).
// Returns: adjusted data, shouldForward (false = skip this event).
func AdjustStreamIndices(data []byte, offset int) ([]byte, bool) {
if len(data) == 0 {
return data, true
}
// Quick check: parse the JSON
var event map[string]interface{}
if err := json.Unmarshal(data, &event); err != nil {
// Not valid JSON, pass through
return data, true
}
eventType, _ := event["type"].(string)
// Suppress duplicate message_start events
if eventType == "message_start" {
return data, false
}
// Adjust index for content_block events
switch eventType {
case "content_block_start", "content_block_delta", "content_block_stop":
if idx, ok := event["index"].(float64); ok {
event["index"] = int(idx) + offset
adjusted, err := json.Marshal(event)
if err != nil {
return data, true
}
return adjusted, true
}
}
// Pass through all other events unchanged (message_delta, message_stop, ping, etc.)
return data, true
}
// AdjustSSEChunk processes a raw SSE chunk (potentially containing multiple "event:/data:" pairs)
// and adjusts content block indices. Suppresses duplicate message_start events.
// Returns the adjusted chunk and whether it should be forwarded.
func AdjustSSEChunk(chunk []byte, offset int) ([]byte, bool) {
chunkStr := string(chunk)
// Fast path: if no "data:" prefix, pass through
if !strings.Contains(chunkStr, "data: ") {
return chunk, true
}
var result strings.Builder
hasContent := false
lines := strings.Split(chunkStr, "\n")
for i := 0; i < len(lines); i++ {
line := lines[i]
if strings.HasPrefix(line, "data: ") {
dataPayload := strings.TrimPrefix(line, "data: ")
dataPayload = strings.TrimSpace(dataPayload)
if dataPayload == "[DONE]" {
result.WriteString(line + "\n")
hasContent = true
continue
}
adjusted, shouldForward := AdjustStreamIndices([]byte(dataPayload), offset)
if !shouldForward {
// Skip this event and its preceding "event:" line
// Also skip the trailing empty line
continue
}
result.WriteString("data: " + string(adjusted) + "\n")
hasContent = true
} else if strings.HasPrefix(line, "event: ") {
// Check if the next data line will be suppressed
if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
dataPayload := strings.TrimPrefix(lines[i+1], "data: ")
dataPayload = strings.TrimSpace(dataPayload)
var event map[string]interface{}
if err := json.Unmarshal([]byte(dataPayload), &event); err == nil {
if eventType, ok := event["type"].(string); ok && eventType == "message_start" {
// Skip both the event: and data: lines
i++ // skip the data: line too
continue
}
}
}
result.WriteString(line + "\n")
hasContent = true
} else {
result.WriteString(line + "\n")
if strings.TrimSpace(line) != "" {
hasContent = true
}
}
}
if !hasContent {
return nil, false
}
return []byte(result.String()), true
}
// BufferedStreamResult contains the analysis of buffered SSE chunks from a Kiro API response.
type BufferedStreamResult struct {
// StopReason is the detected stop_reason from the stream (e.g., "end_turn", "tool_use")
StopReason string
// WebSearchQuery is the extracted query if the model requested another web_search
WebSearchQuery string
// WebSearchToolUseId is the tool_use ID from the model's response (needed for toolResults)
WebSearchToolUseId string
// HasWebSearchToolUse indicates whether the model requested web_search
HasWebSearchToolUse bool
// WebSearchToolUseIndex is the content_block index of the web_search tool_use
WebSearchToolUseIndex int
}
// AnalyzeBufferedStream scans buffered SSE chunks to detect stop_reason and web_search tool_use.
// This is used in the search loop to determine if the model wants another search round.
func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult {
result := BufferedStreamResult{WebSearchToolUseIndex: -1}
// Track tool use state across chunks
var currentToolName string
var currentToolIndex int = -1
var toolInputBuilder strings.Builder
for _, chunk := range chunks {
chunkStr := string(chunk)
lines := strings.Split(chunkStr, "\n")
for _, line := range lines {
if !strings.HasPrefix(line, "data: ") {
continue
}
dataPayload := strings.TrimPrefix(line, "data: ")
dataPayload = strings.TrimSpace(dataPayload)
if dataPayload == "[DONE]" || dataPayload == "" {
continue
}
var event map[string]interface{}
if err := json.Unmarshal([]byte(dataPayload), &event); err != nil {
continue
}
eventType, _ := event["type"].(string)
switch eventType {
case "message_delta":
// Extract stop_reason from message_delta
if delta, ok := event["delta"].(map[string]interface{}); ok {
if sr, ok := delta["stop_reason"].(string); ok && sr != "" {
result.StopReason = sr
}
}
case "content_block_start":
// Detect tool_use content blocks
if cb, ok := event["content_block"].(map[string]interface{}); ok {
if cbType, ok := cb["type"].(string); ok && cbType == "tool_use" {
if name, ok := cb["name"].(string); ok {
currentToolName = strings.ToLower(name)
if idx, ok := event["index"].(float64); ok {
currentToolIndex = int(idx)
}
// Capture tool use ID for toolResults handshake
if id, ok := cb["id"].(string); ok {
result.WebSearchToolUseId = id
}
toolInputBuilder.Reset()
}
}
}
case "content_block_delta":
// Accumulate tool input JSON
if currentToolName != "" {
if delta, ok := event["delta"].(map[string]interface{}); ok {
if deltaType, ok := delta["type"].(string); ok && deltaType == "input_json_delta" {
if partial, ok := delta["partial_json"].(string); ok {
toolInputBuilder.WriteString(partial)
}
}
}
}
case "content_block_stop":
// Finalize tool use detection
if currentToolName == "web_search" || currentToolName == "websearch" || currentToolName == "remote_web_search" {
result.HasWebSearchToolUse = true
result.WebSearchToolUseIndex = currentToolIndex
// Extract query from accumulated input JSON
inputJSON := toolInputBuilder.String()
var input map[string]string
if err := json.Unmarshal([]byte(inputJSON), &input); err == nil {
if q, ok := input["query"]; ok {
result.WebSearchQuery = q
}
}
log.Debugf("kiro/websearch: detected web_search tool_use")
}
currentToolName = ""
currentToolIndex = -1
toolInputBuilder.Reset()
}
}
}
return result
}
// FilterChunksForClient processes buffered SSE chunks and removes web_search tool_use
// content blocks. This prevents the client from seeing "Tool use" prompts for web_search
// when the proxy is handling the search loop internally.
// Also suppresses message_start and message_delta/message_stop events since those
// are managed by the outer handleWebSearchStream.
func FilterChunksForClient(chunks [][]byte, wsToolIndex int, indexOffset int) [][]byte {
var filtered [][]byte
for _, chunk := range chunks {
chunkStr := string(chunk)
lines := strings.Split(chunkStr, "\n")
var resultBuilder strings.Builder
hasContent := false
for i := 0; i < len(lines); i++ {
line := lines[i]
if strings.HasPrefix(line, "data: ") {
dataPayload := strings.TrimPrefix(line, "data: ")
dataPayload = strings.TrimSpace(dataPayload)
if dataPayload == "[DONE]" {
// Skip [DONE] — the outer loop manages stream termination
continue
}
var event map[string]interface{}
if err := json.Unmarshal([]byte(dataPayload), &event); err != nil {
resultBuilder.WriteString(line + "\n")
hasContent = true
continue
}
eventType, _ := event["type"].(string)
// Skip message_start (outer loop sends its own)
if eventType == "message_start" {
continue
}
// Skip message_delta and message_stop (outer loop manages these)
if eventType == "message_delta" || eventType == "message_stop" {
continue
}
// Check if this event belongs to the web_search tool_use block
if wsToolIndex >= 0 {
if idx, ok := event["index"].(float64); ok && int(idx) == wsToolIndex {
// Skip events for the web_search tool_use block
continue
}
}
// Apply index offset for remaining events
if indexOffset > 0 {
switch eventType {
case "content_block_start", "content_block_delta", "content_block_stop":
if idx, ok := event["index"].(float64); ok {
event["index"] = int(idx) + indexOffset
adjusted, err := json.Marshal(event)
if err == nil {
resultBuilder.WriteString("data: " + string(adjusted) + "\n")
hasContent = true
continue
}
}
}
}
resultBuilder.WriteString(line + "\n")
hasContent = true
} else if strings.HasPrefix(line, "event: ") {
// Check if the next data line will be suppressed
if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
nextData := strings.TrimPrefix(lines[i+1], "data: ")
nextData = strings.TrimSpace(nextData)
var nextEvent map[string]interface{}
if err := json.Unmarshal([]byte(nextData), &nextEvent); err == nil {
nextType, _ := nextEvent["type"].(string)
if nextType == "message_start" || nextType == "message_delta" || nextType == "message_stop" {
i++ // skip the data line
continue
}
if wsToolIndex >= 0 {
if idx, ok := nextEvent["index"].(float64); ok && int(idx) == wsToolIndex {
i++ // skip the data line
continue
}
}
}
}
resultBuilder.WriteString(line + "\n")
hasContent = true
} else {
resultBuilder.WriteString(line + "\n")
if strings.TrimSpace(line) != "" {
hasContent = true
}
}
}
if hasContent {
filtered = append(filtered, []byte(resultBuilder.String()))
}
}
return filtered
}

View File

@@ -0,0 +1,543 @@
// Package claude provides tool calling support for Kiro to Claude translation.
// This package handles parsing embedded tool calls, JSON repair, and deduplication.
package claude
import (
"encoding/json"
"regexp"
"strings"
"github.com/google/uuid"
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
log "github.com/sirupsen/logrus"
)
// ToolUseState tracks the state of an in-progress tool use during streaming.
type ToolUseState struct {
ToolUseID string
Name string
InputBuffer strings.Builder
IsComplete bool
TruncationInfo *TruncationInfo // Truncation detection result (set when complete)
}
// Pre-compiled regex patterns for performance
var (
// embeddedToolCallPattern matches [Called tool_name with args: {...}] format
embeddedToolCallPattern = regexp.MustCompile(`\[Called\s+([A-Za-z0-9_.-]+)\s+with\s+args:\s*`)
// trailingCommaPattern matches trailing commas before closing braces/brackets
trailingCommaPattern = regexp.MustCompile(`,\s*([}\]])`)
)
// ParseEmbeddedToolCalls extracts [Called tool_name with args: {...}] format from text.
// Kiro sometimes embeds tool calls in text content instead of using toolUseEvent.
// Returns the cleaned text (with tool calls removed) and extracted tool uses.
func ParseEmbeddedToolCalls(text string, processedIDs map[string]bool) (string, []KiroToolUse) {
if !strings.Contains(text, "[Called") {
return text, nil
}
var toolUses []KiroToolUse
cleanText := text
// Find all [Called markers
matches := embeddedToolCallPattern.FindAllStringSubmatchIndex(text, -1)
if len(matches) == 0 {
return text, nil
}
// Process matches in reverse order to maintain correct indices
for i := len(matches) - 1; i >= 0; i-- {
matchStart := matches[i][0]
toolNameStart := matches[i][2]
toolNameEnd := matches[i][3]
if toolNameStart < 0 || toolNameEnd < 0 {
continue
}
toolName := text[toolNameStart:toolNameEnd]
// Find the JSON object start (after "with args:")
jsonStart := matches[i][1]
if jsonStart >= len(text) {
continue
}
// Skip whitespace to find the opening brace
for jsonStart < len(text) && (text[jsonStart] == ' ' || text[jsonStart] == '\t') {
jsonStart++
}
if jsonStart >= len(text) || text[jsonStart] != '{' {
continue
}
// Find matching closing bracket
jsonEnd := findMatchingBracket(text, jsonStart)
if jsonEnd < 0 {
continue
}
// Extract JSON and find the closing bracket of [Called ...]
jsonStr := text[jsonStart : jsonEnd+1]
// Find the closing ] after the JSON
closingBracket := jsonEnd + 1
for closingBracket < len(text) && text[closingBracket] != ']' {
closingBracket++
}
if closingBracket >= len(text) {
continue
}
// End index of the full tool call (closing ']' inclusive)
matchEnd := closingBracket + 1
// Repair and parse JSON
repairedJSON := RepairJSON(jsonStr)
var inputMap map[string]interface{}
if err := json.Unmarshal([]byte(repairedJSON), &inputMap); err != nil {
log.Debugf("kiro: failed to parse embedded tool call JSON: %v, raw: %s", err, jsonStr)
continue
}
// Generate unique tool ID
toolUseID := "toolu_" + uuid.New().String()[:12]
// Check for duplicates using name+input as key
dedupeKey := toolName + ":" + repairedJSON
if processedIDs != nil {
if processedIDs[dedupeKey] {
log.Debugf("kiro: skipping duplicate embedded tool call: %s", toolName)
// Still remove from text even if duplicate
if matchStart >= 0 && matchEnd <= len(cleanText) && matchStart <= matchEnd {
cleanText = cleanText[:matchStart] + cleanText[matchEnd:]
}
continue
}
processedIDs[dedupeKey] = true
}
toolUses = append(toolUses, KiroToolUse{
ToolUseID: toolUseID,
Name: toolName,
Input: inputMap,
})
log.Infof("kiro: extracted embedded tool call: %s (ID: %s)", toolName, toolUseID)
// Remove from clean text (index-based removal to avoid deleting the wrong occurrence)
if matchStart >= 0 && matchEnd <= len(cleanText) && matchStart <= matchEnd {
cleanText = cleanText[:matchStart] + cleanText[matchEnd:]
}
}
return cleanText, toolUses
}
// findMatchingBracket finds the index of the closing brace/bracket that matches
// the opening one at startPos. Handles nested objects and strings correctly.
func findMatchingBracket(text string, startPos int) int {
if startPos >= len(text) {
return -1
}
openChar := text[startPos]
var closeChar byte
switch openChar {
case '{':
closeChar = '}'
case '[':
closeChar = ']'
default:
return -1
}
depth := 1
inString := false
escapeNext := false
for i := startPos + 1; i < len(text); i++ {
char := text[i]
if escapeNext {
escapeNext = false
continue
}
if char == '\\' && inString {
escapeNext = true
continue
}
if char == '"' {
inString = !inString
continue
}
if !inString {
if char == openChar {
depth++
} else if char == closeChar {
depth--
if depth == 0 {
return i
}
}
}
}
return -1
}
// RepairJSON attempts to fix common JSON issues that may occur in tool call arguments.
// Conservative repair strategy:
// 1. First try to parse JSON directly - if valid, return as-is
// 2. Only attempt repair if parsing fails
// 3. After repair, validate the result - if still invalid, return original
func RepairJSON(jsonString string) string {
// Handle empty or invalid input
if jsonString == "" {
return "{}"
}
str := strings.TrimSpace(jsonString)
if str == "" {
return "{}"
}
// CONSERVATIVE STRATEGY: First try to parse directly
var testParse interface{}
if err := json.Unmarshal([]byte(str), &testParse); err == nil {
log.Debugf("kiro: repairJSON - JSON is already valid, returning unchanged")
return str
}
log.Debugf("kiro: repairJSON - JSON parse failed, attempting repair")
originalStr := str
// First, escape unescaped newlines/tabs within JSON string values
str = escapeNewlinesInStrings(str)
// Remove trailing commas before closing braces/brackets
str = trailingCommaPattern.ReplaceAllString(str, "$1")
// Calculate bracket balance
braceCount := 0
bracketCount := 0
inString := false
escape := false
lastValidIndex := -1
for i := 0; i < len(str); i++ {
char := str[i]
if escape {
escape = false
continue
}
if char == '\\' {
escape = true
continue
}
if char == '"' {
inString = !inString
continue
}
if inString {
continue
}
switch char {
case '{':
braceCount++
case '}':
braceCount--
case '[':
bracketCount++
case ']':
bracketCount--
}
if braceCount >= 0 && bracketCount >= 0 {
lastValidIndex = i
}
}
// If brackets are unbalanced, try to repair
if braceCount > 0 || bracketCount > 0 {
if lastValidIndex > 0 && lastValidIndex < len(str)-1 {
truncated := str[:lastValidIndex+1]
// Recount brackets after truncation
braceCount = 0
bracketCount = 0
inString = false
escape = false
for i := 0; i < len(truncated); i++ {
char := truncated[i]
if escape {
escape = false
continue
}
if char == '\\' {
escape = true
continue
}
if char == '"' {
inString = !inString
continue
}
if inString {
continue
}
switch char {
case '{':
braceCount++
case '}':
braceCount--
case '[':
bracketCount++
case ']':
bracketCount--
}
}
str = truncated
}
// Add missing closing brackets
for braceCount > 0 {
str += "}"
braceCount--
}
for bracketCount > 0 {
str += "]"
bracketCount--
}
}
// Validate repaired JSON
if err := json.Unmarshal([]byte(str), &testParse); err != nil {
log.Warnf("kiro: repairJSON - repair failed to produce valid JSON, returning original")
return originalStr
}
log.Debugf("kiro: repairJSON - successfully repaired JSON")
return str
}
// escapeNewlinesInStrings escapes literal newlines, tabs, and other control characters
// that appear inside JSON string values.
func escapeNewlinesInStrings(raw string) string {
var result strings.Builder
result.Grow(len(raw) + 100)
inString := false
escaped := false
for i := 0; i < len(raw); i++ {
c := raw[i]
if escaped {
result.WriteByte(c)
escaped = false
continue
}
if c == '\\' && inString {
result.WriteByte(c)
escaped = true
continue
}
if c == '"' {
inString = !inString
result.WriteByte(c)
continue
}
if inString {
switch c {
case '\n':
result.WriteString("\\n")
case '\r':
result.WriteString("\\r")
case '\t':
result.WriteString("\\t")
default:
result.WriteByte(c)
}
} else {
result.WriteByte(c)
}
}
return result.String()
}
// ProcessToolUseEvent handles a toolUseEvent from the Kiro stream.
// It accumulates input fragments and emits tool_use blocks when complete.
// Returns events to emit and updated state.
func ProcessToolUseEvent(event map[string]interface{}, currentToolUse *ToolUseState, processedIDs map[string]bool) ([]KiroToolUse, *ToolUseState) {
var toolUses []KiroToolUse
// Extract from nested toolUseEvent or direct format
tu := event
if nested, ok := event["toolUseEvent"].(map[string]interface{}); ok {
tu = nested
}
toolUseID := kirocommon.GetString(tu, "toolUseId")
toolName := kirocommon.GetString(tu, "name")
isStop := false
if stop, ok := tu["stop"].(bool); ok {
isStop = stop
}
// Get input - can be string (fragment) or object (complete)
var inputFragment string
var inputMap map[string]interface{}
if inputRaw, ok := tu["input"]; ok {
switch v := inputRaw.(type) {
case string:
inputFragment = v
case map[string]interface{}:
inputMap = v
}
}
// New tool use starting
if toolUseID != "" && toolName != "" {
if currentToolUse != nil && currentToolUse.ToolUseID != toolUseID {
log.Warnf("kiro: interleaved tool use detected - new ID %s arrived while %s in progress, completing previous",
toolUseID, currentToolUse.ToolUseID)
if !processedIDs[currentToolUse.ToolUseID] {
incomplete := KiroToolUse{
ToolUseID: currentToolUse.ToolUseID,
Name: currentToolUse.Name,
}
if currentToolUse.InputBuffer.Len() > 0 {
raw := currentToolUse.InputBuffer.String()
repaired := RepairJSON(raw)
var input map[string]interface{}
if err := json.Unmarshal([]byte(repaired), &input); err != nil {
log.Warnf("kiro: failed to parse interleaved tool input: %v, raw: %s", err, raw)
input = make(map[string]interface{})
}
incomplete.Input = input
}
toolUses = append(toolUses, incomplete)
processedIDs[currentToolUse.ToolUseID] = true
}
currentToolUse = nil
}
if currentToolUse == nil {
if processedIDs != nil && processedIDs[toolUseID] {
log.Debugf("kiro: skipping duplicate toolUseEvent: %s", toolUseID)
return nil, nil
}
currentToolUse = &ToolUseState{
ToolUseID: toolUseID,
Name: toolName,
}
log.Infof("kiro: starting new tool use: %s (ID: %s)", toolName, toolUseID)
}
}
// Accumulate input fragments
if currentToolUse != nil && inputFragment != "" {
currentToolUse.InputBuffer.WriteString(inputFragment)
log.Debugf("kiro: accumulated input fragment, total length: %d", currentToolUse.InputBuffer.Len())
}
// If complete input object provided directly
if currentToolUse != nil && inputMap != nil {
inputBytes, _ := json.Marshal(inputMap)
currentToolUse.InputBuffer.Reset()
currentToolUse.InputBuffer.Write(inputBytes)
}
// Tool use complete
if isStop && currentToolUse != nil {
fullInput := currentToolUse.InputBuffer.String()
// Repair and parse the accumulated JSON
repairedJSON := RepairJSON(fullInput)
var finalInput map[string]interface{}
if err := json.Unmarshal([]byte(repairedJSON), &finalInput); err != nil {
log.Warnf("kiro: failed to parse accumulated tool input: %v, raw: %s", err, fullInput)
finalInput = make(map[string]interface{})
}
// Detect truncation for all tools
truncInfo := DetectTruncation(currentToolUse.Name, currentToolUse.ToolUseID, fullInput, finalInput)
if truncInfo.IsTruncated {
log.Warnf("kiro: TRUNCATION DETECTED for tool %s (ID: %s): type=%s, raw_size=%d bytes",
currentToolUse.Name, currentToolUse.ToolUseID, truncInfo.TruncationType, len(fullInput))
log.Warnf("kiro: truncation details: %s", truncInfo.ErrorMessage)
if len(truncInfo.ParsedFields) > 0 {
log.Infof("kiro: partial fields received: %v", truncInfo.ParsedFields)
}
// Store truncation info in the state for upstream handling
currentToolUse.TruncationInfo = &truncInfo
} else {
log.Infof("kiro: tool use %s input length: %d bytes (no truncation)", currentToolUse.Name, len(fullInput))
}
// Create the tool use with truncation info if applicable
toolUse := KiroToolUse{
ToolUseID: currentToolUse.ToolUseID,
Name: currentToolUse.Name,
Input: finalInput,
IsTruncated: truncInfo.IsTruncated,
TruncationInfo: nil, // Will be set below if truncated
}
if truncInfo.IsTruncated {
toolUse.TruncationInfo = &truncInfo
}
toolUses = append(toolUses, toolUse)
if processedIDs != nil {
processedIDs[currentToolUse.ToolUseID] = true
}
log.Infof("kiro: completed tool use: %s (ID: %s, truncated: %v)", currentToolUse.Name, currentToolUse.ToolUseID, truncInfo.IsTruncated)
return toolUses, nil
}
return toolUses, currentToolUse
}
// DeduplicateToolUses removes duplicate tool uses based on toolUseId and content.
func DeduplicateToolUses(toolUses []KiroToolUse) []KiroToolUse {
seenIDs := make(map[string]bool)
seenContent := make(map[string]bool)
var unique []KiroToolUse
for _, tu := range toolUses {
if seenIDs[tu.ToolUseID] {
log.Debugf("kiro: removing ID-duplicate tool use: %s (name: %s)", tu.ToolUseID, tu.Name)
continue
}
inputJSON, _ := json.Marshal(tu.Input)
contentKey := tu.Name + ":" + string(inputJSON)
if seenContent[contentKey] {
log.Debugf("kiro: removing content-duplicate tool use: %s (id: %s)", tu.Name, tu.ToolUseID)
continue
}
seenIDs[tu.ToolUseID] = true
seenContent[contentKey] = true
unique = append(unique, tu)
}
return unique
}

View File

@@ -0,0 +1,495 @@
// Package claude provides web search functionality for Kiro translator.
// This file implements detection, MCP request/response types, and pure data
// transformation utilities for web search. SSE event generation, stream analysis,
// and HTTP I/O logic reside in the executor package (kiro_executor.go).
package claude
import (
"encoding/json"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// cachedToolDescription stores the dynamically-fetched web_search tool description.
// Written by the executor via SetWebSearchDescription, read by the translator
// when building the remote_web_search tool for Kiro API requests.
var cachedToolDescription atomic.Value // stores string
// GetWebSearchDescription returns the cached web_search tool description,
// or empty string if not yet fetched. Lock-free via atomic.Value.
func GetWebSearchDescription() string {
if v := cachedToolDescription.Load(); v != nil {
return v.(string)
}
return ""
}
// SetWebSearchDescription stores the dynamically-fetched web_search tool description.
// Called by the executor after fetching from MCP tools/list.
func SetWebSearchDescription(desc string) {
cachedToolDescription.Store(desc)
}
// McpRequest represents a JSON-RPC 2.0 request to Kiro MCP API
type McpRequest struct {
ID string `json:"id"`
JSONRPC string `json:"jsonrpc"`
Method string `json:"method"`
Params McpParams `json:"params"`
}
// McpParams represents MCP request parameters
type McpParams struct {
Name string `json:"name"`
Arguments McpArguments `json:"arguments"`
}
// McpArgumentsMeta represents the _meta field in MCP arguments
type McpArgumentsMeta struct {
IsValid bool `json:"_isValid"`
ActivePath []string `json:"_activePath"`
CompletedPaths [][]string `json:"_completedPaths"`
}
// McpArguments represents MCP request arguments
type McpArguments struct {
Query string `json:"query"`
Meta *McpArgumentsMeta `json:"_meta,omitempty"`
}
// McpResponse represents a JSON-RPC 2.0 response from Kiro MCP API
type McpResponse struct {
Error *McpError `json:"error,omitempty"`
ID string `json:"id"`
JSONRPC string `json:"jsonrpc"`
Result *McpResult `json:"result,omitempty"`
}
// McpError represents an MCP error
type McpError struct {
Code *int `json:"code,omitempty"`
Message *string `json:"message,omitempty"`
}
// McpResult represents MCP result
type McpResult struct {
Content []McpContent `json:"content"`
IsError bool `json:"isError"`
}
// McpContent represents MCP content item
type McpContent struct {
ContentType string `json:"type"`
Text string `json:"text"`
}
// WebSearchResults represents parsed search results
type WebSearchResults struct {
Results []WebSearchResult `json:"results"`
TotalResults *int `json:"totalResults,omitempty"`
Query *string `json:"query,omitempty"`
Error *string `json:"error,omitempty"`
}
// WebSearchResult represents a single search result
type WebSearchResult struct {
Title string `json:"title"`
URL string `json:"url"`
Snippet *string `json:"snippet,omitempty"`
PublishedDate *int64 `json:"publishedDate,omitempty"`
ID *string `json:"id,omitempty"`
Domain *string `json:"domain,omitempty"`
MaxVerbatimWordLimit *int `json:"maxVerbatimWordLimit,omitempty"`
PublicDomain *bool `json:"publicDomain,omitempty"`
}
// isWebSearchTool checks if a tool name or type indicates a web_search tool.
func isWebSearchTool(name, toolType string) bool {
return name == "web_search" ||
strings.HasPrefix(toolType, "web_search") ||
toolType == "web_search_20250305"
}
// HasWebSearchTool checks if the request contains ONLY a web_search tool.
// Returns true only if tools array has exactly one tool named "web_search".
// Only intercept pure web_search requests (single-tool array).
func HasWebSearchTool(body []byte) bool {
tools := gjson.GetBytes(body, "tools")
if !tools.IsArray() {
return false
}
toolsArray := tools.Array()
if len(toolsArray) != 1 {
return false
}
// Check if the single tool is web_search
tool := toolsArray[0]
// Check both name and type fields for web_search detection
name := strings.ToLower(tool.Get("name").String())
toolType := strings.ToLower(tool.Get("type").String())
return isWebSearchTool(name, toolType)
}
// ExtractSearchQuery extracts the search query from the request.
// Reads messages[0].content and removes "Perform a web search for the query: " prefix.
func ExtractSearchQuery(body []byte) string {
messages := gjson.GetBytes(body, "messages")
if !messages.IsArray() || len(messages.Array()) == 0 {
return ""
}
firstMsg := messages.Array()[0]
content := firstMsg.Get("content")
var text string
if content.IsArray() {
// Array format: [{"type": "text", "text": "..."}]
for _, block := range content.Array() {
if block.Get("type").String() == "text" {
text = block.Get("text").String()
break
}
}
} else {
// String format
text = content.String()
}
// Remove prefix "Perform a web search for the query: "
const prefix = "Perform a web search for the query: "
if strings.HasPrefix(text, prefix) {
text = text[len(prefix):]
}
return strings.TrimSpace(text)
}
// generateRandomID8 generates an 8-character random lowercase alphanumeric string
func generateRandomID8() string {
u := uuid.New()
return strings.ToLower(strings.ReplaceAll(u.String(), "-", "")[:8])
}
// CreateMcpRequest creates an MCP request for web search.
// Returns (toolUseID, McpRequest)
// ID format: web_search_tooluse_{22 random}_{timestamp_millis}_{8 random}
func CreateMcpRequest(query string) (string, *McpRequest) {
random22 := GenerateToolUseID()
timestamp := time.Now().UnixMilli()
random8 := generateRandomID8()
requestID := fmt.Sprintf("web_search_tooluse_%s_%d_%s", random22, timestamp, random8)
// tool_use_id format: srvtoolu_{32 hex chars}
toolUseID := "srvtoolu_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:32]
request := &McpRequest{
ID: requestID,
JSONRPC: "2.0",
Method: "tools/call",
Params: McpParams{
Name: "web_search",
Arguments: McpArguments{
Query: query,
Meta: &McpArgumentsMeta{
IsValid: true,
ActivePath: []string{"query"},
CompletedPaths: [][]string{{"query"}},
},
},
},
}
return toolUseID, request
}
// GenerateToolUseID generates a Kiro-style tool use ID (base62-like UUID)
func GenerateToolUseID() string {
return strings.ReplaceAll(uuid.New().String(), "-", "")[:22]
}
// ReplaceWebSearchToolDescription replaces the web_search tool description with
// a minimal version that allows re-search without the restrictive "do not search
// non-coding topics" instruction from the original Kiro tools/list response.
// This keeps the tool available so the model can request additional searches.
func ReplaceWebSearchToolDescription(body []byte) ([]byte, error) {
tools := gjson.GetBytes(body, "tools")
if !tools.IsArray() {
return body, nil
}
var updated []json.RawMessage
for _, tool := range tools.Array() {
name := strings.ToLower(tool.Get("name").String())
toolType := strings.ToLower(tool.Get("type").String())
if isWebSearchTool(name, toolType) {
// Replace with a minimal web_search tool definition
minimalTool := map[string]interface{}{
"name": "web_search",
"description": "Search the web for information. Use this when the previous search results are insufficient or when you need additional information on a different aspect of the query. Provide a refined or different search query.",
"input_schema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"query": map[string]interface{}{
"type": "string",
"description": "The search query to execute",
},
},
"required": []string{"query"},
"additionalProperties": false,
},
}
minimalJSON, err := json.Marshal(minimalTool)
if err != nil {
return body, fmt.Errorf("failed to marshal minimal tool: %w", err)
}
updated = append(updated, json.RawMessage(minimalJSON))
} else {
updated = append(updated, json.RawMessage(tool.Raw))
}
}
updatedJSON, err := json.Marshal(updated)
if err != nil {
return body, fmt.Errorf("failed to marshal updated tools: %w", err)
}
result, err := sjson.SetRawBytes(body, "tools", updatedJSON)
if err != nil {
return body, fmt.Errorf("failed to set updated tools: %w", err)
}
return result, nil
}
// FormatSearchContextPrompt formats search results as a structured text block
// for injection into the system prompt.
func FormatSearchContextPrompt(query string, results *WebSearchResults) string {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("[Web Search Results for \"%s\"]\n", query))
if results != nil && len(results.Results) > 0 {
for i, r := range results.Results {
sb.WriteString(fmt.Sprintf("%d. %s - %s\n", i+1, r.Title, r.URL))
if r.Snippet != nil && *r.Snippet != "" {
snippet := *r.Snippet
if len(snippet) > 500 {
snippet = snippet[:500] + "..."
}
sb.WriteString(fmt.Sprintf(" %s\n", snippet))
}
}
} else {
sb.WriteString("No results found.\n")
}
sb.WriteString("[End Web Search Results]")
return sb.String()
}
// FormatToolResultText formats search results as JSON text for the toolResults content field.
// This matches the format observed in Kiro IDE HAR captures.
func FormatToolResultText(results *WebSearchResults) string {
if results == nil || len(results.Results) == 0 {
return "No search results found."
}
text := fmt.Sprintf("Found %d search result(s):\n\n", len(results.Results))
resultJSON, err := json.MarshalIndent(results.Results, "", " ")
if err != nil {
return text + "Error formatting results."
}
return text + string(resultJSON)
}
// InjectToolResultsClaude modifies a Claude-format JSON payload to append
// tool_use (assistant) and tool_result (user) messages to the messages array.
// BuildKiroPayload correctly translates:
// - assistant tool_use → KiroAssistantResponseMessage.toolUses
// - user tool_result → KiroUserInputMessageContext.toolResults
//
// This produces the exact same GAR request format as the Kiro IDE (HAR captures).
// IMPORTANT: The web_search tool must remain in the "tools" array for this to work.
// Use ReplaceWebSearchToolDescription to keep the tool available with a minimal description.
func InjectToolResultsClaude(claudePayload []byte, toolUseId, query string, results *WebSearchResults) ([]byte, error) {
var payload map[string]interface{}
if err := json.Unmarshal(claudePayload, &payload); err != nil {
return claudePayload, fmt.Errorf("failed to parse claude payload: %w", err)
}
messages, _ := payload["messages"].([]interface{})
// 1. Append assistant message with tool_use (matches HAR: assistantResponseMessage.toolUses)
assistantMsg := map[string]interface{}{
"role": "assistant",
"content": []interface{}{
map[string]interface{}{
"type": "tool_use",
"id": toolUseId,
"name": "web_search",
"input": map[string]interface{}{"query": query},
},
},
}
messages = append(messages, assistantMsg)
// 2. Append user message with tool_result + search behavior instructions.
// NOTE: We embed search instructions HERE (not in system prompt) because
// BuildKiroPayload clears the system prompt when len(history) > 0,
// which is always true after injecting assistant + user messages.
now := time.Now()
searchGuidance := fmt.Sprintf(`<search_guidance>
Current date: %s (%s)
IMPORTANT: Evaluate the search results above carefully. If the results are:
- Mostly spam, SEO junk, or unrelated websites
- Missing actual information about the query topic
- Outdated or not matching the requested time frame
Then you MUST use the web_search tool again with a refined query. Try:
- Rephrasing in English for better coverage
- Using more specific keywords
- Adding date context
Do NOT apologize for bad results without first attempting a re-search.
</search_guidance>`, now.Format("January 2, 2006"), now.Format("Monday"))
userMsg := map[string]interface{}{
"role": "user",
"content": []interface{}{
map[string]interface{}{
"type": "tool_result",
"tool_use_id": toolUseId,
"content": FormatToolResultText(results),
},
map[string]interface{}{
"type": "text",
"text": searchGuidance,
},
},
}
messages = append(messages, userMsg)
payload["messages"] = messages
result, err := json.Marshal(payload)
if err != nil {
return claudePayload, fmt.Errorf("failed to marshal updated payload: %w", err)
}
log.Infof("kiro/websearch: injected tool_use+tool_result (toolUseId=%s, messages=%d)",
toolUseId, len(messages))
return result, nil
}
// InjectSearchIndicatorsInResponse prepends server_tool_use + web_search_tool_result
// content blocks into a non-streaming Claude JSON response. Claude Code counts
// server_tool_use blocks to display "Did X searches in Ys".
//
// Input response: {"content": [{"type":"text","text":"..."}], ...}
// Output response: {"content": [{"type":"server_tool_use",...}, {"type":"web_search_tool_result",...}, {"type":"text","text":"..."}], ...}
func InjectSearchIndicatorsInResponse(responsePayload []byte, searches []SearchIndicator) ([]byte, error) {
if len(searches) == 0 {
return responsePayload, nil
}
var resp map[string]interface{}
if err := json.Unmarshal(responsePayload, &resp); err != nil {
return responsePayload, fmt.Errorf("failed to parse response: %w", err)
}
existingContent, _ := resp["content"].([]interface{})
// Build new content: search indicators first, then existing content
newContent := make([]interface{}, 0, len(searches)*2+len(existingContent))
for _, s := range searches {
// server_tool_use block
newContent = append(newContent, map[string]interface{}{
"type": "server_tool_use",
"id": s.ToolUseID,
"name": "web_search",
"input": map[string]interface{}{"query": s.Query},
})
// web_search_tool_result block
searchContent := make([]map[string]interface{}, 0)
if s.Results != nil {
for _, r := range s.Results.Results {
snippet := ""
if r.Snippet != nil {
snippet = *r.Snippet
}
searchContent = append(searchContent, map[string]interface{}{
"type": "web_search_result",
"title": r.Title,
"url": r.URL,
"encrypted_content": snippet,
"page_age": nil,
})
}
}
newContent = append(newContent, map[string]interface{}{
"type": "web_search_tool_result",
"tool_use_id": s.ToolUseID,
"content": searchContent,
})
}
// Append existing content blocks
newContent = append(newContent, existingContent...)
resp["content"] = newContent
result, err := json.Marshal(resp)
if err != nil {
return responsePayload, fmt.Errorf("failed to marshal response: %w", err)
}
log.Infof("kiro/websearch: injected %d search indicator(s) into non-stream response", len(searches))
return result, nil
}
// SearchIndicator holds the data for one search operation to inject into a response.
type SearchIndicator struct {
ToolUseID string
Query string
Results *WebSearchResults
}
// BuildMcpEndpoint constructs the MCP endpoint URL for the given AWS region.
// Centralizes the URL pattern used by both handleWebSearch and handleWebSearchStream.
func BuildMcpEndpoint(region string) string {
return fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region)
}
// ParseSearchResults extracts WebSearchResults from MCP response
func ParseSearchResults(response *McpResponse) *WebSearchResults {
if response == nil || response.Result == nil || len(response.Result.Content) == 0 {
return nil
}
content := response.Result.Content[0]
if content.ContentType != "text" {
return nil
}
var results WebSearchResults
if err := json.Unmarshal([]byte(content.Text), &results); err != nil {
log.Warnf("kiro/websearch: failed to parse search results: %v", err)
return nil
}
return &results
}

View File

@@ -0,0 +1,191 @@
// Package claude provides tool compression functionality for Kiro translator.
// This file implements dynamic tool compression to reduce tool payload size
// when it exceeds the target threshold, preventing 500 errors from Kiro API.
package claude
import (
"encoding/json"
"unicode/utf8"
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
log "github.com/sirupsen/logrus"
)
// calculateToolsSize calculates the JSON serialized size of the tools list.
// Returns the size in bytes.
func calculateToolsSize(tools []KiroToolWrapper) int {
if len(tools) == 0 {
return 0
}
data, err := json.Marshal(tools)
if err != nil {
log.Warnf("kiro: failed to marshal tools for size calculation: %v", err)
return 0
}
return len(data)
}
// simplifyInputSchema simplifies the input_schema by keeping only essential fields:
// type, enum, required. Recursively processes nested properties.
func simplifyInputSchema(schema interface{}) interface{} {
if schema == nil {
return nil
}
schemaMap, ok := schema.(map[string]interface{})
if !ok {
return schema
}
simplified := make(map[string]interface{})
// Keep essential fields
if t, ok := schemaMap["type"]; ok {
simplified["type"] = t
}
if enum, ok := schemaMap["enum"]; ok {
simplified["enum"] = enum
}
if required, ok := schemaMap["required"]; ok {
simplified["required"] = required
}
// Recursively process properties
if properties, ok := schemaMap["properties"].(map[string]interface{}); ok {
simplifiedProps := make(map[string]interface{})
for key, value := range properties {
simplifiedProps[key] = simplifyInputSchema(value)
}
simplified["properties"] = simplifiedProps
}
// Process items for array types
if items, ok := schemaMap["items"]; ok {
simplified["items"] = simplifyInputSchema(items)
}
// Process additionalProperties if present
if additionalProps, ok := schemaMap["additionalProperties"]; ok {
simplified["additionalProperties"] = simplifyInputSchema(additionalProps)
}
// Process anyOf, oneOf, allOf
for _, key := range []string{"anyOf", "oneOf", "allOf"} {
if arr, ok := schemaMap[key].([]interface{}); ok {
simplifiedArr := make([]interface{}, len(arr))
for i, item := range arr {
simplifiedArr[i] = simplifyInputSchema(item)
}
simplified[key] = simplifiedArr
}
}
return simplified
}
// compressToolDescription compresses a description to the target length.
// Ensures the result is at least MinToolDescriptionLength characters.
// Uses UTF-8 safe truncation.
func compressToolDescription(description string, targetLength int) string {
if targetLength < kirocommon.MinToolDescriptionLength {
targetLength = kirocommon.MinToolDescriptionLength
}
if len(description) <= targetLength {
return description
}
// Find a safe truncation point (UTF-8 boundary)
truncLen := targetLength - 3 // Leave room for "..."
// Ensure we don't cut in the middle of a UTF-8 character
for truncLen > 0 && !utf8.RuneStart(description[truncLen]) {
truncLen--
}
if truncLen <= 0 {
return description[:kirocommon.MinToolDescriptionLength]
}
return description[:truncLen] + "..."
}
// compressToolsIfNeeded compresses tools if their total size exceeds the target threshold.
// Compression strategy:
// 1. First, check if compression is needed (size > ToolCompressionTargetSize)
// 2. Step 1: Simplify input_schema (keep only type/enum/required)
// 3. Step 2: Proportionally compress descriptions (minimum MinToolDescriptionLength chars)
// Returns the compressed tools list.
func compressToolsIfNeeded(tools []KiroToolWrapper) []KiroToolWrapper {
if len(tools) == 0 {
return tools
}
originalSize := calculateToolsSize(tools)
if originalSize <= kirocommon.ToolCompressionTargetSize {
log.Debugf("kiro: tools size %d bytes is within target %d bytes, no compression needed",
originalSize, kirocommon.ToolCompressionTargetSize)
return tools
}
log.Infof("kiro: tools size %d bytes exceeds target %d bytes, starting compression",
originalSize, kirocommon.ToolCompressionTargetSize)
// Create a copy of tools to avoid modifying the original
compressedTools := make([]KiroToolWrapper, len(tools))
for i, tool := range tools {
compressedTools[i] = KiroToolWrapper{
ToolSpecification: KiroToolSpecification{
Name: tool.ToolSpecification.Name,
Description: tool.ToolSpecification.Description,
InputSchema: KiroInputSchema{JSON: tool.ToolSpecification.InputSchema.JSON},
},
}
}
// Step 1: Simplify input_schema
for i := range compressedTools {
compressedTools[i].ToolSpecification.InputSchema.JSON =
simplifyInputSchema(compressedTools[i].ToolSpecification.InputSchema.JSON)
}
sizeAfterSchemaSimplification := calculateToolsSize(compressedTools)
log.Debugf("kiro: size after schema simplification: %d bytes (reduced by %d bytes)",
sizeAfterSchemaSimplification, originalSize-sizeAfterSchemaSimplification)
// Check if we're within target after schema simplification
if sizeAfterSchemaSimplification <= kirocommon.ToolCompressionTargetSize {
log.Infof("kiro: compression complete after schema simplification, final size: %d bytes",
sizeAfterSchemaSimplification)
return compressedTools
}
// Step 2: Compress descriptions proportionally
sizeToReduce := float64(sizeAfterSchemaSimplification - kirocommon.ToolCompressionTargetSize)
var totalDescLen float64
for _, tool := range compressedTools {
totalDescLen += float64(len(tool.ToolSpecification.Description))
}
if totalDescLen > 0 {
// Assume size reduction comes primarily from descriptions.
keepRatio := 1.0 - (sizeToReduce / totalDescLen)
if keepRatio > 1.0 {
keepRatio = 1.0
} else if keepRatio < 0 {
keepRatio = 0
}
for i := range compressedTools {
desc := compressedTools[i].ToolSpecification.Description
targetLen := int(float64(len(desc)) * keepRatio)
compressedTools[i].ToolSpecification.Description = compressToolDescription(desc, targetLen)
}
}
finalSize := calculateToolsSize(compressedTools)
log.Infof("kiro: compression complete, original: %d bytes, final: %d bytes (%.1f%% reduction)",
originalSize, finalSize, float64(originalSize-finalSize)/float64(originalSize)*100)
return compressedTools
}

View File

@@ -0,0 +1,532 @@
// Package claude provides truncation detection for Kiro tool call responses.
// When Kiro API reaches its output token limit, tool call JSON may be truncated,
// resulting in incomplete or unparseable input parameters.
package claude
import (
"encoding/json"
"strings"
log "github.com/sirupsen/logrus"
)
// TruncationInfo contains details about detected truncation in a tool use event.
type TruncationInfo struct {
IsTruncated bool // Whether truncation was detected
TruncationType string // Type of truncation detected
ToolName string // Name of the truncated tool
ToolUseID string // ID of the truncated tool use
RawInput string // The raw (possibly truncated) input string
ParsedFields map[string]string // Fields that were successfully parsed before truncation
ErrorMessage string // Human-readable error message
}
// TruncationType constants for different truncation scenarios
const (
TruncationTypeNone = "" // No truncation detected
TruncationTypeEmptyInput = "empty_input" // No input data received at all
TruncationTypeInvalidJSON = "invalid_json" // JSON is syntactically invalid (truncated mid-value)
TruncationTypeMissingFields = "missing_fields" // JSON parsed but critical fields are missing
TruncationTypeIncompleteString = "incomplete_string" // String value was cut off mid-content
)
// KnownWriteTools lists tool names that typically write content and have a "content" field.
// These tools are checked for content field truncation specifically.
var KnownWriteTools = map[string]bool{
"Write": true,
"write_to_file": true,
"fsWrite": true,
"create_file": true,
"edit_file": true,
"apply_diff": true,
"str_replace_editor": true,
"insert": true,
}
// KnownCommandTools lists tool names that execute commands.
var KnownCommandTools = map[string]bool{
"Bash": true,
"execute": true,
"run_command": true,
"shell": true,
"terminal": true,
"execute_python": true,
}
// RequiredFieldsByTool maps tool names to their required field groups.
// Each outer element is a required group; each inner slice lists alternative field names (OR logic).
// A group is satisfied when ANY one of its alternatives exists in the parsed input.
// All groups must be satisfied for the tool input to be considered valid.
//
// Example:
// {{"cmd", "command"}} means the tool needs EITHER "cmd" OR "command".
// {{"file_path"}, {"content"}} means the tool needs BOTH "file_path" AND "content".
var RequiredFieldsByTool = map[string][][]string{
"Write": {{"file_path"}, {"content"}},
"write_to_file": {{"path"}, {"content"}},
"fsWrite": {{"path"}, {"content"}},
"create_file": {{"path"}, {"content"}},
"edit_file": {{"path"}},
"apply_diff": {{"path"}, {"diff"}},
"str_replace_editor": {{"path"}, {"old_str"}, {"new_str"}},
"Bash": {{"cmd", "command"}},
"execute": {{"command"}},
"run_command": {{"command"}},
}
// DetectTruncation checks if the tool use input appears to be truncated.
// It returns detailed information about the truncation status and type.
func DetectTruncation(toolName, toolUseID, rawInput string, parsedInput map[string]interface{}) TruncationInfo {
info := TruncationInfo{
ToolName: toolName,
ToolUseID: toolUseID,
RawInput: rawInput,
ParsedFields: make(map[string]string),
}
// Scenario 1: Empty input buffer - no data received at all
if strings.TrimSpace(rawInput) == "" {
info.IsTruncated = true
info.TruncationType = TruncationTypeEmptyInput
info.ErrorMessage = "Tool input was completely empty - API response may have been truncated before tool parameters were transmitted"
log.Warnf("kiro: truncation detected [%s] for tool %s (ID: %s): empty input buffer",
info.TruncationType, toolName, toolUseID)
return info
}
// Scenario 2: JSON parse failure - syntactically invalid JSON
if parsedInput == nil || len(parsedInput) == 0 {
// Check if the raw input looks like truncated JSON
if looksLikeTruncatedJSON(rawInput) {
info.IsTruncated = true
info.TruncationType = TruncationTypeInvalidJSON
info.ParsedFields = extractPartialFields(rawInput)
info.ErrorMessage = buildTruncationErrorMessage(toolName, info.TruncationType, info.ParsedFields, rawInput)
log.Warnf("kiro: truncation detected [%s] for tool %s (ID: %s): JSON parse failed, raw length=%d bytes",
info.TruncationType, toolName, toolUseID, len(rawInput))
return info
}
}
// Scenario 3: JSON parsed but critical fields are missing
if parsedInput != nil {
requiredGroups, hasRequirements := RequiredFieldsByTool[toolName]
if hasRequirements {
missingFields := findMissingRequiredFields(parsedInput, requiredGroups)
if len(missingFields) > 0 {
info.IsTruncated = true
info.TruncationType = TruncationTypeMissingFields
info.ParsedFields = extractParsedFieldNames(parsedInput)
info.ErrorMessage = buildMissingFieldsErrorMessage(toolName, missingFields, info.ParsedFields)
log.Warnf("kiro: truncation detected [%s] for tool %s (ID: %s): missing required fields: %v",
info.TruncationType, toolName, toolUseID, missingFields)
return info
}
}
// Scenario 4: Check for incomplete string values (very short content for write tools)
if isWriteTool(toolName) {
if contentTruncation := detectContentTruncation(parsedInput, rawInput); contentTruncation != "" {
info.IsTruncated = true
info.TruncationType = TruncationTypeIncompleteString
info.ParsedFields = extractParsedFieldNames(parsedInput)
info.ErrorMessage = contentTruncation
log.Warnf("kiro: truncation detected [%s] for tool %s (ID: %s): %s",
info.TruncationType, toolName, toolUseID, contentTruncation)
return info
}
}
}
// No truncation detected
info.IsTruncated = false
info.TruncationType = TruncationTypeNone
return info
}
// looksLikeTruncatedJSON checks if the raw string appears to be truncated JSON.
func looksLikeTruncatedJSON(raw string) bool {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return false
}
// Must start with { to be considered JSON
if !strings.HasPrefix(trimmed, "{") {
return false
}
// Count brackets to detect imbalance
openBraces := strings.Count(trimmed, "{")
closeBraces := strings.Count(trimmed, "}")
openBrackets := strings.Count(trimmed, "[")
closeBrackets := strings.Count(trimmed, "]")
// Bracket imbalance suggests truncation
if openBraces > closeBraces || openBrackets > closeBrackets {
return true
}
// Check for obvious truncation patterns
// - Ends with a quote but no closing brace
// - Ends with a colon (mid key-value)
// - Ends with a comma (mid object/array)
lastChar := trimmed[len(trimmed)-1]
if lastChar != '}' && lastChar != ']' {
// Check if it's not a complete simple value
if lastChar == '"' || lastChar == ':' || lastChar == ',' {
return true
}
}
// Check for unclosed strings (odd number of unescaped quotes)
inString := false
escaped := false
for i := 0; i < len(trimmed); i++ {
c := trimmed[i]
if escaped {
escaped = false
continue
}
if c == '\\' {
escaped = true
continue
}
if c == '"' {
inString = !inString
}
}
if inString {
return true // Unclosed string
}
return false
}
// extractPartialFields attempts to extract any field names from malformed JSON.
// This helps provide context about what was received before truncation.
func extractPartialFields(raw string) map[string]string {
fields := make(map[string]string)
// Simple pattern matching for "key": "value" or "key": value patterns
// This works even with truncated JSON
trimmed := strings.TrimSpace(raw)
if !strings.HasPrefix(trimmed, "{") {
return fields
}
// Remove opening brace
content := strings.TrimPrefix(trimmed, "{")
// Split by comma (rough parsing)
parts := strings.Split(content, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
if colonIdx := strings.Index(part, ":"); colonIdx > 0 {
key := strings.TrimSpace(part[:colonIdx])
key = strings.Trim(key, `"`)
value := strings.TrimSpace(part[colonIdx+1:])
// Truncate long values for display
if len(value) > 50 {
value = value[:50] + "..."
}
fields[key] = value
}
}
return fields
}
// extractParsedFieldNames returns the field names from a successfully parsed map.
func extractParsedFieldNames(parsed map[string]interface{}) map[string]string {
fields := make(map[string]string)
for key, val := range parsed {
switch v := val.(type) {
case string:
if len(v) > 50 {
fields[key] = v[:50] + "..."
} else {
fields[key] = v
}
case nil:
fields[key] = "<null>"
default:
// For complex types, just indicate presence
fields[key] = "<present>"
}
}
return fields
}
// findMissingRequiredFields checks which required field groups are unsatisfied.
// Each group is a slice of alternative field names; the group is satisfied when ANY alternative exists.
// Returns the list of unsatisfied groups (represented by their alternatives joined with "/").
func findMissingRequiredFields(parsed map[string]interface{}, requiredGroups [][]string) []string {
var missing []string
for _, group := range requiredGroups {
satisfied := false
for _, field := range group {
if _, exists := parsed[field]; exists {
satisfied = true
break
}
}
if !satisfied {
missing = append(missing, strings.Join(group, "/"))
}
}
return missing
}
// isWriteTool checks if the tool is a known write/file operation tool.
func isWriteTool(toolName string) bool {
return KnownWriteTools[toolName]
}
// detectContentTruncation checks if the content field appears truncated for write tools.
func detectContentTruncation(parsed map[string]interface{}, rawInput string) string {
// Check for content field
content, hasContent := parsed["content"]
if !hasContent {
return ""
}
contentStr, isString := content.(string)
if !isString {
return ""
}
// Heuristic: if raw input is very large but content is suspiciously short,
// it might indicate truncation during JSON repair
if len(rawInput) > 1000 && len(contentStr) < 100 {
return "content field appears suspiciously short compared to raw input size"
}
// Check for code blocks that appear to be cut off
if strings.Contains(contentStr, "```") {
openFences := strings.Count(contentStr, "```")
if openFences%2 != 0 {
return "content contains unclosed code fence (```) suggesting truncation"
}
}
return ""
}
// buildTruncationErrorMessage creates a human-readable error message for truncation.
func buildTruncationErrorMessage(toolName, truncationType string, parsedFields map[string]string, rawInput string) string {
var sb strings.Builder
sb.WriteString("Tool input was truncated by the API. ")
switch truncationType {
case TruncationTypeEmptyInput:
sb.WriteString("No input data was received.")
case TruncationTypeInvalidJSON:
sb.WriteString("JSON was cut off mid-transmission. ")
if len(parsedFields) > 0 {
sb.WriteString("Partial fields received: ")
first := true
for k := range parsedFields {
if !first {
sb.WriteString(", ")
}
sb.WriteString(k)
first = false
}
}
case TruncationTypeMissingFields:
sb.WriteString("Required fields are missing from the input.")
case TruncationTypeIncompleteString:
sb.WriteString("Content appears to be shortened or incomplete.")
}
sb.WriteString(" Received ")
sb.WriteString(string(rune(len(rawInput))))
sb.WriteString(" bytes. Please retry with smaller content chunks.")
return sb.String()
}
// buildMissingFieldsErrorMessage creates an error message for missing required fields.
func buildMissingFieldsErrorMessage(toolName string, missingFields []string, parsedFields map[string]string) string {
var sb strings.Builder
sb.WriteString("Tool '")
sb.WriteString(toolName)
sb.WriteString("' is missing required fields: ")
sb.WriteString(strings.Join(missingFields, ", "))
sb.WriteString(". Fields received: ")
first := true
for k := range parsedFields {
if !first {
sb.WriteString(", ")
}
sb.WriteString(k)
first = false
}
sb.WriteString(". This usually indicates the API response was truncated.")
return sb.String()
}
// IsTruncated is a convenience function to check if a tool use appears truncated.
func IsTruncated(toolName, rawInput string, parsedInput map[string]interface{}) bool {
info := DetectTruncation(toolName, "", rawInput, parsedInput)
return info.IsTruncated
}
// GetTruncationSummary returns a short summary string for logging.
func GetTruncationSummary(info TruncationInfo) string {
if !info.IsTruncated {
return ""
}
result, _ := json.Marshal(map[string]interface{}{
"tool": info.ToolName,
"type": info.TruncationType,
"parsed_fields": info.ParsedFields,
"raw_input_size": len(info.RawInput),
})
return string(result)
}
// SoftFailureMessage contains the message structure for a truncation soft failure.
// This is returned to Claude as a tool_result to guide retry behavior.
type SoftFailureMessage struct {
Status string // "incomplete" - not an error, just incomplete
Reason string // Why the tool call was incomplete
Guidance []string // Step-by-step retry instructions
Context string // Any context about what was received
MaxLineHint int // Suggested maximum lines per chunk
}
// BuildSoftFailureMessage creates a structured message for Claude when truncation is detected.
// This follows the "soft failure" pattern:
// - For Claude: Clear explanation of what happened and how to fix
// - For User: Hidden or minimized (appears as normal processing)
//
// Key principle: "Conclusion First"
// 1. First state what happened (incomplete)
// 2. Then explain how to fix (chunked approach)
// 3. Provide specific guidance (line limits)
func BuildSoftFailureMessage(info TruncationInfo) SoftFailureMessage {
msg := SoftFailureMessage{
Status: "incomplete",
MaxLineHint: 300, // Conservative default
}
// Build reason based on truncation type
switch info.TruncationType {
case TruncationTypeEmptyInput:
msg.Reason = "Your tool call was too large and the input was completely lost during transmission."
msg.MaxLineHint = 200
case TruncationTypeInvalidJSON:
msg.Reason = "Your tool call was truncated mid-transmission, resulting in incomplete JSON."
msg.MaxLineHint = 250
case TruncationTypeMissingFields:
msg.Reason = "Your tool call was partially received but critical fields were cut off."
msg.MaxLineHint = 300
case TruncationTypeIncompleteString:
msg.Reason = "Your tool call content was truncated - the full content did not arrive."
msg.MaxLineHint = 350
default:
msg.Reason = "Your tool call was truncated by the API due to output size limits."
}
// Build context from parsed fields
if len(info.ParsedFields) > 0 {
var parts []string
for k, v := range info.ParsedFields {
if len(v) > 30 {
v = v[:30] + "..."
}
parts = append(parts, k+"="+v)
}
msg.Context = "Received partial data: " + strings.Join(parts, ", ")
}
// Build retry guidance - CRITICAL: Conclusion first approach
msg.Guidance = []string{
"CONCLUSION: Split your output into smaller chunks and retry.",
"",
"REQUIRED APPROACH:",
"1. For file writes: Write in chunks of ~" + formatInt(msg.MaxLineHint) + " lines maximum",
"2. For new files: First create with initial chunk, then append remaining sections",
"3. For edits: Make surgical, targeted changes - avoid rewriting entire files",
"",
"EXAMPLE (writing a 600-line file):",
" - Step 1: Write lines 1-300 (create file)",
" - Step 2: Append lines 301-600 (extend file)",
"",
"DO NOT attempt to write the full content again in a single call.",
"The API has a hard output limit that cannot be bypassed.",
}
return msg
}
// formatInt converts an integer to string (helper to avoid strconv import)
func formatInt(n int) string {
if n == 0 {
return "0"
}
result := ""
for n > 0 {
result = string(rune('0'+n%10)) + result
n /= 10
}
return result
}
// BuildSoftFailureToolResult creates a tool_result content for Claude.
// This is what Claude will see when a tool call is truncated.
// Returns a string that should be used as the tool_result content.
func BuildSoftFailureToolResult(info TruncationInfo) string {
msg := BuildSoftFailureMessage(info)
var sb strings.Builder
sb.WriteString("TOOL_CALL_INCOMPLETE\n")
sb.WriteString("status: ")
sb.WriteString(msg.Status)
sb.WriteString("\n")
sb.WriteString("reason: ")
sb.WriteString(msg.Reason)
sb.WriteString("\n")
if msg.Context != "" {
sb.WriteString("context: ")
sb.WriteString(msg.Context)
sb.WriteString("\n")
}
sb.WriteString("\n")
for _, line := range msg.Guidance {
if line != "" {
sb.WriteString(line)
sb.WriteString("\n")
}
}
return sb.String()
}
// CreateTruncationToolResult creates a KiroToolUse that represents a soft failure.
// Instead of returning the truncated tool_use, we return a tool with a special
// error result that guides Claude to retry with smaller chunks.
//
// This is the key mechanism for "soft failure":
// - stop_reason remains "tool_use" so Claude continues
// - The tool_result content explains the issue and how to fix it
// - Claude will read this and adjust its approach
func CreateTruncationToolResult(info TruncationInfo) KiroToolUse {
// We create a pseudo tool_use that represents the failed attempt
// The executor will convert this to a tool_result with the guidance message
return KiroToolUse{
ToolUseID: info.ToolUseID,
Name: info.ToolName,
Input: nil, // No input since it was truncated
IsTruncated: true,
TruncationInfo: &info,
}
}

View File

@@ -0,0 +1,103 @@
// Package common provides shared constants and utilities for Kiro translator.
package common
const (
// KiroMaxToolDescLen is the maximum description length for Kiro API tools.
// Kiro API limit is 10240 bytes, leave room for "..."
KiroMaxToolDescLen = 10237
// ToolCompressionTargetSize is the target total size for compressed tools (20KB).
// If tools exceed this size, compression will be applied.
ToolCompressionTargetSize = 20 * 1024 // 20KB
// MinToolDescriptionLength is the minimum description length after compression.
// Descriptions will not be shortened below this length.
MinToolDescriptionLength = 50
// ThinkingStartTag is the start tag for thinking blocks in responses.
ThinkingStartTag = "<thinking>"
// ThinkingEndTag is the end tag for thinking blocks in responses.
ThinkingEndTag = "</thinking>"
// CodeFenceMarker is the markdown code fence marker.
CodeFenceMarker = "```"
// AltCodeFenceMarker is the alternative markdown code fence marker.
AltCodeFenceMarker = "~~~"
// InlineCodeMarker is the markdown inline code marker (backtick).
InlineCodeMarker = "`"
// DefaultAssistantContentWithTools is the fallback content for assistant messages
// that have tool_use but no text content. Kiro API requires non-empty content.
// IMPORTANT: Use a minimal neutral string that the model won't mimic in responses.
// Previously "I'll help you with that." which caused the model to parrot it back.
DefaultAssistantContentWithTools = "."
// DefaultAssistantContent is the fallback content for assistant messages
// that have no content at all. Kiro API requires non-empty content.
// IMPORTANT: Use a minimal neutral string that the model won't mimic in responses.
// Previously "I understand." which could leak into model behavior.
DefaultAssistantContent = "."
// DefaultUserContentWithToolResults is the fallback content for user messages
// that have only tool_result (no text). Kiro API requires non-empty content.
DefaultUserContentWithToolResults = "Tool results provided."
// DefaultUserContent is the fallback content for user messages
// that have no content at all. Kiro API requires non-empty content.
DefaultUserContent = "Continue"
// KiroAgenticSystemPrompt is injected only for -agentic models to prevent timeouts on large writes.
// AWS Kiro API has a 2-3 minute timeout for large file write operations.
KiroAgenticSystemPrompt = `
# CRITICAL: CHUNKED WRITE PROTOCOL (MANDATORY)
You MUST follow these rules for ALL file operations. Violation causes server timeouts and task failure.
## ABSOLUTE LIMITS
- **MAXIMUM 350 LINES** per single write/edit operation - NO EXCEPTIONS
- **RECOMMENDED 300 LINES** or less for optimal performance
- **NEVER** write entire files in one operation if >300 lines
## MANDATORY CHUNKED WRITE STRATEGY
### For NEW FILES (>300 lines total):
1. FIRST: Write initial chunk (first 250-300 lines) using write_to_file/fsWrite
2. THEN: Append remaining content in 250-300 line chunks using file append operations
3. REPEAT: Continue appending until complete
### For EDITING EXISTING FILES:
1. Use surgical edits (apply_diff/targeted edits) - change ONLY what's needed
2. NEVER rewrite entire files - use incremental modifications
3. Split large refactors into multiple small, focused edits
### For LARGE CODE GENERATION:
1. Generate in logical sections (imports, types, functions separately)
2. Write each section as a separate operation
3. Use append operations for subsequent sections
## EXAMPLES OF CORRECT BEHAVIOR
✅ CORRECT: Writing a 600-line file
- Operation 1: Write lines 1-300 (initial file creation)
- Operation 2: Append lines 301-600
✅ CORRECT: Editing multiple functions
- Operation 1: Edit function A
- Operation 2: Edit function B
- Operation 3: Edit function C
❌ WRONG: Writing 500 lines in single operation → TIMEOUT
❌ WRONG: Rewriting entire file to change 5 lines → TIMEOUT
❌ WRONG: Generating massive code blocks without chunking → TIMEOUT
## WHY THIS MATTERS
- Server has 2-3 minute timeout for operations
- Large writes exceed timeout and FAIL completely
- Chunked writes are FASTER and more RELIABLE
- Failed writes waste time and require retry
REMEMBER: When in doubt, write LESS per operation. Multiple small operations > one large operation.`
)

View File

@@ -0,0 +1,160 @@
// Package common provides shared utilities for Kiro translators.
package common
import (
"encoding/json"
"github.com/tidwall/gjson"
)
// MergeAdjacentMessages merges adjacent messages with the same role.
// This reduces API call complexity and improves compatibility.
// Based on AIClient-2-API implementation.
// NOTE: Tool messages are NOT merged because each has a unique tool_call_id that must be preserved.
func MergeAdjacentMessages(messages []gjson.Result) []gjson.Result {
if len(messages) <= 1 {
return messages
}
var merged []gjson.Result
for _, msg := range messages {
if len(merged) == 0 {
merged = append(merged, msg)
continue
}
lastMsg := merged[len(merged)-1]
currentRole := msg.Get("role").String()
lastRole := lastMsg.Get("role").String()
// Don't merge tool messages - each has a unique tool_call_id
if currentRole == "tool" || lastRole == "tool" {
merged = append(merged, msg)
continue
}
if currentRole == lastRole {
// Merge content from current message into last message
mergedContent := mergeMessageContent(lastMsg, msg)
var mergedToolCalls []interface{}
if currentRole == "assistant" {
// Preserve assistant tool_calls when adjacent assistant messages are merged.
mergedToolCalls = mergeToolCalls(lastMsg.Get("tool_calls"), msg.Get("tool_calls"))
}
// Create a new merged message JSON.
mergedMsg := createMergedMessage(lastRole, mergedContent, mergedToolCalls)
merged[len(merged)-1] = gjson.Parse(mergedMsg)
} else {
merged = append(merged, msg)
}
}
return merged
}
// mergeMessageContent merges the content of two messages with the same role.
// Handles both string content and array content (with text, tool_use, tool_result blocks).
func mergeMessageContent(msg1, msg2 gjson.Result) string {
content1 := msg1.Get("content")
content2 := msg2.Get("content")
// Extract content blocks from both messages
var blocks1, blocks2 []map[string]interface{}
if content1.IsArray() {
for _, block := range content1.Array() {
blocks1 = append(blocks1, blockToMap(block))
}
} else if content1.Type == gjson.String {
blocks1 = append(blocks1, map[string]interface{}{
"type": "text",
"text": content1.String(),
})
}
if content2.IsArray() {
for _, block := range content2.Array() {
blocks2 = append(blocks2, blockToMap(block))
}
} else if content2.Type == gjson.String {
blocks2 = append(blocks2, map[string]interface{}{
"type": "text",
"text": content2.String(),
})
}
// Merge text blocks if both end/start with text
if len(blocks1) > 0 && len(blocks2) > 0 {
if blocks1[len(blocks1)-1]["type"] == "text" && blocks2[0]["type"] == "text" {
// Merge the last text block of msg1 with the first text block of msg2
text1 := blocks1[len(blocks1)-1]["text"].(string)
text2 := blocks2[0]["text"].(string)
blocks1[len(blocks1)-1]["text"] = text1 + "\n" + text2
blocks2 = blocks2[1:] // Remove the merged block from blocks2
}
}
// Combine all blocks
allBlocks := append(blocks1, blocks2...)
// Convert to JSON
result, _ := json.Marshal(allBlocks)
return string(result)
}
// blockToMap converts a gjson.Result block to a map[string]interface{}
func blockToMap(block gjson.Result) map[string]interface{} {
result := make(map[string]interface{})
block.ForEach(func(key, value gjson.Result) bool {
if value.IsObject() {
result[key.String()] = blockToMap(value)
} else if value.IsArray() {
var arr []interface{}
for _, item := range value.Array() {
if item.IsObject() {
arr = append(arr, blockToMap(item))
} else {
arr = append(arr, item.Value())
}
}
result[key.String()] = arr
} else {
result[key.String()] = value.Value()
}
return true
})
return result
}
// createMergedMessage creates a JSON string for a merged message.
// toolCalls is optional and only emitted for assistant role.
func createMergedMessage(role string, content string, toolCalls []interface{}) string {
msg := map[string]interface{}{
"role": role,
"content": json.RawMessage(content),
}
if role == "assistant" && len(toolCalls) > 0 {
msg["tool_calls"] = toolCalls
}
result, _ := json.Marshal(msg)
return string(result)
}
// mergeToolCalls combines tool_calls from two assistant messages while preserving order.
func mergeToolCalls(tc1, tc2 gjson.Result) []interface{} {
var merged []interface{}
if tc1.IsArray() {
for _, tc := range tc1.Array() {
merged = append(merged, tc.Value())
}
}
if tc2.IsArray() {
for _, tc := range tc2.Array() {
merged = append(merged, tc.Value())
}
}
return merged
}

View File

@@ -0,0 +1,106 @@
package common
import (
"strings"
"testing"
"github.com/tidwall/gjson"
)
func parseMessages(t *testing.T, raw string) []gjson.Result {
t.Helper()
parsed := gjson.Parse(raw)
if !parsed.IsArray() {
t.Fatalf("expected JSON array, got: %s", raw)
}
return parsed.Array()
}
func TestMergeAdjacentMessages_AssistantMergePreservesToolCalls(t *testing.T) {
messages := parseMessages(t, `[
{"role":"assistant","content":"part1"},
{
"role":"assistant",
"content":"part2",
"tool_calls":[
{
"id":"call_1",
"type":"function",
"function":{"name":"Read","arguments":"{}"}
}
]
},
{"role":"tool","tool_call_id":"call_1","content":"ok"}
]`)
merged := MergeAdjacentMessages(messages)
if len(merged) != 2 {
t.Fatalf("expected 2 messages after merge, got %d", len(merged))
}
assistant := merged[0]
if assistant.Get("role").String() != "assistant" {
t.Fatalf("expected first message role assistant, got %q", assistant.Get("role").String())
}
toolCalls := assistant.Get("tool_calls")
if !toolCalls.IsArray() || len(toolCalls.Array()) != 1 {
t.Fatalf("expected assistant.tool_calls length 1, got: %s", toolCalls.Raw)
}
if toolCalls.Array()[0].Get("id").String() != "call_1" {
t.Fatalf("expected tool call id call_1, got %q", toolCalls.Array()[0].Get("id").String())
}
contentRaw := assistant.Get("content").Raw
if !strings.Contains(contentRaw, "part1") || !strings.Contains(contentRaw, "part2") {
t.Fatalf("expected merged content to contain both parts, got: %s", contentRaw)
}
if merged[1].Get("role").String() != "tool" {
t.Fatalf("expected second message role tool, got %q", merged[1].Get("role").String())
}
}
func TestMergeAdjacentMessages_AssistantMergeCombinesMultipleToolCalls(t *testing.T) {
messages := parseMessages(t, `[
{
"role":"assistant",
"content":"first",
"tool_calls":[
{"id":"call_1","type":"function","function":{"name":"Read","arguments":"{}"}}
]
},
{
"role":"assistant",
"content":"second",
"tool_calls":[
{"id":"call_2","type":"function","function":{"name":"Write","arguments":"{}"}}
]
}
]`)
merged := MergeAdjacentMessages(messages)
if len(merged) != 1 {
t.Fatalf("expected 1 message after merge, got %d", len(merged))
}
toolCalls := merged[0].Get("tool_calls").Array()
if len(toolCalls) != 2 {
t.Fatalf("expected 2 merged tool calls, got %d", len(toolCalls))
}
if toolCalls[0].Get("id").String() != "call_1" || toolCalls[1].Get("id").String() != "call_2" {
t.Fatalf("unexpected merged tool call ids: %q, %q", toolCalls[0].Get("id").String(), toolCalls[1].Get("id").String())
}
}
func TestMergeAdjacentMessages_ToolMessagesRemainUnmerged(t *testing.T) {
messages := parseMessages(t, `[
{"role":"tool","tool_call_id":"call_1","content":"r1"},
{"role":"tool","tool_call_id":"call_2","content":"r2"}
]`)
merged := MergeAdjacentMessages(messages)
if len(merged) != 2 {
t.Fatalf("expected tool messages to remain separate, got %d", len(merged))
}
}

View File

@@ -0,0 +1,16 @@
// Package common provides shared constants and utilities for Kiro translator.
package common
// GetString safely extracts a string from a map.
// Returns empty string if the key doesn't exist or the value is not a string.
func GetString(m map[string]interface{}, key string) string {
if v, ok := m[key].(string); ok {
return v
}
return ""
}
// GetStringValue is an alias for GetString for backward compatibility.
func GetStringValue(m map[string]interface{}, key string) string {
return GetString(m, key)
}

View File

@@ -0,0 +1,20 @@
// Package openai provides translation between OpenAI Chat Completions and Kiro formats.
package openai
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
OpenAI, // source format
Kiro, // target format
ConvertOpenAIRequestToKiro,
interfaces.TranslateResponse{
Stream: ConvertKiroStreamToOpenAI,
NonStream: ConvertKiroNonStreamToOpenAI,
},
)
}

Some files were not shown because too many files have changed in this diff Show More